def run_test():
    """Run single-image inference with a trained RetinaNet and save a
    visualization with predicted boxes/labels to ./rst/rst.png.

    Reads from module-level `args`: num_classes, checkpoint, img_path.
    Side effects: loads the checkpoint, runs the net on GPU, writes a PNG.
    """
    print('Loading model..')
    net = RetinaNet(args.num_classes)
    ckpt = torch.load(args.checkpoint)
    net.load_state_dict(ckpt['net'])
    net.eval()
    net.cuda()

    # ImageNet mean/std normalization to match training preprocessing.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    print('Loading image..')
    img = Image.open(args.img_path)
    w, h = img.size

    print('Predicting..')
    x = transform(img)
    x = x.unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        loc_preds, cls_preds = net(x.cuda())

    print('Decoding..')
    encoder = DataEncoder()
    boxes, labels, scores = encoder.decode(loc_preds.cpu().data.squeeze(),
                                           cls_preds.cpu().data.squeeze(),
                                           (w, h))

    label_map = load_pickled_label_map()
    draw = ImageDraw.Draw(img, 'RGBA')
    # NOTE(review): hard-coded font path only exists inside a Pillow source
    # checkout; consider shipping the font with the project.
    fnt = ImageFont.truetype('Pillow/Tests/fonts/DejaVuSans.ttf', 11)
    for idx in range(len(boxes)):
        box = boxes[idx]
        label = labels[idx]
        draw.rectangle(list(box), outline=(255, 0, 0, 200))
        item_tag = '{0}: {1:.2f}'.format(label_map[label.item()], scores[idx])
        # NOTE(review): ImageFont.getsize was removed in Pillow 10; use
        # fnt.getbbox/getlength if upgrading Pillow.
        iw, ih = fnt.getsize(item_tag)
        ix, iy = list(box[:2])
        # Semi-transparent background behind the label text.
        draw.rectangle((ix, iy, ix + iw, iy + ih), fill=(255, 0, 0, 100))
        draw.text(list(box[:2]), item_tag, font=fnt, fill=(255, 255, 255, 255))
    # BUG FIX: img.save raises FileNotFoundError when ./rst does not exist.
    os.makedirs('./rst', exist_ok=True)
    img.save(os.path.join('./rst', 'rst.png'), 'PNG')
def test():
    """Load a trained BDD RetinaNet and dispatch detection by input type.

    Reads from module-level `args`: pth (checkpoint path), onnx (export flag),
    img (image/video path or '0' for webcam), thresh (score threshold).
    Exports an ONNX model and returns early when args.onnx is set; otherwise
    runs image or video detection depending on the file extension.
    """
    print('initializing network...')
    network = RetinaNet(3, 10, 9)
    checkpoint = torch.load(args.pth)
    network.load_state_dict(checkpoint['net'])
    network = network.cuda().eval()

    if args.onnx:
        # Export with a fixed 416x416 input and stop — no detection run.
        dummy_input = torch.randn(1, 3, 416, 416, device='cuda')
        torch.onnx.export(network, dummy_input, "retina-bdd.onnx", verbose=True)
        return

    class_names = ["bus", "traffic light", "traffic sign", "person", "bike",
                   "truck", "motor", "car", "train", "rider"]
    image = args.img
    # Case-insensitive extension dispatch ('0' selects the webcam stream).
    ext = image.split('.')[-1].lower()
    if ext in ('jpg', 'jpeg', 'png'):
        detect_image(image, network, args.thresh, class_names)
    elif ext in ('mp4', 'mkv', 'avi', '0'):
        detect_vedio(image, network, args.thresh, class_names)
    else:
        # BUG FIX: message typo ('unknow' -> 'unknown').
        print('unknown image type!!!')
