全网最详细Yolov3训练Caltech Pedestrain数据集并绘制fppi miss rate图


概述

本帖来自于Shadow丶dream大佬,欢迎关注。 原帖https://blog.csdn.net/fatterrier/article/details/109559083

1.环境要求

1.python3,网上很多代码是python2版本的,我大概修改了一下,所有代码都只适用于python3版本。

2.matlab ,这是因为matlab提供了这个画图的函数,所以画图必须要用到matlab。

3.AlexeyAB 版本的yolo,本文基于的是这个版本。因为图片数量巨大,在ubuntu打开对应文件夹时会很卡,所以将caltech转换为yolo格式时我是在windows下处理的。训练的时候我是在ubuntu训练的(因为环境配置简单很多)。但是代码对系统没有要求,哪个系统都可以。

4.路径如果跟我设置成一样的基本不用修改代码,否则在代码中修改成自己的路径即可。

2.caltech数据集做成VOC格式

​ 这一部分内容参考自https://github.com/shadowwalker00/CaltechPestrain2VOC。

2.1官网下载数据集

​ 在官网https://drive.google.com/drive/folders/1IBlcJP8YsCaT81LwQ2YwQJac8bf1q8xF上将所有文件下载下来,建议搭梯子,会快很多。其中annotations是标注文件,下载下来格式为vbb,后续转换为xml格式;set00-set10包含是图像,格式是seq,后续转换为jpg。

2.2将seq格式文件转换为jpg图片

​ 调用seq2jpg_py3.py文件进行转换,只需在代码41行和42行修改输入输出路径即可。

#coding=utf-8
# Deal with .seq format for video sequence
# Author: Kaij
# The .seq file is combined with images,
# so I split the file into several images with the image prefix
# "\xFF\xD8\xFF\xE0\x00\x10\x4A\x46\x49\x46".

import os.path
import fnmatch
import shutil

def open_save(file,savepath):
    # read .seq file, and save the images into the savepath

    f = open(file,'rb+')
    string = f.read().decode('latin-1')
    splitstring = "\xFF\xD8\xFF\xE0\x00\x10\x4A\x46\x49\x46"
    # split .seq file into segment with the image prefix
    strlist=string.split(splitstring)
    f.close()
    count = 0
    # delete the image folder path if it exists
    if os.path.exists(savepath):
        shutil.rmtree(savepath)
    # create the image folder path
    if not os.path.exists(savepath):
        os.makedirs(savepath)
    # deal with file segment, every segment is an image except the first one
    for img in strlist:
        filename = str(count)+'.jpg'
        filenamewithpath=os.path.join(savepath, filename)
        # abandon the first one, which is filled with .seq header
        if count > 0:
            i=open(filenamewithpath,'wb+')
            i.write(splitstring.encode('latin-1'))
            i.write(img.encode('latin-1'))
            i.close()
        count += 1

if __name__=="__main__":
    rootdir = "D:\data set\caltech"
    saveroot = "D:/data set/caltech/VOCdevkit/JPEG"
    # walk in the rootdir, take down the .seq filename and filepath
    for parent, dirnames, filenames in os.walk(rootdir):
        for filename in filenames:
            # check .seq file with suffix
            if fnmatch.fnmatch(filename,'*.seq'):
                # take down the filename with path of .seq file
                thefilename = os.path.join(parent, filename)
                # create the image folder by combining .seq file path with .seq filename
                thesavepath = saveroot +'\\'+ parent.split('\\')[-1] + '\\' + filename.split('.')[0]+'\\'
                print ("Filename=" + thefilename)
                print ("Savepath=" + thesavepath)
                open_save(thefilename,thesavepath)

输出目录如下图所示:

2.3将vbb格式转换为xml格式

​ 调用vbb2voc.py文件,只需修改代码161和162行的输入和输出目录即可。

#-*- coding:utf-8 -*-
import os, glob
import cv2
from scipy.io import loadmat
from collections import defaultdict
import numpy as np
from lxml import etree, objectify

def vbb_anno2dict(vbb_file, cam_id):
    #通过os.path.basename获得路径的最后部分“文件名.扩展名”
    #通过os.path.splitext获得文件名
    filename = os.path.splitext(os.path.basename(vbb_file))[0]

    #定义字典对象annos
    annos = defaultdict(dict)
    vbb = loadmat(vbb_file)
    # object info in each frame: id, pos, occlusion, lock, posv
    objLists = vbb['A'][0][0][1][0]
    objLbl = [str(v[0]) for v in vbb['A'][0][0][4][0]]     #可查看所有类别        
    # person index
    person_index_list = np.where(np.array(objLbl) == "person")[0]   #只选取类别为‘person’的xml
    for frame_id, obj in enumerate(objLists):
        if len(obj) > 0:
            frame_name = str(cam_id) + "_" + str(filename) + "_" + str(frame_id+1) + ".jpg"
            annos[frame_name] = defaultdict(list)
            annos[frame_name]["id"] = frame_name
            annos[frame_name]["label"] = "person"
            for id, pos, occl in zip(obj['id'][0], obj['pos'][0], obj['occl'][0]):
                id = int(id[0][0]) - 1  # for matlab start from 1 not 0
                if not id in person_index_list:  # only use bbox whose label is person
                    continue
                pos = pos[0].tolist()
                occl = int(occl[0][0])
                annos[frame_name]["occlusion"].append(occl)
                annos[frame_name]["bbox"].append(pos)
            if not annos[frame_name]["bbox"]:
                del annos[frame_name]
    print (annos)
    return annos


