Exemple #1
0
def train_model():
    """Run the training loop and return the checkpoint mapping.

    The model and its bookkeeping objects come from ``create_model()``. If
    that mapping already contains a ``'final'`` entry the function returns
    immediately. Otherwise it iterates from ``start_iter`` up to
    ``cfg.SOLVER.MAX_ITER``, writing periodic snapshots and, at the end, a
    ``model_final.pkl`` file in the output directory.

    Returns:
        The ``checkpoints`` mapping (iteration number or ``'final'`` ->
        weights file path).
    """
    log = logging.getLogger(__name__)
    (model, weights_file, start_iter, checkpoints, output_dir,
     writer) = create_model()

    # A final model already sits in the output directory: nothing to train.
    if 'final' in checkpoints:
        return checkpoints

    setup_model_for_training(model, weights_file, output_dir)
    writer.write_graph(model, single_gpu=True, custom_rename=nu.scope_function)
    stats = TrainingStats(model)
    snapshot_every = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)

    for cur_iter in range(start_iter, cfg.SOLVER.MAX_ITER):
        stats.IterTic()
        lr = model.UpdateWorkspaceLr(
            cur_iter, lr_policy.get_lr_at_iter(cur_iter))
        workspace.RunNet(model.net.Proto().name)
        if cur_iter == start_iter:
            # Dump the net once, right after the first executed iteration.
            nu.print_net(model)
        stats.IterToc()
        stats.UpdateIterStats()
        stats.LogIterStats(cur_iter, lr)

        # Periodic snapshot; never taken on the very first iteration.
        if (cur_iter + 1) % snapshot_every == 0 and cur_iter > start_iter:
            snapshot_path = os.path.join(
                output_dir, 'model_iter{}.pkl'.format(cur_iter))
            checkpoints[cur_iter] = snapshot_path
            nu.save_model_to_weights_file(snapshot_path, model)

        # Drop the timings of the first few SGD iterations, which are
        # outliers, by resetting the iteration timer once.
        if cur_iter == start_iter + stats.LOG_PERIOD:
            stats.ResetIterTimer()

        # Abort on a NaN loss: log it, stop the data loader, then exit.
        if np.isnan(stats.iter_total_loss):
            stats.LogIterStats(cur_iter, lr, nan=True)
            log.critical('Loss is NaN, exiting...')
            model.roi_data_loader.shutdown()
            envu.exit_on_error()

    # Persist the final model.
    final_path = os.path.join(output_dir, 'model_final.pkl')
    checkpoints['final'] = final_path
    nu.save_model_to_weights_file(final_path, model)

    # Stop the data loading threads and release the writer.
    model.roi_data_loader.shutdown()
    writer.close()
    return checkpoints
