def main():
    """Train the relationship model with the stage-1 / stage-2 memory module.

    Reads command-line args, fills dataset-specific entries of the global
    ``cfg``, merges the config file and ``--set`` overrides, rescales batch
    size / learning rate / solver steps to the number of visible GPUs, builds
    the training data loader, model and optimizer(s), optionally restores
    weights, then runs the step-based training loop (warm-up, step LR decay,
    periodic checkpointing). If training is interrupted by an exception or by
    Ctrl-C, a checkpoint is saved and the stack trace is printed.
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    cfg.DATASET = args.dataset
    # dataset name -> (cfg.MODEL.NUM_CLASSES, cfg.MODEL.NUM_PRD_CLASSES).
    # Predicate counts exclude background; the object counts of vg80k and the
    # gvqa variants are documented as including background.
    dataset_class_counts = {
        'vrd': (101, 70),
        'vg': (151, 50),
        'vg80k': (53305, 29086),
        'gvqa20k': (1704, 310),
        'gvqa10k': (1704, 310),
        'gvqa': (1704, 310),
    }
    if args.dataset not in dataset_class_counts:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))
    cfg.TRAIN.DATASETS = (args.dataset + '_train',)
    cfg.MODEL.NUM_CLASSES, cfg.MODEL.NUM_PRD_CLASSES = dataset_class_counts[args.dataset]

    # The config file and --set overrides are applied after the dataset
    # defaults above, so they take precedence over them.
    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d'
          % (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print(' effective_batch_size: %d --> %d' % (original_batch_size, effective_batch_size))
    print(' NUM_GPUS: %d --> %d' % (original_num_gpus, cfg.NUM_GPUS))
    print(' IMS_PER_BATCH: %d --> %d' % (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          ' BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print('Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
          ' SOLVER.STEPS: {} --> {}\n'
          ' SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps, cfg.SOLVER.STEPS,
                                               old_max_iter, cfg.SOLVER.MAX_ITER))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    # (applied after the batch-size rescaling, so --lr is used verbatim)
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True
    )
    dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Split trainable params into GroupNorm ('gn'), backbone
    # (Conv_Body / Box_Head / Box_Outs / RPN) and everything else (the
    # predicate branch); backbone and predicate params are further split into
    # bias / non-bias so they can get different lr / weight decay.
    gn_params = []
    backbone_bias_params = []
    backbone_bias_param_names = []
    prd_branch_bias_params = []
    prd_branch_bias_param_names = []
    backbone_nonbias_params = []
    backbone_nonbias_param_names = []
    prd_branch_nonbias_params = []
    prd_branch_nonbias_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'gn' in key:
                gn_params.append(value)
            elif 'Conv_Body' in key or 'Box_Head' in key or 'Box_Outs' in key or 'RPN' in key:
                if 'bias' in key:
                    backbone_bias_params.append(value)
                    backbone_bias_param_names.append(key)
                else:
                    backbone_nonbias_params.append(value)
                    backbone_nonbias_param_names.append(key)
            else:
                if 'bias' in key:
                    prd_branch_bias_params.append(value)
                    prd_branch_bias_param_names.append(key)
                else:
                    prd_branch_nonbias_params.append(value)
                    prd_branch_nonbias_param_names.append(key)

    # Learning rate of 0 is a dummy value to be set properly at the start of training.
    # Group order matters: index 0 (backbone non-bias) and index 2 (predicate
    # non-bias) are read back below as `backbone_lr` and `lr`.
    params = [
        {'params': backbone_nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': backbone_bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': prd_branch_nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': prd_branch_bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': gn_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN},
    ]

    print('Initializing model and classifier optimizers.')
    if cfg.MODEL.MEMORY_MODULE_STAGE == 1:
        step_size = 10
    elif cfg.MODEL.MEMORY_MODULE_STAGE == 2:
        step_size = 20
    else:
        raise NotImplementedError
    scheduler_params = {'step_size': step_size, 'gamma': 0.1}
    # NOTE: the returned schedulers are never stepped in this function; the
    # learning rate is driven manually in the training loop below.
    optimizer, optimizer_scheduler = init_optimizers(params, scheduler_params)

    criterion_optimizer, criterion_optimizer_scheduler = None, None
    if cfg.MODEL.MEMORY_MODULE_STAGE == 2:
        print('Initializing criterion optimizer.')
        feat_loss_optim_param = {'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0.0005}
        # Both feature-loss modules share the same hyper-parameters.
        optim_params = [
            {'params': module.parameters(), **feat_loss_optim_param}
            for module in (maskRCNN.feature_loss_sbj_obj, maskRCNN.feature_loss_prd)
        ]
        # Initialize criterion optimizer and scheduler
        criterion_optimizer, criterion_optimizer_scheduler = init_optimizers(optim_params, scheduler_params)

        # Stage 2 initialises from a previously trained checkpoint; strict=False
        # tolerates missing/unexpected keys.
        # NOTE(review): machine-specific hard-coded path — consider exposing it
        # as a command-line argument.
        weights_path = 'Outputs/e2e_relcnn_VGG16_8_epochs_gvqa_y_loss_only_1_gpu/gvqa/Feb07-10-55-03_login104-09_step_with_prd_cls_v3/ckpt/model_step1439.pth'
        weights = torch.load(weights_path)
        maskRCNN.load_state_dict(weights['model'], strict=False)

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print('train_size value: %d different from the one in checkpoint: %d'
                          % (train_size, checkpoint['train_size']))

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1,
            # hence the project helper instead of optimizer.load_state_dict.
            misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  # TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[2]['lr']  # lr of non-backbone parameters, for commmand line outputs.
    backbone_lr = optimizer.param_groups[0]['lr']  # lr of backbone parameters, for commmand line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step_with_prd_cls_v' + str(cfg.MODEL.SUBTYPE)
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    CHECKPOINT_PERIOD = 200000

    # Index of the next decay step at or after start_step (STEPS[0] is skipped).
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args,
        args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate_rel(optimizer, lr, cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                if criterion_optimizer:
                    net_utils.update_learning_rate_rel(criterion_optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            if criterion_optimizer:
                criterion_optimizer.zero_grad()
            # Gradients are accumulated over iter_size minibatches before stepping.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            if criterion_optimizer:
                criterion_optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr, backbone_lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                print('Saving Checkpoint..')
                save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (Exception, KeyboardInterrupt):
        # KeyboardInterrupt derives from BaseException, not Exception, so it
        # must be listed explicitly for Ctrl-C to also trigger a checkpoint save.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
def _build_eval_dataloader(datasets, num_classes, batch_size, num_workers,
                           timers, timer_key, label):
    """Build a non-shuffled evaluation DataLoader for ``datasets``.

    The roidb is built with ``combined_roidb_for_training`` and timed under
    ``timers[timer_key]``. Its size and build time are logged using ``label``
    (e.g. 'val', 'test'). The data is wrapped in a ``RoiDataLoader`` with
    ``training=False`` and batched with ``drop_last=True``.

    Returns:
        The ``torch.utils.data.DataLoader`` over the split.
    """
    timers[timer_key].tic()
    roidb, _, _, _, _ = combined_roidb_for_training(datasets)
    timers[timer_key].toc()
    logger.info('{:d} {} roidb entries'.format(len(roidb), label))
    logger.info('Takes %.2f sec(s) to construct %s roidb',
                timers[timer_key].average_time, label)

    dataset = RoiDataLoader(roidb, num_classes, training=False, dataset=datasets)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_minibatch,
        drop_last=True)


def main():
    """Train the relationship model, or only evaluate it on one split.

    Builds the training, val and test loaders (plus the 'unseen' split for
    the vhico dataset), the model and its optimizer, and optionally restores
    a checkpoint. If ``cfg.EVAL_SUBSET`` is 'unseen' or 'test', it runs
    ``run_eval`` once on that split and returns. Otherwise it runs the
    step-based training loop (warm-up, step LR decay, checkpoints every
    ``cfg.SAVE_MODEL_ITER`` steps). On RuntimeError or Ctrl-C a checkpoint is
    saved before exiting.

    Raises:
        ValueError: unknown ``cfg.SOLVER.TYPE``, or ``cfg.EVAL_SUBSET == 'unseen'``
            for a dataset other than vhico (no unseen split is built).
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    cfg = set_configs(args)

    timers = defaultdict(Timer)

    ### --------------------------------------------------------------------------------
    ### Dataset Training ###
    ### --------------------------------------------------------------------------------
    timers['roidb_training'].tic()
    (roidb_training, _, _, category_to_id_map,
     prd_category_to_id_map) = combined_roidb_for_training(cfg.TRAIN.DATASETS)
    timers['roidb_training'].toc()
    roidb_size_training = len(roidb_training)
    logger.info('{:d} training roidb entries'.format(roidb_size_training))
    logger.info('Takes %.2f sec(s) to construct training roidb',
                timers['roidb_training'].average_time)

    batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    dataset_training = RoiDataLoader(roidb_training, cfg.MODEL.NUM_CLASSES,
                                     training=True, dataset=cfg.TRAIN.DATASETS)
    dataloader_training = torch.utils.data.DataLoader(
        dataset_training,
        batch_size=batch_size,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch,
        shuffle=True,
        drop_last=True)
    dataiterator_training = iter(dataloader_training)

    ### --------------------------------------------------------------------------------
    ### Dataset Validation / Test / Unseen ###
    ### --------------------------------------------------------------------------------
    dataloader_val = _build_eval_dataloader(
        cfg.VAL.DATASETS, cfg.MODEL.NUM_CLASSES, batch_size,
        cfg.DATA_LOADER.NUM_THREADS, timers, 'roidb_val', 'val')
    dataloader_test = _build_eval_dataloader(
        cfg.TEST.DATASETS, cfg.MODEL.NUM_CLASSES, batch_size,
        cfg.DATA_LOADER.NUM_THREADS, timers, 'roidb_test', 'test')
    # The unseen split is only built for vhico; None elsewhere.
    dataloader_unseen = None
    if args.dataset == 'vhico':
        dataloader_unseen = _build_eval_dataloader(
            cfg.UNSEEN.DATASETS, cfg.MODEL.NUM_CLASSES, batch_size,
            cfg.DATA_LOADER.NUM_THREADS, timers, 'roidb_unseen', 'test unseen')

    ### --------------------------------------------------------------------------------
    ### Model ###
    ### --------------------------------------------------------------------------------
    maskRCNN = Generalized_RCNN(category_to_id_map=category_to_id_map,
                                prd_category_to_id_map=prd_category_to_id_map,
                                args=args)
    if cfg.CUDA:
        maskRCNN.cuda()

    ### --------------------------------------------------------------------------------
    ### Optimizer ###
    # record backbone params, i.e., conv_body and box_head params
    ### --------------------------------------------------------------------------------
    gn_params = []
    backbone_bias_params = []
    backbone_bias_param_names = []
    prd_branch_bias_params = []
    prd_branch_bias_param_names = []
    backbone_nonbias_params = []
    backbone_nonbias_param_names = []
    prd_branch_nonbias_params = []
    prd_branch_nonbias_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'gn' in key:
                gn_params.append(value)
            elif 'Conv_Body' in key or 'Box_Head' in key or 'Box_Outs' in key or 'RPN' in key:
                if 'bias' in key:
                    backbone_bias_params.append(value)
                    backbone_bias_param_names.append(key)
                else:
                    backbone_nonbias_params.append(value)
                    backbone_nonbias_param_names.append(key)
            else:
                if 'bias' in key:
                    prd_branch_bias_params.append(value)
                    prd_branch_bias_param_names.append(key)
                else:
                    prd_branch_nonbias_params.append(value)
                    prd_branch_nonbias_param_names.append(key)

    # Learning rate of 0 is a dummy value to be set properly at the start of training.
    # Group order matters: index 0 (backbone non-bias) and index 2 (predicate
    # non-bias) are read back below as `backbone_lr` and `lr`.
    params = [{
        'params': backbone_nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': backbone_bias_params,
        'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': prd_branch_nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': prd_branch_bias_params,
        'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Without this, `optimizer` would be unbound and fail later with a NameError.
        raise ValueError('Unknown SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### --------------------------------------------------------------------------------
    ### Load checkpoint
    ### --------------------------------------------------------------------------------
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        print('--------------------------------------------------------------------------------')
        print('loading checkpoint %s' % load_name)
        print('--------------------------------------------------------------------------------')
        if args.resume:
            print('resume')
            args.start_step = checkpoint['step'] + 1
            misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()
    else:
        print('args.load_ckpt', args.load_ckpt)

    lr = optimizer.param_groups[2]['lr']  # lr of non-backbone parameters, for commmand line outputs.
    backbone_lr = optimizer.param_groups[0]['lr']  # lr of backbone parameters, for commmand line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### --------------------------------------------------------------------------------
    ### Training Setups ###
    ### --------------------------------------------------------------------------------
    args.run_name = args.out_dir
    output_dir = misc_utils.get_output_dir(args, args.out_dir)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            tblogger = SummaryWriter(output_dir)

    ### --------------------------------------------------------------------------------
    ### Training Loop ###
    ### --------------------------------------------------------------------------------
    maskRCNN.train()

    # Index of the next decay step at or after start_step (STEPS[0] is skipped).
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    # `tblogger` only exists when both conditions hold.
    stats_logger = tblogger if args.use_tfboard and not args.no_save else None
    training_stats = TrainingStats(args, args.disp_interval, stats_logger, True)
    val_stats = ValStats(args, args.disp_interval, stats_logger, False)
    test_stats = TestStats(args, args.disp_interval, stats_logger, False)

    best_total_loss = np.inf
    best_eval_result = 0

    ### --------------------------------------------------------------------------------
    ### EVAL ###
    ### --------------------------------------------------------------------------------
    if cfg.EVAL_SUBSET == 'unseen':
        if dataloader_unseen is None:
            raise ValueError(
                "cfg.EVAL_SUBSET == 'unseen' is only supported for args.dataset == "
                "'vhico' (got {!r})".format(args.dataset))
        print('testing unseen ...')
        is_best, best_eval_result = run_eval(args, cfg, maskRCNN, dataloader_unseen,
                                             step=0, output_dir=output_dir,
                                             test_stats=test_stats,
                                             best_eval_result=best_eval_result,
                                             eval_subset=cfg.EVAL_SUBSET)
        return
    elif cfg.EVAL_SUBSET == 'test':
        print('testing ...')
        is_best, best_eval_result = run_eval(args, cfg, maskRCNN, dataloader_test,
                                             step=0, output_dir=output_dir,
                                             test_stats=test_stats,
                                             best_eval_result=best_eval_result,
                                             eval_subset=cfg.EVAL_SUBSET)
        return

    ### --------------------------------------------------------------------------------
    ### TRAIN ###
    ### --------------------------------------------------------------------------------
    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):
            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate_rel(optimizer, lr, cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            ## train
            training_stats.IterTic()
            optimizer.zero_grad()
            # Gradients are accumulated over iter_size minibatches before stepping.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator_training)
                except StopIteration:
                    print('recurrence data loader')
                    dataiterator_training = iter(dataloader_training)
                    input_data = next(dataiterator_training)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs['gt_label'], inner_iter)
                loss = net_outputs['gt_label']['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr, backbone_lr)

            if (step + 1) % cfg.SAVE_MODEL_ITER == 0:
                save_ckpt(output_dir, args, step, batch_size, maskRCNN, optimizer,
                          False, best_total_loss)

        # ---- Training ends ----
        save_ckpt(output_dir, args, step, batch_size, maskRCNN, optimizer,
                  False, best_total_loss)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator_training
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, batch_size, maskRCNN, optimizer,
                  False, best_total_loss)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
# Per-dataset class counts: name -> (MODEL.NUM_CLASSES, MODEL.NUM_PRD_CLASSES).
# NUM_CLASSES counts the background class (the per-dataset comments state this for
# vg80k, vg8k and the gvqa* variants); NUM_PRD_CLASSES excludes background.
# Train/val split names are derived as '<name>_train' / '<name>_val'.
_DATASET_CLASS_COUNTS = {
    'vg80k': (53305, 29086),
    'vg8k': (5331, 2000),
    'vrd': (101, 70),
    'vg': (151, 50),
    'gvqa20k': (1704, 310),
    'gvqa10k': (1704, 310),
    'gvqa': (1704, 310),
}


def _configure_dataset(dataset):
    """Point ``cfg`` at the train/val splits and class counts of ``dataset``.

    Sets cfg.DATASET first (even for an unknown name), then TRAIN/TEST dataset
    names and MODEL.NUM_CLASSES / MODEL.NUM_PRD_CLASSES.

    Raises:
        ValueError: if ``dataset`` is not a key of ``_DATASET_CLASS_COUNTS``.
    """
    cfg.DATASET = dataset
    if dataset not in _DATASET_CLASS_COUNTS:
        raise ValueError("Unexpected args.dataset: {}".format(dataset))
    num_classes, num_prd_classes = _DATASET_CLASS_COUNTS[dataset]
    cfg.TRAIN.DATASETS = (dataset + '_train', )
    cfg.TEST.DATASETS = (dataset + '_val', )
    cfg.MODEL.NUM_CLASSES = num_classes
    cfg.MODEL.NUM_PRD_CLASSES = num_prd_classes


def _build_param_groups(model):
    """Split ``model``'s trainable parameters into the five optimizer param groups.

    Group order matters: main() reads the backbone lr from group 0 and the
    predicate-branch lr from group 2.
        0: backbone non-bias   1: backbone bias
        2: prd-branch non-bias 3: prd-branch bias
        4: parameters whose name contains 'gn' (weight decay WEIGHT_DECAY_GN)
    A parameter counts as "backbone" when its name contains Conv_Body, Box_Head,
    Box_Outs or RPN. Parameters with requires_grad=False are left out entirely.
    """
    backbone_markers = ('Conv_Body', 'Box_Head', 'Box_Outs', 'RPN')
    gn_params = []
    backbone_bias_params = []
    backbone_nonbias_params = []
    prd_branch_bias_params = []
    prd_branch_nonbias_params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        if 'gn' in key:
            gn_params.append(value)
        elif any(marker in key for marker in backbone_markers):
            if 'bias' in key:
                backbone_bias_params.append(value)
            else:
                backbone_nonbias_params.append(value)
        elif 'bias' in key:
            prd_branch_bias_params.append(value)
        else:
            prd_branch_nonbias_params.append(value)

    bias_lr = 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1)
    bias_weight_decay = cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    return [
        {'params': backbone_nonbias_params, 'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': backbone_bias_params, 'lr': bias_lr,
         'weight_decay': bias_weight_decay},
        {'params': prd_branch_nonbias_params, 'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': prd_branch_bias_params, 'lr': bias_lr,
         'weight_decay': bias_weight_decay},
        {'params': gn_params, 'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN},
    ]


def _update_lr(optimizer, cur_lr, new_lr):
    """Move ``optimizer`` from ``cur_lr`` to ``new_lr``; return ``(lr, backbone_lr)``.

    ``lr`` is read back from param group 2 (predicate-branch non-bias) and
    ``backbone_lr`` from group 0, matching the order built by _build_param_groups.
    """
    net_utils.update_learning_rate_rel(optimizer, cur_lr, new_lr)
    lr = optimizer.param_groups[2]['lr']
    backbone_lr = optimizer.param_groups[0]['lr']
    assert lr == new_lr
    return lr, backbone_lr


def _first_decay_step_index(steps, start_step):
    """Return the first index i >= 1 with steps[i] >= start_step, else len(steps).

    steps[0] is never considered. A result of len(steps) means no further decay.
    """
    for i in range(1, len(steps)):
        if steps[i] >= start_step:
            return i
    return len(steps)


def _load_json(path):
    """Return the JSON document stored at ``path``."""
    with open(path) as f:
        return json.load(f)


def _dump_json(obj, path):
    """Write ``obj`` to ``path`` as JSON, closing the file before returning."""
    with open(path, 'w') as f:
        json.dump(obj, f)


def main():
    """Train the relationship model, validating and checkpointing periodically.

    Flow:
      1. Parse args and build ``cfg`` (dataset, cfg file, --set overrides).
      2. Rescale batch size, BASE_LR and the solver schedule to the visible GPUs.
      3. Build the roidb data loader, the model and the optimizer.
      4. Optionally restore a checkpoint or Detectron weights.
      5. Run the training loop with warm-up and step decay. Every EVAL_PERIOD
         iterations (and at the last one) it validates and tracks the best
         average per-class accuracy in ``best``. Every CHECKPOINT_PERIOD
         iterations it saves a checkpoint.

    Raises:
        SystemExit: if no CUDA device is available.
        ValueError: if CUDA is not requested, the dataset is unknown, or
            cfg.SOLVER.TYPE is neither "SGD" nor "Adam".
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    _configure_dataset(args.dataset)

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)
    # NOTE(review): a seed of 0 is ignored by this truthiness test; switch to
    # `is not None` if parse_args() defaults --seed to None.
    if args.seed:
        cfg.RNG_SEED = args.seed

    # Some imports need to be done after loading the config to avoid using default values
    from datasets.roidb_rel import combined_roidb_for_training
    from modeling.model_builder_rel import Generalized_RCNN
    from core.test_engine_rel import run_eval_inference, run_inference
    from core.test_engine_rel import get_inference_dataset, get_roidb_and_dataset

    logger.info('Training with config:')
    logger.info(pprint.pformat(cfg))

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print(' effective_batch_size: %d --> %d' % (original_batch_size, effective_batch_size))
    print(' NUM_GPUS: %d --> %d' % (original_num_gpus, cfg.NUM_GPUS))
    print(' IMS_PER_BATCH: %d --> %d' % (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          ' BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = [int(x * step_scale + 0.5) for x in cfg.SOLVER.STEPS]
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print('Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
          ' SOLVER.STEPS: {} --> {}\n'
          ' SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps, cfg.SOLVER.STEPS,
                                              old_max_iter, cfg.SOLVER.MAX_ITER))

    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print('Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
              ' cfg.FPN.RPN_COLLECT_SCALE: {}'.format(cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    # (applied after the batch-size rescaling above, so --lr is used verbatim)
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    # Effective training sample size for one epoch (drop_last=True below)
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()
    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    if cfg.MODEL.DECOUPLE:
        # Freeze every parameter whose name does not contain 'so_sem_embeddings.2'.
        for key, value in maskRCNN.named_parameters():
            if 'so_sem_embeddings.2' not in key:
                value.requires_grad = False
    params = _build_param_groups(maskRCNN)

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        raise ValueError('Unsupported SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    load_ckpt_dir = './'
    if args.load_ckpt_dir:
        load_name = get_checkpoint_resume_file(args.load_ckpt_dir)
        load_ckpt_dir = args.load_ckpt_dir
    elif args.load_ckpt:
        load_name = args.load_ckpt
        load_ckpt_dir = os.path.dirname(args.load_ckpt)

    if args.load_ckpt or args.load_ckpt_dir:
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        if cfg.MODEL.DECOUPLE:
            # Drop these layers from the checkpoint so they are not restored.
            # NOTE(review): prd_sem_embeddings.2 is also frozen above (its name lacks
            # 'so_sem_embeddings.2'), so it keeps its initial values — confirm intended.
            for name in ('RelDN.so_sem_embeddings.2.weight',
                         'RelDN.so_sem_embeddings.2.bias',
                         'RelDN.prd_sem_embeddings.2.weight',
                         'RelDN.prd_sem_embeddings.2.bias'):
                del checkpoint['model'][name]
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print('train_size value: %d different from the one in checkpoint: %d'
                          % (train_size, checkpoint['train_size']))
            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1, hence
            # the misc_utils helper instead of optimizer.load_state_dict.
            misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    # Current learning rates, tracked for command-line output:
    # group 2 = non-backbone (predicate branch), group 0 = backbone.
    lr = optimizer.param_groups[2]['lr']
    backbone_lr = optimizer.param_groups[0]['lr']

    prd_categories = maskRCNN.prd_categories
    obj_categories = maskRCNN.obj_categories
    prd_freq_dict = maskRCNN.prd_freq_dict
    obj_freq_dict = maskRCNN.obj_freq_dict

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step_with_prd_cls_v' + str(
        cfg.MODEL.SUBTYPE)
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    # Best validation result so far. Held in memory so validation also works with
    # --no_save; it is persisted to <ckpt_dir>/best.json only when saving.
    best = {'avg_per_class_acc': 0.0, 'iteration': 0, 'accuracies': []}
    resume_best_json = os.path.join(load_ckpt_dir, 'best.json')
    if args.resume and os.path.exists(resume_best_json):
        logger.info('Loading best json from :' + resume_best_json)
        best = _load_json(resume_best_json)

    tblogger = None
    best_json_path = None
    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        ckpt_dir = os.path.join(output_dir, 'ckpt')
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
        best_json_path = os.path.join(ckpt_dir, 'best.json')
        _dump_json(best, best_json_path)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    args.output_dir = output_dir
    args.do_val = True
    args.use_gt_boxes = True
    args.use_gt_labels = True
    logger.info('Creating val roidb')
    val_dataset_name, val_proposal_file = get_inference_dataset(0)
    val_roidb, val_dataset, _, _, _ = get_roidb_and_dataset(
        val_dataset_name, val_proposal_file, None, args.do_val)
    logger.info('Done')

    ### Training Loop ###
    maskRCNN.train()

    CHECKPOINT_PERIOD = 10000  # iterations between periodic checkpoints
    EVAL_PERIOD = cfg.TRAIN.EVAL_PERIOD

    # Index into cfg.SOLVER.STEPS of the next learning-rate decay
    decay_steps_ind = _first_decay_step_index(cfg.SOLVER.STEPS, args.start_step)

    training_stats = TrainingStats(args, args.disp_interval, tblogger)

    # Bound before `try` so the exception handler can always report a step.
    step = args.start_step
    try:
        logger.info('Training starts !')
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):
            # Validation below switches the shared module to eval mode.
            maskRCNN.train()

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr, backbone_lr = _update_lr(optimizer, lr,
                                             cfg.SOLVER.BASE_LR * warmup_factor)
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                lr, backbone_lr = _update_lr(optimizer, lr, cfg.SOLVER.BASE_LR)

            # Learning rate decay
            if (decay_steps_ind < len(cfg.SOLVER.STEPS)
                    and step == cfg.SOLVER.STEPS[decay_steps_ind]):
                logger.info('Decay the learning on step %d', step)
                lr, backbone_lr = _update_lr(optimizer, lr, lr * cfg.SOLVER.GAMMA)
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Gradients from iter_size minibatches accumulate before one optimizer step.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Epoch exhausted: restart the loader.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr, backbone_lr)

            if (step + 1) % EVAL_PERIOD == 0 or step == cfg.SOLVER.MAX_ITER - 1:
                logger.info('Validating model')
                # Re-wrap the underlying module for single-GPU evaluation.
                eval_model = mynn.DataParallel(
                    maskRCNN.module, cpu_keywords=['im_info', 'roidb'],
                    device_ids=[0], minibatch=True)
                eval_model.eval()
                all_results = run_eval_inference(
                    eval_model, val_roidb, args, val_dataset, val_dataset_name,
                    val_proposal_file, ind_range=None, multi_gpu_testing=False,
                    check_expected_results=True)
                # NOTE(review): with --no_save, output_dir is not created in this
                # function (unless misc_utils.get_output_dir does) — confirm the
                # CSV writer copes.
                csv_path = os.path.join(output_dir, 'eval.csv')
                all_results = all_results[0]
                generate_csv_file_from_det_obj(all_results, csv_path,
                                               obj_categories, prd_categories,
                                               obj_freq_dict, prd_freq_dict)
                overall_metrics, per_class_metrics = get_metrics_from_csv(csv_path)
                obj_acc = per_class_metrics[(csv_path, 'obj', 'top1')]
                sbj_acc = per_class_metrics[(csv_path, 'sbj', 'top1')]
                prd_acc = per_class_metrics[(csv_path, 'rel', 'top1')]
                # Average of the predicate accuracy and the mean subject/object accuracy.
                avg_obj_sbj = (obj_acc + sbj_acc) / 2.0
                avg_acc = (prd_acc + avg_obj_sbj) / 2.0

                if avg_acc > best['avg_per_class_acc']:
                    print('Found new best validation accuracy at {:2.2f}%'.format(avg_acc))
                    best['avg_per_class_acc'] = avg_acc
                    best['iteration'] = step
                    best['per_class_metrics'] = {
                        'obj_top1': obj_acc,
                        'sbj_top1': sbj_acc,
                        'prd_top1': prd_acc,
                    }
                    best['overall_metrics'] = {
                        'obj_top1': overall_metrics[(csv_path, 'obj', 'top1')],
                        'sbj_top1': overall_metrics[(csv_path, 'sbj', 'top1')],
                        'prd_top1': overall_metrics[(csv_path, 'rel', 'top1')],
                    }
                    if not args.no_save:
                        print('Saving best model..')
                        save_best_ckpt(output_dir, args, step, train_size,
                                       maskRCNN, optimizer)
                        _dump_json(best, best_json_path)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                print('Saving Checkpoint..')
                save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except Exception:
        # Best effort: drop the data iterator, save a checkpoint at the failing
        # step and print the traceback. The exception is NOT re-raised.
        # KeyboardInterrupt is not an Exception subclass, so Ctrl-C skips this.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if tblogger is not None:
            tblogger.close()