def seq2img(annos, seq_file, outdir, cam_id):
    cap = cv2.VideoCapture(seq_file)
    index = 1
    # captured frame list
    v_id = os.path.splitext(os.path.basename(seq_file))[0]
    cap_frames_index = np.sort([int(os.path.splitext(id)[0].split("_")[2]) for id in annos.keys()])
    while True:
        ret, frame = cap.read()
        print (ret)
        if ret:
            if not index in cap_frames_index:
                index += 1
                continue
            if not os.path.exists(outdir):
                os.makedirs(outdir)
            outname = os.path.join(outdir, str(cam_id)+"_"+v_id+"_"+str(index)+".jpg")
            print ("Current frame: ", v_id, str(index))
            cv2.imwrite(outname, frame)
            height, width, _ = frame.shape
        else:
            break
        index += 1
    img_size = (width, height)
    return img_size


def instance2xml_base(anno, bbox_type='xyxy'):
    """bbox_type: xyxy (xmin, ymin, xmax, ymax); xywh (xmin, ymin, width, height)"""
    assert bbox_type in ['xyxy', 'xywh']
    E = objectify.ElementMaker(annotate=False)
    anno_tree = E.annotation(
        E.folder('VOC2014_instance/person'),
        E.filename(anno['id']),
        E.source(
            E.database('Caltech pedestrian'),
            E.annotation('Caltech pedestrian'),
            E.image('Caltech pedestrian'),
            E.url('None')
        ),
        E.size(
            E.width(640),
            E.height(480),
            E.depth(3)
        ),
        E.segmented(0),
    )
    for index, bbox in enumerate(anno['bbox']):
        bbox = [float(x) for x in bbox]
        if bbox_type == 'xyxy':
            xmin, ymin, w, h = bbox
            xmax = xmin+w
            ymax = ymin+h
        else:
            xmin, ymin, xmax, ymax = bbox
        E = objectify.ElementMaker(annotate=False)
        anno_tree.append(
            E.object(
            E.name(anno['label']),
            E.bndbox(
                E.xmin(xmin),
                E.ymin(ymin),
                E.xmax(xmax),
                E.ymax(ymax)
            ),
            E.difficult(0),
            E.occlusion(anno["occlusion"][index])
            )
        )
    return anno_tree


def parse_anno_file(vbb_inputdir,vbb_outputdir):
    # annotation sub-directories in hda annotation input directory
    assert os.path.exists(vbb_inputdir)
    sub_dirs = os.listdir(vbb_inputdir)     #对应set00,set01...
    for sub_dir in sub_dirs:
        print ("Parsing annotations of camera: ", sub_dir)
        cam_id = sub_dir
        #获取某一个子set下面的所有vbb文件
        vbb_files = glob.glob(os.path.join(vbb_inputdir, sub_dir, "*.vbb")) 
        for vbb_file in vbb_files:
            #返回一个vbb文件中所有的帧的标注结果
            annos = vbb_anno2dict(vbb_file, cam_id)

            if annos:
                #组成xml文件的存储文件夹,形如“/Users/chenguanghao/Desktop/Caltech/xmlresult/”
                vbb_outdir = vbb_outputdir

                #如果不存在
                if not os.path.exists(vbb_outdir):
                    os.makedirs(vbb_outdir)

                for filename, anno in sorted(annos.items(), key=lambda x: x[0]):                  
                    if "bbox" in anno:
                        anno_tree = instance2xml_base(anno)
                        outfile = os.path.join(vbb_outdir, os.path.splitext(filename)[0]+".xml")
                        print ("Generating annotation xml file of picture: ", filename)
                        #生成最终的xml文件,对应一张图片
                        etree.ElementTree(anno_tree).write(outfile, pretty_print=True)            
def visualize_bbox(xml_file, img_file):
    import cv2
    tree = etree.parse(xml_file)
    # load image
    image = cv2.imread(img_file)
    origin =  cv2.imread(img_file)
    # 获取一张图片的所有bbox
    for bbox in tree.xpath('//bndbox'):
        coord = []
        for corner in bbox.getchildren():
            coord.append(int(float(corner.text)))
        print (coord)
        cv2.rectangle(image, (coord[0], coord[1]), (coord[2], coord[3]), (0, 0, 255), 2)
    # visualize image
    cv2.imshow("test", image)
    cv2.imshow('origin',origin)
    cv2.waitKey(0)


def main():
    vbb_inputdir = "D:/data set/caltech/annotations"
    vbb_outputdir = "D:/data set/caltech/VOCdevkit/annotations"
    parse_anno_file(vbb_inputdir,vbb_outputdir)

if __name__ == "__main__":
    main()