Exemple #2
0
def main():
    """Train a ``Generalized_RCNN`` model from command-line arguments.

    Steps, in order:
      1. Parse arguments, require CUDA, select the dataset and load config.
      2. Adapt batch size, learning rate and solver schedule to the number of
         visible GPUs and to ``args.iter_size`` (gradient accumulation).
      3. Build the roidb, the batch sampler and the data loader.
      4. Build the model and an optimizer with separate parameter groups for
         non-bias, bias and GroupNorm parameters.
      5. Optionally restore a checkpoint or Detectron weights.
      6. Run the training loop with warm-up, step decay and periodic
         checkpoints. On ``RuntimeError`` / ``KeyboardInterrupt`` a checkpoint
         is saved and the traceback is printed (the exception is not
         re-raised).

    Raises:
        ValueError: if no CUDA configuration is requested, if
            ``args.dataset`` is not recognised, or if ``cfg.SOLVER.TYPE`` is
            neither ``"SGD"`` nor ``"Adam"``.
        SystemExit: if no CUDA device is available.
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train',)
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train',)
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    # Remember the values from the config file before they are overwritten so
    # the learning rate and schedule can be rescaled relative to them.
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' % (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' % (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' % (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' % (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    # `+ 0.5` before int() rounds to the nearest integer.
    cfg.SOLVER.STEPS = list(map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print('Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
          '    SOLVER.STEPS: {} --> {}\n'
          '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps, cfg.SOLVER.STEPS,
                                                  old_max_iter, cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print('Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
              '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    # NOTE: an explicit --lr replaces the batch-size-scaled BASE_LR above.
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True
    )
    dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Collect the fully-qualified names of all GroupNorm weights/biases so
    # they can be put in their own parameter group.
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name+'.weight')
            gn_param_nameset.add(name+'.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    for key, value in maskRCNN.named_parameters():
        if value.requires_grad:
            # The bias test comes first, so GroupNorm biases land in the bias
            # group and only GroupNorm weights land in the GN group.
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    assert (gn_param_nameset - set(nograd_param_names) - set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [
        {'params': nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': gn_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN}
    ]
    # names of parameters for each parameter group
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail here with a clear message instead of an UnboundLocalError on
        # the first later use of `optimizer`.
        raise ValueError(
            "Unsupported cfg.SOLVER.TYPE: {}".format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # map_location keeps every storage on CPU while loading.
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print('train_size value: %d different from the one in checkpoint: %d'
                          % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            optimizer.load_state_dict(checkpoint['optimizer'])
            # misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0]['lr']  # lr of non-bias parameters, for command line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        # Store the resolved config and arguments next to the checkpoints.
        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)

    # Set index for decay steps: the first entry of SOLVER.STEPS (searching
    # from index 1) that is not before the starting step.
    # NOTE(review): index 0 is skipped, presumably because STEPS[0] is the
    # initial step -- confirm against the config convention.
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args,
        args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        # Bound before the loop so the final/exception save_ckpt calls have a
        # value even when the range below is empty.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                # Warm-up finished: switch to the base learning rate.
                net_utils.update_learning_rate(optimizer, lr, cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Accumulate gradients over iter_size minibatches before stepping.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Loader exhausted: start a new pass over the data.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb': # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step+1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        # Drop the iterator first, then save a checkpoint and print the
        # traceback. The exception is intentionally not re-raised.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
Exemple #3
0
def main():
    """Train a detection model on a PASCAL VOC trainval split.

    Steps, in order:
      1. Require CUDA, parse arguments and load the config file.
      2. Adapt batch size, learning rate and solver schedule to the number of
         visible GPUs and to ``args.iter_size`` (gradient accumulation).
      3. Build the input transforms, the VOC dataset, the sampler and the
         data loader.
      4. Build the model named by ``args.model`` and an optimizer with
         separate parameter groups for non-bias and bias parameters.
      5. Optionally restore a checkpoint.
      6. Run the training loop with warm-up, step decay and periodic
         checkpoints. On ``RuntimeError`` / ``KeyboardInterrupt`` a checkpoint
         is saved and the traceback is printed (the exception is not
         re-raised).

    Raises:
        ValueError: if ``cfg.NUM_GPUS`` is not positive, if
            ``cfg.TRAIN.DATASET`` is not a supported VOC split, or if
            ``cfg.SOLVER.TYPE`` is neither ``"SGD"`` nor ``"Adam"``.
        SystemExit: if no CUDA device is available.
    """

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    args = parse_args()
    print('Called with args:')
    print(args)

    load_config(args.cfg_file)

    if cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    ### Adaptively adjust some configs ###
    # Remember the values from the config file before they are overwritten so
    # the learning rate and schedule can be rescaled relative to them.
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()

    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
     'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    # `+ 0.5` before int() rounds to the nearest integer.
    cfg.SOLVER.STEPS = list(
        map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    # NOTE: an explicit --lr replaces the batch-size-scaled BASE_LR above.
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma

    # np.random.seed(cfg.RNG_SEED)
    # torch.manual_seed(cfg.RNG_SEED)
    # if cfg.CUDA:
    #     torch.cuda.manual_seed_all(cfg.RNG_SEED)
    torch.backends.cudnn.deterministic = True

    transforms = dt_trans.Compose([
        dt_trans.Normalize([102.9801, 115.9465, 122.7717]),
        dt_trans.HorizontalFlip(),
        dt_trans.Resize(),
        dt_trans.ToTensor(),
    ])

    if cfg.TRAIN.DATASET == 'voc_2007_trainval':
        dataset = VOCDetection(cfg.DATA_DIR + '/',
                               year='2007',
                               image_set='trainval',
                               transforms=transforms)

    elif cfg.TRAIN.DATASET == 'voc_2012_trainval':
        dataset = VOCDetection(cfg.DATA_DIR + '/',
                               year='2012',
                               image_set='trainval',
                               transforms=transforms)
    else:
        # Fail here with a clear message instead of an UnboundLocalError on
        # the first later use of `dataset`.
        raise ValueError(
            "Unsupported cfg.TRAIN.DATASET: {}".format(cfg.TRAIN.DATASET))

    # Effective training sample size for one epoch
    train_size = len(dataset) // args.batch_size * args.batch_size

    batchSampler = TrainSampler(subdivision=cfg.TRAIN.ITERATION_SIZE,
                                batch_size=cfg.TRAIN.ITERATION_SIZE,
                                max_iterations=cfg.SOLVER.MAX_ITER,
                                num_samples=len(dataset),
                                image_scales=cfg.TRAIN.SCALES,
                                scale_interval=10)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    # NOTE(review): eval() on a command-line value executes arbitrary code.
    # Acceptable only for trusted invocations; consider replacing it with an
    # explicit name -> module registry.
    model = eval(args.model).loot_model(args)

    if cfg.CUDA:
        model.cuda()

    ### Optimizer ###
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    for key, value in model.named_parameters():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [
        {
            'params': nonbias_params,
            'lr': 0,
            'weight_decay': cfg.SOLVER.WEIGHT_DECAY
        },
        {
            'params':
            bias_params,
            'lr':
            0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
            'weight_decay':
            cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
        },
    ]
    # names of parameters for each parameter group
    param_names = [nonbias_param_names, bias_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail here with a clear message instead of an UnboundLocalError on
        # the first later use of `optimizer`.
        raise ValueError(
            "Unsupported cfg.SOLVER.TYPE: {}".format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # map_location keeps every storage on CPU while loading.
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        optimizer_utils.load_ckpt(model, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            optimizer.load_state_dict(checkpoint['optimizer'])
            # misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    lr = optimizer.param_groups[0][
        'lr']  # lr of non-bias parameters, for command line outputs.

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args)

    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        # Store the resolved config and arguments next to the checkpoints.
        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    model.train()

    CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS /
                            (cfg.NUM_GPUS * args.iter_size))

    # Set index for decay steps: the first entry of SOLVER.STEPS (searching
    # from index 1) that is not before the starting step.
    # NOTE(review): index 0 is skipped, presumably because STEPS[0] is the
    # initial step -- confirm against the config convention.
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        # Bound before the loop so the final/exception save_ckpt calls have a
        # value even when the range below is empty.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                optimizer_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                # Warm-up finished: switch to the base learning rate.
                optimizer_utils.update_learning_rate(optimizer, lr,
                                                     cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
              step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                optimizer_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Accumulate gradients over iter_size minibatches before stepping.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Loader exhausted: start a new pass over the data.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                model.set_inner_iter(step)

                im_data = input_data[0].cuda()
                labels = input_data[1].cuda()
                rois = input_data[2].cuda()

                net_outputs = model(im_data, rois, labels)

                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                # NOTE(review): retain_graph=True keeps the autograd graph
                # alive after backward; confirm the model really needs it.
                loss.backward(retain_graph=True)
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, model, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, model, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        # Drop the iterator first, then save a checkpoint and print the
        # traceback. The exception is intentionally not re-raised.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, model, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
Exemple #4
0
def main():
    """Train the detector with SGD, looping over the dataset until max_iter.

    Builds the model from the command-line arguments, freezes its first
    layers, and then repeatedly iterates over the training loader. Each
    iteration updates the learning rate, runs a forward/backward pass with
    gradient-norm clipping, logs statistics and periodically saves a
    checkpoint. Training stops after exactly
    ``args.max_iter - args.start_iter`` iterations, possibly in the middle of
    an epoch.

    Raises:
        AssertionError: if ``args.batch_size`` differs from the number of
            visible CUDA devices.
        RuntimeError: if the training loader yields no batches at all (the
            outer loop could otherwise never terminate).
    """
    args = parser.parse_args()
    print(args)
    # for now, batch_size should match number of gpus
    assert(args.batch_size==torch.cuda.device_count())

    # create model
    model = detector(arch=args.cnn_arch,
                     base_cnn_pkl_file=args.cnn_pkl,
                     mapping_file=args.cnn_mapping,
                     output_prob=False,
                     return_rois=False,
                     return_img_features=False)
    model = model.cuda()

    # freeze part of the net
    stop_grad = ['conv1', 'bn1', 'relu', 'maxpool', 'layer1']
    model_no_grad = torch.nn.Sequential(*[getattr(model.model, l) for l in stop_grad])
    for param in model_no_grad.parameters():
        param.requires_grad = False

    # define optimizer (only over the parameters that are still trainable)
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.wd)

    # create dataset
    train_dataset = CocoDataset(ann_file=args.dset_ann,
                                img_dir=args.dset_path,
                                proposal_file=args.dset_rois,
                                mode='train',
                                sample_transform=preprocess_sample(target_sizes=[800],
                                                                   sample_proposals_for_training=True))
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False,
                              num_workers=args.workers, collate_fn=collate_custom)

    training_stats = TrainingStats(losses=['loss_cls', 'loss_bbox'],
                                   metrics=['accuracy_cls'],
                                   solver_max_iters=args.max_iter)

    single_gpu = args.batch_size == 1
    if single_gpu:
        # Ground-truth data is already a tensor on the GPU.
        def list_to_tensor(x):
            return x
    else:
        # convert gt data from lists to concatenated tensors
        def list_to_tensor(x):
            return torch.cat(tuple(t.cuda() for t in x), 0)

    # Named cur_iter so the builtin iter() is not shadowed.
    cur_iter = args.start_iter

    print('starting training')

    while cur_iter < args.max_iter:
        iter_at_epoch_start = cur_iter
        for batch in train_loader:

            if single_gpu:
                batch = to_cuda_variable(batch, volatile=False)
            else:
                # when using multiple GPUs convert to cuda later in data_parallel and list_to_tensor
                batch = to_variable(batch, volatile=False)

            # update lr
            lr = get_lr_at_iter(cur_iter)
            adjust_learning_rate(optimizer, lr)

            # start measuring time
            training_stats.IterTic()

            # forward pass
            if single_gpu:
                cls_score, bbox_pred = model(batch['image'], batch['rois'])
            else:
                # run model distributed over gpus and concatenate outputs for all batch
                cls_score, bbox_pred = data_parallel(model, (batch['image'], batch['rois']))

            cls_labels = list_to_tensor(batch['labels_int32']).long()
            bbox_targets = list_to_tensor(batch['bbox_targets'])
            bbox_inside_weights = list_to_tensor(batch['bbox_inside_weights'])
            bbox_outside_weights = list_to_tensor(batch['bbox_outside_weights'])

            # compute loss
            loss_cls = cross_entropy(cls_score, cls_labels)
            loss_bbox = smooth_L1(bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_weights)

            # compute classification accuracy (for stats reporting)
            acc = accuracy(cls_score, cls_labels)

            # get final loss
            loss = loss_cls + loss_bbox

            # update
            optimizer.zero_grad()
            loss.backward()
            # Without gradient clipping I get inf's and NaNs.
            # it seems that in Caffe the SGD solver performs grad clipping by default.
            # https://github.com/BVLC/caffe/blob/master/src/caffe/solvers/sgd_solver.cpp
            # it also seems that Matterport's Mask R-CNN required grad clipping as well
            # (see README in https://github.com/matterport/Mask_RCNN)
            # the value max_norm=35 was taken from here https://github.com/BVLC/caffe/blob/master/src/caffe/proto/caffe.proto
            clip_grad_norm(filter(lambda p: p.requires_grad, model.parameters()), max_norm=35, norm_type=2)
            optimizer.step()

            # stats
            training_stats.IterToc()

            training_stats.UpdateIterStats(losses_dict={'loss_cls': loss_cls.data.cpu().numpy().item(),
                                                        'loss_bbox': loss_bbox.data.cpu().numpy().item()},
                                           metrics_dict={'accuracy_cls': acc.data.cpu().numpy().item()})

            training_stats.LogIterStats(cur_iter, lr)
            # save checkpoint
            if (cur_iter + 1) % args.checkpoint_period == 0:
                save_checkpoint({
                    'iter': cur_iter,
                    'args': args,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, args.checkpoint_fn)

            if cur_iter == args.start_iter + 20:  # training_stats.LOG_PERIOD=20
                # Reset the iteration timer to remove outliers from the first few
                # SGD iterations
                training_stats.ResetIterTimer()

            # advance iteration, then allow finishing in the middle of an
            # epoch. Checking after the increment with >= stops right after
            # iteration max_iter - 1 instead of running two extra steps.
            cur_iter += 1
            if cur_iter >= args.max_iter:
                break

        if cur_iter == iter_at_epoch_start:
            # An empty loader would make the while loop spin forever.
            raise RuntimeError('train_loader produced no batches')
Exemple #5
0
def main():
    """Train the detection model configured by ``parse_cfg()`` and ``cfg``.

    The function builds the roidb-backed data loader, the model and an
    optimizer with three parameter groups (non-bias, bias and GroupNorm
    parameters), optionally restores a checkpoint and/or Detectron weights,
    and then runs the step-based training loop: learning-rate warm-up,
    multiplicative decay at ``cfg.SOLVER.STEPS``, and gradient accumulation
    over ``args.iter_size`` minibatches per optimizer step.

    A checkpoint is written every ``CHECKPOINT_PERIOD`` steps, once more when
    the loop finishes, and also when a ``RuntimeError`` or
    ``KeyboardInterrupt`` interrupts training (the traceback is printed and
    the exception is not re-raised).

    Raises:
        ValueError: if ``cfg.SOLVER.TYPE`` is neither ``"SGD"`` nor
            ``"Adam"``.
        KeyError: if ``cfg.SOLVER.WARM_UP_METHOD`` is neither ``'constant'``
            nor ``'linear'`` while warm-up is active.
    """
    args = parse_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch: the incomplete last batch
    # is dropped (see drop_last=True below), so round down to a multiple of
    # the batch size.
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(sampler=MinibatchSampler(
        ratio_list, ratio_index),
                                batch_size=args.batch_size,
                                drop_last=True)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = GetRCNNModel()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Collect the fully-qualified names of every GroupNorm weight/bias so
    # those parameters can get their own weight decay below.
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name + '.weight')
            gn_param_nameset.add(name + '.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    # Split trainable parameters into bias / GroupNorm / everything else.
    # The 'bias' test comes first, so GroupNorm biases land in the bias group.
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    # Sanity check: every GroupNorm parameter that is trainable and was not
    # classified as a bias must have ended up in the GroupNorm group.
    assert (gn_param_nameset - set(nograd_param_names) -
            set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        bias_params,
        'lr':
        0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]
    # Parameter names for each parameter group, in the same order as `params`
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast with a clear message; otherwise `optimizer` would be
        # unbound and the first use below would raise an obscure NameError.
        raise ValueError(
            'Unsupported SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # map_location keeps every tensor on the CPU while loading.
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            # Continue with the step after the one stored in the checkpoint.
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        # Drop the checkpoint dict and release cached GPU memory before
        # training starts.
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0][
        'lr']  # lr of non-bias parameters, for command line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)
    if cfg.TRAIN_SYNC_BN:
        # For synchronized BN
        patch_replication_callback(maskRCNN)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        # Persist the configuration and arguments used for this run.
        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)

    # Set index for decay steps: the first entry of SOLVER.STEPS (skipping
    # index 0) that is not yet behind the starting step. If none is left,
    # the index is set past the end so no decay is ever triggered.
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    # `tblogger` only exists when tensorboard logging is on and saving is
    # enabled, hence the guarded expression.
    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        # Bind `step` up front so the checkpoint calls below still have a
        # value if the loop body never runs.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    # Interpolate from WARM_UP_FACTOR (alpha=0) towards 1.
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                # Warm-up is over: switch to the base learning rate.
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Accumulate gradients over `iter_size` minibatches before a
            # single optimizer step.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # The loader is exhausted: start a new pass over it.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if cfg.MODEL.LR_VIEW_ON or cfg.MODEL.GIF_ON or cfg.MODEL.LRASY_MAHA_ON:
                        if key != 'roidb' and key != 'data':  # roidb is a list of ndarrays with inconsistent length
                            input_data[key] = list(
                                map(Variable, input_data[key]))
                        if key == 'data':
                            # In these modes each 'data' item is squeezed
                            # before being wrapped.
                            input_data[key] = [
                                torch.squeeze(item) for item in input_data[key]
                            ]
                            input_data[key] = list(
                                map(Variable, input_data[key]))
                    else:
                        if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                            input_data[key] = list(
                                map(Variable, input_data[key]))
                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                net_utils.train_save_ckpt(output_dir, args, step, train_size,
                                          maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        net_utils.train_save_ckpt(output_dir, args, step, train_size, maskRCNN,
                                  optimizer)

    except (RuntimeError, KeyboardInterrupt):
        # Best-effort rescue: drop the data iterator, save whatever state we
        # have, and print the traceback instead of re-raising.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        net_utils.train_save_ckpt(output_dir, args, step, train_size, maskRCNN,
                                  optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        # Same condition under which `tblogger` was created above.
        if args.use_tfboard and not args.no_save:
            tblogger.close()
Exemple #6
0
def main():
    """Set up configuration, data, model and optimizer, then run training.

    The function exits through ``sys.exit`` when no CUDA device is available.
    Otherwise it loads the configuration from ``args.cfg_file`` (with
    optional ``args.set_cfgs`` overrides), builds the roidb data loader,
    instantiates the model named by ``args.model``, restores
    ``args.load_ckpt`` and runs the training loop with gradient accumulation
    over ``cfg.TRAIN.ITERATION_SIZE`` minibatches per optimizer step.

    A checkpoint is saved every ``CHECKPOINT_PERIOD`` steps, once more when
    the loop finishes, and also when a ``RuntimeError`` or
    ``KeyboardInterrupt`` interrupts training (the traceback is printed and
    the exception is not re-raised).
    """

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the training code, sry bro :(.")
    else:
        cfg.CUDA = True
        cfg.NUM_GPUS = torch.cuda.device_count()

    #######~~~.Parameters stuff.~~~#######
    args = parse_args()
    print('Called with args:\n', args)

    # Enables fixed seed
    if args.fixed_seed:
        np.random.seed(cfg.RNG_SEED)
        torch.manual_seed(cfg.RNG_SEED)
        if cfg.CUDA:
            torch.cuda.manual_seed_all(cfg.RNG_SEED)
    # NOTE(review): set unconditionally (outside the fixed_seed branch) --
    # confirm whether this was meant to be part of the fixed-seed setup.
    torch.backends.cudnn.deterministic = True

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    check_and_overwrite_params(cfg, args)
    assert_and_infer_cfg()

    # Identification of this specific run
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args)

    save_training_config(output_dir, cfg, args)

    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        # Set the Tensorboard logger
        tblogger = SummaryWriter(output_dir)
    else:
        tblogger = None

    #######~~~.Dataset.~~~#######
    timers = defaultdict(Timer)

    # Roi datasets
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    batchSampler = BatchSampler(sampler=MinibatchSampler(
        ratio_list, ratio_index),
                                batch_size=args.batch_size,
                                drop_last=True)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ########~~~.Model.~~~#######
    # SECURITY NOTE(review): `eval` executes whatever expression is passed
    # in `args.model`; only safe while that argument comes from a trusted
    # source. Consider an explicit name -> class mapping instead.
    model = eval(args.model).loot_model(args)

    if cfg.CUDA:
        model.cuda()

    optimizer = OptimizerHandler(model, cfg)

    load_ckpt(args.load_ckpt, model, optimizer)

    #######~~~.Training Loop.~~~#######
    # Bind `step` before entering the `try` so the exception handler can
    # always pass it to save_ckpt, even when the failure happens before the
    # loop starts (e.g. inside TrainingStats(...) or model.train()).
    step = args.start_step
    try:
        # Effective training sample size for one epoch
        train_size = roidb_size // args.batch_size * args.batch_size
        CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS /
                                (cfg.NUM_GPUS * cfg.TRAIN.ITERATION_SIZE))

        training_stats = TrainingStats(args, tblogger)

        model.train()

        logger.info('Training starts !')
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            optimizer.update_learning_rate(step)

            training_stats.IterTic()
            optimizer.zero_grad()
            # Accumulate gradients over ITERATION_SIZE minibatches before a
            # single optimizer step.
            for inner_iter in range(cfg.TRAIN.ITERATION_SIZE):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # The loader is exhausted: start a new pass over it.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                # input_data['inner_iter'] = torch.tensor((inner_iter))
                # NOTE(review): this passes the outer `step`, not
                # `inner_iter` -- confirm that is what set_inner_iter expects.
                model.set_inner_iter(step)

                # Only the first element of each list is used; rois and
                # labels are cast to the image tensor's dtype.
                im_data = input_data['data'][0].cuda()
                rois = input_data['rois'][0].cuda().type(im_data.dtype)
                labels = input_data['labels'][0].cuda().type(im_data.dtype)

                net_outputs = model(im_data, rois, labels)

                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                # NOTE(review): retain_graph=True keeps the autograd graph
                # alive after backward -- confirm the model really needs it.
                loss.backward(retain_graph=True)

            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, optimizer.get_lr())

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, model, optimizer)

        # Training ends, saves the last checkpoint
        save_ckpt(output_dir, args, step, train_size, model, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        # Best-effort rescue: drop the data iterator, save whatever state we
        # have, and print the traceback instead of re-raising.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, model, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print()
        print(stack_trace)

    finally:
        # Close the writer whenever one was created above; the creation only
        # depends on args.use_tfboard, so the cleanup must not depend on
        # args.no_save or the writer would leak.
        if tblogger is not None:
            tblogger.close()
Exemple #7
0
def main():
    """Main function"""

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 2

    # ADE20k as a detection dataset
    elif args.dataset == "ade_train":
        cfg.TRAIN.DATASETS = ('ade_train', )
        cfg.MODEL.NUM_CLASSES = 446

    # Noisy CS6+WIDER datasets
    elif args.dataset == 'cs6_noise020+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise020', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise050+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise050', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise080+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise080', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise085+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise085', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise090+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise090', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise095+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise095', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise100+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise100', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    # Just Noisy CS6 datasets
    elif args.dataset == 'cs6_noise020':
        cfg.TRAIN.DATASETS = ('cs6_noise020', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise030':
        cfg.TRAIN.DATASETS = ('cs6_noise030', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise040':
        cfg.TRAIN.DATASETS = ('cs6_noise040', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise050':
        cfg.TRAIN.DATASETS = ('cs6_noise050', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise060':
        cfg.TRAIN.DATASETS = ('cs6_noise060', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise070':
        cfg.TRAIN.DATASETS = ('cs6_noise070', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise080':
        cfg.TRAIN.DATASETS = ('cs6_noise080', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise085':
        cfg.TRAIN.DATASETS = ('cs6_noise085', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise090':
        cfg.TRAIN.DATASETS = ('cs6_noise090', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise095':
        cfg.TRAIN.DATASETS = ('cs6_noise095', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise100':
        cfg.TRAIN.DATASETS = ('cs6_noise100', )
        cfg.MODEL.NUM_CLASSES = 2

    # Cityscapes 7 classes
    elif args.dataset == "cityscapes":
        cfg.TRAIN.DATASETS = ('cityscapes_train', )
        cfg.MODEL.NUM_CLASSES = 8

    # BDD 7 classes
    elif args.dataset == "bdd_any_any_any":
        cfg.TRAIN.DATASETS = ('bdd_any_any_any_train', )
        cfg.MODEL.NUM_CLASSES = 8
    elif args.dataset == "bdd_any_any_daytime":
        cfg.TRAIN.DATASETS = ('bdd_any_any_daytime_train', )
        cfg.MODEL.NUM_CLASSES = 8
    elif args.dataset == "bdd_clear_any_daytime":
        cfg.TRAIN.DATASETS = ('bdd_clear_any_daytime_train', )
        cfg.MODEL.NUM_CLASSES = 8

    # Cityscapes Pedestrian sets
    elif args.dataset == "cityscapes_peds":
        cfg.TRAIN.DATASETS = ('cityscapes_peds_train', )
        cfg.MODEL.NUM_CLASSES = 2

    # Cityscapes Car sets
    elif args.dataset == "cityscapes_cars_HPlen3+kitti_car_train":
        cfg.TRAIN.DATASETS = ('cityscapes_cars_HPlen3', 'kitti_car_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cityscapes_cars_HPlen5+kitti_car_train":
        cfg.TRAIN.DATASETS = ('cityscapes_cars_HPlen5', 'kitti_car_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cityscapes_car_train+kitti_car_train":
        cfg.TRAIN.DATASETS = ('cityscapes_car_train', 'kitti_car_train')
        cfg.MODEL.NUM_CLASSES = 2

    # KITTI Car set
    elif args.dataset == "kitti_car_train":
        cfg.TRAIN.DATASETS = ('kitti_car_train', )
        cfg.MODEL.NUM_CLASSES = 2

    # BDD pedestrians sets
    elif args.dataset == "bdd_peds":
        cfg.TRAIN.DATASETS = ('bdd_peds_train',
                              )  # bdd peds: clear_any_daytime
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "bdd_peds_full":
        cfg.TRAIN.DATASETS = ('bdd_peds_full_train', )  # bdd peds: any_any_any
        cfg.MODEL.NUM_CLASSES = 2
    # Pedestrians with constraints
    elif args.dataset == "bdd_peds_not_clear_any_daytime":
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train', )
        cfg.MODEL.NUM_CLASSES = 2
    # Ashish's  20k samples videos
    elif args.dataset == "bdd_peds_not_clear_any_daytime_20k":
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_20k_train', )
        cfg.MODEL.NUM_CLASSES = 2
    # Source domain + Target domain detections
    elif args.dataset == "bdd_peds+DETS_20k":
        cfg.TRAIN.DATASETS = ('bdd_peds_dets_20k_target_domain',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    # Source domain + Target domain detections -- same 18k images as HP18k
    elif args.dataset == "bdd_peds+DETS18k":
        cfg.TRAIN.DATASETS = ('bdd_peds_dets18k_target_domain',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    # Only Dets
    elif args.dataset == "DETS20k":
        cfg.TRAIN.DATASETS = ('bdd_peds_dets_20k_target_domain', )
        cfg.MODEL.NUM_CLASSES = 2
    # Only Dets18k - same images as HP18k
    elif args.dataset == 'DETS18k':
        cfg.TRAIN.DATASETS = ('bdd_peds_dets18k_target_domain', )
        cfg.MODEL.NUM_CLASSES = 2
    # Only HP
    elif args.dataset == 'HP':
        cfg.TRAIN.DATASETS = ('bdd_peds_HP_target_domain', )
        cfg.MODEL.NUM_CLASSES = 2
    # Only HP 18k videos
    elif args.dataset == 'HP18k':
        cfg.TRAIN.DATASETS = ('bdd_peds_HP18k_target_domain', )
        cfg.MODEL.NUM_CLASSES = 2
    # Source domain + Target domain HP
    elif args.dataset == 'bdd_peds+HP':
        cfg.TRAIN.DATASETS = ('bdd_peds_train', 'bdd_peds_HP_target_domain')
        cfg.MODEL.NUM_CLASSES = 2
        # Source domain + Target domain HP 18k videos
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k':
        cfg.TRAIN.DATASETS = ('bdd_peds_HP18k_target_domain', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    #### Source domain + Target domain with different conf threshold theta ####
    elif args.dataset == 'bdd_peds+HP18k_thresh-050':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_thresh-050', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_thresh-060':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_thresh-060', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_thresh-070':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_thresh-070', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_thresh-090':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_thresh-090', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    ##############################

    #### Data distillation on BDD -- for rebuttal
    elif args.dataset == 'bdd_peds+bdd_data_dist_small':
        cfg.TRAIN.DATASETS = ('bdd_data_dist_small', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+bdd_data_dist_mid':
        cfg.TRAIN.DATASETS = ('bdd_data_dist_mid', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+bdd_data_dist':
        cfg.TRAIN.DATASETS = ('bdd_data_dist', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    ##############################

    #### Source domain + **Labeled** Target domain with varying number of images
    elif args.dataset == 'bdd_peds+labeled_100':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_100',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_075':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_075',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_050':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_050',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_025':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_025',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_010':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_010',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_005':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_005',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_001':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_001',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    ##############################

    # Source domain + Target domain HP tracker bboxes only
    elif args.dataset == 'bdd_peds+HP18k_track_only':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_track_only', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2

    ##### subsets of bdd_HP18k with different constraints
    # Source domain + HP tracker images at NIGHT
    elif args.dataset == 'bdd_peds+HP18k_any_any_night':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_any_any_night', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == 'bdd_peds+HP18k_rainy_any_daytime':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_rainy_any_daytime', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_rainy_any_night':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_rainy_any_night', 'bdd_peds_train')

        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_overcast,rainy_any_daytime':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_overcast,rainy_any_daytime',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_overcast,rainy_any_night':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_overcast,rainy_any_night',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == 'bdd_peds+HP18k_overcast,rainy,snowy_any_daytime':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_overcast,rainy,snowy_any_daytime',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    #############  end of bdd constrained subsets  #####################

    # Source domain + Target domain HP18k -- after histogram matching
    elif args.dataset == 'bdd_peds+HP18k_remap_hist':
        cfg.TRAIN.DATASETS = ('bdd_peds_HP18k_target_domain_remap_hist',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_remap_cityscape_hist':
        cfg.TRAIN.DATASETS = (
            'bdd_peds_HP18k_target_domain_remap_cityscape_hist',
            'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    # Source domain + Target domain HP18k -- after histogram matching
    elif args.dataset == 'bdd_peds+HP18k_remap_random':
        cfg.TRAIN.DATASETS = ('bdd_peds_HP18k_target_domain_remap_random',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    # Source+Noisy Target domain -- prevent domain adv from using HP roi info
    elif args.dataset == 'bdd_peds+bdd_HP18k_noisy_100k':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_noisy_100k', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+bdd_HP18k_noisy_080':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_noisy_080', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+bdd_HP18k_noisy_060':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_noisy_060', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+bdd_HP18k_noisy_070':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_noisy_070', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "wider_train":
        cfg.TRAIN.DATASETS = ('wider_train', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-subset":
        cfg.TRAIN.DATASETS = ('cs6-subset', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-subset-score":
        cfg.TRAIN.DATASETS = ('cs6-subset-score', )
    elif args.dataset == "cs6-subset-gt":
        cfg.TRAIN.DATASETS = ('cs6-subset-gt', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-3013-gt":
        cfg.TRAIN.DATASETS = ('cs6-3013-gt',
                              )  # DEBUG: overfit on one video annots
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-subset-gt+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-subset-gt', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-subset+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-subset', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-gt":
        cfg.TRAIN.DATASETS = ('cs6-train-gt', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-gt-noisy-0.3":
        cfg.TRAIN.DATASETS = ('cs6-train-gt-noisy-0.3', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-gt-noisy-0.5":
        cfg.TRAIN.DATASETS = ('cs6-train-gt-noisy-0.5', )
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-det-score":
        cfg.TRAIN.DATASETS = ('cs6-train-det-score', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-det-score-0.5":
        cfg.TRAIN.DATASETS = ('cs6-train-det-score-0.5', )
    elif args.dataset == "cs6-train-det":
        cfg.TRAIN.DATASETS = ('cs6-train-det', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-det-0.5":
        cfg.TRAIN.DATASETS = ('cs6-train-det-0.5', )
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-hp":
        cfg.TRAIN.DATASETS = ('cs6-train-hp', )
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-easy-gt":
        cfg.TRAIN.DATASETS = ('cs6-train-easy-gt', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-easy-gt-sub":
        cfg.TRAIN.DATASETS = ('cs6-train-easy-gt-sub', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-easy-hp":
        cfg.TRAIN.DATASETS = ('cs6-train-easy-hp', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-easy-det":
        cfg.TRAIN.DATASETS = ('cs6-train-easy-det', )
        cfg.MODEL.NUM_CLASSES = 2

        # Joint training with CS6 and WIDER
    elif args.dataset == "cs6-train-easy-gt-sub+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-train-easy-gt-sub', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-gt+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-train-gt', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-hp+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-train-hp', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-dummy+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-train-dummy', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-det+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-train-det', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    # Dets dataset created by removing tracker results from the HP json
    elif args.dataset == "cs6_train_det_from_hp+WIDER":
        cfg.TRAIN.DATASETS = ('cs6_train_det_from_hp', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    # Dataset created by removing det results from the HP json -- HP tracker only
    elif args.dataset == "cs6_train_hp_tracker_only+WIDER":
        cfg.TRAIN.DATASETS = ('cs6_train_hp_tracker_only', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    # HP dataset with noisy labels: used to prevent DA from getting any info from HP
    elif args.dataset == "cs6_train_hp_noisy_100+WIDER":
        cfg.TRAIN.DATASETS = ('cs6_train_hp_noisy_100', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(
        map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)
    """def _init_fn(worker_id):
        random.seed(999)
        np.random.seed(999)
        torch.cuda.manual_seed(999)
        torch.cuda.manual_seed_all(999)
        torch.manual_seed(999)
        torch.backends.cudnn.deterministic = True
    """

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)
    if cfg.TRAIN.JOINT_TRAINING:
        if len(cfg.TRAIN.DATASETS) == 2:
            print('Joint training on two datasets')
        else:
            raise NotImplementedError

        joint_training_roidb = []
        for i, dataset_name in enumerate(cfg.TRAIN.DATASETS):
            # ROIDB construction
            timers['roidb'].tic()
            roidb, ratio_list, ratio_index = combined_roidb_for_training(
                (dataset_name), cfg.TRAIN.PROPOSAL_FILES)
            timers['roidb'].toc()
            roidb_size = len(roidb)
            logger.info('{:d} roidb entries'.format(roidb_size))
            logger.info('Takes %.2f sec(s) to construct roidb',
                        timers['roidb'].average_time)

            if i == 0:
                roidb_size = len(roidb)

            batchSampler = BatchSampler(sampler=MinibatchSampler(
                ratio_list, ratio_index),
                                        batch_size=args.batch_size,
                                        drop_last=True)
            dataset = RoiDataLoader(roidb,
                                    cfg.MODEL.NUM_CLASSES,
                                    training=True)
            dataloader = torch.utils.data.DataLoader(
                dataset,
                batch_sampler=batchSampler,
                num_workers=cfg.DATA_LOADER.NUM_THREADS,
                collate_fn=collate_minibatch)
            #worker_init_fn=_init_fn)
            # decrease num-threads when using two dataloaders
            dataiterator = iter(dataloader)

            joint_training_roidb.append({
                'dataloader': dataloader,
                'dataiterator': dataiterator,
                'dataset_name': dataset_name
            })
    else:
        roidb, ratio_list, ratio_index = combined_roidb_for_training(
            cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
        timers['roidb'].toc()
        roidb_size = len(roidb)
        logger.info('{:d} roidb entries'.format(roidb_size))
        logger.info('Takes %.2f sec(s) to construct roidb',
                    timers['roidb'].average_time)

        batchSampler = BatchSampler(sampler=MinibatchSampler(
            ratio_list, ratio_index),
                                    batch_size=args.batch_size,
                                    drop_last=True)
        dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batchSampler,
            num_workers=cfg.DATA_LOADER.NUM_THREADS,
            collate_fn=collate_minibatch)
        #worker_init_fn=init_fn)
        dataiterator = iter(dataloader)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    if cfg.TRAIN.JOINT_SELECTIVE_FG:
        orig_fg_batch_ratio = cfg.TRAIN.FG_FRACTION

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name + '.weight')
            gn_param_nameset.add(name + '.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    assert (gn_param_nameset - set(nograd_param_names) -
            set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        bias_params,
        'lr':
        0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]
    # names of paramerters for each paramter
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0][
        'lr']  # lr of non-bias parameters, for commmand line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_' + str(
        cfg.TRAIN.DATASETS) + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)

    # Set index for decay steps
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):
            """
            random.seed(cfg.RNG_SEED)
            np.random.seed(cfg.RNG_SEED)
            torch.cuda.manual_seed(cfg.RNG_SEED)
            torch.cuda.manual_seed_all(cfg.RNG_SEED)
            torch.manual_seed(cfg.RNG_SEED)
            torch.backends.cudnn.deterministic = True
            """

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()

            for inner_iter in range(args.iter_size):
                # use a iter counter for optional alternating batches
                if args.iter_size == 1:
                    iter_counter = step
                else:
                    iter_counter = inner_iter

                if cfg.TRAIN.JOINT_TRAINING:
                    # alternate batches between dataset[0] and dataset[1]
                    if iter_counter % 2 == 0:
                        if True:  #DEBUG:
                            print('Dataset: %s' %
                                  joint_training_roidb[0]['dataset_name'])
                        dataloader = joint_training_roidb[0]['dataloader']
                        dataiterator = joint_training_roidb[0]['dataiterator']
                        # NOTE: if available FG samples cannot fill minibatch
                        # then batchsize will be smaller than cfg.TRAIN.BATCH_SIZE_PER_IM.
                    else:
                        if True:  #DEBUG:
                            print('Dataset: %s' %
                                  joint_training_roidb[1]['dataset_name'])
                        dataloader = joint_training_roidb[1]['dataloader']
                        dataiterator = joint_training_roidb[1]['dataiterator']

                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # end of epoch for dataloader
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)
                    if cfg.TRAIN.JOINT_TRAINING:
                        if iter_counter % 2 == 0:
                            joint_training_roidb[0][
                                'dataiterator'] = dataiterator
                        else:
                            joint_training_roidb[1][
                                'dataiterator'] = dataiterator

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                # [p.data.get_device() for p in maskRCNN.parameters()]
                # [(name, p.data.get_device()) for name, p in maskRCNN.named_parameters()]
                loss.backward()

            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
Exemple #8
0
def main():
    """Train the sequence-based (VOS) Mask R-CNN model on the DAVIS dataset.

    The function is a script entry point: it takes no arguments and reads
    everything from ``parse_args()`` and the global ``cfg``.  In order, it

    1. validates the CUDA / dataset arguments and loads the config file,
    2. rescales batch size, learning rate and solver steps to the number of
       visible GPUs and the requested ``batch_size`` / ``iter_size``,
    3. builds the sequenced roidb, the model and the optimizer (separate
       parameter groups for non-bias, bias and GroupNorm parameters),
    4. optionally restores a checkpoint and/or Detectron weights,
    5. runs the training loop, feeding each sequence frame by frame and
       back-propagating through at most ``cfg.TRAIN.MAX_TRAINABLE_SEQ_LENGTH``
       frames before the hidden states are detached.

    Side effects: mutates the global ``cfg`` and ``args``, creates the output
    directory and writes ``config_and_args.pkl`` into it (unless
    ``args.no_save``), and calls ``save_ckpt`` periodically, at the end of
    training and when a ``RuntimeError`` / ``KeyboardInterrupt`` is caught.

    Raises:
        SystemExit: if no CUDA device is available.
        ValueError: if neither ``args.cuda`` nor ``cfg.NUM_GPUS`` requests
            CUDA, if ``args.dataset`` is not ``"davis2017"``, or if
            ``cfg.SOLVER.TYPE`` is neither ``"SGD"`` nor ``"Adam"``.
        KeyError: if ``cfg.SOLVER.WARM_UP_METHOD`` is unknown.
        AssertionError: if a loaded sequence does not have the expected
            number of frames, or one of the internal consistency checks fails.
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "davis2017":
        cfg.TRAIN.DATASETS = ('davis_train', )
        # For DAVIS the COCO categories are used: 80 foreground + 1 background.
        cfg.MODEL.NUM_CLASSES = 81
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    if cfg.MODEL.IDENTITY_TRAINING and cfg.MODEL.IDENTITY_REPLACE_CLASS:
        cfg.MODEL.NUM_CLASSES = 145
        cfg.MODEL.IDENTITY_TRAINING = False
        cfg.MODEL.ADD_UNKNOWN_CLASS = False

    # Add the unknown class type if necessary.
    if cfg.MODEL.ADD_UNKNOWN_CLASS is True:
        cfg.MODEL.NUM_CLASSES += 1

    ### Adaptively adjust some configs ###
    # Remember the values from the config file before they are overwritten
    # with the values derived from the actual hardware / command line.
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size

    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))
    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    # One optimisation step sees SEQUENCE_LENGTH frames, so the lr is divided
    # by the sequence length as well.
    cfg.SOLVER.BASE_LR *= 1.0 / cfg.MODEL.SEQUENCE_LENGTH
    print(
        'Adjust BASE_LR linearly according to batch_size change and sequence length change:\n'
        '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    # "+ 0.5" followed by int() rounds to the nearest iteration.
    cfg.SOLVER.STEPS = [
        int(solver_step * step_scale + 0.5)
        for solver_step in cfg.SOLVER.STEPS
    ]
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    # NOTE: an explicit --lr replaces the batch-size-scaled BASE_LR above.
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    merged_roidb, seq_num, seq_start_end = sequenced_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES, load_inv_db=True)

    timers['roidb'].toc()
    roidb_size = len(merged_roidb)
    logger.info('{:d} roidbs sequences.'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidbs',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch, number of sequences.
    train_size = roidb_size // args.batch_size * args.batch_size

    ### Model ###
    maskRCNN = vos_model_builder.Generalized_VOS_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Collect the names of all GroupNorm affine parameters so that they can
    # get their own weight decay below.
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name + '.weight')
            gn_param_nameset.add(name + '.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nograd_param_names = []
    # Every trainable parameter ends up in exactly one group; 'bias' takes
    # precedence over GroupNorm membership.
    for key, value in maskRCNN.named_parameters():
        if not value.requires_grad:
            nograd_param_names.append(key)
        elif 'bias' in key:
            bias_params.append(value)
            bias_param_names.append(key)
        elif key in gn_param_nameset:
            gn_params.append(value)
            gn_param_names.append(key)
        else:
            nonbias_params.append(value)
    assert (gn_param_nameset - set(nograd_param_names) -
            set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        bias_params,
        'lr':
        0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail here with a clear message instead of hitting a NameError on
        # the first use of `optimizer` further down.
        raise ValueError("Unexpected cfg.SOLVER.TYPE: {}".format(
            cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # map_location keeps every tensor on the CPU while loading.
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt_no_mapping(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # NOTE: optimizer.load_state_dict is buggy on PyTorch 0.3.1
            # (fixed on later versions).
            optimizer.load_state_dict(checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN,
                              args.load_detectron,
                              force_load_all=False)

    # lr of non-bias parameters, for command line outputs.
    lr = optimizer.param_groups[0]['lr']

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)

    # Set index for decay steps: the first entry of SOLVER.STEPS (index >= 1)
    # that has not been passed yet when training starts at args.start_step.
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    dataloader, dataiterator, warmup_length = gen_sequence_data_sampler(
        merged_roidb, seq_num, seq_start_end, use_seq_warmup=False)
    try:
        logger.info('Training starts !')
        # Pre-set `step` so the checkpoint calls below work even when the
        # loop body never runs.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):
            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()

            if cfg.TRAIN.USE_SEQ_WARMUP is True and step >= cfg.TRAIN.ITER_BEFORE_USE_SEQ_WARMUP and (
                (step - cfg.TRAIN.ITER_BEFORE_USE_SEQ_WARMUP) %
                    cfg.TRAIN.WARMUP_LENGTH_CHANGE_STEP) == 0:
                # TODO: re-create the data sampler with use_seq_warmup=True
                # once the old data iterator can be released safely.
                # NotImplementedError is a RuntimeError subclass, so it is
                # handled by the checkpoint-on-exception branch below.
                raise NotImplementedError('not able to delete.')
            try:
                input_data_sequence = next(dataiterator)
            except StopIteration:
                # End of epoch: restart the loader.
                dataiterator = iter(dataloader)
                input_data_sequence = next(dataiterator)

            # clean hidden states before training.
            maskRCNN.module.clean_hidden_states()
            maskRCNN.module.clean_flow_features()

            sequence_length = cfg.MODEL.SEQUENCE_LENGTH + warmup_length
            loaded_length = len(input_data_sequence['data'])
            assert loaded_length == sequence_length, \
                'Unexpected sequence length: {} != {}'.format(
                    loaded_length, sequence_length)

            # if train_part == 0: train backbone.
            # else train_part == 1: train heads.
            train_part = 1
            if cfg.TRAIN.ALTERNATE_TRAINING:
                if step % 2 == 0:
                    maskRCNN.module.freeze_conv_body_only()
                    train_part = 1
                else:
                    maskRCNN.module.train_conv_body_only()
                    train_part = 0
            # this is used for longer sequence training.
            # when reach maximum trainable length, detach the hidden states.
            cnter_for_detach_hidden_states = 0
            for inner_iter in range(sequence_length):
                # Slice out the inputs of the current frame, keeping the
                # leading (length-1) dimension.
                input_data = {
                    key: input_data_sequence[key][inner_iter:inner_iter + 1]
                    for key in input_data_sequence.keys()
                }
                for key in input_data:
                    if key == 'data_flow':
                        if inner_iter != 0 and input_data[key][0][0][
                                0] is not None:  # flow is not None.
                            flow = np.squeeze(
                                np.array(input_data[key], dtype=np.float32))
                            input_data[key] = [
                                Variable(
                                    torch.tensor(
                                        np.expand_dims(flow, 0),
                                        device=input_data['data'][0].device))
                            ]
                            assert input_data['data'][0].shape[
                                -2:] == input_data[key][0].shape[
                                    -2:], "Spatial shape of image and flow are not equal."
                        else:
                            # No flow for the first frame of a sequence.
                            input_data[key] = [None]
                    elif key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))
                if cfg.TRAIN.USE_SEQ_WARMUP and inner_iter < warmup_length:
                    # Warm-up frames only update the hidden states; no loss.
                    maskRCNN.module.set_stop_after_hidden_states(stop=True)
                    net_outputs = maskRCNN(**input_data)
                    assert net_outputs is None
                    maskRCNN.module.detach_hidden_states()
                    maskRCNN.module.set_stop_after_hidden_states(stop=False)
                    continue

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']

                reached_max_length = (cnter_for_detach_hidden_states >=
                                      cfg.TRAIN.MAX_TRAINABLE_SEQ_LENGTH)
                is_last_frame = inner_iter == sequence_length - 1
                if train_part == 0 or reached_max_length or is_last_frame:
                    # if reach the max trainable length or end of the sequence. free the graph and detach the hidden states.
                    loss.backward()
                    maskRCNN.module.detach_hidden_states()
                    cnter_for_detach_hidden_states = 0
                else:
                    # Keep the graph so later frames can back-propagate
                    # through this one.
                    loss.backward(retain_graph=True)
                    cnter_for_detach_hidden_states += 1
                #TODO step every time?
                if cnter_for_detach_hidden_states == 0:
                    optimizer.step()
                    optimizer.zero_grad()

            training_stats.IterToc()
            training_stats.LogIterStats(step, lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt) as exc:
        # Drop the iterator first, then save a checkpoint so the run can be
        # resumed.
        del dataiterator
        if isinstance(exc, KeyboardInterrupt):
            logger.info('Save ckpt on Keyboard Interrupt ...')
        else:
            logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        print(traceback.format_exc())
    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
def main():
    """Set up data, model and optimizer from CLI args/config, then train.

    Reads the command line via ``parse_args``, merges the config file and
    command-line overrides into ``cfg``, builds the roidb-backed
    ``DataLoader``, the ``Generalized_RCNN`` model and an SGD/Adam optimizer,
    optionally restores a checkpoint or Detectron weights, and then runs the
    training loop with warm-up, step-wise learning-rate decay and periodic
    checkpointing.

    A ``RuntimeError`` or ``KeyboardInterrupt`` raised inside the training
    loop is not propagated: a checkpoint is saved and the traceback printed.

    Raises:
        SystemExit: if no CUDA device is available.
        ValueError: if neither ``args.cuda`` nor ``cfg.NUM_GPUS`` enables
            CUDA, if ``args.dataset`` is not a supported name, or if
            ``cfg.SOLVER.TYPE`` is neither ``"SGD"`` nor ``"Adam"``.
        KeyError: if ``cfg.SOLVER.WARM_UP_METHOD`` is neither ``'constant'``
            nor ``'linear'`` while warming up.
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    # Dataset-specific settings; applied before the config file is merged.
    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    # Batch size implied by the config file, captured before NUM_GPUS is
    # replaced with the number of devices actually visible.
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    print('Batch size change from {} (in config file) to {}'.format(
        original_batch_size, args.batch_size))
    print('NUM_GPUs: %d, TRAIN.IMS_PER_BATCH: %d' %
          (cfg.NUM_GPUS, cfg.TRAIN.IMS_PER_BATCH))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Adjust learning based on batch size change linearly
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch size change: {} --> {}'.
          format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Overwrite some solver settings from command line arguments
    # An explicit --lr replaces the linearly scaled BASE_LR computed above.
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch
    # (roidb size rounded down to a whole number of batches).
    train_size = roidb_size // args.batch_size * args.batch_size

    sampler = MinibatchSampler(ratio_list, ratio_index)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        drop_last=True,
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Bias and non-bias parameters go into separate groups so that they can
    # receive different weight-decay settings.
    bias_params = []
    nonbias_params = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
            else:
                nonbias_params.append(value)
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        bias_params,
        'lr':
        0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail here with a clear message instead of a NameError on the first
        # later use of `optimizer`.
        raise ValueError(
            'Unsupported SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # map_location keeps the loaded storages where they were
        # deserialized instead of moving them to the saved device.
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                assert checkpoint['train_size'] == train_size
            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0][
        'lr']  # lr of non-bias parameters, for command line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        # Persist the exact config and arguments next to the checkpoints.
        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)

    # Set index for decay steps: the first entry of SOLVER.STEPS (index 0 is
    # skipped) that is not already behind the starting step.
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)

    # Bind `step` before the loop so that the checkpoint calls below never
    # hit an unbound local when the loop body does not run (e.g. start_step
    # >= MAX_ITER) or an exception fires before the first iteration.
    step = args.start_step
    try:
        logger.info('Training starts !')
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    # Interpolate from WARM_UP_FACTOR (alpha=0) towards 1.
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                # Warm-up finished: switch to the base learning rate.
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            # Restart the iterator once the loader is exhausted.
            try:
                input_data = next(dataiterator)
            except StopIteration:
                dataiterator = iter(dataloader)
                input_data = next(dataiterator)

            for key in input_data:
                if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                    input_data[key] = list(map(Variable, input_data[key]))

            # Only the forward pass is timed.
            training_stats.IterTic()
            net_outputs = maskRCNN(**input_data)
            training_stats.IterToc()

            training_stats.UpdateIterStats(net_outputs)
            training_stats.LogIterStats(step, lr)

            loss = net_outputs['total_loss']
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt) as e:
        # Best effort: keep the progress made so far, then report the error.
        print('Save on exception:', e)
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        # ---- Training ends ----
        # tblogger only exists when it was created under the same condition.
        if args.use_tfboard and not args.no_save:
            tblogger.close()
Exemple #10
0
def train_val(model,
              args,
              optimizer,
              lr,
              dataloader,
              train_size,
              output_dir,
              tblogger=None,
              checkpoint_period=2000,
              checkpoint_min_step=15000):
    """Run the training loop with warm-up, LR decay and gradient accumulation.

    Each outer step zeroes the gradients, runs ``args.iter_size`` forward /
    backward passes (accumulating gradients) and then applies one optimizer
    step. Checkpoints are written through ``save_ckpt``.

    Args:
        model: network called as ``model(**input_data)``; its output must
            contain ``'total_loss'``.
        args: parsed arguments; ``start_step``, ``disp_interval``,
            ``iter_size``, ``use_tfboard`` and ``no_save`` are read here.
        optimizer: optimizer whose ``param_groups[0]['lr']`` is tracked.
        lr: learning rate currently set on the optimizer.
        dataloader: iterable of minibatch dicts; re-iterated when exhausted.
        train_size: forwarded unchanged to ``save_ckpt``.
        output_dir: forwarded unchanged to ``save_ckpt``.
        tblogger: optional logger handed to ``TrainingStats`` and closed when
            training ends.
        checkpoint_period: a checkpoint is considered every this many steps
            (and on the last step). The default matches the value that used
            to be hard-coded; ``cfg.TRAIN.SNAPSHOT_ITERS`` is not consulted.
        checkpoint_min_step: checkpoints are only written once ``step + 1``
            exceeds this value.

    Raises:
        KeyError: if ``cfg.SOLVER.WARM_UP_METHOD`` is neither ``'constant'``
            nor ``'linear'`` while warming up.

    A ``RuntimeError`` or ``KeyboardInterrupt`` raised during training is not
    propagated: a checkpoint is saved and the traceback printed.
    """

    dataiterator = iter(dataloader)
    model.train()

    # Set index for decay steps: the first entry of SOLVER.STEPS (index 0 is
    # skipped) that is not already behind the starting step.
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)

    try:
        logger.info('Training starts !')
        # Bound up front so the exception handler can always reference it.
        step = args.start_step

        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):
            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    # Interpolate from WARM_UP_FACTOR (alpha=0) towards 1.
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                # Warm-up finished: switch to the base learning rate.
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()

            # Accumulate gradients over iter_size minibatches.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Loader exhausted: start over.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = model(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                # Skip backward when the loss is detached from the graph.
                if loss.requires_grad:
                    loss.backward()

            optimizer.step()
            training_stats.IterToc()
            training_stats.LogIterStats(step, lr)

            is_period_end = (step + 1) % checkpoint_period == 0
            is_last_step = step == cfg.SOLVER.MAX_ITER - 1
            if (is_period_end or is_last_step) and \
                    (step + 1) > checkpoint_min_step:
                save_ckpt(output_dir, args, step, train_size, model,
                          optimizer)

        # ---- Training ends ----

    except (RuntimeError, KeyboardInterrupt):
        # Drop the iterator before saving the checkpoint.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, model, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        # tblogger defaults to None, so guard against closing a missing one.
        if args.use_tfboard and not args.no_save and tblogger is not None:
            tblogger.close()
Exemple #11
0
def main():
    """Configure and run Car3D (ApolloScape) training from CLI args/config.

    Reads the command line via ``parse_args``, merges the config file into
    ``cfg``, applies the manual dataset settings below, builds the
    roidb-backed ``DataLoader``, the ``Generalized_RCNN`` model and an
    SGD/Adam optimizer, optionally restores a checkpoint or Detectron
    weights, and runs the training loop with warm-up, step-wise
    learning-rate decay, gradient accumulation over ``args.iter_size``
    minibatches and periodic checkpointing.

    A ``RuntimeError`` or ``KeyboardInterrupt`` raised inside the training
    loop is not propagated: a checkpoint is saved and the traceback printed.

    Raises:
        SystemExit: if no CUDA device is available.
        ValueError: if neither ``args.cuda`` nor ``cfg.NUM_GPUS`` enables
            CUDA, or if ``cfg.SOLVER.TYPE`` is neither ``"SGD"`` nor
            ``"Adam"``.
        KeyError: if ``cfg.SOLVER.WARM_UP_METHOD`` is neither ``'constant'``
            nor ``'linear'`` while warming up.
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    merge_cfg_from_file(args.cfg_file)

    # Some manual adjustment for the ApolloScape dataset parameters here
    cfg.OUTPUT_DIR = args.output_dir
    cfg.TRAIN.DATASETS = 'Car3D'
    cfg.MODEL.NUM_CLASSES = 8
    if cfg.CAR_CLS.SIM_MAT_LOSS:
        cfg.MODEL.NUMBER_CARS = 79
    else:
        # Loss is only cross entropy, hence, we detect only car categories in the training set.
        cfg.MODEL.NUMBER_CARS = 34
    cfg.TRAIN.MIN_AREA = 196  # 14*14
    cfg.TRAIN.USE_FLIPPED = False  # Currently I don't know how to handle the flipped case
    cfg.TRAIN.IMS_PER_BATCH = 1

    cfg.NUM_GPUS = torch.cuda.device_count()
    effective_batch_size = cfg.TRAIN.IMS_PER_BATCH * cfg.NUM_GPUS * args.iter_size

    ### Adaptively adjust some configs ###
    # NOTE: these "original" values are captured after NUM_GPUS and
    # IMS_PER_BATCH were overwritten above.
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size

    assert (args.batch_size %
            cfg.NUM_GPUS) == 0, 'batch_size: %d, NUM_GPUS: %d' % (
                args.batch_size, cfg.NUM_GPUS)

    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))
    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print(
        'Adjust BASE_LR linearly according to batch_size change:\n BASE_LR: {} --> {}'
        .format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    # Scale and round to the nearest integer step.
    cfg.SOLVER.STEPS = list(
        map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    # An explicit --lr replaces the linearly scaled BASE_LR computed above.
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    # With the 3D-2D loss enabled the roidb builder also returns the dataset
    # object, which is needed to construct the model below.
    if cfg.MODEL.LOSS_3D_2D_ON:
        roidb, ratio_list, ratio_index, ds = combined_roidb_for_training(
            cfg.TRAIN.DATASETS, args.dataset_dir)
    else:
        roidb, ratio_list, ratio_index = combined_roidb_for_training(
            cfg.TRAIN.DATASETS, args.dataset_dir)

    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch
    # (roidb size rounded down to a whole number of batches).
    train_size = roidb_size // args.batch_size * args.batch_size

    sampler = MinibatchSampler(ratio_list, ratio_index)
    dataset = RoiDataLoader(roidb,
                            cfg.MODEL.NUM_CLASSES,
                            training=True,
                            valid_keys=[
                                'has_visible_keypoints', 'boxes', 'seg_areas',
                                'gt_classes', 'gt_overlaps',
                                'box_to_gt_ind_map', 'is_crowd',
                                'car_cat_classes', 'poses', 'quaternions',
                                'im_info'
                            ])
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        drop_last=True,
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    if cfg.MODEL.LOSS_3D_2D_ON:
        maskRCNN = Generalized_RCNN(ds.Car3D)
    else:
        maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Bias and non-bias parameters go into separate groups so that they can
    # receive different weight-decay settings.
    bias_params = []
    nonbias_params = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
            else:
                nonbias_params.append(value)
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        bias_params,
        'lr':
        0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail here with a clear message instead of a NameError on the first
        # later use of `optimizer`.
        raise ValueError(
            'Unsupported SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # map_location keeps the loaded storages where they were
        # deserialized instead of moving them to the saved device.
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN,
                            checkpoint['model'],
                            ignore_list=args.ckpt_ignore_head)
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                # A mismatch is only reported, not treated as an error.
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  # TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0][
        'lr']  # lr of non-bias parameters, for command line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        # Persist the exact config and arguments next to the checkpoints.
        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()
    CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)
    # Set index for decay steps: the first entry of SOLVER.STEPS (index 0 is
    # skipped) that is not already behind the starting step.
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    # Must be bound before the loop: it is passed to LogIterStats on every
    # step, but the loop only assigns it in the 'linear' warm-up branch.
    # Without this default, 'constant' warm-up or resuming past
    # WARM_UP_ITERS would fail with an unbound local.
    warmup_factor_trans = 1.0
    try:
        logger.info('Training starts !')
        # Bound up front so the checkpoint calls can always reference it.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    # Interpolate from WARM_UP_FACTOR (alpha=0) towards 1.
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                    # Translation-loss warm-up is currently disabled, so the
                    # factor stays at 1.0.
                    warmup_factor_trans = 1.0
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                # Warm-up finished: switch to the base learning rate.
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(
                    cfg.SOLVER.STEPS
            ) and step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Accumulate gradients over iter_size minibatches.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Loader exhausted: start over.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)

                # Weight the individual head losses in place, so the logged
                # statistics reflect the weighted values.
                net_outputs['losses'][
                    'loss_car_cls'] *= cfg.CAR_CLS.CAR_CLS_LOSS_BETA
                net_outputs['losses']['loss_rot'] *= cfg.CAR_CLS.ROT_LOSS_BETA
                if cfg.MODEL.TRANS_HEAD_ON:
                    net_outputs['losses'][
                        'loss_trans'] *= cfg.TRANS_HEAD.TRANS_LOSS_BETA

                training_stats.UpdateIterStats_car_3d(net_outputs)

                # start training
                # loss_car_cls: 2.233790, loss_rot: 0.296853, loss_trans: ~100
                loss = net_outputs['losses']['loss_car_cls'] + net_outputs[
                    'losses']['loss_rot']
                if cfg.MODEL.TRANS_HEAD_ON:
                    loss += net_outputs['losses']['loss_trans']
                if cfg.MODEL.LOSS_3D_2D_ON:
                    loss += net_outputs['losses']['UV_projection_loss']
                # The conv-body loss only contributes when none of the
                # backbone parts is frozen.
                if not cfg.TRAIN.FREEZE_CONV_BODY and not cfg.TRAIN.FREEZE_RPN and not cfg.TRAIN.FREEZE_FPN:
                    loss += net_outputs['total_loss_conv']

                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr, warmup_factor_trans)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        # Drop the iterator before saving the checkpoint.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        # tblogger only exists when it was created under the same condition.
        if args.use_tfboard and not args.no_save:
            tblogger.close()