def train():
    """Train EfficientDet on Pascal VOC using focal loss.

    Reads configuration from the module-level ``args`` namespace
    (``dataset_root``, ``batch_size``, ``lr``, ``num_epoch``) and writes a
    checkpoint to ``./weights/checkpoint_<epoch>.pth`` after every epoch.
    Requires a CUDA device (model and batches are moved with ``.cuda()``).
    """
    dataset = VOCDetection(root=args.dataset_root,
                           transform=SSDAugmentation(512, MEANS))
    data_loader = data.DataLoader(dataset, args.batch_size,
                                  num_workers=0,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=False)
    model = EfficientDet(num_classes=21)  # 20 VOC classes + background
    model = model.cuda()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    criterion = FocalLoss()
    model.train()
    iteration = 0
    for epoch in range(args.num_epoch):
        print('Start epoch: {} ...'.format(epoch))
        total_loss = []
        for idx, sample in enumerate(data_loader):
            images = sample['img'].cuda()
            classification, regression, anchors = model(images)
            classification_loss, regression_loss = criterion(
                classification, regression, anchors, sample['annot'])
            classification_loss = classification_loss.mean()
            regression_loss = regression_loss.mean()
            loss = classification_loss + regression_loss
            # Skip batches that produced no positive anchors (zero loss
            # would yield zero gradients anyway).
            if bool(loss == 0):
                continue
            optimizer.zero_grad()
            loss.backward()
            # Clip gradients to stabilize early training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            total_loss.append(loss.item())
            if iteration % 100 == 0:
                # Fixed typo in the log key: 'totol_loss' -> 'total_loss'.
                print('Epoch/Iteration: {}/{}, classification: {}, '
                      'regression: {}, total_loss: {}'.format(
                          epoch, iteration,
                          classification_loss.item(),
                          regression_loss.item(),
                          np.mean(total_loss)))
            iteration += 1
        torch.save(model.state_dict(),
                   './weights/checkpoint_{}.pth'.format(epoch))
def train():
    """Train EfficientDet with an SSD-style MultiBox loss on VOC or COCO.

    Dataset selection is driven by ``args.dataset`` / ``args.dataset_root``;
    a periodic checkpoint is saved every 5000 iterations and the final
    weights are written to ``args.save_folder + args.dataset + '.pth'``.
    """
    if args.dataset == 'COCO':
        if args.dataset_root == VOC_ROOT:
            # VOC_ROOT is the argparse default; falling through means the
            # user asked for COCO without pointing at a COCO tree.
            if not os.path.exists(COCO_ROOT):
                parser.error('Must specify dataset_root if specifying dataset')
            print("WARNING: Using default COCO dataset_root because " +
                  "--dataset_root was not specified.")
            args.dataset_root = COCO_ROOT
        cfg = coco
        dataset = COCODetection(root=args.dataset_root,
                                transform=SSDAugmentation(cfg['min_dim'],
                                                          MEANS))
    elif args.dataset == 'VOC':
        if args.dataset_root == COCO_ROOT:
            parser.error('Must specify dataset if specifying dataset_root')
        cfg = voc
        dataset = VOCDetection(root=args.dataset_root,
                               transform=SSDAugmentation(cfg['min_dim'],
                                                         MEANS))

    net = EfficientDet(num_class=cfg['num_classes'])
    if args.cuda:
        net = net.cuda()

    optimizer = optim.AdamW(net.parameters(), lr=args.lr)
    criterion = MultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3,
                             0.5, False, args.cuda)
    data_loader = data.DataLoader(dataset, args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=True)
    net.train()
    iteration = 0
    for epoch in range(args.num_epoch):
        print('\n Start epoch: {} ...'.format(epoch))
        for idx, (images, targets) in enumerate(data_loader):
            # Plain tensor transfer: Variable(..., volatile=True) is
            # deprecated since PyTorch 0.4 and was the *inference* flag —
            # wrong to apply to training targets.
            if args.cuda:
                images = images.cuda()
                targets = [ann.cuda() for ann in targets]
            # forward
            t0 = time.time()
            out = net(images)
            optimizer.zero_grad()
            loss_l, loss_c = criterion(out, targets)
            loss = loss_l + loss_c
            loss.backward()
            optimizer.step()
            t1 = time.time()
            if iteration % 10 == 0:
                print('timer: %.4f sec.' % (t1 - t0))
                print('iter ' + repr(iteration) + ' || Loss: %.4f ||'
                      % loss.item(), end=' ')
            if iteration != 0 and iteration % 5000 == 0:
                print('Saving state, iteration:', iteration)
                # BUG FIX: checkpoint was keyed by the per-epoch batch
                # index `idx`, so files were overwritten every epoch; key
                # by the global iteration counter instead, matching the
                # log line above.
                torch.save(net.state_dict(),
                           'weights/Effi' + repr(iteration) + '.pth')
            iteration += 1
    torch.save(net.state_dict(),
               args.save_folder + '' + args.dataset + '.pth')