​ 输出目录在VOCdevkit/annotations下,总共生成122187个xml文件,且全部放在了一起。第一个文件名是 set00_V000_69.xml,这是因为并不是每一张图片都有行人,只有含有行人的图片才被标注。这里细心的小伙伴可能发现在set00_V000_69.jpg图片里找不到人,其实是因为caltech数据集行人太小,分辨率过低的缘故。图中人在天桥上特别小,可以用opencv画矩形框在图上就可以发现。这里给出一个在图片上批量标注矩形框的代码。此处需要安装opencv。代码主要修改第11行的输入图片路径、12行的输入xml路径、28行的输出图片路径和30行的输入路径。运行完可以查看一下标注图片,发现确实有人。

#coding=tuf-8
import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join
import cv2

def convert_annotation(setxxx,vxxx,xxxjpg):

    img = cv2.imread("D:\data set\caltech\VOCdevkit\JPEG\set%s\V%s\%s.jpg"%(setxxx,vxxx,xxxjpg))
    in_file = open('D:\data set\caltech\VOCdevkit\\annotations\set%s_V%s_%s.xml'%(setxxx,vxxx,xxxjpg))
    tree=ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        if int(difficult)==1:
            continue
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        print(b[0],b[1],b[2],b[3])
        colors = (0,0,255)
        cv2.rectangle(img,(int(b[0]),int(b[2])),(int(b[1]),int(b[3])),colors)
        cv2.imwrite('D:\data set\caltech\VOCdevkit\JPEGS\%s.jpg'%(xxxjpg), img)               
if __name__=="__main__":     
    rootdir = "D:\data set\caltech\VOCdevkit\\annotations"
    for parent, dirnames, filenames in os.walk(rootdir):
        for name in filenames:
            temp = name.split(".",1)[0].split("_",2)
            setxx = temp[0][3:5]
            Vxxx = temp[1][1:4]
            xxxjpg = temp[2]
            convert_annotation(setxx,Vxxx,xxxjpg)          

2.4将jpg图片放在一个统一的文件夹下和xml文件对应

​ 2.3步骤中xml文件已经放在了一个文件夹下,而图片还放在不同目录下,没有和xml标注文件相对应。因此调用mergejpg.py文件,将所有图片放在一起,只需修改代码第6和7行的输入输出路径即可。

#-*- coding:utf-8 -*-
import os
import glob
import shutil
if __name__ == "__main__":
    imgpathin = 'D:\data set\caltech\VOCdevkit\JPEG'
    imgout = 'D:\data set\caltech\VOCdevkit\JPEG'
    for subdir in os.listdir(imgpathin):
        print(subdir)
        file_path = os.path.join(imgpathin, subdir)
        for subdir1 in os.listdir(file_path):
            print(subdir1)
            file_path1 = os.path.join(file_path, subdir1)
            print(file_path1)
            for jpg_file in os.listdir(file_path1):
                src = os.path.join(file_path1, jpg_file)
                new_name=str(subdir+"_"+subdir1+"_"+jpg_file)
                print(new_name)
                dst=os.path.join(imgout,new_name)
                os.rename(src,dst)

​ 我这里输入输出同目录,生成的图片放在了一起。删除不需要的set00-set10文件夹,剩下的名为set0_V000_1.jpg格式图片,总共有249884张。如2.3节所述,图片数量比xml文件要多很多。

2.5重命名图片和xml文件

​ 按照“xxxxxx”6位数字格式给图片和xml文件命名方便后续操作。标注有人的图片命名为xxxxxx.jpg和命名为xxxxxx.xml的xml文件相对应,多余的图片保持原名。调用renameindex.py文件,只需要修改代码的第3和4行的输入和输出目录即可。

#-*- coding:utf-8 -*-
import os
xmlpath = 'D:/data set/caltech/VOCdevkit/annotations'
imgpath = 'D:/data set/caltech/VOCdevkit/JPEG'
index = 0
count = 0
emptyset = set()
xmlFiles = os.listdir(xmlpath)
imgFiles = os .listdir(imgpath)
print (len(xmlFiles),len(imgFiles))

for xml in xmlFiles:    
    xmlname  = os.path.splitext(xml)[0]
    imgname = os.path.join(imgpath,xmlname+'.jpg')    
    print(imgname)
    if os.path.exists(imgname):
        newName = str(index).zfill(6)
        #重命名图像
        os.rename(imgname,os.path.join(imgpath,newName+'.jpg'))
        #重命名xml文件
        os.rename(os.path.join(xmlpath,xml),os.path.join(xmlpath,newName+'.xml'))
        print ('============================================')
        print ('img',imgname,os.path.join(imgpath,newName+'.jpg'))
        print ('__________________________________________')
        print ('xml',os.path.join(xmlpath,xml),os.path.join(xmlpath,newName+'.xml'))
        print ('============================================')
        index = index + 1
    else:
        count += 1
        emptyset.add(xmlname.split('_')[0]+'_'+xmlname.split('_')[1])
sortedSet = sorted(emptyset,key= lambda x:(x.split('_')[0],x.split('_')[1]))
for i in sortedSet:
    print (i) 
    print (count)

​ 调用完毕在annotations下生成名为xxxxxx.xml格式的标注文件,JPEG下生成xxxxxx.jpg图片,将多余的没有重命名图片删除。可以右键属性查看到标注文件和图片分别有122187个,且各自对应。

2.6替换标签

