Example #1
0
def add_model_training_inputs(model):
    """Attach the configured training roidb to ``model`` as its training inputs.

    The roidb is assembled from ``cfg.TRAIN.DATASETS`` together with
    ``cfg.TRAIN.PROPOSAL_FILES`` and then handed to
    ``model_builder.add_training_inputs``.

    Args:
        model: the model object that receives the training inputs.
    """
    log = logging.getLogger(__name__)
    dataset_names = cfg.TRAIN.DATASETS
    log.info('Loading dataset: {}'.format(dataset_names))

    training_roidb = combined_roidb_for_training(dataset_names,
                                                 cfg.TRAIN.PROPOSAL_FILES)
    log.info('{:d} roidb entries'.format(len(training_roidb)))

    model_builder.add_training_inputs(model, roidb=training_roidb)


def build_model(model, split):
    """Attach the inputs for ``split`` to ``model``, then build and initialize it.

    Args:
        model: the model object to populate. It is passed to ``add_inputs``,
            ``optim.build_data_parallel_model``, ``get_dirs`` and
            ``setup_model``.
        split: ``'train'`` or ``'val'``. Any other value is handled as the
            test split, whose dataset suffix comes from ``cfg.TEST.DATA_TYPE``.

    Returns:
        ``(model, odir, cdir)``, where ``odir`` and ``cdir`` are the values
        returned by ``get_dirs(model, split)``.
    """
    def _per_gpu_builder(model):
        # Look up the creator registered for the configured model name and
        # let it populate the per-GPU model.
        creator = model_creator_map[cfg.MODEL.MODEL_NAME]
        return creator.create_model(model=model)

    # Select the roidb and proposals that match the requested split.
    if split == 'train':
        roidb = combined_roidb_for_training(cfg.DATASET + '_train', None)
        proposals = get_gt_perturbed_proposals(roidb)
        logger.info('Training proposals length: {}'.format(len(proposals)))
    elif split == 'val':
        roidb = combined_roidb_for_val_test(cfg.DATASET + '_val')
        proposals = get_gt_val_test_proposals('val', roidb)
        logger.info('Validation proposals length: {}'.format(len(proposals)))
    else:
        test_dataset_name = cfg.DATASET + '_' + cfg.TEST.DATA_TYPE
        roidb = combined_roidb_for_val_test(test_dataset_name)
        proposals = get_gt_val_test_proposals(cfg.TEST.DATA_TYPE, roidb)
    logger.info('{:d} roidb entries'.format(len(roidb)))

    add_inputs(model,
               roidb=roidb,
               landb=get_landb(cfg.DATASET + '_lan'),
               proposals=proposals,
               split=split)
    feed_all_word_vecs(model)

    optim.build_data_parallel_model(model, _per_gpu_builder)
    workspace.RunNetOnce(model.param_init_net)

    odir, cdir = get_dirs(model, split)

    # The test split starts without a params file; every other split loads
    # cfg.TRAIN.PARAMS_FILE.
    params_file = None if split == 'test' else cfg.TRAIN.PARAMS_FILE
    setup_model(model, params_file, split)

    return model, odir, cdir
Example #3
0
def _build_roidb_split(timers, timer_key, datasets, label):
    """Build the combined roidb for ``datasets``, timing and logging the build.

    Args:
        timers: mapping from name to ``Timer``. ``main`` passes a
            ``defaultdict(Timer)``.
        timer_key: key of the timer in ``timers`` that times this build.
        datasets: dataset name(s) forwarded to ``combined_roidb_for_training``.
        label: split name inserted into the two log messages.

    Returns:
        ``(roidb, category_to_id_map, prd_category_to_id_map)``, taken from the
        5-tuple returned by ``combined_roidb_for_training``. The ratio
        list/index entries are dropped because ``main`` does not use them.
    """
    timer = timers[timer_key]
    timer.tic()
    roidb, _ratio_list, _ratio_index, category_to_id_map, prd_category_to_id_map = \
        combined_roidb_for_training(datasets)
    timer.toc()
    logger.info('%d %s roidb entries', len(roidb), label)
    logger.info('Takes %.2f sec(s) to construct %s roidb',
                timer.average_time, label)
    return roidb, category_to_id_map, prd_category_to_id_map