def train():
    """Train RetinaNet on BDD with gradient accumulation (`args.substep`),
    periodic checkpointing, step LR decay, and optional visdom plots.

    Reads from module-level `args`: vis, substep, resume, ngpus.
    Side effects: writes ./checkpoint/*.pth every epoch and a final
    retina-bdd_final.pkl snapshot of the whole model.
    """
    max_epoch = 120
    lr = 0.001
    step_epoch = 50       # decay LR every `step_epoch` epochs
    lr_decay = 0.1
    train_batch_size = 64
    val_batch_size = 16
    if args.vis:
        vis = visdom.Visdom(env=u'test1')

    # dataset
    print('importing dataset...')
    substep = args.substep  # gradient-accumulation steps per optimizer update
    trainset = bdd.bddDataset(416, 416)
    loader_train = data.DataLoader(trainset,
                                   batch_size=train_batch_size // substep,
                                   shuffle=True, num_workers=4, drop_last=True)
    valset = bdd.bddDataset(416, 416, train=0)
    # Shuffling validation data has no effect on the averaged loss.
    loader_val = data.DataLoader(valset,
                                 batch_size=val_batch_size // substep,
                                 shuffle=False, num_workers=4, drop_last=True)

    # model
    print('initializing network...')
    network = RetinaNet(3, 10, 9)
    if args.resume:
        print('Resuming from checkpoint..')
        checkpoint = torch.load('./checkpoint/retina-bdd-backup.pth')
        network.load_state_dict(checkpoint['net'])
        best_loss = checkpoint['loss']
        start_epoch = checkpoint['epoch']
    else:
        start_epoch = 0
    if args.ngpus > 1:
        net = torch.nn.DataParallel(network).cuda()
    else:
        net = network.cuda()

    # criterion / optimizer
    criterion = FocalLoss(10, 4, 9)
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-4)

    # start training
    for i in range(start_epoch, max_epoch):
        print('--------start training epoch %d --------' % i)
        trainset.seen = 0
        valset.seen = 0
        loss_train = 0.0
        net.train()
        t0 = time.time()
        optimizer.zero_grad()
        for ii, (image, cls_truth, box_truth) in enumerate(loader_train):
            image = Variable(image).cuda()
            cls_truth = Variable(cls_truth).cuda()
            box_truth = Variable(box_truth).cuda()
            # forward
            cls_pred, box_pred = net(image)
            # loss
            loss = criterion(cls_pred, box_pred, cls_truth, box_truth)
            # backward (gradients accumulate across `substep` mini-batches)
            loss.backward()
            # update only every `substep`-th batch
            if (ii + 1) % substep == 0:
                optimizer.step()
                optimizer.zero_grad()
            loss_train += loss.data
            print('%3d/%3d => loss: %f, cls_loss: %f, box_loss: %f' %
                  (ii, i, criterion.loss, criterion.cls_loss, criterion.box_loss))
            if args.vis:
                vis.line(Y=loss.data.cpu().view(1, 1).numpy(), X=np.array([ii]),
                         win='loss', update='append' if ii > 0 else None)
        t1 = time.time()
        print('---one training epoch time: %fs---' % ((t1 - t0)))
        if i < 3:
            # Early epochs: report last-batch loss only (warm-up noise).
            loss_train = loss.data
        else:
            # BUG FIX: `ii` is the last batch INDEX; dividing by it (instead of
            # the batch COUNT ii + 1) overstated the epoch average.
            loss_train = loss_train / (ii + 1)

        # validation
        loss_val = 0.0
        net.eval()
        with torch.no_grad():  # no graphs needed for evaluation
            for jj, (image, cls_truth, box_truth) in enumerate(loader_val):
                image = Variable(image).cuda()
                cls_truth = Variable(cls_truth).cuda()
                box_truth = Variable(box_truth).cuda()
                cls_pred, box_pred = net(image)
                loss = criterion(cls_pred, box_pred, cls_truth, box_truth)
                loss_val += loss.data
                print('val: %3d/%3d => loss: %f, cls_loss: %f, box_loss: %f' %
                      (jj, i, criterion.loss, criterion.cls_loss, criterion.box_loss))
        # BUG FIX: was `loss_val / jj` — off by one, and ZeroDivisionError
        # when the val loader yields a single batch.
        loss_val = loss_val / (jj + 1)
        if args.vis:
            vis.line(Y=torch.cat((loss_val.view(1, 1),
                                  loss_train.view(1, 1)), 1).cpu().numpy(),
                     X=np.array([i]),
                     win='eval-train loss',
                     update='append' if i > 0 else None)

        print('Saving weights...')
        # Unwrap DataParallel so the checkpoint loads on any GPU count.
        if args.ngpus > 1:
            state = {
                'net': net.module.state_dict(),
                'loss': loss_val,
                'epoch': i,
            }
        else:
            state = {
                'net': net.state_dict(),
                'loss': loss_val,
                'epoch': i,
            }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        if (i + 1) % 10 == 0:
            torch.save(state, './checkpoint/retina-bdd-%03d.pth' % i)
        torch.save(state, './checkpoint/retina-bdd-backup.pth')

        if (i + 1) % step_epoch == 0:
            lr = lr * lr_decay
            print('learning rate: %f' % lr)
            # BUG FIX: rebuilding the optimizer here discarded Adam's moment
            # estimates; update the learning rate in place instead.
            for group in optimizer.param_groups:
                group['lr'] = lr

    torch.save(network, 'retina-bdd_final.pkl')
    print('finished training!!!')