​ 这里网上说是caltech标注中行人还会细划分为几类不同的行人,但我好像没发现,如果有大佬知道欢迎留言,这里为了以防万一还是运行此文件。运行findPeople.py将不同类person转换为同类person。代码只需修改第6和7行的输入和输出路径即可。

# -*- coding:utf-8 -*-
import os
import re

if __name__ == "__main__":
    xmlin = 'D:\data set\caltech\VOCdevkit\\annotations'
    xmlout = 'D:\data set\caltech\VOCdevkit\\annotations'
    files = os.listdir(xmlin)
    #编译一个pattern
    pattern = re.compile('people')
    #每张图片进行判断
    for file in files:
        f = open(os.path.join(xmlin,file), 'r')
        content = f.read()
        f.close()
        result = re.search('people', content)
        if (result!=None):
            updateFile = pattern.sub('person', content)
        else:
            updateFile = content
        with open(os.path.join(xmlout,file), 'w') as fout:
            fout.write(updateFile)
        print ('updating file {}'.format(file))

2.7对xml文件和jpg图片采样(可选)

​ 因为生成图片过多,我不想训练这么多的内容,选择每8帧对图片和标注文件采样,减少训练内容。调用delete_file.py采样,仍然保持标注文件和图片对应。代码只需修改第22行输入目录和第14行的采样频率即可。这里分两步,第一步先对图片采样,然后将图片的路径改为xml文件的路径再对xml文件采样。

#!/usr/bin/python
# -*- coding: UTF-8 -*-
import os
import shutil
def get_file_path(root_path, file_list):
    dir_or_files = os.listdir(root_path)
    for dir_file in dir_or_files:
        dir_file_path = os.path.join(root_path, dir_file)
        file_list.append(dir_file_path)

def delete_file(file_list, length):
    count = 0
    for file_name in file_list:
        if (count % 8 == 0):    #数字自行修改
            count = count + 1
            continue
        else:
            os.remove(file_name)
        count = count + 1

if __name__ == "__main__":
    root_path = r"D:\data set\caltech\VOCdevkit\JPEG"  #先对图片采样,之后更换路劲对标注文件采样。
    file_list = []
    get_file_path(root_path, file_list)
    length = len(file_list)
    delete_file(file_list, length)

​ 采样后xml和jpg分别有15274张。

2.8重命名xml文件和jpg图片(可选)

​ 2.7步骤采样后删除很多文件,需要重新排列命名。调用rename_after_cut.py文件重命名,代码只需修改第3、4行路径即可。

#-*- coding:utf-8 -*-
import os
xmlpath = 'D:/data set/caltech/VOCdevkit/annotations'
imgpath = 'D:/data set/caltech/VOCdevkit/JPEG'
index = 0
count = 0
emptyset = set()
xmlFiles = os.listdir(xmlpath)
imgFiles = os .listdir(imgpath)
print (len(xmlFiles),len(imgFiles))

for xml in xmlFiles:    
    xmlname  = os.path.splitext(xml)[0]
    imgname = os.path.join(imgpath,xmlname+'.jpg')    
    print(imgname)
    if os.path.exists(imgname):
        newName = str(index).zfill(6)
        #重命名图像
        os.rename(imgname,os.path.join(imgpath,newName+'.jpg'))
        #重命名xml文件
        os.rename(os.path.join(xmlpath,xml),os.path.join(xmlpath,newName+'.xml'))
        print ('img',imgname,os.path.join(imgpath,newName+'.jpg'))
        print ('xml',os.path.join(xmlpath,xml),os.path.join(xmlpath,newName+'.xml'))
        index = index + 1
    else:
        count += 1
        emptyset.add(xmlname.split('_')[0]+'_'+xmlname.split('_')[1])
sortedSet = sorted(emptyset,key= lambda x:(x.split('_')[0],x.split('_')[1]))
for i in sortedSet:
    print (i) 
    print (count)

2.9生成txt文件指定训练集、验证集、数据集、训练验证集

​ 调用generate_txt.py生成trainval.txt、test.txt、train.txt、val.txt四个文件,这些文件内容只包含图像名字的数字索引,也就是xxxxxx。因为要转换为标准VOC数据集格式,这几个文件最后放入VOC2007/ImageSets/Main文件夹中。代码只需要修改第6和7行的目录,第11行和12行的训练集测试集百分比即可。这里我选取测试集占比28%,因此trainval.txt有10997行,test.txt有4277行,加起来正好15274行。每次选取图片均是随机选取,所以后面的文件格式跟我相同即可,内容是不一样的。

#coding=utf-8
import os
import random
import time

xmlfilepath='D:/data set/caltech/VOCdevkit/annotations'
saveBasePath='D:/data set/caltech/VOCdevkit/txt'
if not os.path.exists(saveBasePath):
    os.makedirs(saveBasePath)
#设置训练集和测试集的百分比,训练验证集72%,测试集28%,训练集和验证集在对半分。
trainval_percent=0.72
train_percent=0.5
total_xml = os.listdir(xmlfilepath)#所有的

num = len(total_xml)      #xml文件的数量
index_list = range(num)   #生成一个index列表
trainval_num = int(num*trainval_percent) 
train_num = int(trainval_num*train_percent)
trainval_index = random.sample(index_list,trainval_num)#从index_list中随机采样trainval_num数量的内容(正好一半)
train_index = random.sample(trainval_index,train_num)#再次取一半

