# store pruning models
if not os.path.exists(args.prune_folder):
    os.mkdir(args.prune_folder)

# ------------------------------------------- 1st prune: load model from state_dict
model = build_refine('train', cfg['min_dim'], cfg['num_classes'],
                     use_refine=True, use_tcb=True).cuda()
state_dict = torch.load(args.trained_model)

from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    head = k[:7]            # head = k[:4]
    if head == 'module.':   # head == 'vgg.'
        name = k[7:]        # name = 'base.' + k[4:]
    else:
        name = k
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)
# model.load_state_dict(torch.load(args.trained_model))

# ------------------------------------------- >= 2nd prune: load model from previous pruning
# model = torch.load(args.trained_model).cuda()
print('Finished loading model!')

testset = VOCDetection(root=args.dataset_root,
                       image_sets=[('2007', 'test')],
                       transform=BaseTransform(cfg['min_dim'], cfg['testset_mean']))

arm_criterion = RefineMultiBoxLoss(2, 0.5, True, 0, True, 3, 0.5, False, 0, args.cuda)
odm_criterion = RefineMultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5, False,
                                   0.01, args.cuda)  # 0.01 -> 0.99 negative confidence threshold

prunner = Prunner_refineDet(testset, arm_criterion, odm_criterion, model)
prunner.prune(cut_ratio=args.cut_ratio)
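# The `module.`-prefix stripping loop above recurs in several of these scripts whenever a
# checkpoint saved from a DataParallel-wrapped model is loaded into a bare model. A minimal
# sketch of the same idea as a reusable helper; the name `strip_module_prefix` is
# illustrative and not part of the original code:
from collections import OrderedDict

def strip_module_prefix(state_dict):
    """Return a copy of state_dict with any leading 'module.' removed from the keys."""
    cleaned = OrderedDict()
    for k, v in state_dict.items():
        cleaned[k[7:] if k.startswith('module.') else k] = v
    return cleaned

# usage, equivalent to the explicit loop above:
# model.load_state_dict(strip_module_prefix(torch.load(args.trained_model)))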
if args.cuda:
    net = torch.nn.DataParallel(net, device_ids=[args.device])

if args.cuda:
    net.cuda()
    cudnn.benchmark = True

optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=args.momentum, weight_decay=args.weight_decay)
# optimizer = optim.RMSprop(net.parameters(), lr=args.lr, alpha=0.9, eps=1e-08,
#                           momentum=args.momentum, weight_decay=args.weight_decay)
criterion = RefineMultiBoxLoss(num_classes, 0.5, True, 0, True, 3, 0.5, False, 0.9)

priorbox = PriorBox(cfg)
# with torch.no_grad():
priors = Variable(priorbox.forward(), volatile=True)
if args.cuda:
    priors = priors.cuda()


def train():
    net.train()
    # loss counters
    loc_loss = 0  # epoch
    conf_loss = 0
    epoch = 0 + args.resume_epoch

    print('Loading Dataset...')
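# Note: `Variable(priorbox.forward(), volatile=True)` is the pre-0.4 PyTorch idiom for an
# inference-only tensor. The commented-out `with torch.no_grad():` above hints at the modern
# equivalent; a minimal sketch on PyTorch >= 0.4, reusing the `priorbox` and `args` already
# defined in this script, would be:
with torch.no_grad():
    priors = priorbox.forward()   # no autograd graph is recorded for the prior boxes
if args.cuda:
    priors = priors.cuda()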
if args.gpu_id:
    net = torch.nn.DataParallel(net, device_ids=args.gpu_id)

if args.cuda:
    net.cuda()
    cudnn.benchmark = True

optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=args.momentum, weight_decay=args.weight_decay)
# optimizer = optim.RMSprop(net.parameters(), lr=args.lr, alpha=0.9, eps=1e-08,
#                           momentum=args.momentum, weight_decay=args.weight_decay)

arm_criterion = RefineMultiBoxLoss(2, 0.5, True, 0, True, 3, 0.5, False)
odm_criterion = RefineMultiBoxLoss(num_classes, 0.5, True, 0, True, 3, 0.5, False, 0.01)

priorbox = PriorBox(cfg)
detector = Detect(num_classes, 0, cfg, object_score=0.01)
priors = Variable(priorbox.forward(), volatile=True)

# dataset
print('Loading Dataset...')
if args.dataset == 'VOC':
    testset = VOCDetection(VOCroot, [('2007', 'test')], None, AnnotationTransform())
    train_dataset = VOCDetection(VOCroot, train_sets, preproc(img_dim, rgb_means, p),
                                 AnnotationTransform())
elif args.dataset == 'COCO':
    testset = COCODetection(COCOroot, [('2014', 'minival')], None)
def main():
    global args
    args = arg_parse()
    cfg_from_file(args.cfg_file)
    save_folder = args.save_folder
    batch_size = cfg.TRAIN.BATCH_SIZE
    bgr_means = cfg.TRAIN.BGR_MEAN
    p = 0.6
    gamma = cfg.SOLVER.GAMMA
    momentum = cfg.SOLVER.MOMENTUM
    weight_decay = cfg.SOLVER.WEIGHT_DECAY
    size = cfg.MODEL.SIZE
    thresh = cfg.TEST.CONFIDENCE_THRESH
    if cfg.DATASETS.DATA_TYPE == 'VOC':
        trainvalDataset = VOCDetection
        top_k = 1000
    else:
        trainvalDataset = COCODetection
        top_k = 1000
    dataset_name = cfg.DATASETS.DATA_TYPE
    dataroot = cfg.DATASETS.DATAROOT
    trainSet = cfg.DATASETS.TRAIN_TYPE
    valSet = cfg.DATASETS.VAL_TYPE
    num_classes = cfg.MODEL.NUM_CLASSES
    start_epoch = args.resume_epoch
    epoch_step = cfg.SOLVER.EPOCH_STEPS
    end_epoch = cfg.SOLVER.END_EPOCH
    if not os.path.exists(save_folder):
        os.mkdir(save_folder)
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    net = SSD(cfg)
    print(net)
    if cfg.MODEL.SIZE == '300':
        size_cfg = cfg.SMALL
    else:
        size_cfg = cfg.BIG

    optimizer = optim.SGD(net.parameters(), lr=cfg.SOLVER.BASE_LR,
                          momentum=momentum, weight_decay=weight_decay)

    if args.resume_net is not None:
        checkpoint = torch.load(args.resume_net)
        state_dict = checkpoint['model']
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            head = k[:7]
            if head == 'module.':
                name = k[7:]  # remove `module.`
            else:
                name = k
            new_state_dict[name] = v
        net.load_state_dict(new_state_dict)
        optimizer.load_state_dict(checkpoint['optimizer'])
        print('Loading resume network...')

    if args.ngpu > 1:
        net = torch.nn.DataParallel(net)
    net.cuda()
    cudnn.benchmark = True

    criterion = list()
    if cfg.MODEL.REFINE:
        detector = Detect(cfg)
        arm_criterion = RefineMultiBoxLoss(cfg, 2)
        odm_criterion = RefineMultiBoxLoss(cfg, cfg.MODEL.NUM_CLASSES)
        criterion.append(arm_criterion)
        criterion.append(odm_criterion)
    else:
        detector = Detect(cfg)
        ssd_criterion = MultiBoxLoss(cfg)
        criterion.append(ssd_criterion)

    TrainTransform = preproc(size_cfg.IMG_WH, bgr_means, p)
    ValTransform = BaseTransform(size_cfg.IMG_WH, bgr_means, (2, 0, 1))
    val_dataset = trainvalDataset(dataroot, valSet, ValTransform, dataset_name)
    val_loader = data.DataLoader(val_dataset, batch_size, shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=detection_collate)

    for epoch in range(start_epoch + 1, end_epoch + 1):
        train_dataset = trainvalDataset(dataroot, trainSet, TrainTransform, dataset_name)
        epoch_size = len(train_dataset)
        train_loader = data.DataLoader(train_dataset, batch_size, shuffle=True,
                                       num_workers=args.num_workers,
                                       collate_fn=detection_collate)
        train(train_loader, net, criterion, optimizer, epoch, epoch_step, gamma,
              end_epoch, cfg)
        if (epoch % 5 == 0) or (epoch % 2 == 0 and epoch >= 60):
            save_checkpoint(net, epoch, size, optimizer)
        if (epoch >= 2 and epoch % 2 == 0):
            eval_net(val_dataset, val_loader, net, detector, cfg, ValTransform,
                     top_k, thresh=thresh, batch_size=batch_size)
    save_checkpoint(net, end_epoch, size, optimizer)
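# The resume branch above reads checkpoint['model'] and checkpoint['optimizer'], which implies
# that save_checkpoint stores both state dicts. The actual save_checkpoint in this repo is not
# shown and may differ; a minimal sketch consistent with the resume logic (the file-naming
# scheme here is an assumption; `os`, `torch`, and the module-level `args` are used as in the
# rest of the script) would be:
def save_checkpoint(net, epoch, size, optimizer):
    checkpoint = {
        'model': net.state_dict(),          # carries 'module.' prefixes under DataParallel
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(checkpoint,
               os.path.join(args.save_folder, 'ssd_{}_epoch_{}.pth'.format(size, epoch)))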
def main_worker(gpu, ngpus_per_node, args):
    global best_map
    # deal with args
    args.gpu = gpu
    cfg_from_file(args.cfg_file)
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # distributed cfgs
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
        torch.cuda.set_device(args.gpu)

    net = SSD(cfg)
    # print(net)
    if args.resume_net is not None:
        checkpoint = torch.load(args.resume_net)
        state_dict = checkpoint['model']
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            head = k[:7]
            if head == 'module.':
                name = k[7:]  # remove `module.`
            else:
                name = k
            new_state_dict[name] = v
        net.load_state_dict(new_state_dict)
        print('Loading resume network...')

    if args.distributed:
        # For multiprocessing distributed, the DistributedDataParallel constructor
        # should always set the single device scope; otherwise
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            # print(args.gpu)
            torch.cuda.set_device(args.gpu)
            net.cuda(args.gpu)
            # When using a single GPU per process and per DistributedDataParallel,
            # we need to divide the batch size ourselves based on the total
            # number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])
        else:
            net.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            net = torch.nn.parallel.DistributedDataParallel(net)
    elif args.gpu is not None:
        # torch.cuda.set_device(args.gpu)
        net = net.cuda(args.gpu)

    # args = arg_parse()
    batch_size = args.batch_size
    print("batch_size = ", batch_size)
    bgr_means = cfg.TRAIN.BGR_MEAN
    p = 0.6
    gamma = cfg.SOLVER.GAMMA
    momentum = cfg.SOLVER.MOMENTUM
    weight_decay = cfg.SOLVER.WEIGHT_DECAY
    size = cfg.MODEL.SIZE  # size = 300
    thresh = cfg.TEST.CONFIDENCE_THRESH
    if cfg.DATASETS.DATA_TYPE == 'VOC':
        trainvalDataset = VOCDetection
        top_k = 1000
    else:
        trainvalDataset = COCODetection
        top_k = 1000
    dataset_name = cfg.DATASETS.DATA_TYPE
    dataroot = cfg.DATASETS.DATAROOT
    trainSet = cfg.DATASETS.TRAIN_TYPE
    valSet = cfg.DATASETS.VAL_TYPE
    num_classes = cfg.MODEL.NUM_CLASSES
    start_epoch = args.resume_epoch
    epoch_step = cfg.SOLVER.EPOCH_STEPS
    end_epoch = cfg.SOLVER.END_EPOCH
    args.num_workers = args.workers

    # optimizer
    optimizer = optim.SGD(net.parameters(), lr=cfg.SOLVER.BASE_LR,
                          momentum=momentum, weight_decay=weight_decay)
    if cfg.MODEL.SIZE == '300':
        size_cfg = cfg.SMALL
    else:
        size_cfg = cfg.BIG
    # if args.resume_net != None:
    #     checkpoint = torch.load(args.resume_net)
    #     optimizer.load_state_dict(checkpoint['optimizer'])
    cudnn.benchmark = True

    # deal with criterion
    criterion = list()
    if cfg.MODEL.REFINE:
        detector = Detect(cfg)
        arm_criterion = RefineMultiBoxLoss(cfg, 2)
        odm_criterion = RefineMultiBoxLoss(cfg, cfg.MODEL.NUM_CLASSES)
        arm_criterion.cuda(args.gpu)
        odm_criterion.cuda(args.gpu)
        criterion.append(arm_criterion)
        criterion.append(odm_criterion)
    else:
        detector = Detect(cfg)
        ssd_criterion = MultiBoxLoss(cfg)
        criterion.append(ssd_criterion)

    # deal with dataset
    TrainTransform = preproc(size_cfg.IMG_WH, bgr_means, p)
    ValTransform = BaseTransform(size_cfg.IMG_WH, bgr_means, (2, 0, 1))
    val_dataset = trainvalDataset(dataroot, valSet, ValTransform, dataset_name)
    val_loader = data.DataLoader(val_dataset, batch_size, shuffle=False,
                                 num_workers=args.num_workers * ngpus_per_node,
                                 collate_fn=detection_collate)

    # deal with training dataset
    train_dataset = trainvalDataset(dataroot, trainSet, TrainTransform, dataset_name)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               collate_fn=detection_collate,
                                               pin_memory=True, sampler=train_sampler)

    # set net in training phase
    net.train()
    for epoch in range(start_epoch + 1, end_epoch + 1):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        # train_loader = data.DataLoader(train_dataset, batch_size, shuffle=True,
        #                                num_workers=args.num_workers,
        #                                collate_fn=detection_collate)
        # Training
        train(train_loader, net, criterion, optimizer, epoch, epoch_step, gamma,
              end_epoch, cfg, args)
        if (epoch >= 0 and epoch % 10 == 0):
            # print("here", args.rank % ngpus_per_node)
            # validate the model
            eval_net(val_dataset, val_loader, net, detector, cfg, ValTransform, args,
                     top_k, thresh=thresh, batch_size=cfg.TEST.BATCH_SIZE)
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
            if (epoch % 10 == 0) or (epoch % 5 == 0 and epoch >= 60):
                save_name = os.path.join(
                    args.save_folder,
                    cfg.MODEL.TYPE + "_epoch_{}_rank_{}_{}".format(
                        str(epoch), str(args.rank), str(size)) + '.pth')
                save_checkpoint(net, epoch, size, optimizer, batch_size, save_name)
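# main_worker(gpu, ngpus_per_node, args) follows the standard PyTorch multiprocessing-
# distributed entry-point convention. The launching main() is not shown in this excerpt; a
# minimal sketch of how it would typically spawn one process per GPU (this launcher is an
# assumption, reusing the arg_parse() and args fields referenced above) would be:
import torch
import torch.multiprocessing as mp

def main():
    args = arg_parse()
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # world_size becomes the total number of processes across all nodes
        args.world_size = ngpus_per_node * args.world_size
        # mp.spawn passes the GPU index as the first argument to main_worker
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)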
if not args.resume:
    from model.networks import net_init
    net_init(ssd_net, args.backbone, logging, refine=args.refine,
             deform=args.deform, multihead=args.multihead)

if args.augm_type == 'ssd':
    data_transform = SSDAugmentation
else:
    data_transform = BaseTransform

optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=args.momentum, weight_decay=args.weight_decay)

# criterion
if 'RefineDet' in args.backbone and args.refine:
    use_refine = True
    arm_criterion = RefineMultiBoxLoss(2, 0.5, True, 0, True, 3, 0.5, False,
                                       device=device, only_loc=True)
    criterion = RefineMultiBoxLoss(num_classes, 0.5, True, 0, True, 3, 0.5, False,
                                   device=device)
else:
    use_refine = False
    criterion = MultiBoxLoss(num_classes, 0.5, True, 0, True, 3, 0.5, False,
                             device=device)

priorbox = PriorBox(cfg)
with torch.no_grad():
    priors = priorbox.forward().to(device)


def train():
    net.train()
    epoch = args.start_iter
    if args.dataset_name == 'COCO':
        dataset = COCODetection(COCOroot, year='trainval2014', image_sets=train_sets,
                                transform=data_transform(ssd_dim, means), phase='train')
    else:
def train():
    # network set-up
    ssd_net = build_refine('train', cfg['min_dim'], cfg['num_classes'],
                           use_refine=True, use_tcb=True)
    net = ssd_net
    if args.cuda:
        net = torch.nn.DataParallel(ssd_net)  # state_dict will have .module. prefix
        cudnn.benchmark = True

    if args.resume:
        print('Resuming training, loading {}...'.format(args.resume))
        ssd_net.load_weights(args.resume)
    else:
        print('Using preloaded base network...')  # Preloaded.
        print('Initializing other weights...')
        # initialize newly added layers' weights with the xavier method
        ssd_net.extras.apply(weights_init)
        ssd_net.trans_layers.apply(weights_init)
        ssd_net.latent_layrs.apply(weights_init)
        ssd_net.up_layers.apply(weights_init)
        ssd_net.arm_loc.apply(weights_init)
        ssd_net.arm_conf.apply(weights_init)
        ssd_net.odm_loc.apply(weights_init)
        ssd_net.odm_conf.apply(weights_init)

    if args.cuda:
        net = net.cuda()

    # optimizer and loss set-up
    optimizer = optim.SGD(net.parameters(), lr=args.lr,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    arm_criterion = RefineMultiBoxLoss(2, 0.5, True, 0, True, 3, 0.5, False, 0, args.cuda)
    odm_criterion = RefineMultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5,
                                       False, 0.01, args.cuda)  # 0.01 -> 0.99 negative confidence threshold

    # different from normal ssd, where the PriorBox is stored inside the SSD object
    priorbox = PriorBox(cfg)
    priors = Variable(priorbox.forward(), volatile=True)
    # detector used in test_net for testing
    detector = RefineDetect(cfg['num_classes'], 0, cfg, object_score=0.01)

    net.train()
    # loss counters
    loc_loss = 0
    conf_loss = 0
    epoch = 0
    print('Loading the dataset...')
    epoch_size = len(dataset) // args.batch_size
    print('Training refineDet on:', dataset.name)
    print('Using the specified args:')
    print(args)

    if args.visdom:
        import visdom
        viz = visdom.Visdom()
        # initialize visdom loss plot
        vis_title = 'SSD.PyTorch on ' + dataset.name
        vis_legend = ['Loc Loss', 'Conf Loss', 'Total Loss']
        iter_plot = create_vis_plot('Iteration', 'Loss', vis_title, vis_legend)
        epoch_plot = create_vis_plot('Epoch', 'Loss', vis_title, vis_legend)

    # adjust learning rate based on epoch
    stepvalues_VOC = (150 * epoch_size, 200 * epoch_size, 250 * epoch_size)
    stepvalues_COCO = (90 * epoch_size, 120 * epoch_size, 140 * epoch_size)
    stepvalues = (stepvalues_VOC, stepvalues_COCO)[args.dataset == 'COCO']
    step_index = 0

    # training data loader
    data_loader = data.DataLoader(dataset, args.batch_size,
                                  num_workers=args.num_workers, shuffle=True,
                                  collate_fn=detection_collate, pin_memory=True)
    # create batch iterator
    batch_iterator = iter(data_loader)
    # batch_iterator = None
    mean_odm_loss_c = 0
    mean_odm_loss_l = 0
    mean_arm_loss_c = 0
    mean_arm_loss_l = 0

    # max_iter = cfg['max_epoch'] * epoch_size
    for iteration in range(args.start_iter, cfg['max_epoch'] * epoch_size + 10):
        try:
            images, targets = next(batch_iterator)
        except StopIteration:
            batch_iterator = iter(data_loader)  # the dataloader cannot re-initialize
            images, targets = next(batch_iterator)

        if args.visdom and iteration != 0 and (iteration % epoch_size == 0):
            # update visdom loss plot
            update_vis_plot(epoch, loc_loss, conf_loss, epoch_plot, None,
                            'append', epoch_size)
            # reset epoch loss counters
            loc_loss = 0
            conf_loss = 0

        if iteration != 0 and (iteration % epoch_size == 0):
            # adjust_learning_rate(optimizer, args.gamma, epoch)
            # evaluation
            if args.evaluate:
                # load net
                net.eval()
                APs, mAP = test_net(args.eval_folder, net, detector, priors, args.cuda,
                                    val_dataset,
                                    BaseTransform(net.module.size, cfg['testset_mean']),
                                    args.max_per_image,
                                    thresh=args.confidence_threshold)  # 320 originally for cfg['min_dim']
                net.train()
            epoch += 1

        # update learning rate
        if iteration in stepvalues:
            step_index = stepvalues.index(iteration) + 1
        lr = adjust_learning_rate(optimizer, args.gamma, epoch, step_index,
                                  iteration, epoch_size)

        if args.cuda:
            images = Variable(images.cuda())
            targets = [Variable(ann.cuda(), volatile=True) for ann in targets]
        else:
            images = Variable(images)
            targets = [Variable(ann, volatile=True) for ann in targets]

        # forward
        t0 = time.time()
        out = net(images)
        arm_loc, arm_conf, odm_loc, odm_conf = out

        # backprop
        optimizer.zero_grad()
        # arm branch loss
        # priors = priors.type(type(images.data))  # convert to same datatype
        arm_loss_l, arm_loss_c = arm_criterion((arm_loc, arm_conf), priors, targets)
        # odm branch loss
        odm_loss_l, odm_loss_c = odm_criterion((odm_loc, odm_conf), priors, targets,
                                               (arm_loc, arm_conf), False)

        mean_arm_loss_c += arm_loss_c.data[0]
        mean_arm_loss_l += arm_loss_l.data[0]
        mean_odm_loss_c += odm_loss_c.data[0]
        mean_odm_loss_l += odm_loss_l.data[0]

        loss = arm_loss_l + arm_loss_c + odm_loss_l + odm_loss_c
        loss.backward()
        optimizer.step()
        t1 = time.time()

        if iteration % 10 == 0:
            print('Epoch:' + repr(epoch) + ' || epochiter: ' +
                  repr(iteration % epoch_size) + '/' + repr(epoch_size) +
                  ' || Total iter ' + repr(iteration) +
                  ' || AL: %.4f AC: %.4f OL: %.4f OC: %.4f ||' % (
                      mean_arm_loss_l / 10, mean_arm_loss_c / 10,
                      mean_odm_loss_l / 10, mean_odm_loss_c / 10) +
                  ' Timer: %.4f sec. ||' % (t1 - t0) +
                  ' Loss: %.4f ||' % (loss.data[0]) +
                  ' LR: %.8f' % (lr))
            mean_odm_loss_c = 0
            mean_odm_loss_l = 0
            mean_arm_loss_c = 0
            mean_arm_loss_l = 0

        # if args.visdom:
        #     update_vis_plot(iteration, loss_l.data[0], loss_c.data[0],
        #                     iter_plot, epoch_plot, 'append')

        if iteration != 0 and iteration % 5000 == 0:
            print('Saving state, iter:', iteration)
            torch.save(ssd_net.state_dict(),
                       'weights/ssd300_refineDet_' + repr(iteration) + '.pth')

    torch.save(ssd_net.state_dict(), args.save_folder + '' + args.dataset + '.pth')
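# adjust_learning_rate is called every iteration in the loop above but is not defined in this
# excerpt. A minimal sketch of the usual step-decay-with-warmup variant found in SSD/RefineDet
# training scripts, matching the call signature above (the 5-epoch warmup length is an
# assumption, and `args.lr` is the base learning rate already used by the optimizer):
def adjust_learning_rate(optimizer, gamma, epoch, step_index, iteration, epoch_size):
    if epoch < 5:
        # linear warmup from a tiny lr up to the base lr over the first epochs
        lr = 1e-6 + (args.lr - 1e-6) * iteration / (epoch_size * 5)
    else:
        # step decay: multiply the base lr by gamma once per milestone passed
        lr = args.lr * (gamma ** step_index)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr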
def main():
    global args
    args = arg_parse()
    save_folder = args.save_folder
    batch_size = args.batch_size
    bgr_means = (104, 117, 123)
    weight_decay = 0.0005
    p = 0.6
    gamma = 0.1
    momentum = 0.9
    dataset_name = args.dataset
    size = args.size
    channel_size = args.channel_size
    thresh = args.confidence_threshold
    use_refine = False
    if args.version.split("_")[0] == "refine":
        use_refine = True
    if dataset_name[0] == "V":
        cfg = cfg_dict["VOC"][args.version][str(size)]
        trainvalDataset = VOCDetection
        dataroot = VOCroot
        targetTransform = AnnotationTransform()
        valSet = datasets_dict["VOC2007"]
        top_k = 200
    else:
        cfg = cfg_dict["COCO"][args.version][str(size)]
        trainvalDataset = COCODetection
        dataroot = COCOroot
        targetTransform = None
        valSet = datasets_dict["COCOval"]
        top_k = 300
    num_classes = cfg['num_classes']
    start_epoch = args.resume_epoch
    epoch_step = cfg["epoch_step"]
    end_epoch = cfg["end_epoch"]
    if not os.path.exists(save_folder):
        os.mkdir(save_folder)
    if args.cuda and torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    net = model_builder(args.version, cfg, "train", int(size), num_classes,
                        args.channel_size)
    print(net)
    if args.resume_net is None:
        net.load_weights(pretrained_model[args.version])
    else:
        state_dict = torch.load(args.resume_net)
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            head = k[:7]
            if head == 'module.':
                name = k[7:]  # remove `module.`
            else:
                name = k
            new_state_dict[name] = v
        net.load_state_dict(new_state_dict)
        print('Loading resume network...')

    if args.ngpu > 1:
        net = torch.nn.DataParallel(net)
    if args.cuda:
        net.cuda()
        cudnn.benchmark = True

    optimizer = optim.SGD(net.parameters(), lr=args.lr,
                          momentum=momentum, weight_decay=weight_decay)

    criterion = list()
    if use_refine:
        detector = Detect(num_classes, 0, cfg, use_arm=use_refine)
        arm_criterion = RefineMultiBoxLoss(2, 0.5, True, 0, True, 3, 0.5, False,
                                           args.cuda)
        odm_criterion = RefineMultiBoxLoss(num_classes, 0.5, True, 0, True, 3, 0.5,
                                           False, 0.01, args.cuda)
        criterion.append(arm_criterion)
        criterion.append(odm_criterion)
    else:
        detector = Detect(num_classes, 0, cfg, use_arm=use_refine)
        ssd_criterion = MultiBoxLoss(num_classes, 0.5, True, 0, True, 3, 0.5, False,
                                     args.cuda)
        criterion.append(ssd_criterion)

    TrainTransform = preproc(cfg["img_wh"], bgr_means, p)
    ValTransform = BaseTransform(cfg["img_wh"], bgr_means, (2, 0, 1))
    val_dataset = trainvalDataset(dataroot, valSet, ValTransform, targetTransform,
                                  dataset_name)
    val_loader = data.DataLoader(val_dataset, batch_size, shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=detection_collate)

    for epoch in range(start_epoch + 1, end_epoch + 1):
        train_dataset = trainvalDataset(dataroot, datasets_dict[dataset_name],
                                        TrainTransform, targetTransform, dataset_name)
        epoch_size = len(train_dataset)
        train_loader = data.DataLoader(train_dataset, batch_size, shuffle=True,
                                       num_workers=args.num_workers,
                                       collate_fn=detection_collate)
        train(train_loader, net, criterion, optimizer, epoch, epoch_step, gamma,
              use_refine)
        if (epoch % 10 == 0) or (epoch % 5 == 0 and epoch >= 200):
            save_checkpoint(net, epoch, size)
        if (epoch >= 150 and epoch % 10 == 0):
            eval_net(val_dataset, val_loader, net, detector, cfg, ValTransform, top_k,
                     thresh=thresh, batch_size=batch_size)
    eval_net(val_dataset, val_loader, net, detector, cfg, ValTransform, top_k,
             thresh=thresh, batch_size=batch_size)
    save_checkpoint(net, end_epoch, size)
print('Loading resume network...')

if args.ngpu > 1:
    # net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    net = torch.nn.DataParallel(net)
if args.cuda:
    net.cuda()
    cudnn.benchmark = True

optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=args.momentum, weight_decay=args.weight_decay)
arm_criterion = RefineMultiBoxLoss(2, 0.5, True, 0, True, 3, 0.5, False, args.cuda)
criterion = RefineMultiBoxLoss(num_classes, 0.5, True, 0, True, 3, 0.5, False,
                               0.01, args.cuda)


def train():
    net.train()
    # loss counters
    loc_loss = 0  # epoch
    conf_loss = 0
    epoch = 0 + args.resume_epoch

    print('Loading Dataset...')
    epoch_size = len(train_dataset) // args.batch_size
    max_iter = args.max_epoch * epoch_size