def initialize_model_from_cfg(args, gpu_id=0):
    """Initialize a model from the global cfg. Loads test-time weights and
    sets the model to evaluation mode.
    """
    Generalized_RCNN = importlib.import_module(
        'modeling_rel.' + cfg.MODEL.TYPE).Generalized_RCNN
    model = Generalized_RCNN()
    model.eval()

    if args.cuda:
        model.cuda()

    if args.load_ckpt:
        load_name = args.load_ckpt
        logger.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils_rel.load_ckpt_rel(model, checkpoint['model'])

    if args.load_detectron:
        logger.info("loading detectron weights %s", args.load_detectron)
        load_detectron_weight(model, args.load_detectron)

    model = mynn.DataParallel(model, cpu_keywords=['im_info', 'roidb'],
                              minibatch=True)
    return model
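For reference, a minimal usage sketch of this initializer; the `args` fields and the checkpoint path below are hypothetical placeholders, not values from the source.

# Hypothetical usage sketch; `args` fields and the path are placeholders.
from argparse import Namespace

args = Namespace(cuda=True,
                 load_ckpt='Outputs/example/model_step9999.pth',  # hypothetical
                 load_detectron=None)
model = initialize_model_from_cfg(args)  # DataParallel-wrapped model in eval mode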
def initialize_model(load_ckpt):
    """Initialize a model from the global cfg. Loads the given checkpoint and
    sets the model to training mode.
    """
    model = model_builder_rel.Generalized_RCNN()
    model.train()
    model.cuda()

    load_name = load_ckpt
    checkpoint = torch.load(load_name,
                            map_location=lambda storage, loc: storage)
    net_utils_rel.load_ckpt_rel(model, checkpoint['model'])
    # model = mynn.DataParallel(model, cpu_keywords=['im_info', 'roidb'], minibatch=True)
    return model
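Both initializers rely on the `map_location=lambda storage, loc: storage` idiom, which deserializes checkpoint tensors onto CPU regardless of the device they were saved from; the model is only moved to GPU afterwards. A self-contained sketch of the idiom:

import torch
import torch.nn as nn

# Save a toy checkpoint, then reload it with all tensors kept on CPU storage,
# even if it had been written from CUDA tensors.
net = nn.Linear(4, 2)
torch.save({'model': net.state_dict()}, '/tmp/demo_ckpt.pth')
checkpoint = torch.load('/tmp/demo_ckpt.pth',
                        map_location=lambda storage, loc: storage)
net.load_state_dict(checkpoint['model'])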
def load_detector_weights(self, weight_name):
    logger.info("loading pretrained weights from %s", weight_name)
    checkpoint = torch.load(weight_name,
                            map_location=lambda storage, loc: storage)
    net_utils_rel.load_ckpt_rel(self, checkpoint['model'])
    # freeze everything above the rel module
    for p in self.Conv_Body.parameters():
        p.requires_grad = False
    for p in self.RPN.parameters():
        p.requires_grad = False
    if not cfg.MODEL.UNFREEZE_DET:
        for p in self.Box_Head.parameters():
            p.requires_grad = False
        for p in self.Box_Outs.parameters():
            p.requires_grad = False
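Setting `requires_grad = False` stops gradient computation, but the frozen modules still run in the forward pass. A small hypothetical helper (not from the source) to verify how much of a model such a freeze actually covers:

def count_frozen(module):
    """Return (num_trainable, num_frozen) parameter tensors under `module`."""
    trainable = sum(1 for p in module.parameters() if p.requires_grad)
    frozen = sum(1 for p in module.parameters() if not p.requires_grad)
    return trainable, frozen

# e.g. after load_detector_weights: count_frozen(model.Conv_Body) -> (0, N)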
def initialize_model_from_cfg(args, gpu_id=0):
    """Initialize a model from the global cfg. Loads test-time weights and
    sets the model to evaluation mode.
    """
    model = model_builder_rel.Generalized_RCNN()
    model.eval()

    if args.cuda:
        model.cuda()

    if args.load_ckpt:
        load_name = args.load_ckpt
        logger.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils_rel.load_ckpt_rel(model, checkpoint['model'])

    if args.load_detectron:
        logger.info("loading detectron weights %s", args.load_detectron)
        load_detectron_weight(model, args.load_detectron)

    model.RelDN.mix_cent_loss.centroids.data = torch.from_numpy(
        np.load('/home/wwt/ECCV2020/stage_one/Outputs/mix_centroids_headlimittailgo.npy')).cuda()
    # model.RelDN.mix_cent_loss.centroids.data = torch.zeros((51, 1024)).cuda()
    # model.RelDN.mix_cent_loss.centroids.data = (torch.randn((51, 1024)) * 0.08 + 0.02).cuda()

    model = mynn.DataParallel(model, cpu_keywords=['im_info', 'roidb'],
                              minibatch=True)
    return model
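Overwriting `centroids.data` in place, as above, bypasses the shape checks that `load_state_dict` would perform. A hedged sketch of the same pattern with a toy parameter (the file name is hypothetical; the (51, 1024) shape mirrors the commented-out alternatives above):

import numpy as np
import torch

centroids = torch.nn.Parameter(torch.zeros(51, 1024))
arr = np.load('mix_centroids.npy')           # hypothetical path
assert arr.shape == tuple(centroids.shape)   # sanity check the source omits
centroids.data = torch.from_numpy(arr).float()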
def _init_modules(self):
    # VGG16 imagenet pretrained model is initialized in VGG16.py
    if cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS != '':
        logger.info("Loading pretrained weights from %s",
                    cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS)
        resnet_utils.load_pretrained_imagenet_weights(self)
        for p in self.Conv_Body.parameters():
            p.requires_grad = False

    if cfg.RESNETS.VRD_PRETRAINED_WEIGHTS != '':
        self.load_detector_weights(cfg.RESNETS.VRD_PRETRAINED_WEIGHTS)
    if cfg.VGG16.VRD_PRETRAINED_WEIGHTS != '':
        self.load_detector_weights(cfg.VGG16.VRD_PRETRAINED_WEIGHTS)
    if cfg.RESNETS.VG_PRETRAINED_WEIGHTS != '':
        self.load_detector_weights(cfg.RESNETS.VG_PRETRAINED_WEIGHTS)
    if cfg.VGG16.VG_PRETRAINED_WEIGHTS != '':
        self.load_detector_weights(cfg.VGG16.VG_PRETRAINED_WEIGHTS)
    if cfg.RESNETS.OI_REL_PRETRAINED_WEIGHTS != '':
        self.load_detector_weights(cfg.RESNETS.OI_REL_PRETRAINED_WEIGHTS)
    if cfg.VGG16.OI_REL_PRETRAINED_WEIGHTS != '':
        self.load_detector_weights(cfg.VGG16.OI_REL_PRETRAINED_WEIGHTS)

    if cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS != '' or cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS != '' or \
            cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS != '' or cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS != '' or \
            cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS != '' or cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS != '':
        if cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS != '':
            logger.info("loading prd pretrained weights from %s",
                        cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS)
            checkpoint = torch.load(cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
        if cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS != '':
            logger.info("loading prd pretrained weights from %s",
                        cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS)
            checkpoint = torch.load(cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
        if cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS != '':
            logger.info("loading prd pretrained weights from %s",
                        cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS)
            checkpoint = torch.load(cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
        if cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS != '':
            logger.info("loading prd pretrained weights from %s",
                        cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS)
            checkpoint = torch.load(cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
        if cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS != '':
            logger.info("loading prd pretrained weights from %s",
                        cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS)
            checkpoint = torch.load(cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
        if cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS != '':
            logger.info("loading prd pretrained weights from %s",
                        cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS)
            checkpoint = torch.load(cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
        # not using the last softmax layers
        del checkpoint['model']['Box_Outs.cls_score.weight']
        del checkpoint['model']['Box_Outs.cls_score.bias']
        del checkpoint['model']['Box_Outs.bbox_pred.weight']
        del checkpoint['model']['Box_Outs.bbox_pred.bias']
        net_utils_rel.load_ckpt_rel(self.Prd_RCNN, checkpoint['model'])
        if cfg.TRAIN.FREEZE_PRD_CONV_BODY:
            for p in self.Prd_RCNN.Conv_Body.parameters():
                p.requires_grad = False
        if cfg.TRAIN.FREEZE_PRD_BOX_HEAD:
            for p in self.Prd_RCNN.Box_Head.parameters():
                p.requires_grad = False

    if cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS != '' or cfg.VGG16.TO_BE_FINETUNED_WEIGHTS != '':
        if cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS != '':
            logger.info("loading trained and to be finetuned weights from %s",
                        cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS)
            checkpoint = torch.load(cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
        if cfg.VGG16.TO_BE_FINETUNED_WEIGHTS != '':
            logger.info("loading trained and to be finetuned weights from %s",
                        cfg.VGG16.TO_BE_FINETUNED_WEIGHTS)
            checkpoint = torch.load(cfg.VGG16.TO_BE_FINETUNED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
        net_utils_rel.load_ckpt_rel(self, checkpoint['model'])
        for p in self.Conv_Body.parameters():
            p.requires_grad = False
        for p in self.RPN.parameters():
            p.requires_grad = False
        if not cfg.MODEL.UNFREEZE_DET:
            for p in self.Box_Head.parameters():
                p.requires_grad = False
            for p in self.Box_Outs.parameters():
                p.requires_grad = False

    if cfg.RESNETS.REL_PRETRAINED_WEIGHTS != '':
        logger.info("loading rel pretrained weights from %s",
                    cfg.RESNETS.REL_PRETRAINED_WEIGHTS)
        checkpoint = torch.load(cfg.RESNETS.REL_PRETRAINED_WEIGHTS,
                                map_location=lambda storage, loc: storage)
        prd_rcnn_state_dict = {}
        reldn_state_dict = {}
        for name in checkpoint['model']:
            if name.find('Prd_RCNN') >= 0:
                prd_rcnn_state_dict[name] = checkpoint['model'][name]
            if name.find('RelDN') >= 0:
                reldn_state_dict[name] = checkpoint['model'][name]
        net_utils_rel.load_ckpt_rel(self.Prd_RCNN, prd_rcnn_state_dict)
        if cfg.TRAIN.FREEZE_PRD_CONV_BODY:
            for p in self.Prd_RCNN.Conv_Body.parameters():
                p.requires_grad = False
        if cfg.TRAIN.FREEZE_PRD_BOX_HEAD:
            for p in self.Prd_RCNN.Box_Head.parameters():
                p.requires_grad = False
        del reldn_state_dict['RelDN.prd_cls_scores.weight']
        del reldn_state_dict['RelDN.prd_cls_scores.bias']
        if 'RelDN.prd_sbj_scores.weight' in reldn_state_dict:
            del reldn_state_dict['RelDN.prd_sbj_scores.weight']
        if 'RelDN.prd_sbj_scores.bias' in reldn_state_dict:
            del reldn_state_dict['RelDN.prd_sbj_scores.bias']
        if 'RelDN.prd_obj_scores.weight' in reldn_state_dict:
            del reldn_state_dict['RelDN.prd_obj_scores.weight']
        if 'RelDN.prd_obj_scores.bias' in reldn_state_dict:
            del reldn_state_dict['RelDN.prd_obj_scores.bias']
        if 'RelDN.spt_cls_scores.weight' in reldn_state_dict:
            del reldn_state_dict['RelDN.spt_cls_scores.weight']
        if 'RelDN.spt_cls_scores.bias' in reldn_state_dict:
            del reldn_state_dict['RelDN.spt_cls_scores.bias']
        net_utils_rel.load_ckpt_rel(self.RelDN, reldn_state_dict)
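The prefix filtering above builds separate `Prd_RCNN` and `RelDN` sub-state-dicts and strips the score heads whose shapes depend on the old predicate class count. An equivalent-in-spirit, more compact sketch (illustrative only, not the repo's code):

ckpt_model = checkpoint['model']  # as loaded from REL_PRETRAINED_WEIGHTS above
prd_rcnn_state_dict = {k: v for k, v in ckpt_model.items() if 'Prd_RCNN' in k}
reldn_state_dict = {k: v for k, v in ckpt_model.items() if 'RelDN' in k}

# Drop classifier heads; their output dims depend on the old predicate set.
score_prefixes = ('RelDN.prd_cls_scores', 'RelDN.prd_sbj_scores',
                  'RelDN.prd_obj_scores', 'RelDN.spt_cls_scores')
for k in list(reldn_state_dict):
    if k.startswith(score_prefixes):
        del reldn_state_dict[k]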
def main():
    """Main function"""

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need a CUDA device to run!")

    if args.dataset == "vrd":
        cfg.TRAIN.DATASETS = ('vrd_train',)
        cfg.TEST.DATASETS = ('vrd_val',)
        cfg.MODEL.NUM_CLASSES = 101
        cfg.MODEL.NUM_PRD_CLASSES = 70  # exclude background
    elif args.dataset == "vg_mini":
        cfg.TRAIN.DATASETS = ('vg_train_mini',)
        cfg.MODEL.NUM_CLASSES = 151
        cfg.MODEL.NUM_PRD_CLASSES = 50  # exclude background
    elif args.dataset == "vg":
        cfg.TRAIN.DATASETS = ('vg_train',)
        cfg.MODEL.NUM_CLASSES = 151
        cfg.MODEL.NUM_PRD_CLASSES = 50  # exclude background
    elif args.dataset == "oi_rel":
        cfg.TRAIN.DATASETS = ('oi_rel_train',)
        # cfg.MODEL.NUM_CLASSES = 62
        cfg.MODEL.NUM_CLASSES = 58
        cfg.MODEL.NUM_PRD_CLASSES = 9  # rel, exclude background
    elif args.dataset == "oi_rel_mini":
        cfg.TRAIN.DATASETS = ('oi_rel_train_mini',)
        # cfg.MODEL.NUM_CLASSES = 62
        cfg.MODEL.NUM_CLASSES = 58
        cfg.MODEL.NUM_PRD_CLASSES = 9  # rel, exclude background
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    Generalized_RCNN = importlib.import_module(
        'modeling_rel.' + cfg.MODEL.TYPE).Generalized_RCNN
    from core.test_engine_rel_mps import get_metrics_det_boxes, get_metrics_gt_boxes

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning rate linearly based on batch size change
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(
        map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print('Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
          '    SOLVER.STEPS: {} --> {}\n'
          '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                  cfg.SOLVER.STEPS,
                                                  old_max_iter,
                                                  cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print('Scale FPN rpn_proposals collect size directly proportional to the change of IMS_PER_BATCH:\n'
              '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index, ds = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # record backbone params, i.e., conv_body and box_head params
    gn_params = []
    backbone_bias_params = []
    backbone_bias_param_names = []
    prd_branch_bias_params = []
    prd_branch_bias_param_names = []
    backbone_nonbias_params = []
    backbone_nonbias_param_names = []
    prd_branch_nonbias_params = []
    prd_branch_nonbias_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'gn' in key:
                gn_params.append(value)
            elif 'Conv_Body' in key or 'Box_Head' in key or 'Box_Outs' in key or 'RPN' in key:
                if 'bias' in key:
                    backbone_bias_params.append(value)
                    backbone_bias_param_names.append(key)
                else:
                    backbone_nonbias_params.append(value)
                    backbone_nonbias_param_names.append(key)
            else:
                if 'bias' in key:
                    prd_branch_bias_params.append(value)
                    prd_branch_bias_param_names.append(key)
                else:
                    prd_branch_nonbias_params.append(value)
                    prd_branch_nonbias_param_names.append(key)
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [
        {'params': backbone_nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': backbone_bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': prd_branch_nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': prd_branch_bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': gn_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN},
    ]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils_rel.load_ckpt_rel(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print('train_size value: %d different from the one in checkpoint: %d'
                          % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            optimizer.load_state_dict(checkpoint['optimizer'])
            # misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  # TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    # lr = optimizer.param_groups[0]['lr']  # lr of non-bias parameters, for command line outputs.
    lr = optimizer.param_groups[2]['lr']  # lr of non-backbone parameters, for command line outputs.
    backbone_lr = optimizer.param_groups[0]['lr']  # lr of backbone parameters, for command line outputs.

    device_ids = list(range(torch.cuda.device_count()))
    maskRCNN_one_gpu = mynn.DataParallel(maskRCNN,
                                         cpu_keywords=['im_info', 'roidb'],
                                         minibatch=True,
                                         device_ids=[device_ids[0]])
    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_' + args.exp + '_' + \
        '_step_with_prd_cls_v' + str(cfg.MODEL.SUBTYPE)
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)
    CHECKPOINT_PERIOD = ds.len // effective_batch_size

    # Set index for decay steps
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args,
        args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)

    # metrics = get_metrics_gt_boxes(maskRCNN, timers, cfg.TEST.DATASETS[0])
    # tblogger.add_scalar(args.dataset + '_r@100', metrics, 0)
    try:
        logger.info('Training starts!')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils_rel.update_learning_rate_rel(optimizer, lr, lr_new)
                # lr = optimizer.param_groups[0]['lr']
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils_rel.update_learning_rate_rel(optimizer, lr, cfg.SOLVER.BASE_LR)
                # lr = optimizer.param_groups[0]['lr']
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning rate on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils_rel.update_learning_rate_rel(optimizer, lr, lr_new)
                # lr = optimizer.param_groups[0]['lr']
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            if step == args.start_step:
                for n, p in maskRCNN.named_parameters():
                    if p.requires_grad and p.grad is None:
                        logger.warning('Module %s is defined but never used!', n)

            training_stats.LogIterStats(step, lr, backbone_lr)

            if int(step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
                metrics = get_metrics_gt_boxes(maskRCNN_one_gpu, timers,
                                               cfg.TEST.DATASETS[0])
                maskRCNN.train()
                tblogger.add_scalar(args.dataset + '_metrics', metrics, step)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        # metrics = get_metrics(maskRCNN, timers, cfg.TEST.DATASETS)
        # tblogger.add_scalar(args.dataset + '_r@100', metrics, step)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
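The warm-up branch above interpolates linearly from `WARM_UP_FACTOR * BASE_LR` to `BASE_LR` over `WARM_UP_ITERS` steps. A self-contained sketch of that schedule (the constant values below are made up, standing in for the `cfg.SOLVER.*` fields):

BASE_LR = 0.01          # made-up stand-ins for cfg.SOLVER.* values
WARM_UP_FACTOR = 1.0 / 3.0
WARM_UP_ITERS = 500

def warmup_lr(step):
    """Linear warm-up: WARM_UP_FACTOR * BASE_LR at step 0, BASE_LR at WARM_UP_ITERS."""
    if step >= WARM_UP_ITERS:
        return BASE_LR
    alpha = step / WARM_UP_ITERS
    return BASE_LR * (WARM_UP_FACTOR * (1 - alpha) + alpha)

assert abs(warmup_lr(0) - BASE_LR * WARM_UP_FACTOR) < 1e-12
assert abs(warmup_lr(WARM_UP_ITERS) - BASE_LR) < 1e-12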
def load_detector_weights(self, weight_name):
    logger.info("loading pretrained weights from %s", weight_name)
    checkpoint = torch.load(weight_name,
                            map_location=lambda storage, loc: storage)
    net_utils_rel.load_ckpt_rel(self, checkpoint['model'])
def _init_modules(self):
    # VGG16 imagenet pretrained model is initialized in VGG16.py
    if cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS != '':
        logger.info("Loading pretrained weights from %s",
                    cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS)
        resnet_utils.load_pretrained_imagenet_weights(self)
        for p in self.Conv_Body.parameters():
            p.requires_grad = False

    if cfg.RESNETS.VRD_PRETRAINED_WEIGHTS != '':
        self.load_detector_weights(cfg.RESNETS.VRD_PRETRAINED_WEIGHTS)
    if cfg.VGG16.VRD_PRETRAINED_WEIGHTS != '':
        self.load_detector_weights(cfg.VGG16.VRD_PRETRAINED_WEIGHTS)
    if cfg.RESNETS.VG_PRETRAINED_WEIGHTS != '':
        self.load_detector_weights(cfg.RESNETS.VG_PRETRAINED_WEIGHTS)
    if cfg.VGG16.VG_PRETRAINED_WEIGHTS != '':
        self.load_detector_weights(cfg.VGG16.VG_PRETRAINED_WEIGHTS)
    if cfg.RESNETS.OI_REL_PRETRAINED_WEIGHTS != '':
        self.load_detector_weights(cfg.RESNETS.OI_REL_PRETRAINED_WEIGHTS)
    if cfg.VGG16.OI_REL_PRETRAINED_WEIGHTS != '':
        self.load_detector_weights(cfg.VGG16.OI_REL_PRETRAINED_WEIGHTS)

    if cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS != '' or cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS != '' or \
            cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS != '' or cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS != '' or \
            cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS != '' or cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS != '':
        if cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS != '':
            logger.info("loading prd pretrained weights from %s",
                        cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS)
            checkpoint = torch.load(cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
        if cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS != '':
            logger.info("loading prd pretrained weights from %s",
                        cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS)
            checkpoint = torch.load(cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
        if cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS != '':
            logger.info("loading prd pretrained weights from %s",
                        cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS)
            checkpoint = torch.load(cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
        if cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS != '':
            logger.info("loading prd pretrained weights from %s",
                        cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS)
            checkpoint = torch.load(cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
        if cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS != '':
            logger.info("loading prd pretrained weights from %s",
                        cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS)
            checkpoint = torch.load(cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
        if cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS != '':
            logger.info("loading prd pretrained weights from %s",
                        cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS)
            checkpoint = torch.load(cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
        self.Box_Head_sg.heads[0].weight.data.copy_(
            checkpoint['model']['Box_Head.heads.0.weight'])
        self.Box_Head_sg.heads[0].bias.data.copy_(
            checkpoint['model']['Box_Head.heads.0.bias'])
        self.Box_Head_sg.heads[3].weight.data.copy_(
            checkpoint['model']['Box_Head.heads.3.weight'])
        self.Box_Head_sg.heads[3].bias.data.copy_(
            checkpoint['model']['Box_Head.heads.3.bias'])
        self.Box_Head_prd.heads[0].weight.data.copy_(
            checkpoint['model']['Box_Head.heads.0.weight'])
        self.Box_Head_prd.heads[0].bias.data.copy_(
            checkpoint['model']['Box_Head.heads.0.bias'])
        self.Box_Head_prd.heads[3].weight.data.copy_(
            checkpoint['model']['Box_Head.heads.3.weight'])
        self.Box_Head_prd.heads[3].bias.data.copy_(
            checkpoint['model']['Box_Head.heads.3.bias'])

    if cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS != '' or cfg.VGG16.TO_BE_FINETUNED_WEIGHTS != '':
        if cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS != '':
            logger.info("loading trained and to be finetuned weights from %s",
                        cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS)
            checkpoint = torch.load(cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
        if cfg.VGG16.TO_BE_FINETUNED_WEIGHTS != '':
            logger.info("loading trained and to be finetuned weights from %s",
                        cfg.VGG16.TO_BE_FINETUNED_WEIGHTS)
            checkpoint = torch.load(cfg.VGG16.TO_BE_FINETUNED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
        net_utils_rel.load_ckpt_rel(self, checkpoint['model'])
        for p in self.Conv_Body.parameters():
            p.requires_grad = False
        for p in self.RPN.parameters():
            p.requires_grad = False
        if not cfg.MODEL.UNFREEZE_DET:
            for p in self.Box_Head.parameters():
                p.requires_grad = False
            for p in self.Box_Outs.parameters():
                p.requires_grad = False
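The `.data.copy_()` calls above transplant individual layers between architecturally identical heads while keeping the destination `Parameter` objects (and any optimizer references to them) intact. A minimal illustration of the pattern with toy modules, not the repo's classes:

import torch
import torch.nn as nn

src = nn.Linear(8, 8)
dst = nn.Linear(8, 8)
# In-place copy: dst's Parameter identity is preserved, only its values change.
with torch.no_grad():
    dst.weight.copy_(src.weight)
    dst.bias.copy_(src.bias)
assert torch.equal(dst.weight, src.weight)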