print("train and val size", trainval_num)
print("train size", train_num)

ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w')
ftest = open(os.path.join(saveBasePath,'test.txt'), 'w')
ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w')
fval = open(os.path.join(saveBasePath,'val.txt'), 'w')

# Start time
start = time.time()
for i  in index_list:
    name = os.path.splitext(total_xml[i])[0] + '\n'    
    if i in trainval_index:
        ftrainval.write(name)
        if i in train_index:
            ftrain.write(name)
        else:
            fval.write(name)
    else:
        ftest.write(name)
# End time
end = time.time()
seconds = end - start
print( "Time taken : {0} seconds".format(seconds))
ftrainval.close()
ftrain.close()
fval.close()
ftest .close()

3.使用yolov3训练caltech数据集

​ 这里改自https://github.com/AlexeyAB/darknet#how-to-use-on-the-command-line。

​ 这里前期的准备工作已经完成,我们把annotations文件夹和JPEG文件夹以及txt文件夹里面的内容按照VOC数据集的格式存放。这里我为了编译方便用的Ubuntu系统,路径是ubuntu风格,用windows也一样。Alexey版本的yolo下,在data目录里建立VOC文件夹,VOC文件夹里建立VOCdevkit文件夹,VOCdevkit文件夹里建立VOC2007文件夹,VOC2007文件夹下建立Annotations、ImageSets、JPEGImages三个文件夹,其中Annotations将所有的xml文件放入,JPETImages将所有图片放入,ImageSets文件夹下建立Main文件夹并放入2.9步骤生成的4个txt文件。注意这里目录名字一定要打对,仔细核对。

3.1生成训练所需要的的2007_test.txt、2007_train.txt、2007_val.txt、train.all.txt、train.txt以及lables文件夹

​ 将voc_lable.py文件放在和VOCdevkit同级目录下,调用voc_lable.py在同级目录下生成2007_test.txt、2007_train.txt、2007_val.txt、train.all.txt、train.txt5个文件,里面内容是图片的路径。VOC2007文件夹下生成lables文件夹,里面是一堆txt文件,每个文件表示对应图片里归一化的xywh值。代码修改自VOC数据集的voc_lable.py文件,代码路径不建议修改,只是文件一定放在VOCdevkit同级目录下。

import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join

sets=[('2007', 'train'), ('2007', 'val'), ('2007', 'test')]
classes = ["person"]

def convert(size, box):
    dw = 1./(size[0])
    dh = 1./(size[1])
    x = (box[0] + box[1])/2.0 - 1
    y = (box[2] + box[3])/2.0 - 1
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x*dw
    w = w*dw
    y = y*dh
    h = h*dh
    return (x,y,w,h)

def convert_annotation(year, image_id):
    in_file = open('VOCdevkit/VOC%s/Annotations/%s.xml'%(year, image_id))
    out_file = open('VOCdevkit/VOC%s/labels/%s.txt'%(year, image_id), 'w')
    tree=ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult)==1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        bb = convert((w,h), b)
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')

wd = getcwd()

for year, image_set in sets:
    if not os.path.exists('VOCdevkit/VOC%s/labels/'%(year)):
        os.makedirs('VOCdevkit/VOC%s/labels/'%(year))
    image_ids = open('VOCdevkit/VOC%s/ImageSets/Main/%s.txt'%(year, image_set)).read().strip().split()
    list_file = open('%s_%s.txt'%(year, image_set), 'w')
    for image_id in image_ids:
        list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.jpg\n'%(wd, year, image_id))
        convert_annotation(year, image_id)
    list_file.close()

os.system("cat 2007_train.txt 2007_val.txt  > train.txt")
os.system("cat 2007_train.txt 2007_val.txt 2007_test.txt  > train.all.txt")

3.2创建yolov3-caltech.cfg文件

​ 创建与yolov3-voc.cfg类似的yolov3-caltech.cfg,放入cfg文件夹下:

  • 设置training下的batch=64,subdivisions=8.(视电脑配置决定)
  • max_batches设置30000。(自己决定)
  • steps参数为80%和90%max_batches,这里是steps=24000,27000.
  • 网络输入长宽必须能整除32,这里是416x416.
  • 3个yolo层的classes类别数改为1.
  • 3个yolo层前一个卷积层的filter数为18.(公式是filters=(classes+5)×3)

3.3创建caltech.names文件

​ 创建类似voc.namescaltech.names放入data文件夹下,文件里只有person

3.4创建caltech.data文件

​ 创建类似voc.datacaltech.data放入cfg文件夹下,修改为自己对应的路径

classes= 1 # 你的类别的个数
train  =/home/davy/下载/yolo_each_version/alexey_simple/data/VOC/train.txt # 存储用于训练的图片位置
valid  =/home/davy/下载/yolo_each_version/alexey_simple/data/VOC/2007_test.txt# 存储用于测试的图片的位置
names = data/caltech.names # 每行一个类别的名称
backup = backup/