def _build_eval_loader(cfg, roidb, datasets, batch_size):
    """Wrap ``roidb`` in a non-shuffled, non-training DataLoader.

    Args:
        cfg: resolved config; supplies ``MODEL.NUM_CLASSES`` and
            ``DATA_LOADER.NUM_THREADS``.
        roidb: roidb produced by ``_build_roidb_split``.
        datasets: dataset name(s) forwarded to ``RoiDataLoader``.
        batch_size: number of images per loader batch.

    Returns:
        A ``torch.utils.data.DataLoader`` that drops the last incomplete batch.
    """
    dataset = RoiDataLoader(roidb,
                            cfg.MODEL.NUM_CLASSES,
                            training=False,
                            dataset=datasets)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch,
        drop_last=True)


def _build_optimizer_param_groups(model, cfg):
    """Group the trainable parameters of ``model`` into optimizer param groups.

    Classification is by parameter name:
      * a name containing ``'gn'`` goes to the GN group;
      * a name containing 'Conv_Body', 'Box_Head', 'Box_Outs' or 'RPN' is
        "backbone";
      * everything else is the predicate branch.
    Backbone and predicate-branch parameters are further split on whether
    ``'bias'`` occurs in the name.

    ``main`` relies on the group order. Index 0 is backbone non-bias and
    index 2 is predicate-branch non-bias; both are read back to get the
    current learning rates.

    Args:
        model: module whose ``named_parameters()`` are grouped.
        cfg: resolved config; supplies the ``SOLVER.*`` weight-decay options.

    Returns:
        A list of five param-group dicts. Every learning rate is a 0
        placeholder that the training loop overwrites.
    """
    backbone_tags = ('Conv_Body', 'Box_Head', 'Box_Outs', 'RPN')
    gn_params = []
    backbone_bias, backbone_nonbias = [], []
    prd_bias, prd_nonbias = [], []
    for key, value in dict(model.named_parameters()).items():
        if not value.requires_grad:
            continue
        if 'gn' in key:
            gn_params.append(value)
            continue
        is_bias = 'bias' in key
        if any(tag in key for tag in backbone_tags):
            (backbone_bias if is_bias else backbone_nonbias).append(value)
        else:
            (prd_bias if is_bias else prd_nonbias).append(value)

    # Learning rate of 0 is a dummy value to be set properly at the start of
    # training; the bias expression is kept so the relative scaling is visible.
    bias_lr = 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1)
    bias_weight_decay = (cfg.SOLVER.WEIGHT_DECAY
                         if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0)
    return [
        {'params': backbone_nonbias, 'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': backbone_bias, 'lr': bias_lr,
         'weight_decay': bias_weight_decay},
        {'params': prd_nonbias, 'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': prd_bias, 'lr': bias_lr,
         'weight_decay': bias_weight_decay},
        {'params': gn_params, 'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN},
    ]


def main():
    """Train or evaluate ``Generalized_RCNN`` as configured by the command line.

    Steps:
      1. Parse args and resolve the config with ``set_configs``.
      2. Build the train/val/test roidbs and DataLoaders, plus an "unseen"
         split when ``args.dataset == 'vhico'``.
      3. Build the model and an SGD or Adam optimizer (``cfg.SOLVER.TYPE``),
         optionally loading/resuming from ``args.load_ckpt``.
      4. If ``cfg.EVAL_SUBSET`` is 'unseen' or 'test', run ``run_eval`` once
         and return. Otherwise run the training loop with warm-up and step
         learning-rate decay, saving a checkpoint every
         ``cfg.SAVE_MODEL_ITER`` steps, at the end, and on a RuntimeError or
         KeyboardInterrupt.

    Raises:
        ValueError: if ``cfg.SOLVER.TYPE`` is neither 'SGD' nor 'Adam', or if
            ``cfg.EVAL_SUBSET == 'unseen'`` but no unseen split was built
            (it is only built for ``args.dataset == 'vhico'``).
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    cfg = set_configs(args)
    timers = defaultdict(Timer)

    ### --------------------------------------------------------------------------------
    ### Datasets ###
    ### --------------------------------------------------------------------------------
    roidb_training, category_to_id_map, prd_category_to_id_map = _build_roidb_split(
        timers, 'roidb_training', cfg.TRAIN.DATASETS, 'training')

    batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH

    dataset_training = RoiDataLoader(roidb_training,
                                     cfg.MODEL.NUM_CLASSES,
                                     training=True,
                                     dataset=cfg.TRAIN.DATASETS)
    dataloader_training = torch.utils.data.DataLoader(
        dataset_training,
        batch_size=batch_size,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch,
        shuffle=True,
        drop_last=True)
    dataiterator_training = iter(dataloader_training)

    roidb_val, _, _ = _build_roidb_split(timers, 'roidb_val',
                                         cfg.VAL.DATASETS, 'val')
    # NOTE(review): the validation loader is built but not consumed below.
    dataloader_val = _build_eval_loader(cfg, roidb_val, cfg.VAL.DATASETS,
                                        batch_size)

    roidb_test, _, _ = _build_roidb_split(timers, 'roidb_test',
                                          cfg.TEST.DATASETS, 'test')
    dataloader_test = _build_eval_loader(cfg, roidb_test, cfg.TEST.DATASETS,
                                         batch_size)

    dataloader_unseen = None
    if args.dataset == 'vhico':
        roidb_unseen, _, _ = _build_roidb_split(timers, 'roidb_unseen',
                                                cfg.UNSEEN.DATASETS,
                                                'test unseen')
        dataloader_unseen = _build_eval_loader(cfg, roidb_unseen,
                                               cfg.UNSEEN.DATASETS, batch_size)

    # Fail fast, before the expensive model build, instead of hitting an
    # undefined loader at evaluation time.
    if cfg.EVAL_SUBSET == 'unseen' and dataloader_unseen is None:
        raise ValueError(
            "cfg.EVAL_SUBSET == 'unseen' requires args.dataset == 'vhico' "
            "(got {!r})".format(args.dataset))

    ### --------------------------------------------------------------------------------
    ### Model ###
    ### --------------------------------------------------------------------------------
    maskRCNN = Generalized_RCNN(category_to_id_map=category_to_id_map,
                                prd_category_to_id_map=prd_category_to_id_map,
                                args=args)
    if cfg.CUDA:
        maskRCNN.cuda()

    ### --------------------------------------------------------------------------------
    ### Optimizer ###
    ### --------------------------------------------------------------------------------
    params = _build_optimizer_param_groups(maskRCNN, cfg)
    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Previously this fell through and crashed later on an unbound name.
        raise ValueError('Unsupported SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### --------------------------------------------------------------------------------
    ### Load checkpoint
    ### --------------------------------------------------------------------------------
    banner = '--------------------------------------------------------------------------------'
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # Load tensors onto the CPU regardless of the device they were saved from.
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])

        print(banner)
        print('loading checkpoint %s' % load_name)
        print(banner)

        if args.resume:
            print('resume')
            args.start_step = checkpoint['step'] + 1
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()
    else:
        print('args.load_ckpt', args.load_ckpt)

    # Current learning rates, kept only for command-line / logging output.
    lr = optimizer.param_groups[2]['lr']  # non-backbone parameters
    backbone_lr = optimizer.param_groups[0]['lr']  # backbone parameters

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### --------------------------------------------------------------------------------
    ### Training Setups ###
    ### --------------------------------------------------------------------------------
    args.run_name = args.out_dir
    output_dir = misc_utils.get_output_dir(args, args.out_dir)
    args.cfg_filename = os.path.basename(args.cfg_file)

    # Stays None unless saving is enabled and TensorBoard logging is requested.
    tblogger = None
    if not args.no_save:
        os.makedirs(output_dir, exist_ok=True)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            tblogger = SummaryWriter(output_dir)

    maskRCNN.train()

    # Index of the next decay step at or after start_step. STEPS[0] is skipped,
    # and len(STEPS) means "no further decay".
    solver_steps = cfg.SOLVER.STEPS
    decay_steps_ind = next(
        (i for i in range(1, len(solver_steps))
         if solver_steps[i] >= args.start_step),
        len(solver_steps))

    training_stats = TrainingStats(args, args.disp_interval, tblogger, True)
    # NOTE(review): val_stats is never read below; its construction is kept in
    # case ValStats.__init__ has side effects — confirm before removing.
    val_stats = ValStats(args, args.disp_interval, tblogger, False)
    test_stats = TestStats(args, args.disp_interval, tblogger, False)

    best_total_loss = np.inf
    best_eval_result = 0

    ### --------------------------------------------------------------------------------
    ### EVAL ###
    ### --------------------------------------------------------------------------------
    if cfg.EVAL_SUBSET in ('unseen', 'test'):
        if cfg.EVAL_SUBSET == 'unseen':
            print('testing unseen ...')
            eval_loader = dataloader_unseen
        else:
            print('testing ...')
            eval_loader = dataloader_test
        try:
            run_eval(args,
                     cfg,
                     maskRCNN,
                     eval_loader,
                     step=0,
                     output_dir=output_dir,
                     test_stats=test_stats,
                     best_eval_result=best_eval_result,
                     eval_subset=cfg.EVAL_SUBSET)
        finally:
            # Previously the writer was left open on this early-return path.
            if tblogger is not None:
                tblogger.close()
        return

    ### --------------------------------------------------------------------------------
    ### TRAIN ###
    ### --------------------------------------------------------------------------------
    try:
        logger.info('Training starts !')
        # Bound before the loop so the save calls below work even when the
        # loop body never runs.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):
            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate_rel(optimizer, lr,
                                                   cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(solver_steps) and \
                    step == solver_steps[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()

            # Gradients are accumulated over iter_size minibatches before a
            # single optimizer step.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator_training)
                except StopIteration:
                    # Epoch exhausted: restart the loader and keep going.
                    print('recurrence data loader')
                    dataiterator_training = iter(dataloader_training)
                    input_data = next(dataiterator_training)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)

                training_stats.UpdateIterStats(net_outputs['gt_label'],
                                               inner_iter)
                loss = net_outputs['gt_label']['total_loss']
                loss.backward()

            optimizer.step()
            training_stats.IterToc()
            training_stats.LogIterStats(step, lr, backbone_lr)

            if (step + 1) % cfg.SAVE_MODEL_ITER == 0:
                save_ckpt(output_dir, args, step, batch_size, maskRCNN,
                          optimizer, False, best_total_loss)

        # ---- Training ends ----
        save_ckpt(output_dir, args, step, batch_size, maskRCNN, optimizer,
                  False, best_total_loss)

    except (RuntimeError, KeyboardInterrupt):
        # Best-effort: save what we have and print the traceback instead of
        # re-raising. Dropping the iterator presumably releases DataLoader
        # workers before saving — TODO confirm.
        del dataiterator_training
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, batch_size, maskRCNN, optimizer,
                  False, best_total_loss)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if tblogger is not None:
            tblogger.close()
def main():
    """Main function"""

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    cfg.DATASET = args.dataset
    if args.dataset == "vrd":
        cfg.TRAIN.DATASETS = ('vrd_train',)
        cfg.MODEL.NUM_CLASSES = 101
        cfg.MODEL.NUM_PRD_CLASSES = 70  # exclude background
    elif args.dataset == "vg":
        cfg.TRAIN.DATASETS = ('vg_train',)
        cfg.MODEL.NUM_CLASSES = 151
        cfg.MODEL.NUM_PRD_CLASSES = 50  # exclude background
    elif args.dataset == "vg80k":
        cfg.TRAIN.DATASETS = ('vg80k_train',)
        cfg.MODEL.NUM_CLASSES = 53305 # includes background
        cfg.MODEL.NUM_PRD_CLASSES = 29086  # excludes background
    elif args.dataset == "gvqa20k":
        cfg.TRAIN.DATASETS = ('gvqa20k_train',)
        cfg.MODEL.NUM_CLASSES = 1704 # includes background
        cfg.MODEL.NUM_PRD_CLASSES = 310  # exclude background
    elif args.dataset == "gvqa10k":
        cfg.TRAIN.DATASETS = ('gvqa10k_train',)
        cfg.MODEL.NUM_CLASSES = 1704 # includes background
        cfg.MODEL.NUM_PRD_CLASSES = 310  # exclude background
    elif args.dataset == "gvqa":
        cfg.TRAIN.DATASETS = ('gvqa_train',)
        cfg.MODEL.NUM_CLASSES = 1704 # includes background
        cfg.MODEL.NUM_PRD_CLASSES = 310  # exclude background

    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' % (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' % (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' % (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' % (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print('Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
          '    SOLVER.STEPS: {} --> {}\n'
          '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps, cfg.SOLVER.STEPS,
                                                  old_max_iter, cfg.SOLVER.MAX_ITER))


    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True
    )
    dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()
        
    ### Optimizer ###
    # record backbone params, i.e., conv_body and box_head params
    gn_params = []
    backbone_bias_params = []
    backbone_bias_param_names = []
    prd_branch_bias_params = []
    prd_branch_bias_param_names = []
    backbone_nonbias_params = []
    backbone_nonbias_param_names = []
    prd_branch_nonbias_params = []
    prd_branch_nonbias_param_names = []
    classifier_params = []
    classifier_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'gn' in key:
                gn_params.append(value)
            elif 'Conv_Body' in key or 'Box_Head' in key or 'Box_Outs' in key or 'RPN' in key:
                if 'bias' in key:
                    backbone_bias_params.append(value)
                    backbone_bias_param_names.append(key)
                else:
                    backbone_nonbias_params.append(value)
                    backbone_nonbias_param_names.append(key)
            #elif 'classifier' in key:
            #    classifier_params.append(value)
            #    classifier_param_names.append(value)
            else:
                if 'bias' in key:
                    prd_branch_bias_params.append(value)
                    prd_branch_bias_param_names.append(key)
                else:
                    prd_branch_nonbias_params.append(value)
                    prd_branch_nonbias_param_names.append(key)
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [
        {'params': backbone_nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': backbone_bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': prd_branch_nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': prd_branch_bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': gn_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN}#,
        #{'params': classifier_params,
        # 'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0.0005}
    ]

    # print('Initializing model optimizer.')

    # if cfg.SOLVER.TYPE == "SGD":
    #     optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    # elif cfg.SOLVER.TYPE == "Adam":
    #     optimizer = torch.optim.Adam(params)


    print('Initializing model and classifier optimizers.')
    #classifier_optim_param = {'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0.0005}
    #params.append({'params': maskRCNN.classifier.parameters(),
    #                'lr': classifier_optim_param['lr']),
    #                'momentum': classifier_optim_param['momentum'],
    #                'weight_decay': classifier_optim_param['weight_decay']})
    #params.append({'params': maskRCNN.prd_classifier.parameters(),
    #                'lr': classifier_optim_param['lr'],
    #                'momentum': classifier_optim_param['momentum'],
    #                'weight_decay': classifier_optim_param['weight_decay']})

    if cfg.MODEL.MEMORY_MODULE_STAGE == 1:
        step_size = 10
    elif cfg.MODEL.MEMORY_MODULE_STAGE == 2:
        step_size = 20
    else:
        raise NotImplementedError

    scheduler_params = {'step_size': step_size, 'gamma': 0.1}

    optimizer, optimizer_scheduler = init_optimizers(params, scheduler_params)

    criterion_optimizer, criterion_optimizer_scheduler = None, None
    if cfg.MODEL.MEMORY_MODULE_STAGE == 2:
        print('Initializing criterion optimizer.')
        feat_loss_optim_param = {'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0.0005}

        optim_params = feat_loss_optim_param
        optim_params = [{'params': maskRCNN.feature_loss_sbj_obj.parameters(),
                         'lr': optim_params['lr'],
                         'momentum': optim_params['momentum'],
                         'weight_decay': optim_params['weight_decay']},
                        {'params': maskRCNN.feature_loss_prd.parameters(),
                         'lr': optim_params['lr'],
                         'momentum': optim_params['momentum'],
                         'weight_decay': optim_params['weight_decay']}
                        ]

        # Initialize criterion optimizer and scheduler
        criterion_optimizer, criterion_optimizer_scheduler = init_optimizers(optim_params, scheduler_params)
    if cfg.MODEL.MEMORY_MODULE_STAGE == 2:
        weights_path = 'Outputs/e2e_relcnn_VGG16_8_epochs_gvqa_y_loss_only_1_gpu/gvqa/Feb07-10-55-03_login104-09_step_with_prd_cls_v3/ckpt/model_step1439.pth'
        weights = torch.load(weights_path)
        #print('weights', weights['model'].keys())
        #print(maskRCNN.state_dict().keys())
        maskRCNN.load_state_dict(weights['model'], strict=False)
        #print(list(maskRCNN.parameters()))
        #print(maskRCNN.state_dict().keys())
        #print(maskRCNN.state_dict()['prd_classifier.fc_hallucinator.weight'] == weights['model']['prd_classifier.fc.weight'])
        #print(torch.all(torch.eq(maskRCNN.state_dict()['prd_classifier.fc_hallucinator.weight'], weights['model']['prd_classifier.fc.weight'])))
        #print(torch.all(torch.eq(maskRCNN.state_dict()['Box_Head.heads.0.weight'], weights['model']['Box_Head.heads.0.weight'])))
        #exit()

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print('train_size value: %d different from the one in checkpoint: %d'
                          % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[2]['lr']  # lr of non-backbone parameters, for commmand line outputs.
    backbone_lr = optimizer.param_groups[0]['lr']  # lr of backbone parameters, for commmand line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step_with_prd_cls_v' + str(cfg.MODEL.SUBTYPE)
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)
    # CHECKPOINT_PERIOD = cfg.SOLVER.MAX_ITER / cfg.TRAIN.SNAPSHOT_FREQ
    CHECKPOINT_PERIOD = 200000

    # Set index for decay steps
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args,
        args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):
            # optimizer_scheduler.step()
            # if criterion_optimizer:
            #     criterion_optimizer_scheduler.step()
            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate_rel(optimizer, lr, cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                if criterion_optimizer:
                    net_utils.update_learning_rate_rel(criterion_optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new

                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            if criterion_optimizer:
                criterion_optimizer.zero_grad()
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)
                #print('input_data', [torch.isnan(x) for x in input_data.values()])    
                
                for key in input_data:
                    if key != 'roidb': # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))
                #with autograd.detect_anomaly():
                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            if criterion_optimizer:
                criterion_optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr, backbone_lr)

            if (step+1) % CHECKPOINT_PERIOD == 0:
                print('Saving Checkpoint..')
                save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except Exception as e:
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
# Per-dataset class counts: name -> (MODEL.NUM_CLASSES, MODEL.NUM_PRD_CLASSES).
# Following the original per-dataset comments, object-class counts include the
# background class and predicate-class counts exclude it.
_DATASET_NUM_CLASSES = {
    'vg80k': (53305, 29086),
    'vg8k': (5331, 2000),
    'vrd': (101, 70),
    'vg': (151, 50),
    'gvqa20k': (1704, 310),
    'gvqa10k': (1704, 310),
    'gvqa': (1704, 310),
}


def _configure_dataset(dataset):
    """Set cfg dataset names and class counts for *dataset*.

    Raises:
        ValueError: if *dataset* is not a key of ``_DATASET_NUM_CLASSES``.
    """
    cfg.DATASET = dataset
    if dataset not in _DATASET_NUM_CLASSES:
        raise ValueError("Unexpected args.dataset: {}".format(dataset))
    num_classes, num_prd_classes = _DATASET_NUM_CLASSES[dataset]
    cfg.TRAIN.DATASETS = (dataset + '_train', )
    cfg.TEST.DATASETS = (dataset + '_val', )
    cfg.MODEL.NUM_CLASSES = num_classes
    cfg.MODEL.NUM_PRD_CLASSES = num_prd_classes


def _build_param_groups(model):
    """Split the trainable parameters of *model* into five optimizer groups.

    Group order is significant. Index 0 (backbone non-bias) and index 2
    (predicate-branch non-bias) are read back by the caller to report the
    current learning rates. Every group starts with lr 0; the real value is
    set by the warm-up schedule at the start of training.
    """
    gn_params = []
    backbone_bias_params = []
    backbone_nonbias_params = []
    prd_branch_bias_params = []
    prd_branch_nonbias_params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        if 'gn' in key:
            gn_params.append(value)
        elif ('Conv_Body' in key or 'Box_Head' in key or 'Box_Outs' in key
              or 'RPN' in key):
            # Detector backbone: conv body, box head/outputs and RPN.
            if 'bias' in key:
                backbone_bias_params.append(value)
            else:
                backbone_nonbias_params.append(value)
        elif 'bias' in key:
            prd_branch_bias_params.append(value)
        else:
            prd_branch_nonbias_params.append(value)

    bias_weight_decay = (cfg.SOLVER.WEIGHT_DECAY
                         if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0)
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    return [{
        'params': backbone_nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': backbone_bias_params,
        'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay': bias_weight_decay
    }, {
        'params': prd_branch_nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': prd_branch_bias_params,
        'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay': bias_weight_decay
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]


def _load_json(path):
    """Read and return the JSON document stored at *path*."""
    with open(path) as f:
        return json.load(f)


def _dump_json(obj, path):
    """Serialize *obj* as JSON to *path*, closing (and flushing) the file."""
    with open(path, 'w') as f:
        json.dump(obj, f)


def main():
    """Train the relationship-detection model with periodic validation.

    Parses command-line arguments and applies the dataset and config
    overrides. Then it builds the training roidb, the dataloader, the model
    and the optimizer, optionally resumes from a checkpoint, and runs the
    training loop.

    Validation runs every ``cfg.TRAIN.EVAL_PERIOD`` steps and on the final
    step. The score is the mean of the predicate top-1 per-class accuracy and
    the average of the subject and object top-1 per-class accuracies. When it
    improves, the best checkpoint and ``<output_dir>/ckpt/best.json`` are
    updated, unless ``--no_save`` is given.
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    _configure_dataset(args.dataset)

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    if args.seed:
        cfg.RNG_SEED = args.seed

    # Some imports need to be done after loading the config to avoid using default values
    from datasets.roidb_rel import combined_roidb_for_training
    from modeling.model_builder_rel import Generalized_RCNN
    from core.test_engine_rel import run_eval_inference, run_inference
    from core.test_engine_rel import get_inference_dataset, get_roidb_and_dataset

    logger.info('Training with config:')
    logger.info(pprint.pformat(cfg))

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(
        map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(sampler=MinibatchSampler(
        ratio_list, ratio_index),
                                batch_size=args.batch_size,
                                drop_last=True)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    if cfg.MODEL.DECOUPLE:
        # Freeze everything except the 'so_sem_embeddings.2' parameters.
        for key, value in maskRCNN.named_parameters():
            if 'so_sem_embeddings.2' not in key:
                value.requires_grad = False

    params = _build_param_groups(maskRCNN)

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Previously fell through and failed later with a NameError on `optimizer`.
        raise ValueError('Unknown SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    load_ckpt_dir = './'
    ### Load checkpoint
    if args.load_ckpt_dir:
        load_name = get_checkpoint_resume_file(args.load_ckpt_dir)
        load_ckpt_dir = args.load_ckpt_dir
    elif args.load_ckpt:
        load_name = args.load_ckpt
        load_ckpt_dir = os.path.dirname(args.load_ckpt)

    if args.load_ckpt or args.load_ckpt_dir:
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)

        if cfg.MODEL.DECOUPLE:
            # Drop the embedding-head weights so they are not restored
            # from the checkpoint.
            del checkpoint['model']['RelDN.so_sem_embeddings.2.weight']
            del checkpoint['model']['RelDN.so_sem_embeddings.2.bias']
            del checkpoint['model']['RelDN.prd_sem_embeddings.2.weight']
            del checkpoint['model']['RelDN.prd_sem_embeddings.2.bias']

        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])

        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    # lr of non-backbone parameters, for command line outputs.
    lr = optimizer.param_groups[2]['lr']
    # lr of backbone parameters, for command line outputs.
    backbone_lr = optimizer.param_groups[0]['lr']

    # Read before wrapping: the DataParallel wrapper exposes the model as `.module`.
    prd_categories = maskRCNN.prd_categories
    obj_categories = maskRCNN.obj_categories
    prd_freq_dict = maskRCNN.prd_freq_dict
    obj_freq_dict = maskRCNN.obj_freq_dict

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step_with_prd_cls_v' + str(
        cfg.MODEL.SUBTYPE)
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    # Best-validation record. It is kept in memory and mirrored to
    # ckpt/best.json only when saving is enabled, so --no_save runs still work.
    best = {'avg_per_class_acc': 0.0, 'iteration': 0, 'accuracies': []}
    ckpt_dir = None

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        ckpt_dir = os.path.join(output_dir, 'ckpt')
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)

        resume_best_path = os.path.join(load_ckpt_dir, 'best.json')
        if args.resume and os.path.exists(resume_best_path):
            logger.info('Loading best json from :' + resume_best_path)
            best = _load_json(resume_best_path)
        _dump_json(best, os.path.join(ckpt_dir, 'best.json'))

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    args.output_dir = output_dir
    args.do_val = True
    args.use_gt_boxes = True
    args.use_gt_labels = True

    logger.info('Creating val roidb')
    val_dataset_name, val_proposal_file = get_inference_dataset(0)
    val_roidb, val_dataset, start_ind, end_ind, total_num_images = get_roidb_and_dataset(
        val_dataset_name, val_proposal_file, None, args.do_val)
    logger.info('Done')

    ### Training Loop ###
    maskRCNN.train()

    CHECKPOINT_PERIOD = 10000
    EVAL_PERIOD = cfg.TRAIN.EVAL_PERIOD
    # Set index for decay steps
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        # Bound before the loop so the exception handler can always save.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Validation below puts the shared module in eval mode; restore it.
            maskRCNN.train()
            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate_rel(optimizer, lr,
                                                   cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Gradient accumulation over iter_size minibatches.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Epoch exhausted: start a new pass over the data.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()

            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr, backbone_lr)

            if (step + 1) % EVAL_PERIOD == 0 or (step
                                                 == cfg.SOLVER.MAX_ITER - 1):
                logger.info('Validating model')
                # Single-GPU wrapper around the same underlying module.
                eval_model = mynn.DataParallel(
                    maskRCNN.module,
                    cpu_keywords=['im_info', 'roidb'],
                    device_ids=[0],
                    minibatch=True)
                eval_model.eval()
                all_results = run_eval_inference(eval_model,
                                                 val_roidb,
                                                 args,
                                                 val_dataset,
                                                 val_dataset_name,
                                                 val_proposal_file,
                                                 ind_range=None,
                                                 multi_gpu_testing=False,
                                                 check_expected_results=True)
                # NOTE(review): with --no_save, output_dir is not created
                # above. Confirm that generate_csv_file_from_det_obj creates it.
                csv_path = os.path.join(output_dir, 'eval.csv')
                all_results = all_results[0]
                generate_csv_file_from_det_obj(all_results, csv_path,
                                               obj_categories, prd_categories,
                                               obj_freq_dict, prd_freq_dict)
                overall_metrics, per_class_metrics = get_metrics_from_csv(
                    csv_path)
                obj_acc = per_class_metrics[(csv_path, 'obj', 'top1')]
                sbj_acc = per_class_metrics[(csv_path, 'sbj', 'top1')]
                prd_acc = per_class_metrics[(csv_path, 'rel', 'top1')]
                avg_obj_sbj = (obj_acc + sbj_acc) / 2.0
                avg_acc = (prd_acc + avg_obj_sbj) / 2.0

                if avg_acc > best['avg_per_class_acc']:
                    print('Found new best validation accuracy at {:2.2f}%'.
                          format(avg_acc))
                    best['avg_per_class_acc'] = avg_acc
                    best['iteration'] = step
                    best['per_class_metrics'] = {
                        'obj_top1': obj_acc,
                        'sbj_top1': sbj_acc,
                        'prd_top1': prd_acc
                    }
                    best['overall_metrics'] = {
                        'obj_top1': overall_metrics[(csv_path, 'obj', 'top1')],
                        'sbj_top1': overall_metrics[(csv_path, 'sbj', 'top1')],
                        'prd_top1': overall_metrics[(csv_path, 'rel', 'top1')]
                    }
                    if ckpt_dir is not None:
                        print('Saving best model..')
                        save_best_ckpt(output_dir, args, step, train_size,
                                       maskRCNN, optimizer)
                        _dump_json(best, os.path.join(ckpt_dir, 'best.json'))

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                print('Saving Checkpoint..')
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except Exception:
        # Best effort: release dataloader workers, save a recovery checkpoint,
        # and report the traceback. The exception is intentionally not re-raised.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()