# yolov2_train.py - YOLOv2 trainer script
import numpy as np
import chainer
import os
from chainer import serializers, optimizers, cuda, training
from chainer.training import extension,extensions,updaters
import argparse
import glob
from matplotlib import pylab as plt
plt.switch_backend('agg')
from chainer.datasets import TransformDataset
from user_dataset_3class import UserDataset3Class
from guinness_net_yolov2 import GUINNESS_YOLOv2
from yolo_predictor import YOLOv2Predictor
from transform_sg import convert_sg, Transform
# Command-line interface for the YOLOv2 trainer.
parser = argparse.ArgumentParser(description='YOLOv2 trainer')
parser.add_argument('--batch_size', '-b', type=int, default=6, help='Mini batch size')
parser.add_argument('--img_size', '-s', type=int, default=213,
                    help='Input image size used by the training data transform')
parser.add_argument('--gpu', '-g', type=int, default=0, help='GPU device ID (negative value uses CPU)')
parser.add_argument('--n_epoch', '-e', type=int, default=200, help='# of epochs for training')
parser.add_argument('--lr', type=float, default=0.0001, help='Initial learning rate for Optimizer')
# Previously carried a copy-pasted help string ("Initial learning rate ...").
parser.add_argument('--pretrained_model', type=str, default=None,
                    help='Path to a saved model (.npz) used to initialize the network weights')
parser.add_argument('--output_dir', '-p', type=str, default='logs',
                    help='Output directory for logs, loss plots and snapshots')
# NOTE(review): 'hoge' looks like a placeholder default; both paths presumably
# must be supplied for a real run.
parser.add_argument('--annotation_path', '-a', type=str, default='hoge',
                    help='Directory containing the *.xml annotation files')
parser.add_argument('--image_path', '-i', type=str, default='hoge',
                    help='Directory containing the training images')
args = parser.parse_args()
# Detection classes; they are handed to the dataset as cls_label and their
# count sets the number of classes the network predicts.
label_names = ('car', 'person', 'bicycle', 'other')
# The number of *.xml annotation files is used as the dataset size estimate.
anno_files = glob.glob(os.path.join(args.annotation_path, '*.xml'))
n_datasets = len(anno_files)
print("# DATASETS = ", n_datasets)
if n_datasets == 0:
    # Fail fast: an empty annotation directory (e.g. the 'hoge' default)
    # would otherwise reach the iterator/trainer and fail far less clearly.
    raise SystemExit("No *.xml annotation files found in '{}' "
                     "(check --annotation_path)".format(args.annotation_path))
# NOTE(review): a '+1' (presumably an extra/background class) was commented
# out here; confirm the model does not expect it.
n_classes = len(label_names)
# Anchor boxes predicted per grid cell (passed to GUINNESS_YOLOv2 as n_boxes).
n_boxes = 5
# Initialize the CNN: the GUINNESS_YOLOv2 network wrapped by YOLOv2Predictor.
# The wrapper is what the optimizer trains and the object later snapshotted
# via model.predictor.
model = GUINNESS_YOLOv2(n_classes=n_classes, n_boxes=n_boxes)
# NOTE(review): unstable_seen is presumably a warm-up length in samples/iterations
# (10% of #annotation files x #epochs) - confirm in YOLOv2Predictor.
model = YOLOv2Predictor(model, conf_scale=0.01,
                        unstable_seen=int(n_datasets * args.n_epoch * 0.10))
chainer.config.train = True
# --gpu documents "negative value uses CPU"; only select a CUDA device and
# move the model when a non-negative device id was requested.
if args.gpu >= 0:
    cuda.get_device(args.gpu).use()
    model.to_gpu()
# Optimizer: momentum SGD (momentum 0.95) at the learning rate given on the
# command line, plus a weight-decay hook with rate 5e-4 on all parameters.
optimizer = optimizers.MomentumSGD(momentum=0.95, lr=args.lr)
optimizer.setup(model)
optimizer.add_hook(chainer.optimizer.WeightDecay(rate=0.0005))
# Training pipeline: annotated images -> augmenting Transform (random crop and
# flip; mean 0 / std 1, i.e. presumably no normalization) -> multiprocess
# mini-batch iterator -> updater using convert_sg -> epoch-bounded trainer.
train = TransformDataset(
    UserDataset3Class(anno_dir=args.annotation_path,
                      img_dir=args.image_path,
                      cls_label=label_names),
    Transform(n_classes, args.img_size,
              random_crop=True, flip=True,
              mean=[0, 0, 0], std=[1, 1, 1]))
train_iter = chainer.iterators.MultiprocessIterator(train, batch_size=args.batch_size)
updater = training.StandardUpdater(train_iter, optimizer,
                                   converter=convert_sg, device=args.gpu)
trainer = training.Trainer(updater, stop_trigger=(args.n_epoch, 'epoch'),
                           out=args.output_dir)
# Optionally initialize the network weights from an .npz file, e.g. a
# model_epoch_XXX file written by this script's snapshot_object extension
# (which saves model.predictor, the same link loaded into here).
if args.pretrained_model is not None:
    if os.path.isfile(args.pretrained_model):
        serializers.load_npz(args.pretrained_model, model.predictor)
    else:
        # Still best-effort, but no longer silent: the user explicitly asked
        # for these weights, so say that training starts from scratch.
        print("WARNING: pretrained model '{}' not found; "
              "training from scratch".format(args.pretrained_model))
# Reporting: aggregate observations every half epoch and print them with the LR.
log_interval = 0.5, 'epoch'
trainer.extend(extensions.LogReport(trigger=log_interval))
trainer.extend(extensions.observe_lr(), trigger=log_interval)
trainer.extend(extensions.PrintReport(
    ['epoch', 'iteration', 'lr', 'main/loss', 'main/x_loss', 'main/y_loss',
     'main/w_loss', 'main/h_loss', 'main/c_loss', 'main/p_loss']),
    trigger=log_interval)
# One loss curve per reported term, written as <term>.png in the output dir.
# Registration order matches the original one-call-per-plot layout.
for loss_name in ('loss', 'x_loss', 'y_loss', 'w_loss', 'h_loss', 'c_loss', 'p_loss'):
    trainer.extend(extensions.PlotReport(['main/' + loss_name], x_key='epoch',
                                         file_name=loss_name + '.png'))
trainer.extend(extensions.ProgressBar(update_interval=2))

# Periodic snapshots: full trainer state, optimizer state and model weights.
snapshot_interval = (100, 'epoch')
trainer.extend(extensions.snapshot(filename='snapshot_{.updater.epoch}'),
               trigger=snapshot_interval)
trainer.extend(extensions.snapshot_object(optimizer, 'snapshot_optimizer_{.updater.epoch}'),
               trigger=snapshot_interval)
trainer.extend(extensions.snapshot_object(model.predictor, filename='model_epoch_{.updater.epoch}'),
               trigger=snapshot_interval)
# The periodic snapshots only fire on multiples of 100 epochs, so a run whose
# --n_epoch is not such a multiple (e.g. < 100) would end without saving its
# final weights. Save the model once more at the last epoch in that case.
if args.n_epoch % snapshot_interval[0] != 0:
    trainer.extend(extensions.snapshot_object(model.predictor, filename='model_epoch_{.updater.epoch}'),
                   trigger=(args.n_epoch, 'epoch'))
trainer.run()