3.5开始训练

  • ./darknet detector train cfg/caltech.data cfg/yolov3-caltech.cfg darknet53.conv.74 -map实时显示map
  • ./darknet detector train cfg/caltech.data cfg/yolov3-caltech.cfg darknet53.conv.74 2>1 | tee visualization/train_yolov3.log可视化中间参数并保存终端内容到visualization/train_yolov3.log文件中

4绘制fppi miss rate图

​ 这里参考了https://www.cnblogs.com/ya-cpp/p/8282383.html的内容。

​ 这里运用的是matlab下的[am,fppi,missRate] = evaluateDetectionMissRate(detectionResults,trainingData)函数。detectionResults是自己算法的预测结果,类型如下图所示:

Boxes代表坐标信息(左上角xy、w、h),维度很高是因为一张图中有多个预测框,Src是对应的置信度。

trainingData是真实的gt结果,类型如下图所示:

Boxes代表真实坐标信息,name是图片名称。

4.1生成detectionResults内容

​ 训练完毕后需要在主目录下创建results文件夹,在源码中的detector.c文件中validate_detector函数下找到int classes = l.classes;在后面添加char *classesnum = option_find_str(options, "classes", "80");classes = atoi(classesnum);这样在文件夹results中生成的txt文件才和你自定义的类别数相同。

​ 在终端输入./darknet detector valid cfg/caltech.data cfg/yolov3-caltech.cfg backup/yolov3-caltech_last.weights -out "" -gpu 0,权重文件根据自己的修改。这会在results文件夹下生成person.txt。这个文件是我们根据3.4步骤中valid路径的2007_test.txt中对应图片的自己预测的结果,每张图会生成几个预测框,里面包含图片名、置信度、坐标这几个内容。person.txt文件部分内容如下(每个人预测结果不同):

000001 0.559587 224.088120 133.759949 231.628952 146.779480
000001 0.145668 232.491028 132.304642 239.329895 148.342789
000009 0.468610 228.921371 131.483200 235.663467 144.126236
000009 0.257587 235.649506 132.068237 241.292664 143.776550
000009 0.011003 247.271637 131.625412 253.146362 143.372421
000010 0.939331 592.104309 169.121338 608.267395 212.208038
000010 0.787617 480.261322 171.480667 491.733246 196.037827
000010 0.014944 475.234589 171.413132 485.475250 195.526871

​ 在主目录下创建caltech_test文件夹,里面创建out.txt,out1.txt,out2.txt,person_gt.txt,person_gt_matlab.txt这几个空文件。文件内容下面会解释。

​ 调用pro_det.py文件,生成三个文件,一个文件out.txt只存储图片名(根据person.txt得出,重复的图片名只计一次),一个文件out1.txt只包含置信度(一个图片内所有预测框的置信度放在一行中,每个预测框的置信度用分号间隔),一个文件out2.txt保存坐标(一个图片内所有预测框的四个坐标放在一行,每个框的内容之间用分号间隔,四个坐标之间用逗号间隔)。三个文件内容分别对应。代码只需修改第4、5、63行路径。

#coding=utf-8
import os
def generate_result(resource_path, des_path):
    des_path1 = "/home/davy/下载/yolo_each_version/alexey_simple/caltech_test/out1.txt"
    des_path2 = "/home/davy/下载/yolo_each_version/alexey_simple/caltech_test/out2.txt"

    rf = open(resource_path)

    content = rf.readline()
    cnt = 0
    tmp_dick = {}
    #tmp_dick是一个字典,cls表示图片名字,bbox表示图片的置信度和四个坐标
    while content:
        res = content.replace("\n", "").split(" ")
        cls = str(res[0:1][0])
        bbox = res[1:6]
        #一个图有多个框
        if cls in tmp_dick:
            tmp_dick[cls].append(bbox)
        #一个图就一个框
        else:
            tmp_dick[cls] = [bbox]
        content = rf.readline()
    rf.close()
    wfname = open(des_path, "r+")#图片名,a+表示追加写入方式,r+覆盖写入
    wfsrc = open(des_path1, "r+")#置信度
    wfbox = open(des_path2, "r+")#坐标
    #字典的键,也就是图片名
    for key_ in tmp_dick:
        wfname.write(str(key_)+',')
        #字典的值,就是每个bbox(包含置信度和坐标)
        for detail in tmp_dick[key_]:
            #取一个bbox中的一个值,分别是置信度,x,y,w,h
            for index in detail:
                if index == detail[0]:
                    wfsrc.write(str(index))
                else:
                    if index is detail[1]:#左上角x
                        tmpp1 = index
                        wfbox.write(str((float(index))))
                    if index is detail[2]:#左上角y
                        tmpp2 = index
                        wfbox.write(str((float(index))))
                    if index is detail[3]:#宽
                        wfbox.write(str((float(index) - float(tmpp1))))
                    if index is detail[4]:#高
                        wfbox.write(str((float(index) - float(tmpp2))))
                    if index is not detail[-1]:#每个坐标间用,隔开
                        wfbox.write(",")
            if len(tmp_dick[key_]) > 1:
                if detail is not tmp_dick[key_][-1]:
                    wfsrc.write(";")#不同框间的内容用;隔开
                    wfbox.write(";")

        wfname.write("\n")#不同图片换行
        wfsrc.write("\n")
        wfbox.write("\n")

    wfname.close()
    wfsrc.close()
    wfbox.close()