def run_train():
    """Train RetinaNet on a ListDataset with per-epoch evaluation, keeping the
    checkpoint with the best test loss.

    Reads paths/hyperparameters from module-level `config`. Resumes from
    `config.checkpoint_filename` when it exists, otherwise starts from the
    pretrained backbone (importing it on demand).
    """
    assert torch.cuda.is_available(), 'Error: CUDA not found!'
    # BUG FIX: `best_loss` was only assigned on the resume path, and the nested
    # test() declared it `global` — so a fresh run crashed with NameError and a
    # resumed run compared against a module-level variable instead of the
    # checkpoint's loss. Initialize here and use `nonlocal` below.
    best_loss = float('inf')
    start_epoch = 0  # start from epoch 0 or last checkpointed epoch

    # Data
    print('Load ListDataset')
    transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    trainset = ListDataset(img_dir=config.img_dir,
                           list_filename=config.train_list_filename,
                           label_map_filename=config.label_map_filename,
                           train=True,
                           transform=transform,
                           input_size=config.img_res)
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=8,
        collate_fn=trainset.collate_fn)
    testset = ListDataset(img_dir=config.img_dir,
                          list_filename=config.test_list_filename,
                          label_map_filename=config.label_map_filename,
                          train=False,
                          transform=transform,
                          input_size=config.img_res)
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=config.test_batch_size,
        shuffle=False,
        num_workers=8,
        collate_fn=testset.collate_fn)

    # Model
    net = RetinaNet()
    if os.path.exists(config.checkpoint_filename):
        print('Load saved checkpoint: {}'.format(config.checkpoint_filename))
        checkpoint = torch.load(config.checkpoint_filename)
        net.load_state_dict(checkpoint['net'])
        best_loss = checkpoint['loss']
        start_epoch = checkpoint['epoch']
    else:
        print('Load pretrained model: {}'.format(config.pretrained_filename))
        if not os.path.exists(config.pretrained_filename):
            import_pretrained_resnet()
        net.load_state_dict(torch.load(config.pretrained_filename))
    net = torch.nn.DataParallel(net,
                                device_ids=range(torch.cuda.device_count()))
    net.cuda()

    criterion = FocalLoss()
    optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9,
                          weight_decay=1e-4)

    # Training
    def train(epoch):
        """Run one training epoch and print running average loss."""
        print('\nEpoch: %d' % epoch)
        net.train()
        net.module.freeze_bn()  # keep BatchNorm statistics frozen
        train_loss = 0
        total_batches = int(
            math.ceil(trainloader.dataset.num_samples / trainloader.batch_size))
        for batch_idx, targets in enumerate(trainloader):
            inputs = targets[0]
            loc_targets = targets[1]
            cls_targets = targets[2]
            inputs = inputs.cuda()
            loc_targets = loc_targets.cuda()
            cls_targets = cls_targets.cuda()

            optimizer.zero_grad()
            loc_preds, cls_preds = net(inputs)
            loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.data
            print('[%d| %d/%d] loss: %.3f | avg: %.3f' %
                  (epoch, batch_idx, total_batches, loss.data,
                   train_loss / (batch_idx + 1)))

    # Test
    def test(epoch):
        """Evaluate on the test set; checkpoint when the loss improves."""
        # BUG FIX: was `global best_loss` — see initialization above.
        nonlocal best_loss
        print('\nTest')
        net.eval()
        test_loss = 0
        total_batches = int(
            math.ceil(testloader.dataset.num_samples / testloader.batch_size))
        with torch.no_grad():  # evaluation needs no autograd graphs
            for batch_idx, targets in enumerate(testloader):
                inputs = targets[0]
                loc_targets = targets[1]
                cls_targets = targets[2]
                inputs = inputs.cuda()
                loc_targets = loc_targets.cuda()
                cls_targets = cls_targets.cuda()

                loc_preds, cls_preds = net(inputs)
                loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
                test_loss += loss.data
                print('[%d| %d/%d] loss: %.3f | avg: %.3f' %
                      (epoch, batch_idx, total_batches, loss.data,
                       test_loss / (batch_idx + 1)))

        # Save checkpoint when the mean test loss improves.
        test_loss /= len(testloader)
        if test_loss < best_loss:
            print('Save checkpoint: {}'.format(config.checkpoint_filename))
            state = {
                'net': net.module.state_dict(),
                'loss': test_loss,
                'epoch': epoch,
            }
            if not os.path.exists(os.path.dirname(config.checkpoint_filename)):
                os.makedirs(os.path.dirname(config.checkpoint_filename))
            torch.save(state, config.checkpoint_filename)
            best_loss = test_loss

    for epoch in range(start_epoch, start_epoch + 1000):
        train(epoch)
        test(epoch)