Пример #1
0
def main():
    """Train a detector on Pascal VOC with an iteration-based schedule.

    Steps performed:
      * parse command-line arguments (``parse_args``) and load the config file
        (``load_config(args.cfg_file)``);
      * rescale batch size, BASE_LR and the solver schedule from the config's
        reference setup to the number of visible GPUs;
      * build the VOC 2007/2012 ``trainval`` dataset and its data loader;
      * build the model named by ``args.model`` and an SGD/Adam optimizer;
      * optionally load/resume a checkpoint;
      * run the training loop with warm-up, step decay and periodic
        checkpointing.

    Raises:
        ValueError: if no GPU is configured, if ``cfg.TRAIN.DATASET`` is not a
            supported VOC split, or if ``cfg.SOLVER.TYPE`` is neither "SGD"
            nor "Adam".
    """

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    args = parse_args()
    print('Called with args:')
    print(args)

    load_config(args.cfg_file)

    if cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    ### Adaptively adjust some configs ###
    # The config describes a reference setup (NUM_GPUS x IMS_PER_BATCH).
    # Remember it so batch size, LR and schedule can be rescaled to the GPUs
    # actually visible on this machine.
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()

    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    # Larger effective batches need proportionally fewer iterations; round
    # to the nearest integer.
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(
        map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    # NOTE: applied after the linear rescaling above, so an explicit --lr is
    # used verbatim (not rescaled).
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma

    # Seeding is currently disabled; only cuDNN determinism is requested.
    # np.random.seed(cfg.RNG_SEED)
    # torch.manual_seed(cfg.RNG_SEED)
    # if cfg.CUDA:
    #     torch.cuda.manual_seed_all(cfg.RNG_SEED)
    torch.backends.cudnn.deterministic = True

    transforms = dt_trans.Compose([
        dt_trans.Normalize([102.9801, 115.9465, 122.7717]),
        dt_trans.HorizontalFlip(),
        dt_trans.Resize(),
        dt_trans.ToTensor(),
    ])

    if cfg.TRAIN.DATASET == 'voc_2007_trainval':
        dataset = VOCDetection(cfg.DATA_DIR + '/',
                               year='2007',
                               image_set='trainval',
                               transforms=transforms)

    elif cfg.TRAIN.DATASET == 'voc_2012_trainval':
        dataset = VOCDetection(cfg.DATA_DIR + '/',
                               year='2012',
                               image_set='trainval',
                               transforms=transforms)
    else:
        # Fail fast: otherwise `dataset` is unbound and the first use below
        # raises an unhelpful NameError.
        raise ValueError(
            'Unexpected cfg.TRAIN.DATASET: {}'.format(cfg.TRAIN.DATASET))

    # Effective training sample size for one epoch
    train_size = len(dataset) // args.batch_size * args.batch_size

    # NOTE(review): both `subdivision` and `batch_size` are taken from
    # cfg.TRAIN.ITERATION_SIZE rather than args.batch_size -- confirm this is
    # the intended TrainSampler contract.
    batchSampler = TrainSampler(subdivision=cfg.TRAIN.ITERATION_SIZE,
                                batch_size=cfg.TRAIN.ITERATION_SIZE,
                                max_iterations=cfg.SOLVER.MAX_ITER,
                                num_samples=len(dataset),
                                image_scales=cfg.TRAIN.SCALES,
                                scale_interval=10)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    # SECURITY NOTE: eval() executes whatever expression is passed on the
    # command line; only acceptable for trusted invocations. A dict of known
    # model modules would be a safer lookup.
    model = eval(args.model).loot_model(args)

    if cfg.CUDA:
        model.cuda()

    ### Optimizer ###
    # Split trainable parameters into bias / non-bias groups so biases can
    # get their own LR multiplier and weight-decay setting.
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    for key, value in model.named_parameters():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [
        {
            'params': nonbias_params,
            'lr': 0,
            'weight_decay': cfg.SOLVER.WEIGHT_DECAY
        },
        {
            'params':
            bias_params,
            'lr':
            0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
            'weight_decay':
            cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
        },
    ]
    # names of parameters for each parameter group
    param_names = [nonbias_param_names, bias_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast: otherwise `optimizer` is unbound and the first use below
        # raises an unhelpful NameError.
        raise ValueError(
            'Unsupported cfg.SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # map_location keeps tensors on CPU while loading.
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        optimizer_utils.load_ckpt(model, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            optimizer.load_state_dict(checkpoint['optimizer'])
            # misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    # lr of non-bias parameters, for command line outputs.
    # NOTE(review): without --resume this is still the dummy 0; if
    # args.start_step > cfg.SOLVER.WARM_UP_ITERS neither warm-up branch in
    # the loop fires and the LR would stay 0 -- confirm start_step usage.
    lr = optimizer.param_groups[0]['lr']

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args)

    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    model.train()

    # Snapshot period in loop iterations. Clamped to >= 1 so the modulo check
    # below cannot raise ZeroDivisionError when SNAPSHOT_ITERS is smaller than
    # NUM_GPUS * iter_size.
    CHECKPOINT_PERIOD = max(
        1, int(cfg.TRAIN.SNAPSHOT_ITERS / (cfg.NUM_GPUS * args.iter_size)))

    # Set index for decay steps: first entry of SOLVER.STEPS (skipping index
    # 0) at or after start_step; len(STEPS) means "no more decays".
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        # Pre-bind `step` so the final/exception checkpoint works even when
        # the loop body never runs.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                optimizer_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                optimizer_utils.update_learning_rate(optimizer, lr,
                                                     cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                optimizer_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Accumulate gradients over iter_size minibatches, then step once.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Restart the loader once it is exhausted.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                # NOTE(review): passes the outer `step`, not `inner_iter` --
                # confirm this is what set_inner_iter expects.
                model.set_inner_iter(step)

                im_data = input_data[0].cuda()
                labels = input_data[1].cuda()
                rois = input_data[2].cuda()

                net_outputs = model(im_data, rois, labels)

                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                # NOTE(review): retain_graph=True keeps the autograd graph
                # alive after backward; confirm the model needs it, otherwise
                # it only increases memory use.
                loss.backward(retain_graph=True)
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, model, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, model, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        # Best-effort: snapshot the current state, then report the error.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, model, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
Пример #2
0
def main():
    """Train a ``Generalized_RCNN`` model on COCO 2017 with an iteration-based schedule.

    Steps performed:
      * parse command-line arguments;
      * select the COCO 2017 train split (detection or keypoints) and apply
        the config file plus any ``--set`` overrides;
      * rescale batch size, BASE_LR, solver schedule and the FPN proposal
        collect scale to the number of visible GPUs;
      * build the roidb, the data loader and the model;
      * group parameters for the optimizer;
      * optionally load a checkpoint or Detectron weights;
      * run the training loop with warm-up, step decay and periodic
        checkpointing.

    Raises:
        ValueError: if no GPU is configured, if ``args.dataset`` is
            unexpected, or if ``cfg.SOLVER.TYPE`` is neither "SGD" nor "Adam".
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    # NOTE: these defaults are set before cfg_from_file, so the config file
    # (or --set) may still override them.
    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train',)
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train',)
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    # Remember the config's reference setup so batch size, LR and schedule
    # can be rescaled to the GPUs actually visible on this machine.
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' % (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' % (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' % (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' % (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    # Larger effective batches need proportionally fewer iterations; round
    # to the nearest integer.
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print('Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
          '    SOLVER.STEPS: {} --> {}\n'
          '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps, cfg.SOLVER.STEPS,
                                                  old_max_iter, cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print('Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
              '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    # NOTE: applied after the linear rescaling above, so an explicit --lr is
    # used verbatim (not rescaled).
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True
    )
    dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # GroupNorm affine parameters get their own weight-decay setting
    # (WEIGHT_DECAY_GN), so collect their names first.
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name+'.weight')
            gn_param_nameset.add(name+'.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    for key, value in maskRCNN.named_parameters():
        if value.requires_grad:
            # 'bias' is checked first, so GroupNorm biases land in the bias
            # group; only GroupNorm weights end up in gn_params.
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    assert (gn_param_nameset - set(nograd_param_names) - set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [
        {'params': nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': gn_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN}
    ]
    # names of parameters for each parameter group
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast: otherwise `optimizer` is unbound and the first use below
        # raises an unhelpful NameError.
        raise ValueError(
            'Unsupported cfg.SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # map_location keeps tensors on CPU while loading.
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print('train_size value: %d different from the one in checkpoint: %d'
                          % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            optimizer.load_state_dict(checkpoint['optimizer'])
            # misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    # lr of non-bias parameters, for command line outputs.
    # NOTE(review): without --resume this is still the dummy 0; if
    # args.start_step > cfg.SOLVER.WARM_UP_ITERS neither warm-up branch in
    # the loop fires and the LR would stay 0 -- confirm start_step usage.
    lr = optimizer.param_groups[0]['lr']

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # Snapshot period in loop iterations. Clamped to >= 1 so the modulo check
    # below cannot raise ZeroDivisionError when SNAPSHOT_ITERS < NUM_GPUS.
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))

    # Set index for decay steps: first entry of SOLVER.STEPS (skipping index
    # 0) at or after start_step; len(STEPS) means "no more decays".
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args,
        args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        # Pre-bind `step` so the final/exception checkpoint works even when
        # the loop body never runs.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr, cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Accumulate gradients over iter_size minibatches, then step once.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Restart the loader once it is exhausted.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb': # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step+1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        # Best-effort: snapshot the current state, then report the error.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
Пример #3
0
def main():

    saveNetStructure = False
    """Main function"""

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        #set gpu device
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
            [str(ids) for ids in args.device_ids])
        torch.backends.cudnn.benchmark = True
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cityscapes":
        cfg.TRAIN.DATASETS = ('cityscapes_semseg_train', )
        cfg.MODEL.NUM_CLASSES = 19
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    print('Batch size change from {} (in config file) to {}'.format(
        original_batch_size, args.batch_size))
    print('NUM_GPUs: %d, TRAIN.IMS_PER_BATCH: %d' %
          (cfg.NUM_GPUS, cfg.TRAIN.IMS_PER_BATCH))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Adjust learning based on batch size change linearly
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch size change: {} --> {}'.
          format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    if cfg.SEM.SEM_ON or cfg.DISP.DISP_ON:
        roidb, ratio_list, ratio_index = combined_roidb_for_training_semseg(
            cfg.TRAIN.DATASETS)
    else:
        roidb, ratio_list, ratio_index = combined_roidb_for_training(
            cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    train_size = len(roidb)
    logger.info('{:d} roidb entries'.format(train_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    sampler = MinibatchSampler(ratio_list, ratio_index)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch_semseg
        if cfg.SEM.SEM_ON or cfg.DISP.DISP_ON else collate_minibatch)

    assert_and_infer_cfg()
    #for args.step, input_data in zip(range(100), dataloader):
    #    data_L = input_data['data']
    #    data_R = input_data['data_R']
    #    label = input_data['disp_label_0']
    #    cv2.imwrite('ims_L.png', data_L[0].numpy()[0].transpose(1,2,0)[:,:,::-1]+cfg.PIXEL_MEANS)
    #    cv2.imwrite('ims_R.png', data_R[0].numpy()[0].transpose(1,2,0)[:,:,::-1]+cfg.PIXEL_MEANS)
    #    cv2.imwrite('label.png', label[0].numpy()[0])
    #    return
    ### Model ###
    dispSeg = DispSeg()

    if cfg.CUDA:
        dispSeg.to('cuda')

    pspnet_bias_params = []
    pspnet_nonbias_params = []
    for key, value in dict(dispSeg.pspnet.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                pspnet_bias_params.append(value)
            else:
                pspnet_nonbias_params.append(value)

    pspnet_params = [{
        'params': pspnet_nonbias_params,
        'lr': cfg.SOLVER.BASE_LR,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        pspnet_bias_params,
        'lr':
        cfg.SOLVER.BASE_LR * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }]

    glassGCN_bias_params = []
    glassGCN_nonbias_params = []
    for key, value in dict(dispSeg.glassGCN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                glassGCN_bias_params.append(value)
            else:
                glassGCN_nonbias_params.append(value)

    segdisp3d_params = [{
        'params': glassGCN_nonbias_params,
        'lr': cfg.SOLVER.BASE_LR,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        glassGCN_bias_params,
        'lr':
        cfg.SOLVER.BASE_LR * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }]

    if cfg.SOLVER.TYPE == "SGD":
        optimizerP = torch.optim.SGD(pspnet_params,
                                     momentum=cfg.SOLVER.MOMENTUM)
        optimizerS = torch.optim.SGD(segdisp3d_params,
                                     momentum=cfg.SOLVER.MOMENTUM)

    elif cfg.SOLVER.TYPE == "Adam":
        optimizerP = torch.optim.Adam(pspnet_params)
        optimizerS = torch.optim.Adam(segdisp3d_params)

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(pspnet, checkpoint['model'])
        net_utils.load_ckpt(segdisp3d, checkpoint['model'])

        if args.resume:
            assert checkpoint['iters_per_epoch'] == train_size // args.batch_size, \
                "iters_per_epoch should match for resume"
            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
            if checkpoint['step'] == (checkpoint['iters_per_epoch'] - 1):
                # Resume from end of an epoch
                args.start_epoch = checkpoint['epoch'] + 1
                args.start_iter = 0
            else:
                # Resume from the middle of an epoch.
                # NOTE: dataloader is not synced with previous state
                args.start_epoch = checkpoint['epoch']
                args.start_iter = checkpoint['step'] + 1
        del checkpoint
        torch.cuda.empty_cache()

    lr = optimizerP.param_groups[0][
        'lr']  # lr of non-bias parameters, for commmand line outputs.

    dispSeg = mynn.DataParallel(dispSeg,
                                cpu_keywords=['im_info', 'roidb'],
                                minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name()
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            #from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    dispSeg.train()
    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)

    iters_per_epoch = int(train_size / args.batch_size)  # drop last
    args.iters_per_epoch = iters_per_epoch
    ckpt_interval_per_epoch = iters_per_epoch // args.ckpt_num_per_epoch
    try:
        logger.info('Training starts !')
        args.step = args.start_iter
        global_step = iters_per_epoch * args.start_epoch + args.step
        for args.epoch in range(args.start_epoch,
                                args.start_epoch + args.num_epochs):
            # ---- Start of epoch ----

            # adjust learning rate
            if args.lr_decay_epochs and args.epoch == args.lr_decay_epochs[
                    0] and args.start_iter == 0:
                args.lr_decay_epochs.pop(0)
                net_utils.decay_learning_rate(optimizerP, lr, cfg.SOLVER.GAMMA)
                net_utils.decay_learning_rate(optimizerS, lr, cfg.SOLVER.GAMMA)
                lr *= cfg.SOLVER.GAMMA

            for args.step, input_data in zip(
                    range(args.start_iter, iters_per_epoch), dataloader):

                if cfg.DISP.DISP_ON:
                    input_data['data'] = list(
                        map(lambda x, y: torch.cat((x, y), dim=0),
                            input_data['data'], input_data['data_R']))
                    if cfg.SEM.DECODER_TYPE.endswith('3Ddeepsup'):
                        input_data['disp_scans'] = torch.arange(
                            0, cfg.DISP.MAX_DISPLACEMENT).float().view(
                                1, cfg.DISP.MAX_DISPLACEMENT, 1,
                                1).repeat(args.batch_size, 1, 1, 1)
                        input_data['semseg_scans'] = torch.arange(
                            0, cfg.MODEL.NUM_CLASSES).long().view(
                                1, cfg.MODEL.NUM_CLASSES, 1,
                                1).repeat(args.batch_size, 1, 1, 1)
                    del input_data['data_R']

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(
                            map(
                                lambda x: Variable(x, requires_grad=False).to(
                                    'cuda'), input_data[key]))
                training_stats.IterTic()
                net_outputs = dispSeg(**input_data)
                training_stats.UpdateIterStats(net_outputs)
                #loss = net_outputs['losses']['loss_semseg']
                #acc  = net_outputs['metrics']['accuracy_pixel']
                #print (loss.item(), acc)
                #for key in net_outputs.keys():
                #    print(key)
                loss = net_outputs['total_loss']

                #print("loss.shape:",loss)
                optimizerP.zero_grad()
                optimizerS.zero_grad()
                loss.backward()
                optimizerP.step()
                optimizerS.step()
                training_stats.IterToc()

                if args.step % args.disp_interval == 0:
                    #disp_image=net_outputs['disp_image']
                    #semseg_image=net_outputs['semseg_image']
                    #tblogger.add_image('disp_image',disp_image,global_step)
                    #tblogger.add_image('semseg_image',semseg_image,global_step)
                    log_training_stats(training_stats, global_step, lr)
                global_step += 1
            # ---- End of epoch ----
            # save checkpoint

            net_utils.save_ckpt(output_dir, args, dispSeg, optimizerS)
            # reset starting iter number after first epoch
            args.start_iter = 0

        # ---- Training ends ----
        #if iters_per_epoch % args.disp_interval != 0:
        # log last stats at the end
        #    log_training_stats(training_stats, global_step, lr)

    except (RuntimeError, KeyboardInterrupt):
        logger.info('Save ckpt on exception ...')
        net_utils.save_ckpt(output_dir, args, dispSeg, optimizerS)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
Пример #4
0
def main():
    """Train a ``Generalized_RCNN`` model with an iteration-based schedule.

    Steps:
      1. Parse command-line args and merge ``args.cfg_file`` /
         ``args.set_cfgs`` into the global ``cfg``.
      2. Re-split the batch over the visible GPUs and rescale
         ``cfg.SOLVER.BASE_LR`` linearly with the batch-size change (this
         happens before an explicit ``args.lr`` override is applied).
      3. Build the roidb, data loader, model and optimizer; optionally load a
         checkpoint (``args.load_ckpt`` / ``args.resume``) or Detectron
         weights (``args.load_detectron``).
      4. Run steps ``[args.start_step, cfg.SOLVER.MAX_ITER)`` with warm-up and
         step decay of the learning rate, printing stats every
         ``args.disp_interval`` steps and checkpointing every
         ``CHECKPOINT_PERIOD`` steps, after the last step, and when a
         ``RuntimeError`` / ``KeyboardInterrupt`` stops training.

    Raises:
        ValueError: if CUDA is not enabled via ``args.cuda`` /
            ``cfg.NUM_GPUS``, ``args.dataset`` is unknown, or
            ``cfg.SOLVER.TYPE`` is neither "SGD" nor "Adam".
        KeyError: if ``cfg.SOLVER.WARM_UP_METHOD`` is unknown.
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    print('Batch size change from {} (in config file) to {}'.format(
        original_batch_size, args.batch_size))
    print('NUM_GPUs: %d, TRAIN.IMS_PER_BATCH: %d' %
          (cfg.NUM_GPUS, cfg.TRAIN.IMS_PER_BATCH))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Adjust learning based on batch size change linearly
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch size change: {} --> {}'.
          format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch (a multiple of batch_size,
    # matching drop_last=True below).
    train_size = roidb_size // args.batch_size * args.batch_size

    sampler = MinibatchSampler(ratio_list, ratio_index)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        drop_last=True,
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Parameters whose name contains 'bias' get their own group so they can
    # use a different lr multiplier / weight decay.
    bias_params = []
    nonbias_params = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
            else:
                nonbias_params.append(value)
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        bias_params,
        'lr':
        0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail early with a clear message instead of an UnboundLocalError on
        # `optimizer` further down.
        raise ValueError("Unknown SOLVER.TYPE: {}".format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            assert checkpoint['train_size'] == train_size
            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0][
        'lr']  # lr of non-bias parameters, for commmand line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    run_name = misc_utils.get_run_name()
    output_dir = misc_utils.get_output_dir(args, run_name)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # max(1, ...) guards against a zero period (SNAPSHOT_ITERS < NUM_GPUS),
    # which would make the `% CHECKPOINT_PERIOD` below raise ZeroDivisionError.
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))

    # Index of the first decay step not yet reached (STEPS[0] is skipped).
    # Stop at the FIRST match: without `break` the index ended on the last
    # future step, so all earlier decay steps were silently skipped.
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    def _to_scalar(tensor):
        # Works for 1-element tensors (PyTorch 0.3) and 0-dim tensors
        # (PyTorch >= 0.4), where `.data[0]` / `.numpy()[0]` raise IndexError.
        return float(tensor.data.view(-1)[0])

    logger.info('Training starts !')
    loss_avg = 0
    loss_cnt = 0  # iterations summed into loss_avg since the last print
    step = args.start_step  # bound even if the loop below never runs
    try:
        timers['train_loop'].tic()
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                # Actually push BASE_LR into the optimizer. Previously only the
                # local `lr` changed, so training kept the last warm-up lr (or
                # the dummy 0 when WARM_UP_ITERS == 0 and no optimizer state
                # was loaded).
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = lr_new
                decay_steps_ind += 1

            # Restart the loader once an epoch is exhausted.
            try:
                input_data = next(dataiterator)
            except StopIteration:
                dataiterator = iter(dataloader)
                input_data = next(dataiterator)

            for key in input_data:
                if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                    input_data[key] = list(map(Variable, input_data[key]))

            outputs = maskRCNN(**input_data)

            rois_label = outputs['rois_label']
            cls_score = outputs['cls_score']
            bbox_pred = outputs['bbox_pred']
            # .mean() reduces the per-GPU values gathered by DataParallel.
            loss_rpn_cls = outputs['loss_rpn_cls'].mean()
            loss_rpn_bbox = outputs['loss_rpn_bbox'].mean()
            loss_rcnn_cls = outputs['loss_rcnn_cls'].mean()
            loss_rcnn_bbox = outputs['loss_rcnn_bbox'].mean()

            loss = loss_rpn_cls + loss_rpn_bbox + loss_rcnn_cls + loss_rcnn_bbox

            if cfg.MODEL.MASK_ON:
                loss_rcnn_mask = outputs['loss_rcnn_mask'].mean()
                loss += loss_rcnn_mask

            if cfg.MODEL.KEYPOINTS_ON:
                loss_rcnn_keypoints = outputs['loss_rcnn_keypoints'].mean()
                loss += loss_rcnn_keypoints

            loss_avg += _to_scalar(loss)
            loss_cnt += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

            if ((step % args.disp_interval == 0 and
                 (step - args.start_step >= args.disp_interval))
                    or step == cfg.SOLVER.MAX_ITER - 1):
                diff = timers['train_loop'].toc(average=False)
                # Average over the iterations actually accumulated; this is not
                # always disp_interval (first window, last step).
                loss_avg /= loss_cnt

                loss_rpn_cls = _to_scalar(loss_rpn_cls)
                loss_rpn_bbox = _to_scalar(loss_rpn_bbox)
                loss_rcnn_cls = _to_scalar(loss_rcnn_cls)
                loss_rcnn_bbox = _to_scalar(loss_rcnn_bbox)
                fg_cnt = int(torch.sum(rois_label.data.ne(0)))
                bg_cnt = rois_label.data.numel() - fg_cnt
                print("[ %s ][ step %d ]" % (run_name, step))
                print("\t\tloss: %.4f, lr: %.2e" % (loss_avg, lr))
                print("\t\tfg/bg=(%d/%d), time cost: %f" %
                      (fg_cnt, bg_cnt, diff))
                print(
                    "\t\trpn_cls: %.4f, rpn_bbox: %.4f, rcnn_cls: %.4f, rcnn_bbox %.4f"
                    % (loss_rpn_cls, loss_rpn_bbox, loss_rcnn_cls,
                       loss_rcnn_bbox))

                print_prefix = "\t\t"
                if cfg.MODEL.MASK_ON:
                    loss_rcnn_mask = _to_scalar(loss_rcnn_mask)
                    print("%srcnn_mask %.4f" % (print_prefix, loss_rcnn_mask))
                    print_prefix = ", "
                if cfg.MODEL.KEYPOINTS_ON:
                    loss_rcnn_keypoints = _to_scalar(loss_rcnn_keypoints)
                    print("%srcnn_keypoints %.4f" %
                          (print_prefix, loss_rcnn_keypoints))

                if args.use_tfboard and not args.no_save:
                    info = {
                        'lr': lr,
                        'loss': loss_avg,
                        'loss_rpn_cls': loss_rpn_cls,
                        'loss_rpn_box': loss_rpn_bbox,
                        'loss_rcnn_cls': loss_rcnn_cls,
                        'loss_rcnn_box': loss_rcnn_bbox,
                    }
                    if cfg.MODEL.MASK_ON:
                        info['loss_rcnn_mask'] = loss_rcnn_mask
                    if cfg.MODEL.KEYPOINTS_ON:
                        info['loss_rcnn_keypoints'] = loss_rcnn_keypoints
                    for tag, value in info.items():
                        tblogger.add_scalar(tag, value, step)

                loss_avg = 0
                loss_cnt = 0
                timers['train_loop'].tic()

        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt) as e:
        # Best effort: keep the progress made so far, then report the error.
        print('Save on exception:', e)
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        # ---- Training ends ----
        if args.use_tfboard and not args.no_save:
            tblogger.close()
Пример #5
0
def main():
    """Train the model returned by ``GetRCNNModel`` on an iteration schedule.

    ``parse_cfg`` prepares ``args`` and the global ``cfg``. The optimizer has
    three parameter groups, each with its own weight decay:

      * non-bias weights;
      * parameters whose name contains 'bias';
      * GroupNorm weights. GroupNorm biases land in the bias group because
        the 'bias' test runs first.

    Each step accumulates gradients over ``args.iter_size`` mini-batches
    before one ``optimizer.step()``. Checkpoints are written:

      * every ``CHECKPOINT_PERIOD`` steps;
      * after the last step;
      * when a ``RuntimeError`` / ``KeyboardInterrupt`` interrupts training.

    Raises:
        ValueError: if ``cfg.SOLVER.TYPE`` is neither "SGD" nor "Adam".
        KeyError: if ``cfg.SOLVER.WARM_UP_METHOD`` is unknown.
    """
    args = parse_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch (a multiple of batch_size,
    # matching drop_last=True below).
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(sampler=MinibatchSampler(
        ratio_list, ratio_index),
                                batch_size=args.batch_size,
                                drop_last=True)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = GetRCNNModel()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Names of the affine parameters of every GroupNorm module.
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name + '.weight')
            gn_param_nameset.add(name + '.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    # Sanity check: every trainable, non-bias GroupNorm parameter was routed
    # to the GroupNorm group.
    assert (gn_param_nameset - set(nograd_param_names) -
            set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        bias_params,
        'lr':
        0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]
    # Parameter names per group (same order as `params`); only consumed by
    # the commented-out ensure_optimizer_ckpt_params_order call below.
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail early with a clear message instead of an UnboundLocalError on
        # `optimizer` further down.
        raise ValueError("Unknown SOLVER.TYPE: {}".format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0][
        'lr']  # lr of non-bias parameters, for commmand line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)
    if cfg.TRAIN_SYNC_BN:
        # Shu:For synchorinized BN
        patch_replication_callback(maskRCNN)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # max(1, ...) guards against a zero period (SNAPSHOT_ITERS < NUM_GPUS),
    # which would make the `% CHECKPOINT_PERIOD` below raise ZeroDivisionError.
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))

    # Index of the first decay step not yet reached (STEPS[0] is skipped).
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        # When any of these model options is on, every tensor under 'data'
        # has all of its size-1 dimensions squeezed away before wrapping.
        # This flag is loop-invariant, so it is evaluated once here.
        squeeze_data = (cfg.MODEL.LR_VIEW_ON or cfg.MODEL.GIF_ON
                        or cfg.MODEL.LRASY_MAHA_ON)
        step = args.start_step  # bound even if the loop below never runs
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Gradient accumulation: iter_size forward/backward passes, one step.
            for inner_iter in range(args.iter_size):
                # Restart the loader once an epoch is exhausted.
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key == 'roidb':
                        # roidb is a list of ndarrays with inconsistent length
                        continue
                    if squeeze_data and key == 'data':
                        input_data[key] = [
                            torch.squeeze(item) for item in input_data[key]
                        ]
                    input_data[key] = list(map(Variable, input_data[key]))
                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                net_utils.train_save_ckpt(output_dir, args, step, train_size,
                                          maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        net_utils.train_save_ckpt(output_dir, args, step, train_size, maskRCNN,
                                  optimizer)

    except (RuntimeError, KeyboardInterrupt):
        # Best effort: drop the loader iterator, keep the progress made so
        # far, then report the error.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        net_utils.train_save_ckpt(output_dir, args, step, train_size, maskRCNN,
                                  optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
Пример #6
0
def main():
    """Train a Generalized_RCNN model on COCO 2017 (boxes or keypoints).

    Steps, in order:
      1. Parse command-line arguments; exit if no CUDA device is available.
      2. Pick ``cfg.TRAIN.DATASETS`` from ``args.dataset``, then merge
         ``args.cfg_file`` and ``args.set_cfgs`` into ``cfg``.
      3. Re-split the batch over the visible GPUs and scale
         ``cfg.SOLVER.BASE_LR`` linearly with the batch-size change;
         ``args.optimizer`` / ``args.lr`` / ``args.lr_decay_gamma`` override
         the solver settings afterwards.
      4. Build the roidb, sampler, DataLoader, model and optimizer (bias
         parameters in their own param group).
      5. Optionally load a checkpoint (resuming optimizer/epoch/iteration
         state when ``args.resume``) and/or Detectron weights.
      6. Run the epoch-based loop, logging every ``args.disp_interval``
         steps and checkpointing ``args.ckpt_num_per_epoch`` times per epoch
         plus at every epoch end.

    A checkpoint is also written when training is interrupted by a
    ``RuntimeError`` or ``KeyboardInterrupt``; the exception is then printed
    and not re-raised.

    Raises:
        ValueError: if CUDA is not enabled, ``args.dataset`` is unknown, or
            ``cfg.SOLVER.TYPE`` is neither "SGD" nor "Adam".
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    # Chosen before the config file is merged; cfg_from_file / cfg_from_list
    # below run afterwards and may override it.
    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train',)
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train',)
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    # The config describes a NUM_GPUS x IMS_PER_BATCH batch; re-split the
    # requested batch over the GPUs actually visible to this process.
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    print('Batch size change from {} (in config file) to {}'.format(
        original_batch_size, args.batch_size))
    print('NUM_GPUs: %d, TRAIN.IMS_PER_BATCH: %d' % (cfg.NUM_GPUS, cfg.TRAIN.IMS_PER_BATCH))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Adjust learning based on batch size change linearly
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch size change: {} --> {}'.format(
        old_base_lr, cfg.SOLVER.BASE_LR))

    ### Overwrite some solver settings from command line arguments
    # An explicit args.lr replaces the batch-size-scaled BASE_LR computed above.
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma

    args.mGPUs = (cfg.NUM_GPUS > 1)

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    train_size = len(roidb)
    logger.info('{:d} roidb entries'.format(train_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    sampler = MinibatchSampler(ratio_list, ratio_index)
    dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)

    assert_and_infer_cfg()

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Bias parameters get their own group: their lr is scaled by
    # (SOLVER.BIAS_DOUBLE_LR + 1) and weight decay applies only when
    # SOLVER.BIAS_WEIGHT_DECAY is set.
    bias_params = []
    nonbias_params = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
            else:
                nonbias_params.append(value)
    params = [
        {'params': nonbias_params,
         'lr': cfg.SOLVER.BASE_LR,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': bias_params,
         'lr': cfg.SOLVER.BASE_LR * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0}
    ]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast rather than hitting a NameError on `optimizer` below.
        raise ValueError("Unexpected cfg.SOLVER.TYPE: {}".format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # Load onto CPU storage regardless of the device the checkpoint was saved from.
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            assert checkpoint['iters_per_epoch'] == train_size // args.batch_size, \
                "iters_per_epoch should match for resume"
            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
            if checkpoint['step'] == (checkpoint['iters_per_epoch'] - 1):
                # Resume from end of an epoch
                args.start_epoch = checkpoint['epoch'] + 1
                args.start_iter = 0
            else:
                # Resume from the middle of an epoch.
                # NOTE: dataloader is not synced with previous state
                args.start_epoch = checkpoint['epoch']
                args.start_iter = checkpoint['step'] + 1
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0]['lr']  # lr of non-bias parameters, for command line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    run_name = misc_utils.get_run_name()
    output_dir = misc_utils.get_output_dir(args, run_name)

    # Only created when saving is enabled; every later use checks for None.
    tblogger = None
    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    iters_per_epoch = int(train_size / args.batch_size)  # drop last
    # At least 1, so `(step+1) % ckpt_interval_per_epoch` cannot divide by zero
    # when an epoch has fewer iterations than requested checkpoints.
    ckpt_interval_per_epoch = max(1, iters_per_epoch // args.ckpt_num_per_epoch)
    step = 0
    # Bound before `try` so the exception handler can always save a checkpoint.
    epoch = args.start_epoch
    try:
        logger.info('Training starts !')
        for epoch in range(args.start_epoch, args.start_epoch + args.num_epochs):
            # ---- Start of epoch ----
            loss_avg = 0
            timers['train_loop'].tic()

            # adjust learning rate (epoch-based step decay; skipped when
            # resuming mid-epoch, i.e. args.start_iter != 0)
            if args.lr_decay_epochs and epoch == args.lr_decay_epochs[0] and args.start_iter == 0:
                args.lr_decay_epochs.pop(0)
                net_utils.decay_learning_rate(optimizer, lr, cfg.SOLVER.GAMMA)
                lr *= cfg.SOLVER.GAMMA

            for step, input_data in zip(range(args.start_iter, iters_per_epoch), dataloader):

                for key in input_data:
                    if key != 'roidb': # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                outputs = maskRCNN(**input_data)

                rois_label = outputs['rois_label']
                cls_score = outputs['cls_score']
                bbox_pred = outputs['bbox_pred']
                # .mean() reduces the per-GPU losses gathered by DataParallel.
                loss_rpn_cls = outputs['loss_rpn_cls'].mean()
                loss_rpn_bbox = outputs['loss_rpn_bbox'].mean()
                loss_rcnn_cls = outputs['loss_rcnn_cls'].mean()
                loss_rcnn_bbox = outputs['loss_rcnn_bbox'].mean()

                loss = loss_rpn_cls + loss_rpn_bbox + loss_rcnn_cls + loss_rcnn_bbox

                if cfg.MODEL.MASK_ON:
                    loss_rcnn_mask = outputs['loss_rcnn_mask'].mean()
                    loss += loss_rcnn_mask

                if cfg.MODEL.KEYPOINTS_ON:
                    loss_rcnn_keypoints = outputs['loss_rcnn_keypoints'].mean()
                    loss += loss_rcnn_keypoints

                loss_avg += loss.data.cpu().numpy()[0]

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if (step+1) % ckpt_interval_per_epoch == 0:
                    net_utils.save_ckpt(output_dir, args, epoch, step, maskRCNN, optimizer, iters_per_epoch)

                if (step+1) % args.disp_interval == 0:
                    if (step + 1 - args.start_iter) >= args.disp_interval:  # for the case of resume
                        diff = timers['train_loop'].toc(average=False)
                        loss_avg /= args.disp_interval

                        loss_rpn_cls = loss_rpn_cls.data[0]
                        loss_rpn_bbox = loss_rpn_bbox.data[0]
                        loss_rcnn_cls = loss_rcnn_cls.data[0]
                        loss_rcnn_bbox = loss_rcnn_bbox.data[0]
                        fg_cnt = torch.sum(rois_label.data.ne(0))
                        bg_cnt = rois_label.data.numel() - fg_cnt
                        print("[%s][epoch %2d][iter %4d / %4d]"
                            % (run_name, epoch, step, iters_per_epoch))
                        print("\t\tloss: %.4f, lr: %.2e" % (loss_avg, lr))
                        print("\t\tfg/bg=(%d/%d), time cost: %f" % (fg_cnt, bg_cnt, diff))
                        print("\t\trpn_cls: %.4f, rpn_bbox: %.4f, rcnn_cls: %.4f, rcnn_bbox %.4f"
                            % (loss_rpn_cls, loss_rpn_bbox, loss_rcnn_cls, loss_rcnn_bbox))

                        print_prefix = "\t\t"
                        if cfg.MODEL.MASK_ON:
                            loss_rcnn_mask = loss_rcnn_mask.data[0]
                            print("%srcnn_mask %.4f" % (print_prefix, loss_rcnn_mask))
                            print_prefix = ", "
                        if cfg.MODEL.KEYPOINTS_ON:
                            loss_rcnn_keypoints = loss_rcnn_keypoints.data[0]
                            print("%srcnn_keypoints %.4f" % (print_prefix, loss_rcnn_keypoints))

                        if tblogger is not None:
                            info = {
                                'loss': loss_avg,
                                'loss_rpn_cls': loss_rpn_cls,
                                'loss_rpn_box': loss_rpn_bbox,
                                'loss_rcnn_cls': loss_rcnn_cls,
                                'loss_rcnn_box': loss_rcnn_bbox,
                            }
                            if cfg.MODEL.MASK_ON:
                                info['loss_rcnn_mask'] = loss_rcnn_mask
                            if cfg.MODEL.KEYPOINTS_ON:
                                info['loss_rcnn_keypoints'] = loss_rcnn_keypoints
                            for tag, value in info.items():
                                tblogger.add_scalar(tag, value, iters_per_epoch * epoch + step)

                    loss_avg = 0
                    timers['train_loop'].tic()

            # ---- End of epoch ----
            # save checkpoint
            net_utils.save_ckpt(output_dir, args, epoch, step, maskRCNN, optimizer, iters_per_epoch)
            # reset timer
            timers['train_loop'].reset()
            # reset starting iter number after first epoch
            args.start_iter = 0

    except (RuntimeError, KeyboardInterrupt) as e:
        print('Save on exception:', e)
        net_utils.save_ckpt(output_dir, args, epoch, step, maskRCNN, optimizer, iters_per_epoch)
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        # ---- Training ends ----
        if tblogger is not None:
            tblogger.close()
Пример #7
0
def main():
    """Train ``args.model`` for ``cfg.SOLVER.MAX_ITER`` steps with gradient accumulation.

    Flow:
      * Exit if no CUDA device is available; otherwise enable CUDA in ``cfg``
        and record the visible GPU count.
      * Parse arguments, merge ``args.cfg_file`` / ``args.set_cfgs`` into
        ``cfg``, then optionally seed NumPy/torch from ``cfg.RNG_SEED``.
      * Save the training config and optionally open a TensorBoard writer.
      * Build the roidb and a drop-last batch DataLoader, the model (via
        ``eval(args.model).loot_model(args)``) and an ``OptimizerHandler``;
        hand ``args.load_ckpt`` to ``load_ckpt``.
      * Run steps ``args.start_step .. MAX_ITER - 1``. Each step accumulates
        gradients over ``cfg.TRAIN.ITERATION_SIZE`` minibatches before one
        optimizer step.

    Checkpoints are saved every ``CHECKPOINT_PERIOD`` steps, at the end, and
    on ``RuntimeError`` / ``KeyboardInterrupt``. In the interrupt case the
    exception is printed and not re-raised.
    """

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the training code, sry bro :(.")
    else:
        cfg.CUDA = True
        cfg.NUM_GPUS = torch.cuda.device_count()

    #######~~~.Parameters stuff.~~~#######
    args = parse_args()
    print('Called with args:\n', args)

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    check_and_overwrite_params(cfg, args)
    assert_and_infer_cfg()

    # Enables fixed seed. Seeding happens after the config is fully merged so
    # that an RNG_SEED given in args.cfg_file / args.set_cfgs is the one used.
    if args.fixed_seed:
        np.random.seed(cfg.RNG_SEED)
        torch.manual_seed(cfg.RNG_SEED)
        if cfg.CUDA:
            torch.cuda.manual_seed_all(cfg.RNG_SEED)
    # NOTE(review): applied unconditionally, not only when args.fixed_seed is set.
    torch.backends.cudnn.deterministic = True

    # Identification of this specific run
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args)

    save_training_config(output_dir, cfg, args)

    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        # Set the Tensorboard logger
        tblogger = SummaryWriter(output_dir)
    else:
        tblogger = None

    #######~~~.Dataset.~~~#######
    timers = defaultdict(Timer)

    # Roi datasets
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # drop_last=True: every yielded batch holds exactly args.batch_size samples.
    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ########~~~.Model.~~~#######
    # SECURITY NOTE(review): eval() on a command-line string executes arbitrary
    # code. This is only acceptable for trusted local invocation; an explicit
    # name -> module table would be safer.
    model = eval(args.model).loot_model(args)

    if cfg.CUDA:
        model.cuda()

    optimizer = OptimizerHandler(model, cfg)

    load_ckpt(args.load_ckpt, model, optimizer)

    #######~~~.Training Loop.~~~#######
    # Effective training sample size for one epoch (last partial batch dropped)
    train_size = roidb_size // args.batch_size * args.batch_size
    # Bound before `try` so the exception handler can always save a checkpoint,
    # even if the failure happens before the first step.
    step = args.start_step
    try:
        # At least 1 so `(step + 1) % CHECKPOINT_PERIOD` cannot divide by zero.
        CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS /
                                       (cfg.NUM_GPUS * cfg.TRAIN.ITERATION_SIZE)))

        training_stats = TrainingStats(args, tblogger)

        model.train()

        logger.info('Training starts !')
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            optimizer.update_learning_rate(step)

            training_stats.IterTic()
            optimizer.zero_grad()
            # Gradient accumulation: ITERATION_SIZE forward/backward passes
            # contribute to a single optimizer.step().
            for inner_iter in range(cfg.TRAIN.ITERATION_SIZE):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Loader exhausted: start a new pass over the data.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                # input_data['inner_iter'] = torch.tensor((inner_iter))
                # NOTE(review): passes the outer `step`, not `inner_iter`,
                # despite the method name -- confirm this is intended.
                model.set_inner_iter(step)

                # Only the first entry of each per-key list is fed to the model.
                im_data = input_data['data'][0].cuda()
                rois = input_data['rois'][0].cuda().type(im_data.dtype)
                labels = input_data['labels'][0].cuda().type(im_data.dtype)

                net_outputs = model(im_data, rois, labels)

                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward(retain_graph=True)

            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, optimizer.get_lr())

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, model, optimizer)

        # Training ends, saves the last checkpoint
        save_ckpt(output_dir, args, step, train_size, model, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, model, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print()
        print(stack_trace)

    finally:
        # Close the TensorBoard writer whenever one was created.
        if tblogger is not None:
            tblogger.close()

# NOTE(review): The statements from here to the next `def main()` run at
# module (import) time. They use `args`, `optimizer` and `maskRCNN`, which the
# visible code binds only as locals inside `main()` functions. This looks like
# a stray fragment of a training script. Unless those names are bound at
# module scope elsewhere in the file, importing this module raises NameError
# here. Confirm, then delete this fragment or move it into a function.

# Optionally load Detectron-format weights into `maskRCNN`.
if args.load_detectron:  # TODO resume for detectron weights (load sgd momentum values)
    logging.info("loading Detectron weights %s", args.load_detectron)
    load_detectron_weight(maskRCNN, args.load_detectron)

lr = optimizer.param_groups[0]['lr']  # lr of non-bias parameters, for command line outputs.

# maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
#                              minibatch=True)

### Training Setups ###
# Name this run and derive its output directory from that name.
args.run_name = misc_utils.get_run_name() + '_step'
output_dir = misc_utils.get_output_dir(args, args.run_name)
args.cfg_filename = os.path.basename(args.cfg_file)

# Unless saving is disabled, create the output directory and pickle the
# YAML-dumped config together with `args` into config_and_args.pkl.
if not args.no_save:
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    blob = {'cfg': yaml.dump(cfg), 'args': args}
    with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
        pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

    if args.use_tfboard:
        from tensorboardX import SummaryWriter

        # Set the Tensorboard logger
        # NOTE(review): no SummaryWriter is actually created here; the fragment
        # appears to be truncated right after the import.
def main():
    """Main function"""

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    cfg.DATASET = args.dataset
    if args.dataset == "vg80k":
        cfg.TRAIN.DATASETS = ('vg80k_train', )
        cfg.TEST.DATASETS = ('vg80k_val', )
        cfg.MODEL.NUM_CLASSES = 53305  # includes background
        cfg.MODEL.NUM_PRD_CLASSES = 29086  # excludes background
    elif args.dataset == "vg8k":
        cfg.TRAIN.DATASETS = ('vg8k_train', )
        cfg.TEST.DATASETS = ('vg8k_val', )
        cfg.MODEL.NUM_CLASSES = 5331  # includes background
        cfg.MODEL.NUM_PRD_CLASSES = 2000  # excludes background
    elif args.dataset == "vrd":
        cfg.TRAIN.DATASETS = ('vrd_train', )
        cfg.TEST.DATASETS = ('vrd_val', )
        cfg.MODEL.NUM_CLASSES = 101
        cfg.MODEL.NUM_PRD_CLASSES = 70  # exclude background
    elif args.dataset == "vg":
        cfg.TRAIN.DATASETS = ('vg_train', )
        cfg.TEST.DATASETS = ('vg_val', )
        cfg.MODEL.NUM_CLASSES = 151
        cfg.MODEL.NUM_PRD_CLASSES = 50  # exclude background
    elif args.dataset == "gvqa20k":
        cfg.TRAIN.DATASETS = ('gvqa20k_train', )
        cfg.TEST.DATASETS = ('gvqa20k_val', )
        cfg.MODEL.NUM_CLASSES = 1704  # includes background
        cfg.MODEL.NUM_PRD_CLASSES = 310  # exclude background
    elif args.dataset == "gvqa10k":
        cfg.TRAIN.DATASETS = ('gvqa10k_train', )
        cfg.TEST.DATASETS = ('gvqa10k_val', )
        cfg.MODEL.NUM_CLASSES = 1704  # includes background
        cfg.MODEL.NUM_PRD_CLASSES = 310  # exclude background
    elif args.dataset == "gvqa":
        cfg.TRAIN.DATASETS = ('gvqa_train', )
        cfg.TEST.DATASETS = ('gvqa_val', )
        cfg.MODEL.NUM_CLASSES = 1704  # includes background
        cfg.MODEL.NUM_PRD_CLASSES = 310  # exclude background

    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    if args.seed:
        cfg.RNG_SEED = args.seed

    # Some imports need to be done after loading the config to avoid using default values
    from datasets.roidb_rel import combined_roidb_for_training
    from modeling.model_builder_rel import Generalized_RCNN
    from core.test_engine_rel import run_eval_inference, run_inference
    from core.test_engine_rel import get_inference_dataset, get_roidb_and_dataset

    logger.info('Training with config:')
    logger.info(pprint.pformat(cfg))

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(
        map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(sampler=MinibatchSampler(
        ratio_list, ratio_index),
                                batch_size=args.batch_size,
                                drop_last=True)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # record backbone params, i.e., conv_body and box_head params
    gn_params = []
    backbone_bias_params = []
    backbone_bias_param_names = []
    prd_branch_bias_params = []
    prd_branch_bias_param_names = []
    backbone_nonbias_params = []
    backbone_nonbias_param_names = []
    prd_branch_nonbias_params = []
    prd_branch_nonbias_param_names = []

    if cfg.MODEL.DECOUPLE:
        for key, value in dict(maskRCNN.named_parameters()).items():
            if not 'so_sem_embeddings.2' in key:
                value.requires_grad = False

    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'gn' in key:
                gn_params.append(value)
            elif 'Conv_Body' in key or 'Box_Head' in key or 'Box_Outs' in key or 'RPN' in key:
                if 'bias' in key:
                    backbone_bias_params.append(value)
                    backbone_bias_param_names.append(key)
                else:
                    backbone_nonbias_params.append(value)
                    backbone_nonbias_param_names.append(key)
            else:
                if 'bias' in key:
                    prd_branch_bias_params.append(value)
                    prd_branch_bias_param_names.append(key)
                else:
                    prd_branch_nonbias_params.append(value)
                    prd_branch_nonbias_param_names.append(key)
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': backbone_nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        backbone_bias_params,
        'lr':
        0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': prd_branch_nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        prd_branch_bias_params,
        'lr':
        0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)

    load_ckpt_dir = './'
    ### Load checkpoint
    if args.load_ckpt_dir:
        load_name = get_checkpoint_resume_file(args.load_ckpt_dir)
        load_ckpt_dir = args.load_ckpt_dir
    elif args.load_ckpt:
        load_name = args.load_ckpt
        load_ckpt_dir = os.path.dirname(args.load_ckpt)

    if args.load_ckpt or args.load_ckpt_dir:
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)

        if cfg.MODEL.DECOUPLE:
            del checkpoint['model']['RelDN.so_sem_embeddings.2.weight']
            del checkpoint['model']['RelDN.so_sem_embeddings.2.bias']
            del checkpoint['model']['RelDN.prd_sem_embeddings.2.weight']
            del checkpoint['model']['RelDN.prd_sem_embeddings.2.bias']

        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])

        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[2][
        'lr']  # lr of non-backbone parameters, for commmand line outputs.
    backbone_lr = optimizer.param_groups[0][
        'lr']  # lr of backbone parameters, for commmand line outputs.

    prd_categories = maskRCNN.prd_categories
    obj_categories = maskRCNN.obj_categories
    prd_freq_dict = maskRCNN.prd_freq_dict
    obj_freq_dict = maskRCNN.obj_freq_dict

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step_with_prd_cls_v' + str(
        cfg.MODEL.SUBTYPE)
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        ckpt_dir = os.path.join(output_dir, 'ckpt')

        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)

        # if os.path.exists(os.path.join(ckpt_dir, 'best.json')):
        #     best = json.load(open(os.path.join(ckpt_dir, 'best.json')))
        if args.resume and os.path.exists(
                os.path.join(load_ckpt_dir, 'best.json')):
            logger.info('Loading best json from :' +
                        os.path.join(load_ckpt_dir, 'best.json'))
            best = json.load(open(os.path.join(load_ckpt_dir, 'best.json')))
            json.dump(best, open(os.path.join(ckpt_dir, 'best.json'), 'w'))
        else:
            best = {}
            best['avg_per_class_acc'] = 0.0
            best['iteration'] = 0
            best['accuracies'] = []
            json.dump(best, open(os.path.join(ckpt_dir, 'best.json'), 'w'))

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    args.output_dir = output_dir
    args.do_val = True
    args.use_gt_boxes = True
    args.use_gt_labels = True

    logger.info('Creating val roidb')
    val_dataset_name, val_proposal_file = get_inference_dataset(0)
    val_roidb, val_dataset, start_ind, end_ind, total_num_images = get_roidb_and_dataset(
        val_dataset_name, val_proposal_file, None, args.do_val)
    logger.info('Done')

    ### Training Loop ###
    maskRCNN.train()

    # CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)
    # CHECKPOINT_PERIOD = cfg.SOLVER.MAX_ITER / cfg.TRAIN.SNAPSHOT_FREQ
    CHECKPOINT_PERIOD = 10000
    EVAL_PERIOD = cfg.TRAIN.EVAL_PERIOD
    # Set index for decay steps
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            maskRCNN.train()
            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate_rel(optimizer, lr,
                                                   cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()

            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr, backbone_lr)

            if (step + 1) % EVAL_PERIOD == 0 or (step
                                                 == cfg.SOLVER.MAX_ITER - 1):
                logger.info('Validating model')
                eval_model = maskRCNN.module
                eval_model = mynn.DataParallel(
                    eval_model,
                    cpu_keywords=['im_info', 'roidb'],
                    device_ids=[0],
                    minibatch=True)
                eval_model.eval()
                all_results = run_eval_inference(eval_model,
                                                 val_roidb,
                                                 args,
                                                 val_dataset,
                                                 val_dataset_name,
                                                 val_proposal_file,
                                                 ind_range=None,
                                                 multi_gpu_testing=False,
                                                 check_expected_results=True)
                csv_path = os.path.join(output_dir, 'eval.csv')
                all_results = all_results[0]
                generate_csv_file_from_det_obj(all_results, csv_path,
                                               obj_categories, prd_categories,
                                               obj_freq_dict, prd_freq_dict)
                overall_metrics, per_class_metrics = get_metrics_from_csv(
                    csv_path)
                obj_acc = per_class_metrics[(csv_path, 'obj', 'top1')]
                sbj_acc = per_class_metrics[(csv_path, 'sbj', 'top1')]
                prd_acc = per_class_metrics[(csv_path, 'rel', 'top1')]
                avg_obj_sbj = (obj_acc + sbj_acc) / 2.0
                avg_acc = (prd_acc + avg_obj_sbj) / 2.0

                best = json.load(open(os.path.join(ckpt_dir, 'best.json')))
                if avg_acc > best['avg_per_class_acc']:
                    print('Found new best validation accuracy at {:2.2f}%'.
                          format(avg_acc))
                    print('Saving best model..')
                    best['avg_per_class_acc'] = avg_acc
                    best['iteration'] = step
                    best['per_class_metrics'] = {
                        'obj_top1':
                        per_class_metrics[(csv_path, 'obj', 'top1')],
                        'sbj_top1':
                        per_class_metrics[(csv_path, 'sbj', 'top1')],
                        'prd_top1':
                        per_class_metrics[(csv_path, 'rel', 'top1')]
                    }
                    best['overall_metrics'] = {
                        'obj_top1': overall_metrics[(csv_path, 'obj', 'top1')],
                        'sbj_top1': overall_metrics[(csv_path, 'sbj', 'top1')],
                        'prd_top1': overall_metrics[(csv_path, 'rel', 'top1')]
                    }
                    save_best_ckpt(output_dir, args, step, train_size,
                                   maskRCNN, optimizer)
                    json.dump(best,
                              open(os.path.join(ckpt_dir, 'best.json'), 'w'))

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                print('Saving Checkpoint..')
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except Exception as e:
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
Пример #10
0
def main():
    """Main function"""

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 2

    # ADE20k as a detection dataset
    elif args.dataset == "ade_train":
        cfg.TRAIN.DATASETS = ('ade_train', )
        cfg.MODEL.NUM_CLASSES = 446

    # Noisy CS6+WIDER datasets
    elif args.dataset == 'cs6_noise020+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise020', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise050+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise050', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise080+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise080', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise085+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise085', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise090+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise090', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise095+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise095', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise100+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise100', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    # Just Noisy CS6 datasets
    elif args.dataset == 'cs6_noise020':
        cfg.TRAIN.DATASETS = ('cs6_noise020', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise030':
        cfg.TRAIN.DATASETS = ('cs6_noise030', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise040':
        cfg.TRAIN.DATASETS = ('cs6_noise040', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise050':
        cfg.TRAIN.DATASETS = ('cs6_noise050', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise060':
        cfg.TRAIN.DATASETS = ('cs6_noise060', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise070':
        cfg.TRAIN.DATASETS = ('cs6_noise070', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise080':
        cfg.TRAIN.DATASETS = ('cs6_noise080', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise085':
        cfg.TRAIN.DATASETS = ('cs6_noise085', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise090':
        cfg.TRAIN.DATASETS = ('cs6_noise090', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise095':
        cfg.TRAIN.DATASETS = ('cs6_noise095', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise100':
        cfg.TRAIN.DATASETS = ('cs6_noise100', )
        cfg.MODEL.NUM_CLASSES = 2

    # Cityscapes 7 classes
    elif args.dataset == "cityscapes":
        cfg.TRAIN.DATASETS = ('cityscapes_train', )
        cfg.MODEL.NUM_CLASSES = 8

    # BDD 7 classes
    elif args.dataset == "bdd_any_any_any":
        cfg.TRAIN.DATASETS = ('bdd_any_any_any_train', )
        cfg.MODEL.NUM_CLASSES = 8
    elif args.dataset == "bdd_any_any_daytime":
        cfg.TRAIN.DATASETS = ('bdd_any_any_daytime_train', )
        cfg.MODEL.NUM_CLASSES = 8
    elif args.dataset == "bdd_clear_any_daytime":
        cfg.TRAIN.DATASETS = ('bdd_clear_any_daytime_train', )
        cfg.MODEL.NUM_CLASSES = 8

    # Cistyscapes Pedestrian sets
    elif args.dataset == "cityscapes_peds":
        cfg.TRAIN.DATASETS = ('cityscapes_peds_train', )
        cfg.MODEL.NUM_CLASSES = 2

    # Cityscapes Car sets
    elif args.dataset == "cityscapes_cars_HPlen3+kitti_car_train":
        cfg.TRAIN.DATASETS = ('cityscapes_cars_HPlen3', 'kitti_car_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cityscapes_cars_HPlen5+kitti_car_train":
        cfg.TRAIN.DATASETS = ('cityscapes_cars_HPlen5', 'kitti_car_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cityscapes_car_train+kitti_car_train":
        cfg.TRAIN.DATASETS = ('cityscapes_car_train', 'kitti_car_train')
        cfg.MODEL.NUM_CLASSES = 2

    # KITTI Car set
    elif args.dataset == "kitti_car_train":
        cfg.TRAIN.DATASETS = ('kitti_car_train', )
        cfg.MODEL.NUM_CLASSES = 2

    # BDD pedestrians sets
    elif args.dataset == "bdd_peds":
        cfg.TRAIN.DATASETS = ('bdd_peds_train',
                              )  # bdd peds: clear_any_daytime
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "bdd_peds_full":
        cfg.TRAIN.DATASETS = ('bdd_peds_full_train', )  # bdd peds: any_any_any
        cfg.MODEL.NUM_CLASSES = 2
    # Pedestrians with constraints
    elif args.dataset == "bdd_peds_not_clear_any_daytime":
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train', )
        cfg.MODEL.NUM_CLASSES = 2
    # Ashish's  20k samples videos
    elif args.dataset == "bdd_peds_not_clear_any_daytime_20k":
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_20k_train', )
        cfg.MODEL.NUM_CLASSES = 2
    # Source domain + Target domain detections
    elif args.dataset == "bdd_peds+DETS_20k":
        cfg.TRAIN.DATASETS = ('bdd_peds_dets_20k_target_domain',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    # Source domain + Target domain detections -- same 18k images as HP18k
    elif args.dataset == "bdd_peds+DETS18k":
        cfg.TRAIN.DATASETS = ('bdd_peds_dets18k_target_domain',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    # Only Dets
    elif args.dataset == "DETS20k":
        cfg.TRAIN.DATASETS = ('bdd_peds_dets_20k_target_domain', )
        cfg.MODEL.NUM_CLASSES = 2
    # Only Dets18k - same images as HP18k
    elif args.dataset == 'DETS18k':
        cfg.TRAIN.DATASETS = ('bdd_peds_dets18k_target_domain', )
        cfg.MODEL.NUM_CLASSES = 2
    # Only HP
    elif args.dataset == 'HP':
        cfg.TRAIN.DATASETS = ('bdd_peds_HP_target_domain', )
        cfg.MODEL.NUM_CLASSES = 2
    # Only HP 18k videos
    elif args.dataset == 'HP18k':
        cfg.TRAIN.DATASETS = ('bdd_peds_HP18k_target_domain', )
        cfg.MODEL.NUM_CLASSES = 2
    # Source domain + Target domain HP
    elif args.dataset == 'bdd_peds+HP':
        cfg.TRAIN.DATASETS = ('bdd_peds_train', 'bdd_peds_HP_target_domain')
        cfg.MODEL.NUM_CLASSES = 2
        # Source domain + Target domain HP 18k videos
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k':
        cfg.TRAIN.DATASETS = ('bdd_peds_HP18k_target_domain', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    #### Source domain + Target domain with different conf threshold theta ####
    elif args.dataset == 'bdd_peds+HP18k_thresh-050':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_thresh-050', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_thresh-060':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_thresh-060', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_thresh-070':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_thresh-070', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_thresh-090':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_thresh-090', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    ##############################

    #### Data distillation on BDD -- for rebuttal
    elif args.dataset == 'bdd_peds+bdd_data_dist_small':
        cfg.TRAIN.DATASETS = ('bdd_data_dist_small', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+bdd_data_dist_mid':
        cfg.TRAIN.DATASETS = ('bdd_data_dist_mid', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+bdd_data_dist':
        cfg.TRAIN.DATASETS = ('bdd_data_dist', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    ##############################

    #### Source domain + **Labeled** Target domain with varying number of images
    elif args.dataset == 'bdd_peds+labeled_100':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_100',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_075':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_075',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_050':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_050',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_025':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_025',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_010':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_010',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_005':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_005',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_001':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_001',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    ##############################

    # Source domain + Target domain HP tracker bboxes only
    elif args.dataset == 'bdd_peds+HP18k_track_only':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_track_only', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2

    ##### subsets of bdd_HP18k with different constraints
    # Source domain + HP tracker images at NIGHT
    elif args.dataset == 'bdd_peds+HP18k_any_any_night':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_any_any_night', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == 'bdd_peds+HP18k_rainy_any_daytime':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_rainy_any_daytime', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_rainy_any_night':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_rainy_any_night', 'bdd_peds_train')

        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_overcast,rainy_any_daytime':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_overcast,rainy_any_daytime',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_overcast,rainy_any_night':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_overcast,rainy_any_night',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == 'bdd_peds+HP18k_overcast,rainy,snowy_any_daytime':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_overcast,rainy,snowy_any_daytime',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    #############  end of bdd constraned subsets  #####################

    # Source domain + Target domain HP18k -- after histogram matching
    elif args.dataset == 'bdd_peds+HP18k_remap_hist':
        cfg.TRAIN.DATASETS = ('bdd_peds_HP18k_target_domain_remap_hist',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_remap_cityscape_hist':
        cfg.TRAIN.DATASETS = (
            'bdd_peds_HP18k_target_domain_remap_cityscape_hist',
            'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    # Source domain + Target domain HP18k -- after histogram matching
    elif args.dataset == 'bdd_peds+HP18k_remap_random':
        cfg.TRAIN.DATASETS = ('bdd_peds_HP18k_target_domain_remap_random',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    # Source+Noisy Target domain -- prevent domain adv from using HP roi info
    elif args.dataset == 'bdd_peds+bdd_HP18k_noisy_100k':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_noisy_100k', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+bdd_HP18k_noisy_080':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_noisy_080', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+bdd_HP18k_noisy_060':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_noisy_060', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+bdd_HP18k_noisy_070':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_noisy_070', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "wider_train":
        cfg.TRAIN.DATASETS = ('wider_train', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-subset":
        cfg.TRAIN.DATASETS = ('cs6-subset', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-subset-score":
        cfg.TRAIN.DATASETS = ('cs6-subset-score', )
    elif args.dataset == "cs6-subset-gt":
        cfg.TRAIN.DATASETS = ('cs6-subset-gt', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-3013-gt":
        cfg.TRAIN.DATASETS = ('cs6-3013-gt',
                              )  # DEBUG: overfit on one video annots
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-subset-gt+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-subset-gt', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-subset+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-subset', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-gt":
        cfg.TRAIN.DATASETS = ('cs6-train-gt', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-gt-noisy-0.3":
        cfg.TRAIN.DATASETS = ('cs6-train-gt-noisy-0.3', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-gt-noisy-0.5":
        cfg.TRAIN.DATASETS = ('cs6-train-gt-noisy-0.5', )
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-det-score":
        cfg.TRAIN.DATASETS = ('cs6-train-det-score', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-det-score-0.5":
        cfg.TRAIN.DATASETS = ('cs6-train-det-score-0.5', )
    elif args.dataset == "cs6-train-det":
        cfg.TRAIN.DATASETS = ('cs6-train-det', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-det-0.5":
        cfg.TRAIN.DATASETS = ('cs6-train-det-0.5', )
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-hp":
        cfg.TRAIN.DATASETS = ('cs6-train-hp', )
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-easy-gt":
        cfg.TRAIN.DATASETS = ('cs6-train-easy-gt', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-easy-gt-sub":
        cfg.TRAIN.DATASETS = ('cs6-train-easy-gt-sub', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-easy-hp":
        cfg.TRAIN.DATASETS = ('cs6-train-easy-hp', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-easy-det":
        cfg.TRAIN.DATASETS = ('cs6-train-easy-det', )
        cfg.MODEL.NUM_CLASSES = 2

        # Joint training with CS6 and WIDER
    elif args.dataset == "cs6-train-easy-gt-sub+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-train-easy-gt-sub', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-gt+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-train-gt', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-hp+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-train-hp', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-dummy+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-train-dummy', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-det+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-train-det', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    # Dets dataset created by removing tracker results from the HP json
    elif args.dataset == "cs6_train_det_from_hp+WIDER":
        cfg.TRAIN.DATASETS = ('cs6_train_det_from_hp', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    # Dataset created by removing det results from the HP json -- HP tracker only
    elif args.dataset == "cs6_train_hp_tracker_only+WIDER":
        cfg.TRAIN.DATASETS = ('cs6_train_hp_tracker_only', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    # HP dataset with noisy labels: used to prevent DA from getting any info from HP
    elif args.dataset == "cs6_train_hp_noisy_100+WIDER":
        cfg.TRAIN.DATASETS = ('cs6_train_hp_noisy_100', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(
        map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)
    """def _init_fn(worker_id):
        random.seed(999)
        np.random.seed(999)
        torch.cuda.manual_seed(999)
        torch.cuda.manual_seed_all(999)
        torch.manual_seed(999)
        torch.backends.cudnn.deterministic = True
    """

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)
    if cfg.TRAIN.JOINT_TRAINING:
        if len(cfg.TRAIN.DATASETS) == 2:
            print('Joint training on two datasets')
        else:
            raise NotImplementedError

        joint_training_roidb = []
        for i, dataset_name in enumerate(cfg.TRAIN.DATASETS):
            # ROIDB construction
            timers['roidb'].tic()
            roidb, ratio_list, ratio_index = combined_roidb_for_training(
                (dataset_name), cfg.TRAIN.PROPOSAL_FILES)
            timers['roidb'].toc()
            roidb_size = len(roidb)
            logger.info('{:d} roidb entries'.format(roidb_size))
            logger.info('Takes %.2f sec(s) to construct roidb',
                        timers['roidb'].average_time)

            if i == 0:
                roidb_size = len(roidb)

            batchSampler = BatchSampler(sampler=MinibatchSampler(
                ratio_list, ratio_index),
                                        batch_size=args.batch_size,
                                        drop_last=True)
            dataset = RoiDataLoader(roidb,
                                    cfg.MODEL.NUM_CLASSES,
                                    training=True)
            dataloader = torch.utils.data.DataLoader(
                dataset,
                batch_sampler=batchSampler,
                num_workers=cfg.DATA_LOADER.NUM_THREADS,
                collate_fn=collate_minibatch)
            #worker_init_fn=_init_fn)
            # decrease num-threads when using two dataloaders
            dataiterator = iter(dataloader)

            joint_training_roidb.append({
                'dataloader': dataloader,
                'dataiterator': dataiterator,
                'dataset_name': dataset_name
            })
    else:
        roidb, ratio_list, ratio_index = combined_roidb_for_training(
            cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
        timers['roidb'].toc()
        roidb_size = len(roidb)
        logger.info('{:d} roidb entries'.format(roidb_size))
        logger.info('Takes %.2f sec(s) to construct roidb',
                    timers['roidb'].average_time)

        batchSampler = BatchSampler(sampler=MinibatchSampler(
            ratio_list, ratio_index),
                                    batch_size=args.batch_size,
                                    drop_last=True)
        dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batchSampler,
            num_workers=cfg.DATA_LOADER.NUM_THREADS,
            collate_fn=collate_minibatch)
        #worker_init_fn=init_fn)
        dataiterator = iter(dataloader)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    if cfg.TRAIN.JOINT_SELECTIVE_FG:
        orig_fg_batch_ratio = cfg.TRAIN.FG_FRACTION

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name + '.weight')
            gn_param_nameset.add(name + '.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    assert (gn_param_nameset - set(nograd_param_names) -
            set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        bias_params,
        'lr':
        0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]
    # names of paramerters for each paramter
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0][
        'lr']  # lr of non-bias parameters, for commmand line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_' + str(
        cfg.TRAIN.DATASETS) + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)

    # Set index for decay steps
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):
            """
            random.seed(cfg.RNG_SEED)
            np.random.seed(cfg.RNG_SEED)
            torch.cuda.manual_seed(cfg.RNG_SEED)
            torch.cuda.manual_seed_all(cfg.RNG_SEED)
            torch.manual_seed(cfg.RNG_SEED)
            torch.backends.cudnn.deterministic = True
            """

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()

            for inner_iter in range(args.iter_size):
                # use a iter counter for optional alternating batches
                if args.iter_size == 1:
                    iter_counter = step
                else:
                    iter_counter = inner_iter

                if cfg.TRAIN.JOINT_TRAINING:
                    # alternate batches between dataset[0] and dataset[1]
                    if iter_counter % 2 == 0:
                        if True:  #DEBUG:
                            print('Dataset: %s' %
                                  joint_training_roidb[0]['dataset_name'])
                        dataloader = joint_training_roidb[0]['dataloader']
                        dataiterator = joint_training_roidb[0]['dataiterator']
                        # NOTE: if available FG samples cannot fill minibatch
                        # then batchsize will be smaller than cfg.TRAIN.BATCH_SIZE_PER_IM.
                    else:
                        if True:  #DEBUG:
                            print('Dataset: %s' %
                                  joint_training_roidb[1]['dataset_name'])
                        dataloader = joint_training_roidb[1]['dataloader']
                        dataiterator = joint_training_roidb[1]['dataiterator']

                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # end of epoch for dataloader
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)
                    if cfg.TRAIN.JOINT_TRAINING:
                        if iter_counter % 2 == 0:
                            joint_training_roidb[0][
                                'dataiterator'] = dataiterator
                        else:
                            joint_training_roidb[1][
                                'dataiterator'] = dataiterator

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                # [p.data.get_device() for p in maskRCNN.parameters()]
                # [(name, p.data.get_device()) for name, p in maskRCNN.named_parameters()]
                loss.backward()

            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
Пример #11
0
def _build_vos_optimizer(model):
    """Build the solver for ``model`` with Detectron-style parameter groups.

    Trainable parameters are split into three groups:

    * parameters whose name contains ``'bias'`` (checked first, so GroupNorm
      biases land in this group too),
    * the remaining parameters of ``nn.GroupNorm`` modules,
    * everything else.

    Each group gets its own weight decay from ``cfg.SOLVER``. Every group
    starts with a learning rate of 0; the real value is set by the warm-up /
    decay schedule at the start of the training loop.

    Args:
        model: the network, before it is wrapped in ``mynn.DataParallel``.

    Returns:
        A ``torch.optim.SGD`` or ``torch.optim.Adam`` instance.

    Raises:
        ValueError: if ``cfg.SOLVER.TYPE`` is neither ``"SGD"`` nor ``"Adam"``
            (without this the failure surfaced later as a ``NameError``).
    """
    gn_param_nameset = set()
    for name, module in model.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name + '.weight')
            gn_param_nameset.add(name + '.bias')

    gn_params, gn_param_names = [], []
    bias_params, bias_param_names = [], []
    nonbias_params = []
    nograd_param_names = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            nograd_param_names.append(key)
        elif 'bias' in key:
            bias_params.append(value)
            bias_param_names.append(key)
        elif key in gn_param_nameset:
            gn_params.append(value)
            gn_param_names.append(key)
        else:
            nonbias_params.append(value)
    # Sanity check: every GroupNorm parameter name is either frozen, already
    # in the bias group, or in the GroupNorm group.
    assert (gn_param_nameset - set(nograd_param_names) -
            set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of
    # training.
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': bias_params,
        'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]

    if cfg.SOLVER.TYPE == "SGD":
        return torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    if cfg.SOLVER.TYPE == "Adam":
        return torch.optim.Adam(params)
    raise ValueError("Unexpected cfg.SOLVER.TYPE: {}".format(cfg.SOLVER.TYPE))


def _frame_input_from_sequence(input_data_sequence, inner_iter):
    """Slice frame ``inner_iter`` out of a collated sequence minibatch.

    Every value of ``input_data_sequence`` is indexed by frame. The returned
    dict holds one-element slices and is meant to be passed as
    ``model(**frame_input)``.

    * ``'roidb'`` is passed through untouched (its entries have inconsistent
      lengths).
    * ``'data_flow'`` becomes ``[None]`` for the first frame, or when the
      flow entry is ``None``. Otherwise it is converted into a single float32
      tensor with a leading axis of size 1, on the same device as
      ``frame_input['data'][0]``. Its last two dimensions are asserted to
      match the image's.
    * Every other key is wrapped element-wise in ``Variable``.
    """
    frame_input = {
        key: input_data_sequence[key][inner_iter:inner_iter + 1]
        for key in input_data_sequence.keys()
    }
    # Values are replaced in place; the key set never changes, so iterating
    # the dict while assigning is safe.
    for key in frame_input:
        if key == 'roidb':
            continue
        if key != 'data_flow':
            frame_input[key] = list(map(Variable, frame_input[key]))
            continue
        # NOTE(review): "no flow" appears to be encoded as a None at
        # [0][0][0] of the slice -- confirm against the data loader.
        if inner_iter != 0 and frame_input[key][0][0][0] is not None:
            flow = np.squeeze(np.array(frame_input[key], dtype=np.float32))
            frame_input[key] = [
                Variable(
                    torch.tensor(np.expand_dims(flow, 0),
                                 device=frame_input['data'][0].device))
            ]
            assert frame_input['data'][0].shape[-2:] == \
                frame_input[key][0].shape[-2:], \
                "Spatial shape of image and flow are not equal."
        else:
            frame_input[key] = [None]
    return frame_input


def main():
    """Train the sequence (video) Mask R-CNN model on the davis2017 dataset.

    Parses command-line arguments, then loads ``cfg`` and adapts it to the
    actual GPU count, batch size and sequence length. Builds the sequenced
    roidb, the model and the optimizer. Optionally resumes from a checkpoint
    or loads Detectron weights, then runs the step-based training loop.

    A checkpoint is saved every ``CHECKPOINT_PERIOD`` steps, once more at
    the end of training, and when a ``RuntimeError`` or
    ``KeyboardInterrupt`` interrupts the loop.
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "davis2017":
        cfg.TRAIN.DATASETS = ('davis_train', )
        # For davis, coco category is used: 80 foreground + 1 background.
        cfg.MODEL.NUM_CLASSES = 81
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    # NOTE: the cfg file and --set overrides are applied after the dataset
    # defaults above, so they may override them.
    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    if cfg.MODEL.IDENTITY_TRAINING and cfg.MODEL.IDENTITY_REPLACE_CLASS:
        cfg.MODEL.NUM_CLASSES = 145
        cfg.MODEL.IDENTITY_TRAINING = False
        cfg.MODEL.ADD_UNKNOWN_CLASS = False

    # Add an extra "unknown" class if requested.
    if cfg.MODEL.ADD_UNKNOWN_CLASS is True:
        cfg.MODEL.NUM_CLASSES += 1

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size

    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))
    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning rate linearly with the batch size change.
    # For iter_size > 1, gradients are accumulated, so lr is scaled based on
    # batch_size instead of effective_batch_size. It is additionally divided
    # by the sequence length (presumably because per-frame losses are
    # back-propagated into the same gradients before an optimizer step).
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    cfg.SOLVER.BASE_LR *= 1.0 / cfg.MODEL.SEQUENCE_LENGTH
    print(
        'Adjust BASE_LR linearly according to batch_size change and sequence length change:\n'
        '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = [int(x * step_scale + 0.5) for x in cfg.SOLVER.STEPS]
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect`
    # function of `collect_and_distribute_fpn_rpn_proposals.py`:
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments.
    # Note that an explicit --lr replaces the scaled BASE_LR computed above.
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    merged_roidb, seq_num, seq_start_end = sequenced_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES, load_inv_db=True)
    timers['roidb'].toc()
    roidb_size = len(merged_roidb)
    logger.info('{:d} roidbs sequences.'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidbs',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch, number of sequences.
    train_size = roidb_size // args.batch_size * args.batch_size

    ### Model ###
    maskRCNN = vos_model_builder.Generalized_VOS_RCNN()
    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    optimizer = _build_vos_optimizer(maskRCNN)

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt_no_mapping(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))
            optimizer.load_state_dict(checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  # TODO: resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN,
                              args.load_detectron,
                              force_load_all=False)

    # lr of non-bias parameters, for command line outputs.
    lr = optimizer.param_groups[0]['lr']

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    tblogger = None
    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # Clamp to 1: a zero period would make `(step + 1) % CHECKPOINT_PERIOD`
    # raise ZeroDivisionError (which is not caught below) when
    # SNAPSHOT_ITERS < NUM_GPUS.
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))

    # Index of the next decay step at or after start_step (index 0 is
    # skipped by design).
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(args, args.disp_interval, tblogger)
    dataloader, dataiterator, warmup_length = gen_sequence_data_sampler(
        merged_roidb, seq_num, seq_start_end, use_seq_warmup=False)
    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):
            # Warm up: scale BASE_LR by a warm-up factor for the first
            # WARM_UP_ITERS steps, then switch back to BASE_LR.
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay: multiply by GAMMA at each configured step.
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()

            if cfg.TRAIN.USE_SEQ_WARMUP is True and step >= cfg.TRAIN.ITER_BEFORE_USE_SEQ_WARMUP and (
                    (step - cfg.TRAIN.ITER_BEFORE_USE_SEQ_WARMUP) %
                    cfg.TRAIN.WARMUP_LENGTH_CHANGE_STEP) == 0:
                # TODO: better dataiterator creator. The intended behaviour
                # is to rebuild (dataloader, dataiterator, warmup_length)
                # with gen_sequence_data_sampler(..., use_seq_warmup=True)
                # and print the new warmup length; this is not supported yet.
                raise NotImplementedError('not able to delete.')
            try:
                input_data_sequence = next(dataiterator)
            except StopIteration:
                # End of epoch: restart the data loader.
                dataiterator = iter(dataloader)
                input_data_sequence = next(dataiterator)

            # Clean hidden states before training on a new sequence.
            maskRCNN.module.clean_hidden_states()
            maskRCNN.module.clean_flow_features()

            seq_len = cfg.MODEL.SEQUENCE_LENGTH + warmup_length
            assert len(input_data_sequence['data']) == seq_len, \
                '{} != {}'.format(len(input_data_sequence['data']), seq_len)

            # train_part == 0: train the conv body (backbone) only.
            # train_part == 1: train the heads (conv body frozen when
            # alternating).
            train_part = 1
            if cfg.TRAIN.ALTERNATE_TRAINING:
                if step % 2 == 0:
                    maskRCNN.module.freeze_conv_body_only()
                    train_part = 1
                else:
                    maskRCNN.module.train_conv_body_only()
                    train_part = 0

            # Number of frames back-propagated with retain_graph since the
            # last detach; used to cap the trainable sequence length.
            cnter_for_detach_hidden_states = 0
            for inner_iter in range(seq_len):
                input_data = _frame_input_from_sequence(
                    input_data_sequence, inner_iter)

                # Warm-up frames only advance the hidden states; no loss.
                if cfg.TRAIN.USE_SEQ_WARMUP and inner_iter < warmup_length:
                    maskRCNN.module.set_stop_after_hidden_states(stop=True)
                    net_outputs = maskRCNN(**input_data)
                    assert net_outputs is None
                    maskRCNN.module.detach_hidden_states()
                    maskRCNN.module.set_stop_after_hidden_states(stop=False)
                    continue

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']

                if (train_part == 0 or cnter_for_detach_hidden_states >=
                        cfg.TRAIN.MAX_TRAINABLE_SEQ_LENGTH
                        or inner_iter == seq_len - 1):
                    # Max trainable length or end of sequence reached: free
                    # the graph, detach the hidden states, and apply the
                    # accumulated gradients.
                    loss.backward()
                    maskRCNN.module.detach_hidden_states()
                    cnter_for_detach_hidden_states = 0
                    optimizer.step()
                    optimizer.zero_grad()
                else:
                    # Keep the graph so later frames can back-propagate
                    # through the shared hidden states.
                    loss.backward(retain_graph=True)
                    cnter_for_detach_hidden_states += 1

            training_stats.IterToc()
            training_stats.LogIterStats(step, lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt) as exc:
        del dataiterator
        if isinstance(exc, KeyboardInterrupt):
            logger.info('Save ckpt on Keyboard Interrupt ...')
        else:
            logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
    finally:
        if tblogger is not None:
            tblogger.close()
def main():
    """Train Generalized_RCNN with an iteration-based (step) schedule.

    Steps performed below:
      1. Parse command-line arguments; exit if no CUDA device is available.
      2. Pick TRAIN.DATASETS / MODEL.NUM_CLASSES from ``args.dataset``, then
         merge ``args.cfg_file`` and any ``args.set_cfgs`` overrides.
      3. Adapt IMS_PER_BATCH to the visible GPUs and rescale BASE_LR (by
         batch size) and SOLVER.STEPS / SOLVER.MAX_ITER (by effective batch
         size = batch_size * iter_size).
      4. Build the roidb dataloader, the model, and an SGD/Adam optimizer with
         separate param groups for weights, biases and GroupNorm parameters.
      5. Optionally load a checkpoint (resuming step and optimizer state) and
         or Detectron weights.
      6. Train with LR warm-up, step decay, gradient accumulation over
         ``args.iter_size`` minibatches and periodic checkpoints.

    A ``RuntimeError`` or ``KeyboardInterrupt`` raised inside the training
    loop is caught: a checkpoint is saved and the stack trace printed instead
    of re-raising.

    Raises:
        ValueError: if neither ``args.cuda`` nor ``cfg.NUM_GPUS > 0`` holds,
            if ``args.dataset`` is unknown, or if ``cfg.SOLVER.TYPE`` is
            neither "SGD" nor "Adam".
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    # Dataset selection happens before the config file is merged, so entries
    # in cfg_file / set_cfgs presumably take precedence over these defaults.
    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train',)
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train',)
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    # Batch size implied by the config file, before adapting to local GPUs.
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' % (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' % (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' % (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' % (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    # A larger effective batch needs proportionally fewer iterations.
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print('Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
          '    SOLVER.STEPS: {} --> {}\n'
          '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps, cfg.SOLVER.STEPS,
                                                  old_max_iter, cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print('Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
              '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    # Applied after the linear rescaling above, so an explicit --lr is used as-is.
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    # Effective training sample size for one epoch (rounded down to whole batches)
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True
    )
    dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Full parameter names (weight and bias) of every GroupNorm module.
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name+'.weight')
            gn_param_nameset.add(name+'.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    # Biases are checked first, so GroupNorm biases land in the bias group.
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    assert (gn_param_nameset - set(nograd_param_names) - set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [
        {'params': nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': gn_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN}
    ]
    # Names of the parameters in each param group (same order as `params`);
    # only consumed by the disabled reorder hook below.
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast: otherwise `optimizer` stays unbound and the run dies
        # later with a confusing NameError.
        raise ValueError("Unexpected cfg.SOLVER.TYPE: {}".format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # Load tensors onto CPU regardless of the device they were saved from.
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print('train_size value: %d different from the one in checkpoint: %d'
                          % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0]['lr']  # lr of non-bias parameters, for commmand line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    # Stays None unless TensorBoard logging is enabled and saving is on.
    tblogger = None
    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # Snapshot interval in steps. NOTE(review): the division by NUM_GPUS is
    # carried over unchanged; confirm it is the intended scaling. The
    # max(1, ...) keeps the modulo below from dividing by zero when
    # SNAPSHOT_ITERS is smaller than the GPU count.
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))

    # Index of the first decay step (from STEPS[1] on) at or after start_step,
    # so a resumed run does not re-apply decays it has already passed.
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args,
        args.disp_interval,
        tblogger)
    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    # Interpolates from WARM_UP_FACTOR (step 0) toward 1.
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr, cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Accumulate gradients over iter_size minibatches, then step once.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Dataloader exhausted: start a new pass over the data.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb': # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step+1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        # Best-effort: persist progress and print the trace instead of re-raising.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if tblogger is not None:
            tblogger.close()
Пример #13
0
def main():
    """Configure and launch V-COCO HOI training (the loop itself is ``train_val``).

    Steps performed below:
      1. Parse command-line arguments; exit if no CUDA device is available.
      2. Pick TRAIN.DATASETS / MODEL.NUM_CLASSES from ``args.dataset``,
         merge ``args.cfg_file`` / ``args.set_cfgs``, then copy many
         command-line options (network name, backbone, FG threshold, VCOCO
         settings, MAX_ITER, snapshot interval, solver steps, ...) into cfg.
      3. Adapt IMS_PER_BATCH to the visible GPUs and rescale BASE_LR,
         SOLVER.STEPS, SOLVER.MAX_ITER, SOLVER.VAL_ITER and
         TRAIN.SNAPSHOT_ITERS.
      4. Build the dataloader, the model and an SGD/Adam optimizer with four
         param groups: HOI weights, Faster R-CNN weights, HOI biases and
         Faster R-CNN biases (the latter two scaled by
         SOLVER.FASTER_RCNN_WEIGHT).
      5. Optionally load a checkpoint / Detectron weights, write
         ``config_and_args.pkl`` under ``Outputs/<expDir>/<expID>``, set the
         ``FABRICATOR`` environment variable, and call ``train_val``.

    Raises:
        ValueError: if neither ``args.cuda`` nor ``cfg.NUM_GPUS > 0`` holds,
            if ``args.dataset`` is unknown, or if ``cfg.SOLVER.TYPE`` is
            neither "SGD" nor "Adam".
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")
    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    # args.dataset -> (cfg.TRAIN.DATASETS, cfg.MODEL.NUM_CLASSES)
    dataset_settings = {
        'coco2017': (('coco_2017_train', ), 81),
        'coco2014': (('coco_2014_train', ), 81),
        'vcoco_trainval': (('vcoco_trainval', ), 81),
        'vcoco_train': (('vcoco_train', ), 81),
        'vcoco_val': (('vcoco_val', ), 81),
        'keypoints_coco2014': (('keypoints_coco_2014_train', ), 2),
        'keypoints_coco2017': (('keypoints_coco_2017_train', ), 2),
    }
    if args.dataset not in dataset_settings:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))
    cfg.TRAIN.DATASETS, cfg.MODEL.NUM_CLASSES = dataset_settings[args.dataset]

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    if args.vcoco_kp_on:
        cfg.VCOCO.KEYPOINTS_ON = True

    cfg.NETWORK_NAME = args.net_name  # network name
    print('Network name:', args.net_name)

    cfg.MODEL.CONV_BODY = args.conv_body  # backbone network name
    print('Conv_body name:', args.conv_body)

    cfg.TRAIN.FG_THRESH = args.fg_thresh
    print('Train fg thresh:', args.fg_thresh)

    cfg.RESNETS.FREEZE_AT = args.freeze_at
    print('Freeze at: ', args.freeze_at)

    cfg.VCOCO.MLP_HEAD_DIM = args.mlp_head_dim
    print('MLP head dim: ', args.mlp_head_dim)

    cfg.SOLVER.MAX_ITER = args.max_iter
    print('MAX iter: ', args.max_iter)

    cfg.TRAIN.SNAPSHOT_ITERS = args.snapshot
    print('Snapshot Iters: ', args.snapshot)

    if args.solver_steps is not None:
        cfg.SOLVER.STEPS = args.solver_steps
    print('Solver_steps: ', cfg.SOLVER.STEPS)

    cfg.VCOCO.TRIPLETS_NUM_PER_IM = args.triplets_num_per_im
    print('triplets_num_per_im: ', cfg.VCOCO.TRIPLETS_NUM_PER_IM)

    cfg.VCOCO.HEATMAP_KERNEL_SIZE = args.heatmap_kernel_size
    print('heatmap_kernel_size: ', cfg.VCOCO.HEATMAP_KERNEL_SIZE)

    cfg.VCOCO.PART_CROP_SIZE = args.part_crop_size
    print('part_crop_size: ', cfg.VCOCO.PART_CROP_SIZE)

    print('use use_kps17 for part Align: ', args.use_kps17)
    cfg.VCOCO.USE_KPS17 = bool(args.use_kps17)

    print('MULTILEVEL_ROIS: ', cfg.FPN.MULTILEVEL_ROIS)

    if args.vcoco_use_spatial:
        cfg.VCOCO.USE_SPATIAL = True

    if args.vcoco_use_union_feat:
        cfg.VCOCO.USE_UNION_FEAT = True

    if args.use_precomp_box:
        cfg.VCOCO.USE_PRECOMP_BOX = True

    cfg.DEBUG_TEST_WITH_GT = True

    # Applied before the linear batch-size rescaling below, so an explicit
    # --lr is itself rescaled.
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    ### Adaptively adjust some configs ###
    # Batch size implied by the config file, before adapting to local GPUs.
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))
    print('    FG_THRESH: ', cfg.TRAIN.FG_THRESH)
    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    # A larger effective batch needs proportionally fewer iterations; the
    # validation and snapshot intervals are scaled the same way.
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(
        map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    cfg.SOLVER.VAL_ITER = int(cfg.SOLVER.VAL_ITER * step_scale + 0.5)
    cfg.TRAIN.SNAPSHOT_ITERS = int(cfg.TRAIN.SNAPSHOT_ITERS * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer

    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch (rounded down to whole batches)
    train_size = roidb_size // args.batch_size * args.batch_size
    # TODO(review): confirm whether MinibatchSampler shuffles the data.
    batchSampler = BatchSampler(sampler=MinibatchSampler(
        ratio_list, ratio_index),
                                batch_size=args.batch_size,
                                drop_last=True)

    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)

    ### Model ###
    from modeling.model_builder import Generalized_RCNN
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Trainable parameters are split by (bias vs. non-bias) x (HOI head vs.
    # the rest, i.e. the Faster R-CNN part); frozen parameters are skipped.
    bias_hoi_params = []
    bias_hoi_param_names = []
    bias_faster_params = []
    bias_faster_param_names = []
    nobias_hoi_params = []
    nobias_hoi_param_names = []
    nobias_faster_params = []
    nobias_faster_param_names = []

    for key, value in maskRCNN.named_parameters():
        print(key, value.size(), value.requires_grad)
        if not value.requires_grad:
            continue
        is_hoi = 'HOI_Head' in key
        if 'bias' in key:
            if is_hoi:
                bias_hoi_params.append(value)
                bias_hoi_param_names.append(key)
            else:
                bias_faster_params.append(value)
                bias_faster_param_names.append(key)
        else:
            if is_hoi:
                nobias_hoi_params.append(value)
                nobias_hoi_param_names.append(key)
            else:
                nobias_faster_params.append(value)
                nobias_faster_param_names.append(key)

    # Learning rate of 0 is a dummy value to be set properly at the start of training.
    # Group order matters: param_groups[0] (HOI weights) is read back as `lr`
    # below, and checkpointed optimizer state is loaded by group position.
    params = [
        {
            'params': nobias_hoi_params,
            'lr': 0,
            'weight_decay': cfg.SOLVER.WEIGHT_DECAY
        },
        {
            'params': nobias_faster_params,
            'lr': 0 * cfg.SOLVER.FASTER_RCNN_WEIGHT,
            'weight_decay': cfg.SOLVER.WEIGHT_DECAY
        },
        {
            'params':
            bias_hoi_params,
            'lr':
            0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
            'weight_decay':
            cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
        },
        {
            'params':
            bias_faster_params,
            'lr':
            0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1) *
            cfg.SOLVER.FASTER_RCNN_WEIGHT,
            'weight_decay':
            cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
        },
    ]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast: otherwise `optimizer` stays unbound and the run dies
        # later with a confusing NameError.
        raise ValueError("Unexpected cfg.SOLVER.TYPE: {}".format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # Load tensors onto CPU regardless of the device they were saved from.
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        if args.krcnn_from_faster:
            net_utils.load_krcnn_from_faster(maskRCNN, checkpoint['model'])
        else:
            net_utils.load_ckpt(maskRCNN, checkpoint['model'])
            print('Original model loaded....')
        if args.resume:
            print('Resume, loaded step\n\n\n: ', checkpoint['step'])
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # Uses torch's own loader; misc_utils.load_optimizer_state_dict is
            # the workaround for the Pytorch 0.3.1 load_state_dict bug.
            optimizer.load_state_dict(checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0][
        'lr']  # lr of non-bias parameters, for commmand line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    # Output goes to Outputs/<expDir>/<expID>; created even with --no_save.
    output_dir = os.path.join('Outputs', args.expDir, args.expID)
    os.makedirs(output_dir, exist_ok=True)

    args.cfg_filename = os.path.basename(args.cfg_file)

    tblogger = None
    if not args.no_save:
        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    # NOTE(review): FABRICATOR is presumably read elsewhere (model code) to
    # pick a variant; experiments whose ID contains 'base' get 'base'.
    if 'base' in args.expID:
        os.environ['FABRICATOR'] = 'base'
    else:
        os.environ['FABRICATOR'] = 'fcl'

    print('log', os.environ['FABRICATOR'])
    ### Training Loop ###
    train_val(maskRCNN, args, optimizer, lr, dataloader, train_size,
              output_dir, tblogger)
def main():
    """Train Generalized_RCNN with an epoch-based schedule.

    Steps performed below:
      1. Parse command-line arguments; exit if no CUDA device is available.
      2. Pick TRAIN.DATASETS / MODEL.NUM_CLASSES from ``args.dataset``, then
         merge ``args.cfg_file`` and any ``args.set_cfgs`` overrides.
      3. Adapt IMS_PER_BATCH to the visible GPUs and rescale BASE_LR linearly
         with the batch size.
      4. Build the dataloader, the model, and an SGD/Adam optimizer with
         separate param groups for weights and biases.
      5. Optionally load a checkpoint (resuming epoch/iteration and optimizer
         state) and/or Detectron weights.
      6. Train for ``args.num_epochs`` epochs, decaying the LR by
         ``cfg.SOLVER.GAMMA`` at the start of each epoch listed in
         ``args.lr_decay_epochs``, checkpointing ``args.ckpt_num_per_epoch``
         times per epoch and at each epoch end.

    A ``RuntimeError`` or ``KeyboardInterrupt`` raised inside the training
    loop is caught: a checkpoint is saved and the stack trace printed instead
    of re-raising.

    Raises:
        ValueError: if neither ``args.cuda`` nor ``cfg.NUM_GPUS > 0`` holds,
            if ``args.dataset`` is unknown, or if ``cfg.SOLVER.TYPE`` is
            neither "SGD" nor "Adam".
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "common":
        cfg.TRAIN.DATASETS = ('common_train', )
        cfg.MODEL.NUM_CLASSES = 81
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    # Batch size implied by the config file, before adapting to local GPUs.
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    print('Batch size change from {} (in config file) to {}'.format(
        original_batch_size, args.batch_size))
    print('NUM_GPUs: %d, TRAIN.IMS_PER_BATCH: %d' %
          (cfg.NUM_GPUS, cfg.TRAIN.IMS_PER_BATCH))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Adjust learning based on batch size change linearly
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch size change: {} --> {}'.
          format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Overwrite some solver settings from command line arguments
    # Applied after the linear rescaling above, so an explicit --lr is used as-is.
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    train_size = len(roidb)
    logger.info('{:d} roidb entries'.format(train_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    sampler = MinibatchSampler(ratio_list, ratio_index)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)

    assert_and_infer_cfg()

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    bias_params = []
    nonbias_params = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
            else:
                nonbias_params.append(value)
    params = [{
        'params': nonbias_params,
        'lr': cfg.SOLVER.BASE_LR,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        bias_params,
        'lr':
        cfg.SOLVER.BASE_LR * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast: otherwise `optimizer` stays unbound and the run dies
        # later with a confusing NameError.
        raise ValueError("Unexpected cfg.SOLVER.TYPE: {}".format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # Load tensors onto CPU regardless of the device they were saved from.
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            assert checkpoint['iters_per_epoch'] == train_size // args.batch_size, \
                "iters_per_epoch should match for resume"
            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
            if checkpoint['step'] == (checkpoint['iters_per_epoch'] - 1):
                # Resume from end of an epoch
                args.start_epoch = checkpoint['epoch'] + 1
                args.start_iter = 0
            else:
                # Resume from the middle of an epoch.
                # NOTE: dataloader is not synced with previous state
                args.start_epoch = checkpoint['epoch']
                args.start_iter = checkpoint['step'] + 1
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0][
        'lr']  # lr of non-bias parameters, for commmand line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name()
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    # Stays None unless TensorBoard logging is enabled and saving is on.
    tblogger = None
    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    training_stats = TrainingStats(
        args, args.disp_interval, tblogger)

    iters_per_epoch = train_size // args.batch_size  # drop last
    args.iters_per_epoch = iters_per_epoch
    # Clamp to 1: when ckpt_num_per_epoch exceeds iters_per_epoch the integer
    # division is 0 and the modulo in the loop would raise ZeroDivisionError.
    ckpt_interval_per_epoch = max(
        1, iters_per_epoch // args.ckpt_num_per_epoch)

    # Decay epochs before start_epoch are already behind us; `lr` was read
    # back from the optimizer above (after any resumed state was loaded), so
    # drop them. Left in place they would sit at the head of the list and,
    # never matching a later epoch, block every subsequent decay.
    if args.lr_decay_epochs:
        args.lr_decay_epochs = [
            decay_epoch for decay_epoch in args.lr_decay_epochs
            if decay_epoch >= args.start_epoch
        ]
    try:
        logger.info('Training starts !')
        args.step = args.start_iter
        global_step = iters_per_epoch * args.start_epoch + args.step
        for args.epoch in range(args.start_epoch,
                                args.start_epoch + args.num_epochs):
            # ---- Start of epoch ----

            # adjust learning rate
            if args.lr_decay_epochs and args.epoch == args.lr_decay_epochs[0]:
                # Always consume the entry so later decay epochs still fire.
                args.lr_decay_epochs.pop(0)
                # On a mid-epoch resume this epoch's decay was presumably
                # applied (at epoch start) before the checkpoint was written.
                if args.start_iter == 0:
                    net_utils.decay_learning_rate(optimizer, lr, cfg.SOLVER.GAMMA)
                    lr *= cfg.SOLVER.GAMMA

            # zip stops after iters_per_epoch batches (drops a partial last batch).
            for args.step, input_data in zip(
                    range(args.start_iter, iters_per_epoch), dataloader):

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                training_stats.IterTic()
                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs)
                loss = net_outputs['total_loss']
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                training_stats.IterToc()

                if (args.step + 1) % ckpt_interval_per_epoch == 0:
                    net_utils.save_ckpt(output_dir, args, maskRCNN, optimizer)

                if args.step % args.disp_interval == 0:
                    log_training_stats(training_stats, global_step, lr)

                global_step += 1

            # ---- End of epoch ----
            # save checkpoint
            net_utils.save_ckpt(output_dir, args, maskRCNN, optimizer)
            # reset starting iter number after first epoch
            args.start_iter = 0

        # ---- Training ends ----
        if iters_per_epoch % args.disp_interval != 0:
            # log last stats at the end
            log_training_stats(training_stats, global_step, lr)

    except (RuntimeError, KeyboardInterrupt):
        # Best-effort: persist progress and print the trace instead of re-raising.
        logger.info('Save ckpt on exception ...')
        net_utils.save_ckpt(output_dir, args, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if tblogger is not None:
            tblogger.close()
Пример #15
0
def main():
    """Train a relationship-detection model with an iteration-based schedule.

    Reads command-line arguments (``parse_args``) and the YAML config, adapts
    the batch size, learning rate and solver steps to the number of visible
    GPUs, builds the roidb data loader, the ``Generalized_RCNN`` model and its
    optimizer(s), then runs training from ``args.start_step`` up to
    ``cfg.SOLVER.MAX_ITER`` with warm-up and step-wise learning-rate decay.

    Checkpoints are written every ``CHECKPOINT_PERIOD`` steps and once at the
    end of training. If training stops because of an exception or Ctrl-C, a
    checkpoint is also saved and the stack trace is printed; the exception is
    not re-raised.

    Raises:
        ValueError: if CUDA is not requested in args/config, or
            ``args.dataset`` is not one of the supported datasets.
        NotImplementedError: if ``cfg.MODEL.MEMORY_MODULE_STAGE`` is not 1 or 2.
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    # dataset name -> (TRAIN.DATASETS, NUM_CLASSES, NUM_PRD_CLASSES).
    # Per the original per-dataset notes, NUM_CLASSES includes the background
    # class and NUM_PRD_CLASSES excludes it.
    dataset_settings = {
        'vrd': (('vrd_train',), 101, 70),
        'vg': (('vg_train',), 151, 50),
        'vg80k': (('vg80k_train',), 53305, 29086),
        'gvqa20k': (('gvqa20k_train',), 1704, 310),
        'gvqa10k': (('gvqa10k_train',), 1704, 310),
        'gvqa': (('gvqa_train',), 1704, 310),
    }
    cfg.DATASET = args.dataset
    if args.dataset not in dataset_settings:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))
    (cfg.TRAIN.DATASETS,
     cfg.MODEL.NUM_CLASSES,
     cfg.MODEL.NUM_PRD_CLASSES) = dataset_settings[args.dataset]

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' % (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' % (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' % (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' % (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust the learning rate linearly with the batch size change.
    # For iter_size > 1, gradients are accumulated, so lr is scaled based
    # on batch_size instead of effective_batch_size.
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps so that (steps * effective batch size) stays constant.
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = [int(x * step_scale + 0.5) for x in cfg.SOLVER.STEPS]
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print('Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
          '    SOLVER.STEPS: {} --> {}\n'
          '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps, cfg.SOLVER.STEPS,
                                                  old_max_iter, cfg.SOLVER.MAX_ITER))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    # Effective training sample size for one epoch (whole batches only).
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True
    )
    dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Split trainable parameters into GroupNorm / backbone / predicate-branch
    # groups, each further split into bias and non-bias parameters, so every
    # group can get its own weight decay. The *_param_names lists are
    # collected but not used further in this function.
    gn_params = []
    backbone_bias_params = []
    backbone_bias_param_names = []
    prd_branch_bias_params = []
    prd_branch_bias_param_names = []
    backbone_nonbias_params = []
    backbone_nonbias_param_names = []
    prd_branch_nonbias_params = []
    prd_branch_nonbias_param_names = []
    for key, value in maskRCNN.named_parameters():
        if not value.requires_grad:
            continue
        if 'gn' in key:
            gn_params.append(value)
        elif 'Conv_Body' in key or 'Box_Head' in key or 'Box_Outs' in key or 'RPN' in key:
            # Backbone: conv body, box head / outputs and RPN.
            if 'bias' in key:
                backbone_bias_params.append(value)
                backbone_bias_param_names.append(key)
            else:
                backbone_nonbias_params.append(value)
                backbone_nonbias_param_names.append(key)
        else:
            # Everything else belongs to the predicate branch.
            if 'bias' in key:
                prd_branch_bias_params.append(value)
                prd_branch_bias_param_names.append(key)
            else:
                prd_branch_nonbias_params.append(value)
                prd_branch_nonbias_param_names.append(key)
    # Learning rate of 0 is a dummy value to be set properly at the start of
    # training. Group order matters: index 0 is read as the backbone lr and
    # index 2 as the (predicate-branch) lr further below.
    params = [
        {'params': backbone_nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': backbone_bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': prd_branch_nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': prd_branch_bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': gn_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN}
    ]

    print('Initializing model and classifier optimizers.')

    if cfg.MODEL.MEMORY_MODULE_STAGE == 1:
        step_size = 10
    elif cfg.MODEL.MEMORY_MODULE_STAGE == 2:
        step_size = 20
    else:
        raise NotImplementedError(
            'Unsupported MODEL.MEMORY_MODULE_STAGE: {}'.format(cfg.MODEL.MEMORY_MODULE_STAGE))

    scheduler_params = {'step_size': step_size, 'gamma': 0.1}

    # NOTE: the returned schedulers are never stepped in the loop below; the
    # learning rate is driven by the warm-up / SOLVER.STEPS logic instead.
    optimizer, optimizer_scheduler = init_optimizers(params, scheduler_params)

    criterion_optimizer, criterion_optimizer_scheduler = None, None
    if cfg.MODEL.MEMORY_MODULE_STAGE == 2:
        print('Initializing criterion optimizer.')
        feat_loss_optim_param = {'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0.0005}
        optim_params = [
            {'params': module.parameters(), **feat_loss_optim_param}
            for module in (maskRCNN.feature_loss_sbj_obj, maskRCNN.feature_loss_prd)
        ]

        # Initialize criterion optimizer and scheduler
        criterion_optimizer, criterion_optimizer_scheduler = init_optimizers(optim_params, scheduler_params)

        # Stage 2 starts from previously trained weights (presumably the
        # stage-1 result). NOTE(review): hard-coded path; strict=False
        # silently ignores missing / unexpected keys.
        weights_path = 'Outputs/e2e_relcnn_VGG16_8_epochs_gvqa_y_loss_only_1_gpu/gvqa/Feb07-10-55-03_login104-09_step_with_prd_cls_v3/ckpt/model_step1439.pth'
        weights = torch.load(weights_path)
        maskRCNN.load_state_dict(weights['model'], strict=False)
        del weights

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print('train_size value: %d different from the one in checkpoint: %d'
                          % (train_size, checkpoint['train_size']))

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[2]['lr']  # lr of non-backbone parameters, for command line outputs.
    backbone_lr = optimizer.param_groups[0]['lr']  # lr of backbone parameters, for command line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step_with_prd_cls_v' + str(cfg.MODEL.SUBTYPE)
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # Effectively disables periodic checkpoints for normal run lengths.
    CHECKPOINT_PERIOD = 200000

    # Index of the next decay step. SOLVER.STEPS[0] is skipped; presumably it
    # is 0 (start of the first lr stage) -- TODO confirm against the configs.
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args,
        args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        # Bind `step` first so the exception handler can always use it.
        step = args.start_step
        logger.info('Training starts !')
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):
            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate_rel(optimizer, lr, cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                if criterion_optimizer:
                    # NOTE(review): uses the main optimizer's lr as the
                    # reference value for the criterion optimizer too.
                    net_utils.update_learning_rate_rel(criterion_optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new

                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            if criterion_optimizer:
                criterion_optimizer.zero_grad()
            # Accumulate gradients over iter_size minibatches before stepping.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))
                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            if criterion_optimizer:
                criterion_optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr, backbone_lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                print('Saving Checkpoint..')
                save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (Exception, KeyboardInterrupt):
        # KeyboardInterrupt is not an Exception subclass; list it explicitly so
        # that Ctrl-C also saves progress. The error is reported, not re-raised.
        del dataiterator  # drop the loader iterator before saving
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
Пример #16
0
def main():
    """Train a semantic-segmentation / disparity network with an epoch loop.

    Reads command-line arguments (``parse_args``) and the YAML config,
    restricts the visible GPUs to ``args.device_ids``, builds the roidb data
    loader and the model class named by ``cfg.MODEL.TYPE`` (optionally
    initialised from ``cfg.SEM.PSPNET_PRETRAINED_WEIGHTS``), and creates the
    optimizer selected by ``cfg.SOLVER.TYPE`` / ``cfg.SOLVER.LR_POLICY``. It
    then trains for ``args.num_epochs`` epochs.

    A checkpoint is saved every ``args.ckpt_num_per_epoch`` epochs, after the
    final epoch if that one was not already saved, and when training stops on
    ``RuntimeError`` or Ctrl-C (the traceback is printed, not re-raised).

    Raises:
        ValueError: if CUDA is not requested, ``args.dataset`` is unknown, or
            ``cfg.SOLVER.TYPE`` is not a supported optimizer type.
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    use_cuda = args.cuda or cfg.NUM_GPUS > 0
    if use_cuda:
        # CUDA_VISIBLE_DEVICES only takes effect if it is set before the CUDA
        # runtime is queried; torch.cuda.is_available() below may initialise
        # it, so export the device list first.
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
            str(device_id) for device_id in args.device_ids)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if not use_cuda:
        raise ValueError("Need Cuda device to run !")
    torch.backends.cudnn.benchmark = True
    cfg.CUDA = True

    # dataset name -> (TRAIN.DATASETS, MODEL.NUM_CLASSES)
    dataset_settings = {
        'coco2017': (('coco_2017_train',), 81),
        'keypoints_coco2017': (('keypoints_coco_2017_train',), 2),
        'cityscapes': (('cityscapes_semseg_train',), 19),
        'cityscape_train_on_val': (('cityscape_train_on_val',), 19),
        'cityscapes_coarse': (('cityscapes_coarse',), 19),
        'cityscapes_all': (('cityscapes_all',), 19),
        'cityscapes_trainval': (('cityscapes_trainval',), 19),
        'cityscapes_fineturn': (('cityscapes_fineturn',), 19),
    }
    if args.dataset not in dataset_settings:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))
    cfg.TRAIN.DATASETS, cfg.MODEL.NUM_CLASSES = dataset_settings[args.dataset]

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    print('Batch size change from {} (in config file) to {}'.format(
        original_batch_size, args.batch_size))
    print('NUM_GPUs: %d, TRAIN.IMS_PER_BATCH: %d' % (cfg.NUM_GPUS, cfg.TRAIN.IMS_PER_BATCH))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Adjust the learning rate linearly with the batch size change
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch size change: {} --> {}'.format(
        old_base_lr, cfg.SOLVER.BASE_LR))

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma

    timers = defaultdict(Timer)

    ### Dataset ###
    use_semseg_roidb = cfg.SEM.SEM_ON or cfg.DISP.DISP_ON
    timers['roidb'].tic()
    if use_semseg_roidb:
        roidb, ratio_list, ratio_index = combined_roidb_for_training_semseg(
            cfg.TRAIN.DATASETS)
    else:
        roidb, ratio_list, ratio_index = combined_roidb_for_training(
            cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    train_size = len(roidb)
    logger.info('{:d} roidb entries'.format(train_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    # Number of full batches per epoch. The DataLoader itself keeps a trailing
    # partial batch (drop_last=False), but the training loop stops after
    # iters_per_epoch batches, so that partial batch is never consumed.
    iters_per_epoch = train_size // args.batch_size

    dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=None,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch_semseg_all if use_semseg_roidb else collate_minibatch,
        drop_last=False,
        shuffle=True,
        pin_memory=True)

    assert_and_infer_cfg()

    ### Model ###
    # NOTE(review): eval() of a config string -- config files must be trusted.
    maskRCNN = eval(cfg.MODEL.TYPE)()
    if len(cfg.SEM.PSPNET_PRETRAINED_WEIGHTS) > 1:
        print("loading pspnet weights")
        pretrained = torch.load(cfg.SEM.PSPNET_PRETRAINED_WEIGHTS,
                                map_location=lambda storage, loc: storage)
        pretrained = pretrained['model']
        if cfg.SEM.SPN_ON:
            maskRCNN.pspnet.load_state_dict(pretrained, strict=True)
        elif 'deeplab' in cfg.SEM.DECODER_TYPE:
            # Drop decoder weights and strip 'encoder.' from the remaining keys.
            encoder = {k.replace('encoder.', ''): v
                       for k, v in pretrained.items() if 'decoder' not in k}
            maskRCNN.encoder.load_state_dict(encoder, strict=True)
            del encoder
        else:
            maskRCNN.load_state_dict(pretrained, strict=True)
        del pretrained
        print("weights load success")

    if cfg.SEM.SPN_ON:
        # The pretrained pspnet is frozen: eval mode and no gradients.
        maskRCNN.pspnet.eval()
        for p in maskRCNN.pspnet.parameters():
            p.requires_grad = False

    # load nets into gpu
    maskRCNN = UserScatteredDataParallel(maskRCNN)
    # For sync bn
    patch_replication_callback(maskRCNN)
    if cfg.CUDA:
        maskRCNN.to('cuda')

    ### Optimizer ###
    bias_params = []
    nonbias_params = []
    for key, value in maskRCNN.named_parameters():
        if not value.requires_grad:
            continue
        if 'bias' in key:
            bias_params.append(value)
        else:
            nonbias_params.append(value)
    params = [
        {'params': nonbias_params,
         'lr': cfg.SOLVER.BASE_LR,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': bias_params,
         'lr': cfg.SOLVER.BASE_LR * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0}
    ]

    # One exclusive branch per optimizer type, so that SGD is built only once
    # and only the lr policy actually used is reported.
    if cfg.SOLVER.TYPE == 'SGD' and cfg.SOLVER.LR_POLICY == 'ReduceLROnPlateau':
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)
        print("Using ReduceLROnPlateau as Lr reduce policy!")
    elif cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
        print("Using STEP as Lr reduce policy!")
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    elif "poly" in cfg.SOLVER.TYPE:  # both 'poly' and 'step_poly'
        optimizer = create_optimizers(maskRCNN, args)
        print("Using Poly as Lr reduce policy!")
    else:
        raise ValueError("Unsupported SOLVER.TYPE: {}".format(cfg.SOLVER.TYPE))

    args.max_iters = iters_per_epoch * args.num_epochs
    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            assert checkpoint['iters_per_epoch'] == iters_per_epoch, \
                "iters_per_epoch should match for resume"
            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
            if checkpoint['step'] == (checkpoint['iters_per_epoch'] - 1):
                # Resume from end of an epoch
                args.start_epoch = checkpoint['epoch'] + 1
                args.start_iter = 0
            else:
                # Resume from the middle of an epoch.
                # NOTE: dataloader is not synced with previous state
                args.start_epoch = checkpoint['epoch']
                args.start_iter = checkpoint['step'] + 1
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    if cfg.SOLVER.TYPE == 'step_poly':
        lr = cfg.SOLVER.BASE_LR / (cfg.SOLVER.GAMMA**len(args.lr_decay_epochs))
    else:
        lr = optimizer.param_groups[0]['lr']  # lr of non-bias parameters, for command line outputs.

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name()
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    training_stats = TrainingStats(
        args,
        args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)

    args.iters_per_epoch = iters_per_epoch
    try:
        logger.info('Training starts !')
        args.step = args.start_iter
        global_step = iters_per_epoch * args.start_epoch + args.step
        for args.epoch in range(args.start_epoch, args.start_epoch + args.num_epochs):
            # ---- Start of epoch ----
            # Step-wise decay at the epochs listed in args.lr_decay_epochs; only
            # applied when the epoch starts from iteration 0.
            if (args.lr_decay_epochs and args.epoch == args.lr_decay_epochs[0]
                    and args.start_iter == 0
                    and cfg.SOLVER.LR_POLICY == 'steps_with_decay'):
                args.lr_decay_epochs.pop(0)
                net_utils.decay_learning_rate(optimizer, lr, cfg.SOLVER.GAMMA)
                lr *= cfg.SOLVER.GAMMA

            for args.step, input_data in zip(range(args.start_iter, iters_per_epoch), dataloader):
                training_stats.IterTic()
                net_outputs = maskRCNN(input_data)

                training_stats.UpdateIterStats(net_outputs)
                loss = net_outputs['total_loss']

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Per-iteration lr schedules.
                if cfg.SOLVER.TYPE == 'poly':
                    lr = adjust_learning_rate(optimizer, global_step, args)
                if cfg.SOLVER.TYPE == 'step_poly':
                    lr = step_adjust_learning_rate(optimizer, lr, global_step, args)

                training_stats.IterToc()

                if args.step % args.disp_interval == 0:
                    log_training_stats(training_stats, global_step, lr)
                global_step += 1

            # ---- End of epoch ----
            if cfg.SOLVER.TYPE == 'SGD' and cfg.SOLVER.LR_POLICY == 'ReduceLROnPlateau':
                # NOTE: the plateau metric is the loss of the epoch's last batch.
                lr_scheduler.step(loss)
                lr = optimizer.param_groups[0]['lr']
            # args.ckpt_num_per_epoch acts as "save a checkpoint every N epochs".
            if (args.epoch + 1) % args.ckpt_num_per_epoch == 0:
                net_utils.save_ckpt(output_dir, args, maskRCNN, optimizer)
            # reset starting iter number after first epoch
            args.start_iter = 0

        # ---- Training ends ----
        # Save the final model unless the last epoch already produced a checkpoint.
        if (args.epoch + 1) % args.ckpt_num_per_epoch:
            net_utils.save_ckpt(output_dir, args, maskRCNN, optimizer)
    except (RuntimeError, KeyboardInterrupt):
        logger.info('Save ckpt on exception ...')
        net_utils.save_ckpt(output_dir, args, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
Пример #17
0
def main():
    """Train Generalized_RCNN on the ApolloScape Car3D dataset, step by step.

    Flow:
      1. Parse CLI args, require CUDA, merge the YAML config and apply the
         hard-coded ApolloScape/Car3D overrides.
      2. Rescale ``SOLVER.BASE_LR`` for the batch size in use and rescale
         ``SOLVER.STEPS`` / ``SOLVER.MAX_ITER`` for the effective batch size
         (``args.batch_size * args.iter_size`` images per optimizer step).
      3. Build the roidb, data loader, model and optimizer; optionally restore
         a checkpoint (``--load_ckpt`` / ``--resume``) or Detectron weights.
      4. Run the training loop: LR warm-up, step LR decay, gradient
         accumulation over ``args.iter_size`` minibatches, periodic
         checkpoints and a final checkpoint.

    A ``RuntimeError`` or ``KeyboardInterrupt`` raised inside the training
    loop is caught: a checkpoint is saved and the traceback is printed; the
    exception is not re-raised. The TensorBoard writer, if one was created,
    is always closed.

    Raises:
        SystemExit: if no CUDA device is available.
        ValueError: if CUDA is not requested/configured, or if
            ``cfg.SOLVER.TYPE`` is neither "SGD" nor "Adam".
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    # NOTE(review): this check runs before the YAML config is merged, so it
    # sees the default cfg.NUM_GPUS (cfg.NUM_GPUS is overwritten from
    # torch.cuda.device_count() below anyway).
    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    merge_cfg_from_file(args.cfg_file)

    # Manual overrides of config values for the ApolloScape (Car3D) dataset.
    cfg.OUTPUT_DIR = args.output_dir
    cfg.TRAIN.DATASETS = 'Car3D'
    cfg.MODEL.NUM_CLASSES = 8
    if cfg.CAR_CLS.SIM_MAT_LOSS:
        cfg.MODEL.NUMBER_CARS = 79
    else:
        # Loss is only cross entropy, hence we detect only the car
        # categories present in the training set.
        cfg.MODEL.NUMBER_CARS = 34
    cfg.TRAIN.MIN_AREA = 196  # 14*14
    cfg.TRAIN.USE_FLIPPED = False  # Flipped samples are not handled yet.
    cfg.TRAIN.IMS_PER_BATCH = 1

    cfg.NUM_GPUS = torch.cuda.device_count()

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size

    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)

    # Images consumed per optimizer step. The data loader yields
    # args.batch_size images per minibatch, and args.iter_size minibatches
    # are accumulated before each optimizer.step(). This must be computed
    # after args.batch_size is resolved, so that a --batch_size override is
    # reflected in the solver-step rescaling below.
    effective_batch_size = args.iter_size * args.batch_size

    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))
    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print(
        'Adjust BASE_LR linearly according to batch_size change:\n BASE_LR: {} --> {}'
        .format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps (fewer optimizer steps when more images per step)
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = [int(x * step_scale + 0.5) for x in cfg.SOLVER.STEPS]
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    # With the 3D->2D projection loss enabled, the roidb builder also returns
    # a dataset object whose `Car3D` attribute is handed to the model.
    if cfg.MODEL.LOSS_3D_2D_ON:
        roidb, ratio_list, ratio_index, ds = combined_roidb_for_training(
            cfg.TRAIN.DATASETS, args.dataset_dir)
    else:
        roidb, ratio_list, ratio_index = combined_roidb_for_training(
            cfg.TRAIN.DATASETS, args.dataset_dir)

    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch (whole batches only).
    train_size = roidb_size // args.batch_size * args.batch_size

    sampler = MinibatchSampler(ratio_list, ratio_index)
    dataset = RoiDataLoader(roidb,
                            cfg.MODEL.NUM_CLASSES,
                            training=True,
                            valid_keys=[
                                'has_visible_keypoints', 'boxes', 'seg_areas',
                                'gt_classes', 'gt_overlaps',
                                'box_to_gt_ind_map', 'is_crowd',
                                'car_cat_classes', 'poses', 'quaternions',
                                'im_info'
                            ])
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        drop_last=True,
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    if cfg.MODEL.LOSS_3D_2D_ON:
        maskRCNN = Generalized_RCNN(ds.Car3D)
    else:
        maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Split trainable parameters into bias / non-bias groups so each group can
    # get its own LR multiplier and weight-decay setting.
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': bias_params,
        'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast. Previously an unknown type left `optimizer` unbound and
        # crashed later with an opaque NameError.
        raise ValueError('Unknown SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # Load onto CPU storage regardless of the device it was saved from.
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN,
                            checkpoint['model'],
                            ignore_list=args.ckpt_ignore_head)
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1
            # (fixed on master), hence the helper instead of
            # optimizer.load_state_dict(checkpoint['optimizer']).
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  # TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    # lr of non-bias parameters, for command line outputs.
    lr = optimizer.param_groups[0]['lr']

    # 'im_info' and 'roidb' inputs stay on the CPU; other inputs are
    # scattered across GPUs.
    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()
    # Clamp to >= 1. A zero period (SNAPSHOT_ITERS < NUM_GPUS) would raise
    # ZeroDivisionError in the `% CHECKPOINT_PERIOD` test below.
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))

    # Index of the next LR-decay boundary. The search skips STEPS[0] and picks
    # the first boundary at or after start_step, so a resumed run continues
    # the schedule. If none is found, it is set past the end (no more decays).
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)

    # Third argument to LogIterStats. It is always 1.0 in this script: a
    # translation-loss warm-up (WARM_UP_FACTOR_TRANS) was tried and disabled.
    # It is initialized here so it is defined for 'constant' warm-up and for
    # steps past WARM_UP_ITERS (e.g. when resuming); previously those paths
    # raised UnboundLocalError.
    warmup_factor_trans = 1.0
    try:
        logger.info('Training starts !')
        # Defined even if the loop body never runs, for the save_ckpt calls.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    # Interpolate from WARM_UP_FACTOR (step 0) towards 1.0.
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if (decay_steps_ind < len(cfg.SOLVER.STEPS)
                    and step == cfg.SOLVER.STEPS[decay_steps_ind]):
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Accumulate gradients over iter_size minibatches per optimizer step.
            for _ in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Data exhausted: start a new pass over the loader.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)

                # Apply per-loss weights before logging and summing.
                losses = net_outputs['losses']
                losses['loss_car_cls'] *= cfg.CAR_CLS.CAR_CLS_LOSS_BETA
                losses['loss_rot'] *= cfg.CAR_CLS.ROT_LOSS_BETA
                if cfg.MODEL.TRANS_HEAD_ON:
                    losses['loss_trans'] *= cfg.TRANS_HEAD.TRANS_LOSS_BETA

                training_stats.UpdateIterStats_car_3d(net_outputs)

                # Observed magnitudes: loss_car_cls ~2.23, loss_rot ~0.30,
                # loss_trans ~100.
                loss = losses['loss_car_cls'] + losses['loss_rot']
                if cfg.MODEL.TRANS_HEAD_ON:
                    loss += losses['loss_trans']
                if cfg.MODEL.LOSS_3D_2D_ON:
                    loss += losses['UV_projection_loss']
                # The backbone/RPN losses only contribute when none of the
                # conv body, RPN or FPN is frozen.
                if not cfg.TRAIN.FREEZE_CONV_BODY and not cfg.TRAIN.FREEZE_RPN and not cfg.TRAIN.FREEZE_FPN:
                    loss += net_outputs['total_loss_conv']

                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr, warmup_factor_trans)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        # Release the loader's iterator (and its workers) before saving.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()