#生成三个文件,out保存图片名,换行间隔;out1保存置信度,一个图片的置信度分号间隔,不同图片换行;out2保存坐标信息(左上和右下),一个图片的坐标信息分号间隔,每个坐标直接逗号间隔,图片换行。
generate_result("/home/davy/下载/yolo_each_version/alexey_simple/results/person.txt", "/home/davy/下载/yolo_each_version/alexey_simple/caltech_test/out.txt")

​ out.txt部分内容如下(每个人预测结果不同):

000001,
000009,
000010,
000020,
000025,
000028,

​ out1.txt部分内容如下(每个人预测结果不同):

0.559587;0.145668
0.468610;0.257587;0.011003
0.939331;0.787617;0.014944
0.867930;0.172820;0.022374;0.005487;0.005403
0.827769;0.011797;0.008957

​ out2.txt部分内容如下(每个人预测结果不同):

224.08812,133.759949,7.540831999999995,13.019531;232.491028,132.304642,6.838866999999993,16.03814700000001
228.921371,131.4832,6.742096000000004,12.643035999999995;235.649506,132.068237,5.643158,11.708312999999976;247.271637,131.625412,5.874725000000012,11.747008999999991
592.104309,169.121338,16.16308600000002,43.08669999999998;480.261322,171.480667,11.471924000000001,24.557159999999982;475.234589,171.413132,10.240660999999989,24.11373900000001
233.762619,130.848083,7.07607999999999,11.57074;230.70842,130.462173,6.369751000000008,12.236480999999998;247.076218,130.692719,6.628844999999984,10.81832799999998;385.98465,168.248779,11.765930000000026,20.181701999999973;213.189636,131.733978,7.869202000000001,11.511993999999987
594.865295,123.218018,22.319703000000004,57.115263999999996;568.346802,130.529251,17.81664999999998,43.75082400000002;633.757751,114.223183,6.242249000000015,78.37160499999999  

​ 这里我每次生的的预测结果数量总是比2007_test.txt中图片数量少几张,就是总有几张没有给出预测,不过这并不影响我们后续的步骤。

4.2生成trainingData内容

​ 上一步生成的out.txt文件中是你预测的测试集图片名字。这一步是找到预测图片相对应的真实坐标信息的xml格式的标注文件并转换为txt类型。调用xml2txt.py文件,代码同样必须放在和VOCdevkit同级目录下,此外只需修改第18、35的路径即可。

import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join
classes = ["person"]

def convert(box):
    #左上角和wh
    x = box[0]
    y = box[2]
    w = box[1] - box[0]
    h = box[3] - box[2]
    return (x,y,w,h)

def convert_annotation(year, image_id):
    in_file = open('VOCdevkit/VOC%s/Annotations/%s.xml'%(year, image_id)) #存放xml格式文件的目录
    out_file = open('/home/davy/下载/yolo_each_version/alexey_simple/caltech_test/person_gt.txt', 'a+') #保存到这个文件中,a+以追加方式写入
    tree=ET.parse(in_file)
    root = tree.getroot()

    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult)==1:
            continue
        #cls_id = classes.index(cls)#改成人的话只有1类,cls_id永远为0
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        bb = convert(b)
        print(bb)
        out_file.write(str(image_id) + " " + " ".join([str(a) for a in bb]) + '\n') #图片名+空格+坐标信息(坐标之间空格间隔)
    out_file.close()
#读取out文件下对应的图片名,并去掉逗号
image_ids = open('/home/davy/下载/yolo_each_version/alexey_simple/caltech_test/out.txt').read().strip().split()
for i in range(len(image_ids)):
    image_ids[i] = image_ids[i][:-1]
#print(image_ids)
for image_id in image_ids:
    convert_annotation(2007, image_id)

​ 这里生成的person_gt.txt部分内容如下:

000001 228.04708948403487 132.5344370077146 6.439862509174759 12.555317447834682
000009 230.92744661144314 131.24515230156234 7.053654320016761 13.13606348956688
000010 591.1319702602229 168.52973977695154 18.088599752168648 44.33240396530363
000010 478.6104328881311 167.79792175807526 11.85185185185179 25.55877616747182
000020 233.06535999869328 129.9418563797974 6.9233349203163925 13.112218043096448

​ 可以看到一张图片可能包含多个gt框,需要把一张图片的内容放在一行来表示。调用pro_train.py文件处理,生成person_gt_matlab.txt。代码只需修改第52行路径即可。