def train():
    """Train EfficientDet with MultiBox loss, LR schedule and visdom hooks.

    Variant of the SSD training loop: supports ``cfg['lr_steps']`` learning
    rate decay, optional visdom plotting, and multi-GPU via DataParallel.
    Saves a checkpoint every 5000 iterations and the final weights to
    ``args.save_folder + args.dataset + '.pth'``.
    """
    if args.dataset == 'COCO':
        if args.dataset_root == VOC_ROOT:
            # VOC_ROOT is the argparse default; reaching here means COCO
            # was requested without a COCO root being supplied.
            if not os.path.exists(COCO_ROOT):
                parser.error('Must specify dataset_root if specifying dataset')
            print("WARNING: Using default COCO dataset_root because " +
                  "--dataset_root was not specified.")
            args.dataset_root = COCO_ROOT
        cfg = coco
        dataset = COCODetection(root=args.dataset_root,
                                transform=SSDAugmentation(cfg['min_dim'],
                                                          MEANS))
    elif args.dataset == 'VOC':
        if args.dataset_root == COCO_ROOT:
            parser.error('Must specify dataset if specifying dataset_root')
        cfg = voc
        dataset = VOCDetection(root=args.dataset_root,
                               transform=SSDAugmentation(cfg['min_dim'],
                                                         MEANS))

    if args.visdom:
        import visdom
        viz = visdom.Visdom()

    net = EfficientDet(num_class=cfg['num_classes'])
    if args.cuda:
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    if args.cuda:
        net = net.cuda()

    optimizer = optim.AdamW(net.parameters(), lr=args.lr)
    criterion = MultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3,
                             0.5, False, args.cuda)
    net.train()
    # Running loss counters (floats — see .item() below).
    loc_loss = 0
    conf_loss = 0
    epoch = 0
    print('Loading the dataset...')
    epoch_size = len(dataset) // args.batch_size
    print('Training SSD on:', dataset.name)
    print('Using the specified args:')
    print(args)
    step_index = 0
    if args.visdom:
        vis_title = 'SSD.PyTorch on ' + dataset.name
        vis_legend = ['Loc Loss', 'Conf Loss', 'Total Loss']
        iter_plot = create_vis_plot('Iteration', 'Loss', vis_title,
                                    vis_legend)
        epoch_plot = create_vis_plot('Epoch', 'Loss', vis_title, vis_legend)
    data_loader = data.DataLoader(dataset, args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=True)
    iteration = 0
    for epoch in range(args.num_epoch):
        for idx, (images, targets) in enumerate(data_loader):
            # Step the LR schedule at the configured global iterations.
            if iteration in cfg['lr_steps']:
                step_index += 1
                adjust_learning_rate(optimizer, args.gamma, step_index)
            # Plain tensor transfer: Variable(..., volatile=True) is
            # deprecated since PyTorch 0.4 and was the *inference* flag —
            # wrong to apply to training targets.
            if args.cuda:
                images = images.cuda()
                targets = [ann.cuda() for ann in targets]
            # forward
            t0 = time.time()
            out = net(images)
            # backprop
            optimizer.zero_grad()
            loss_l, loss_c = criterion(out, targets)
            loss = loss_l + loss_c
            loss.backward()
            optimizer.step()
            t1 = time.time()
            # BUG FIX: accumulate Python floats, not autograd tensors —
            # `loc_loss += loss_l` kept every iteration's graph alive and
            # grew memory without bound.
            loc_loss += loss_l.item()
            conf_loss += loss_c.item()
            if iteration % 10 == 0:
                print('timer: %.4f sec.' % (t1 - t0))
                print('iter ' + repr(iteration) + ' || Loss: %.4f ||'
                      % loss.item(), end=' ')
            if iteration != 0 and iteration % 5000 == 0:
                print('Saving state, iter:', iteration)
                # BUG FIX: checkpoint was keyed by the per-epoch batch
                # index `idx`, overwriting files every epoch; key by the
                # global iteration counter, matching the log line above.
                torch.save(net.state_dict(),
                           'weights/Effi' + repr(iteration) + '.pth')
            iteration += 1
    torch.save(net.state_dict(),
               args.save_folder + '' + args.dataset + '.pth')