0, confidence_threshold=args.confidence_threshold, nms_threshold=args.nms_threshold, top_k=args.top_k, keep_top_k=args.keep_top_k) else: detect = Detect_RefineDet( num_classes, int(args.input_size), 0, objectness_threshold, confidence_threshold=args.confidence_threshold, nms_threshold=args.nms_threshold, top_k=args.top_k, keep_top_k=args.keep_top_k) net = build_refinedet('test', int(args.input_size), num_classes, backbone_dict) # test multi models, to filter out the best model. # start_epoch = 10; step = 10 start_epoch = 160 step = 5 ToBeTested = [] ToBeTested = [ prefix + f'/RefineDet{args.input_size}_VOC_epoches_{epoch}.pth' for epoch in range(start_epoch, 240, step) ] ToBeTested.append(prefix + f'/RefineDet{args.input_size}_VOC_final.pth') # ToBeTested.append(prefix + f'/RefineDet{args.input_size}_VOC_epoches_10.pth') ap_stats = {"ap50": [], "epoch": []} for index, model_path in enumerate(ToBeTested):
def train() -> None:
    """Train RefineDet on VOC or COCO.

    Reads all configuration from the module-level ``args`` namespace
    (dataset choice, input size, LR schedule, resume checkpoint, visdom
    flag, ...). Builds the network, optionally wraps it in DataParallel,
    runs the two-stage (ARM + ODM) MultiBox loss training loop, saves
    periodic epoch checkpoints, and writes a final ``*_final.pth``.

    Relies on module-level names defined elsewhere in this file:
    ``args``, ``parser``, ``backbone_dict``, ``pretrained``,
    ``negpos_ratio``, ``MEANS``, the ``*_refinedet`` config dicts, and
    the visdom helper functions.
    """
    if args.visdom:
        # Imported lazily so visdom is only required when plotting is on.
        import visdom
        viz = visdom.Visdom()

    print('Loading the dataset...')
    if args.dataset == 'COCO':
        if args.dataset_root == VOC_ROOT:
            # dataset_root was left at the VOC default while COCO was
            # requested: fall back to COCOroot, or abort if it is absent.
            if not os.path.exists(COCOroot):
                parser.error('Must specify dataset_root if specifying dataset')
            print("WARNING: Using default COCO dataset_root because " +
                  "--dataset_root was not specified.")
            args.dataset_root = COCOroot
        cfg = coco_refinedet[args.input_size]
        # NOTE(review): ('train2017') is a plain string, not a 1-tuple —
        # presumably COCODetection accepts either; confirm before changing.
        train_sets = [('train2017')]
        # train_sets = [('train2017', 'val2017')]
        dataset = COCODetection(COCOroot, train_sets,
                                SSDAugmentation(cfg['min_dim'], MEANS))
    elif args.dataset == 'VOC':
        '''if args.dataset_root == COCO_ROOT:
            parser.error('Must specify dataset if specifying dataset_root')'''
        cfg = voc_refinedet[args.input_size]
        dataset = VOCDetection(root=VOC_ROOT,
                               transform=SSDAugmentation(
                                   cfg['min_dim'], MEANS))

    print('Training RefineDet on:', dataset.name)
    print('Using the specified args:')
    print(args)

    # Keep a handle to the raw module: checkpoints are loaded into it and
    # the final weights are saved from it even when DataParallel wraps it.
    refinedet_net = build_refinedet('train', int(args.input_size),
                                    cfg['num_classes'], backbone_dict)
    net = refinedet_net
    print(net)

    device = torch.device('cuda:0' if args.cuda else 'cpu')
    if args.ngpu > 1 and args.cuda:
        net = torch.nn.DataParallel(refinedet_net,
                                    device_ids=list(range(args.ngpu)))
        cudnn.benchmark = True
    net = net.to(device)

    if args.resume:
        print('Resuming training, loading {}...'.format(args.resume))
        state_dict = torch.load(args.resume)
        # create new OrderedDict that does not contain `module.`
        # (strips the DataParallel prefix so the checkpoint loads into the
        # unwrapped module regardless of how it was saved).
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            head = k[:7]
            if head == 'module.':
                name = k[7:]  # remove `module.`
            else:
                name = k
            new_state_dict[name] = v
        refinedet_net.load_state_dict(new_state_dict)
    else:
        print('Initializing weights...')
        refinedet_net.init_weights(pretrained=pretrained)

    optimizer = optim.SGD(net.parameters(), lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    # ARM is class-agnostic (2 classes: object vs background); ODM is the
    # full classifier and filters anchors through the ARM (use_ARM=True).
    arm_criterion = RefineDetMultiBoxLoss(2, 0.5, True, 0, True,
                                          negpos_ratio, 0.5, False,
                                          args.cuda)
    odm_criterion = RefineDetMultiBoxLoss(cfg['num_classes'], 0.5, True, 0,
                                          True, negpos_ratio, 0.5, False,
                                          args.cuda, use_ARM=True)
    priorbox = PriorBox(cfg)
    with torch.no_grad():
        # Priors are fixed for a given config; no grad needed.
        priors = priorbox.forward()
        priors = priors.to(device)

    net.train()
    # loss counters (accumulated per epoch for the visdom epoch plot)
    arm_loc_loss = 0
    arm_conf_loss = 0
    odm_loc_loss = 0
    odm_conf_loss = 0
    epoch = 0 + args.resume_epoch
    epoch_size = math.ceil(len(dataset) / args.batch_size)
    max_iter = args.max_epoch * epoch_size

    # LR decay boundaries in iterations: at 2/3 and 8/9 of training for
    # COCO, 2/3 and 5/6 for VOC (final entry marks the end of training).
    stepvalues = (args.max_epoch * 2 // 3 * epoch_size,
                  args.max_epoch * 8 // 9 * epoch_size,
                  args.max_epoch * epoch_size)
    if args.dataset == 'VOC':
        stepvalues = (args.max_epoch * 2 // 3 * epoch_size,
                      args.max_epoch * 5 // 6 * epoch_size,
                      args.max_epoch * epoch_size)
    step_index = 0

    if args.resume_epoch > 0:
        # Fast-forward the schedule: count decay steps already passed.
        start_iter = args.resume_epoch * epoch_size
        for step in stepvalues:
            if step < start_iter:
                step_index += 1
    else:
        start_iter = 0

    if args.visdom:
        vis_title = 'RefineDet.PyTorch on ' + dataset.name
        vis_legend = ['Loc Loss', 'Conf Loss', 'Total Loss']
        iter_plot = create_vis_plot(viz, 'Iteration', 'Loss',
                                    vis_title, vis_legend)
        epoch_plot = create_vis_plot(viz, 'Epoch', 'Loss',
                                     vis_title, vis_legend)

    data_loader = data.DataLoader(dataset, args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=True)

    for iteration in range(start_iter, max_iter):
        if iteration % epoch_size == 0:
            # Epoch boundary: flush the visdom epoch plot (skipped at the
            # very first iteration, when there is nothing to plot yet).
            # NOTE(review): only the ARM losses are sent to the epoch
            # plot; the ODM counters are accumulated but never plotted.
            if args.visdom and iteration != 0:
                update_vis_plot(viz, epoch, arm_loc_loss, arm_conf_loss,
                                epoch_plot, None, 'append', epoch_size)
            # reset epoch loss counters
            arm_loc_loss = 0
            arm_conf_loss = 0
            odm_loc_loss = 0
            odm_conf_loss = 0
            # create batch iterator (one fresh pass over the dataset;
            # epoch_size = ceil(len/batch) matches the loader's length,
            # so next() below never exhausts it mid-epoch)
            batch_iterator = iter(data_loader)
            # Checkpoint every 10 epochs, densifying to every 5 in the
            # last third of training.
            if (epoch % 10 == 0 and epoch > 0) or \
               (epoch % 5 == 0 and epoch > (args.max_epoch * 2 // 3)):
                # NOTE(review): no '/' between save_folder and the file
                # name here, unlike the final save below — assumes
                # args.save_folder ends with a separator; confirm.
                torch.save(net.state_dict(), args.save_folder + 'RefineDet'
                           + args.input_size + '_' + args.dataset
                           + '_epoches_' + repr(epoch) + '.pth')
            epoch += 1

        t0 = time.time()
        if iteration in stepvalues:
            step_index += 1
        # Recompute the LR every iteration (handles warmup/decay inside
        # adjust_learning_rate); step_index selects the decay stage.
        lr = adjust_learning_rate(optimizer, args.gamma, epoch, step_index,
                                  iteration, epoch_size)

        # load train data
        images, targets = next(batch_iterator)
        images = images.to(device)
        targets = [ann.to(device) for ann in targets]
        # for an in targets:
        #     for instance in an:
        #         for cor in instance[:-1]:
        #             if cor < 0 or cor > 1:
        #                 raise StopIteration
        # forward
        out = net(images)
        # backprop: the ARM and ODM losses are summed into a single
        # objective and backpropagated together.
        optimizer.zero_grad()
        arm_loss_l, arm_loss_c = arm_criterion(out, priors, targets)
        odm_loss_l, odm_loss_c = odm_criterion(out, priors, targets)
        arm_loss = arm_loss_l + arm_loss_c
        odm_loss = odm_loss_l + odm_loss_c
        loss = arm_loss + odm_loss
        loss.backward()
        optimizer.step()
        # .item() detaches scalars so the accumulators don't hold graphs.
        arm_loc_loss += arm_loss_l.item()
        arm_conf_loss += arm_loss_c.item()
        odm_loc_loss += odm_loss_l.item()
        odm_conf_loss += odm_loss_c.item()
        t1 = time.time()
        batch_time = t1 - t0
        # Naive ETA: last batch time extrapolated over remaining iters.
        eta = int(batch_time * (max_iter - iteration))
        print('Epoch:{}/{} || Epochiter: {}/{} || Iter: {}/{} || ARM_L Loss: {:.4f} ARM_C Loss: {:.4f} ODM_L Loss: {:.4f} ODM_C Loss: {:.4f} loss: {:.4f} || LR: {:.8f} || Batchtime: {:.4f} s || ETA: {}'.\
            format(epoch, args.max_epoch, (iteration % epoch_size) + 1,
                   epoch_size, iteration + 1, max_iter, arm_loss_l.item(),
                   arm_loss_c.item(), odm_loss_l.item(), odm_loss_c.item(),
                   loss.item(), lr, batch_time,
                   str(datetime.timedelta(seconds=eta))))

        if args.visdom:
            update_vis_plot(viz, iteration, arm_loss_l.item(),
                            arm_loss_c.item(), iter_plot, epoch_plot,
                            'append')

    # Final weights are saved from the unwrapped module so the state dict
    # carries no `module.` prefixes.
    torch.save(refinedet_net.state_dict(), args.save_folder
               + '/RefineDet{}_{}_final.pth'.format(args.input_size,
                                                    args.dataset))