forked from ufownl/global-wheat-detection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
109 lines (95 loc) · 5.14 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import argparse
import math
import os
import random
import time

import gluoncv as gcv
import mxnet as mx

from dataset import load_dataset, get_batches
from model import init_model, load_model
def train(best_score, start_epoch, max_epochs, learning_rate, batch_size, img_w, img_h, sgd, context):
    """Train the global-wheat YOLOv3-DarkNet53 detector with checkpointing.

    Loads the dataset from ``data``, splits it 90/10 into training and
    validation sets, resumes model parameters and optimizer state from the
    ``model/`` directory when checkpoint files exist, and after every epoch
    saves the latest checkpoint (plus a separate "best" checkpoint whenever
    the validation score improves).

    Args:
        best_score: Current best validation score; only a higher score
            triggers saving the ``_best`` parameter file.
        start_epoch: Epoch index to resume training from.
        max_epochs: Train until this epoch index (exclusive).
        learning_rate: Learning rate for the optimizer.
        batch_size: Number of images per batch.
        img_w: Width of the (resized) training images.
        img_h: Height of the (resized) training images.
        sgd: If True use SGD with momentum 0.5, otherwise Nadam.
        context: MXNet device context (``mx.cpu(...)`` or ``mx.gpu(...)``).

    Raises:
        ValueError: If the training loss becomes NaN (diverged run).
    """
    print("Loading dataset...", flush=True)
    dataset = load_dataset("data")
    # 90/10 train/validation split, in the order load_dataset returned.
    split = int(len(dataset) * 0.9)
    training_set = dataset[:split]
    print("Training set: ", len(training_set))
    validation_set = dataset[split:]
    print("Validation set: ", len(validation_set))
    # Resume from a previous checkpoint when one exists.
    if os.path.isfile("model/global-wheat-yolo3-darknet53.params"):
        model = load_model("model/global-wheat-yolo3-darknet53.params", ctx=context)
    else:
        model = init_model(ctx=context)
    # VOC mAP at IoU thresholds 0.50..0.75 (step 0.05); averaged below into
    # a single validation score.
    metrics = [gcv.utils.metrics.VOCMApMetric(iou_thresh=iou) for iou in [0.5, 0.55, 0.6, 0.65, 0.7, 0.75]]
    print("Learning rate: ", learning_rate)
    if sgd:
        print("Optimizer: SGD")
        trainer = mx.gluon.Trainer(model.collect_params(), "SGD", {
            "learning_rate": learning_rate,
            "momentum": 0.5
        })
    else:
        print("Optimizer: Nadam")
        trainer = mx.gluon.Trainer(model.collect_params(), "Nadam", {
            "learning_rate": learning_rate
        })
    # Restore optimizer state (momentum buffers, step counts) when resuming.
    if os.path.isfile("model/global-wheat-yolo3-darknet53.state"):
        trainer.load_states("model/global-wheat-yolo3-darknet53.state")
    print("Training...", flush=True)  # was misspelled "Traning..."
    for epoch in range(start_epoch, max_epochs):
        ts = time.time()
        random.shuffle(training_set)
        training_total_L = 0.0
        training_batches = 0
        for x, objectness, center_targets, scale_targets, weights, class_targets, gt_bboxes in get_batches(training_set, batch_size, width=img_w, height=img_h, net=model, ctx=context):
            training_batches += 1
            with mx.autograd.record():
                # In training mode the network returns its four loss heads.
                obj_loss, center_loss, scale_loss, cls_loss = model(x, gt_bboxes, objectness, center_targets, scale_targets, weights, class_targets)
                L = obj_loss + center_loss + scale_loss + cls_loss
                L.backward()
            trainer.step(x.shape[0])
            training_batch_L = mx.nd.mean(L).asscalar()
            # Abort on divergence instead of continuing with a useless model
            # (was the obscure self-comparison `x != x` NaN idiom).
            if math.isnan(training_batch_L):
                raise ValueError("training loss is NaN at epoch %d batch %d" % (epoch, training_batches))
            training_total_L += training_batch_L
            print("[Epoch %d Batch %d] batch_loss %.10f average_loss %.10f elapsed %.2fs" % (
                epoch, training_batches, training_batch_L, training_total_L / training_batches, time.time() - ts
            ), flush=True)
        training_avg_L = training_total_L / training_batches
        # Validation pass: reset metrics, then score detections per batch.
        for metric in metrics:
            metric.reset()
        for x, label in get_batches(validation_set, batch_size, width=img_w, height=img_h, ctx=context):
            # In inference mode the network returns (class ids, scores, boxes).
            classes, scores, bboxes = model(x)
            for metric in metrics:
                metric.update(
                    bboxes,
                    classes.reshape((0, -1)),
                    scores.reshape((0, -1)),
                    label[:, :, :4],
                    label[:, :, 4:5].reshape((0, -1))
                )
        # Mean mAP across the IoU thresholds; pull to host once instead of
        # calling asscalar() (a device sync) three separate times.
        current_score = mx.nd.array([metric.get()[1] for metric in metrics], ctx=context).mean().asscalar()
        print("[Epoch %d] training_loss %.10f validation_score %.10f best_score %.10f duration %.2fs" % (
            epoch + 1, training_avg_L, current_score, best_score, time.time() - ts
        ), flush=True)
        if current_score > best_score:
            best_score = current_score
            model.save_parameters("model/global-wheat-yolo3-darknet53_best.params")
        # Always checkpoint the latest epoch so training can resume.
        model.save_parameters("model/global-wheat-yolo3-darknet53.params")
        trainer.save_states("model/global-wheat-yolo3-darknet53.state")
if __name__ == "__main__":
    # Command-line front end: collect hyper-parameters, pick a device,
    # then hand everything to train().
    arg_parser = argparse.ArgumentParser(description="Start a global-wheat-detection trainer.")
    arg_parser.add_argument("--best_score", help="set the current best score (default: 0.0)", type=float, default=0.0)
    arg_parser.add_argument("--start_epoch", help="set the start epoch (default: 0)", type=int, default=0)
    arg_parser.add_argument("--max_epochs", help="set the max epochs (default: 100)", type=int, default=100)
    arg_parser.add_argument("--learning_rate", help="set the learning rate (default: 0.001)", type=float, default=0.001)
    arg_parser.add_argument("--batch_size", help="set the batch size (default: 32)", type=int, default=32)
    arg_parser.add_argument("--img_w", help="set the width of training images (default: 512)", type=int, default=512)
    arg_parser.add_argument("--img_h", help="set the height of training images (default: 512)", type=int, default=512)
    arg_parser.add_argument("--sgd", help="using sgd optimizer", action="store_true")
    arg_parser.add_argument("--device_id", help="select device that the model using (default: 0)", type=int, default=0)
    arg_parser.add_argument("--gpu", help="using gpu acceleration", action="store_true")
    args = arg_parser.parse_args()
    # Select GPU or CPU context for both the model and the data batches.
    context = mx.gpu(args.device_id) if args.gpu else mx.cpu(args.device_id)
    train(args.best_score, args.start_epoch, args.max_epochs, args.learning_rate, args.batch_size, args.img_w, args.img_h, args.sgd, context)