#coding=utf-8
import os
def generate_result(resource_path, des_path):
    rf = open(resource_path)

    content = rf.readline()
    cnt = 0
    tmp_dick = {}
    #tmp_dick是一个字典,cls表示图片名字,bbox表示图片四个坐标
    while content:
        res = content.replace("\n", "").split(" ")
        cls = str(res[0:1][0])
        bbox = res[1:5]
        print("cls:",cls)
        print("bbox", bbox)
        #一个图有多个框
        if cls in tmp_dick:
            tmp_dick[cls].append(bbox)
        #一个图就一个框
        else:
            tmp_dick[cls] = [bbox]
            cnt += 1
        content = rf.readline()
        #print(content)
    print(tmp_dick)
    rf.close()
    wfname = open(des_path, "r+")#图片名,r+表示覆盖写入方式
    #字典的键,也就是图片名
    for key_ in tmp_dick:
        #print("key:",key_)
        wfname.write(str(key_)+',')
        #字典的值,就是每个bbox(包含坐标)
        for detail in tmp_dick[key_]:
            #取一个bbox中的一个值,分别是x,y,w,h
            for index in detail:
                #print("index:",index)
                if index is detail[0]:#左上角x
                    wfname.write(str((float(index))))
                if index is detail[1]:#左上角y
                    wfname.write(str((float(index))))
                if index is detail[2]:#宽
                    wfname.write(str((float(index))))
                if index is detail[3]:#高
                    wfname.write(str((float(index))))
                if index is not detail[-1]:#每个坐标间用,隔开
                    wfname.write(" ")
            if len(tmp_dick[key_]) > 1:
                if detail is not tmp_dick[key_][-1]:
                    wfname.write(";")#不同框间的内容用;隔开
        wfname.write("\n")#不同图片换行
    wfname.close()
generate_result("/home/davy/下载/yolo_each_version/alexey_simple/caltech_test/person_gt.txt", "/home/davy/下载/yolo_each_version/alexey_simple/caltech_test/person_gt_matlab.txt")

​ 生成的person_gt_matlab.txt和out.txt、out1.txt、out2.txt的内容对应。部分内容如下所示:

000001,228.04708948403487 132.5344370077146 6.439862509174759 12.555317447834682
000009,230.92744661144314 131.24515230156234 7.053654320016761 13.13606348956688
000010,591.1319702602229 168.52973977695154 18.088599752168648 44.33240396530363;478.6104328881311 167.79792175807526 11.85185185185179 25.55877616747182
000020,233.06535999869328 129.9418563797974 6.9233349203163925 13.112218043096448
000025,598.681926683717 129.96268826371127 19.121328502415395 45.68222222222224

4.3绘制fppi miss rate曲线

​ 这里我们需要的数据已经处理完毕,可以开始转到matlab画图了。首先调用detectionResults.m生成matlab所需要的预测结果格式。代码修改第1和2行的路径,第11和12行你生成结果的行数。

fid1=fopen("D:\Desktop\ubuntu-shared-files\FPPI-miss rate\caltech-none\out2.txt", "rt");%坐标信息
fid2=fopen("D:\Desktop\ubuntu-shared-files\FPPI-miss rate\caltech-none\out1.txt", "rt");%置信度

data1 = textscan(fid1, '%s', 'delimiter', '\n');
data2 = textscan(fid2, '%s', 'delimiter', '\n');

data1 = data1{1,1};
data2 = data2{1,1};


get_scr_bbox(4258) = struct('Boxes',[],'Scr',[]);
for i=1:4258
    A = data1{i};
    A = cellstr(A);
    A = str2num(cell2mat(A));

    B = data2{i};
    B = cellstr(B);
    B = str2num(cell2mat(B));



    get_scr_bbox(i).Boxes = A;
    get_scr_bbox(i).Scr = B;
end
get_scr_bbox = struct2table(get_scr_bbox);

fclose(fid1);
fclose(fid2);

​ 接下来调用trainingData.m生成matlab所需要的gt框结果的格式。代码修改第1、8、9行

fid=fopen("D:\Desktop\ubuntu-shared-files\FPPI-miss rate\caltech-none\person_gt_matlab.txt", "rt");

data = textscan(fid, '%s', 'delimiter', '\n');

data = data{1,1};
count = 0;

get_results(4258) = struct('Boxes',[],'name',[]);
for i=1:4258
    A = data{i};
    A = regexp(A, ',', 'split'); 
    get_results(i).name = A(1); %图片名
    B = A(2);
    B = str2num(cell2mat(B)); 
    get_results(i).Boxes = B;   
end
get_results = struct2table(get_results);
while feof(fid) ~= 1
    file =  fgetl(fpn);
end

fclose(fid);

​ 运行FPPI_miss_rate.m,里面含有matlab内置函数生成图像。

[am,fppi,missRate] = evaluateDetectionMissRate(get_scr_bbox,get_results(:,1),0.5);
% Plot log average miss rate - FPPI.
figure
loglog(fppi, missRate, 'g','linewidth',1.5); %还有线条类型可以选择
xlabel('false positives per image')
ylabel('miss rate')
grid on
set(gca, 'yMinorTick','on')
set(gca, 'xMinorTick', 'on')
hleg1 = legend('yolov3')
set(hleg1, 'Location', 'NorthEast')
title(sprintf('log Average Miss Rate = %.5f',am))

​ 最终图像

5.补充

1.在results下生成的预测结果文件person.txt,坐标信息分别是左上角xy坐标和右下角xy坐标且是真实坐标。

2.绘制fppi/miss rate图中detectionResults中的Boxes,坐标信息左上角xy、w、h且是真实值

3.用voc_label.py生成的labels文件夹里的坐标信息是左上角xy、w、h但归一化后的值。

4.生成的person_gt.txt和person_gt_matlab.txt都是左上角xy,wh真实值


文章作者: YanBin
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 YanBin !
评论
评论
  目录