# Example #1
# 0
	def update_learning_rate(self, step):
		"""Synchronise the optimizer's learning rate with training iteration ``step``.

		Three schedule events are handled, in this order:

		* warm-up (``step < SOLVER.WARM_UP_ITERS``): the rate becomes
		  ``SOLVER.BASE_LR`` scaled by a factor selected by
		  ``SOLVER.WARM_UP_METHOD`` ('constant' or 'linear'); any other method
		  raises ``KeyError``;
		* end of warm-up (``step == SOLVER.WARM_UP_ITERS``): the rate is reset to
		  ``SOLVER.BASE_LR``;
		* decay (``step`` equals ``SOLVER.STEPS[self.decay_steps_ind]``): the
		  current rate is multiplied by ``SOLVER.GAMMA`` and ``decay_steps_ind``
		  advances by one.

		After every change ``self.lr`` is re-read from the first optimizer param
		group and asserted to equal the requested value.
		"""
		solver = self.cfg.SOLVER

		def _set_lr(target):
			# Push the new rate into the optimizer, then mirror what it actually holds.
			net_utils.update_learning_rate(self.optimizer, self.lr, target)
			self.lr = self.optimizer.param_groups[0]['lr']
			assert self.lr == target

		warmup_iters = solver.WARM_UP_ITERS
		if step < warmup_iters:
			method = solver.WARM_UP_METHOD
			if method == 'linear':
				# Interpolate from WARM_UP_FACTOR at step 0 towards 1.0 at the end of warm-up.
				alpha = step / warmup_iters
				factor = solver.WARM_UP_FACTOR * (1 - alpha) + alpha
			elif method == 'constant':
				factor = solver.WARM_UP_FACTOR
			else:
				raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
			_set_lr(solver.BASE_LR * factor)
		elif step == warmup_iters:
			_set_lr(solver.BASE_LR)

		# Step decay: only the milestone at the current index is checked; the
		# length guard keeps the lookup in range once all decays are done.
		ind = self.decay_steps_ind
		milestones = solver.STEPS
		if ind < len(milestones) and step == milestones[ind]:
			logger.info('Decay the learning on step %d', step)
			_set_lr(self.lr * solver.GAMMA)
			self.decay_steps_ind = ind + 1
# Example #2
# 0
def main():
    """Main function: configure, build and train a Generalized_RCNN model.

    Steps, in order:
      1. Parse command-line args and load the config file / overrides.
      2. Adapt NUM_GPUS, IMS_PER_BATCH, BASE_LR, SOLVER.STEPS and
         SOLVER.MAX_ITER to the visible GPUs and the requested batch size.
      3. Build the roidb, data loader, model and SGD/Adam optimizer with three
         param groups (non-bias, bias, GroupNorm).
      4. Optionally load a checkpoint (resuming optimizer state with
         ``args.resume``) and/or Detectron weights.
      5. Train with a warm-up / step-decay learning-rate schedule, gradient
         accumulation over ``args.iter_size`` minibatches per step, and
         periodic checkpointing.

    On ``RuntimeError`` or ``KeyboardInterrupt`` during training a checkpoint
    is saved and the traceback is printed; the exception is not re-raised.

    Raises:
        ValueError: if CUDA is neither requested nor configured, if
            ``args.dataset`` is unknown, or if ``cfg.SOLVER.TYPE`` is neither
            "SGD" nor "Adam".
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train',)
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train',)
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    # Snapshot the config-file values before NUM_GPUS is overwritten with the
    # number of visible devices; they are the reference for the rescaling below.
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' % (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' % (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' % (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' % (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    # Larger effective batches see the data faster, so the schedule is
    # shortened proportionally (rounded to the nearest integer).
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print('Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
          '    SOLVER.STEPS: {} --> {}\n'
          '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps, cfg.SOLVER.STEPS,
                                                  old_max_iter, cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print('Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
              '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    # Note: an explicit --lr replaces the linearly rescaled BASE_LR above.
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True
    )
    dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Collect GroupNorm affine parameter names so they can get their own
    # weight decay (SOLVER.WEIGHT_DECAY_GN).
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name+'.weight')
            gn_param_nameset.add(name+'.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    # 'bias' is tested first, so GroupNorm biases land in the bias group.
    for key, value in maskRCNN.named_parameters():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    assert (gn_param_nameset - set(nograd_param_names) - set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [
        {'params': nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': gn_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN}
    ]
    # Parameter names for each param group, in the same order as `params`
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast: otherwise `optimizer` stays unbound and the error only
        # surfaces later as an unrelated UnboundLocalError.
        raise ValueError('Unknown SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # Load tensors onto CPU first, regardless of the device they were saved from.
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print('train_size value: %d different from the one in checkpoint: %d'
                          % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            optimizer.load_state_dict(checkpoint['optimizer'])
            # misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0]['lr']  # lr of non-bias parameters, for command line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # Clamp to at least 1 so `(step+1) % CHECKPOINT_PERIOD` cannot raise
    # ZeroDivisionError when SNAPSHOT_ITERS < NUM_GPUS.
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))

    # Set index for decay steps: the first milestone (from index 1 on) that has
    # not been passed at args.start_step. Index 0 is skipped -- presumably
    # STEPS[0] marks the start of the schedule; confirm against the config.
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args,
        args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        # Bound before the loop so the checkpoint saves below always have a step.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr, cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Accumulate gradients over iter_size minibatches before one optimizer step.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Dataset exhausted: start a new pass over it.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb': # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step+1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
# Example #3
# 0
def main():
    """Main function: configure, build and train a Generalized_RCNN model.

    Steps, in order:
      1. Parse command-line args and load the config file / overrides.
      2. Adapt NUM_GPUS, IMS_PER_BATCH and BASE_LR to the visible GPUs and
         the requested batch size.
      3. Build the roidb, data loader, model and SGD/Adam optimizer with two
         param groups (non-bias, bias).
      4. Optionally load a checkpoint (resuming optimizer state with
         ``args.resume``) and/or Detectron weights.
      5. Train with a warm-up / step-decay learning-rate schedule, printing
         (and optionally logging to TensorBoard) averaged losses every
         ``args.disp_interval`` steps and checkpointing periodically.

    On ``RuntimeError`` or ``KeyboardInterrupt`` during training a checkpoint
    is saved and the traceback is printed; the exception is not re-raised.

    Raises:
        ValueError: if CUDA is neither requested nor configured, if
            ``args.dataset`` is unknown, or if ``cfg.SOLVER.TYPE`` is neither
            "SGD" nor "Adam".
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    # Reference batch size from the config file, computed before NUM_GPUS is
    # overwritten with the number of visible devices.
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    print('Batch size change from {} (in config file) to {}'.format(
        original_batch_size, args.batch_size))
    print('NUM_GPUs: %d, TRAIN.IMS_PER_BATCH: %d' %
          (cfg.NUM_GPUS, cfg.TRAIN.IMS_PER_BATCH))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Adjust learning based on batch size change linearly
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch size change: {} --> {}'.
          format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Overwrite some solver settings from command line arguments
    # Note: an explicit --lr replaces the linearly rescaled BASE_LR above.
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    sampler = MinibatchSampler(ratio_list, ratio_index)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        drop_last=True,
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    bias_params = []
    nonbias_params = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
            else:
                nonbias_params.append(value)
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        bias_params,
        'lr':
        0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast: otherwise `optimizer` stays unbound and the error only
        # surfaces later as an unrelated UnboundLocalError.
        raise ValueError('Unknown SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # Load tensors onto CPU first, regardless of the device they were saved from.
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            assert checkpoint['train_size'] == train_size
            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0][
        'lr']  # lr of non-bias parameters, for command line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    run_name = misc_utils.get_run_name()
    output_dir = misc_utils.get_output_dir(args, run_name)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # Clamp to at least 1 so `(step + 1) % CHECKPOINT_PERIOD` cannot raise
    # ZeroDivisionError when SNAPSHOT_ITERS < NUM_GPUS.
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))

    # Set index for decay steps: the FIRST milestone (from index 1 on) not yet
    # passed at args.start_step. `>=` keeps a milestone equal to start_step
    # (resumed runs) and `break` stops at the earliest one, so no decay is
    # skipped. Index 0 is skipped -- presumably STEPS[0] marks the start of
    # the schedule; confirm against the config.
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    logger.info('Training starts !')
    loss_avg = 0
    loss_count = 0  # iterations summed into loss_avg since the last display
    # Bound before the loop so the checkpoint saves below always have a step,
    # even if the loop body never executes.
    step = args.start_step
    try:
        timers['train_loop'].tic()
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                # Actually push BASE_LR into the optimizer; only rebinding `lr`
                # would leave the param groups at the last warm-up value (or at
                # the dummy 0 when WARM_UP_ITERS == 0) until the first decay.
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = lr_new
                decay_steps_ind += 1

            try:
                input_data = next(dataiterator)
            except StopIteration:
                # Dataset exhausted: start a new pass over it.
                dataiterator = iter(dataloader)
                input_data = next(dataiterator)

            for key in input_data:
                if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                    input_data[key] = list(map(Variable, input_data[key]))

            outputs = maskRCNN(**input_data)

            rois_label = outputs['rois_label']
            cls_score = outputs['cls_score']
            bbox_pred = outputs['bbox_pred']
            # `.mean()` reduces each loss output to a scalar before summing.
            loss_rpn_cls = outputs['loss_rpn_cls'].mean()
            loss_rpn_bbox = outputs['loss_rpn_bbox'].mean()
            loss_rcnn_cls = outputs['loss_rcnn_cls'].mean()
            loss_rcnn_bbox = outputs['loss_rcnn_bbox'].mean()

            loss = loss_rpn_cls + loss_rpn_bbox + loss_rcnn_cls + loss_rcnn_bbox

            if cfg.MODEL.MASK_ON:
                loss_rcnn_mask = outputs['loss_rcnn_mask'].mean()
                loss += loss_rcnn_mask

            if cfg.MODEL.KEYPOINTS_ON:
                loss_rcnn_keypoints = outputs['loss_rcnn_keypoints'].mean()
                loss += loss_rcnn_keypoints

            loss_avg += loss.data.cpu().numpy()[0]
            loss_count += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

            if ((step % args.disp_interval == 0 and
                 (step - args.start_step >= args.disp_interval))
                    or step == cfg.SOLVER.MAX_ITER - 1):
                diff = timers['train_loop'].toc(average=False)
                # Average over the iterations actually accumulated: the first
                # window after a resume and the final one need not contain
                # exactly disp_interval steps.
                loss_avg /= loss_count

                loss_rpn_cls = loss_rpn_cls.data[0]
                loss_rpn_bbox = loss_rpn_bbox.data[0]
                loss_rcnn_cls = loss_rcnn_cls.data[0]
                loss_rcnn_bbox = loss_rcnn_bbox.data[0]
                fg_cnt = torch.sum(rois_label.data.ne(0))
                bg_cnt = rois_label.data.numel() - fg_cnt
                print("[ %s ][ step %d ]" % (run_name, step))
                print("\t\tloss: %.4f, lr: %.2e" % (loss_avg, lr))
                print("\t\tfg/bg=(%d/%d), time cost: %f" %
                      (fg_cnt, bg_cnt, diff))
                print(
                    "\t\trpn_cls: %.4f, rpn_bbox: %.4f, rcnn_cls: %.4f, rcnn_bbox %.4f"
                    % (loss_rpn_cls, loss_rpn_bbox, loss_rcnn_cls,
                       loss_rcnn_bbox))

                print_prefix = "\t\t"
                if cfg.MODEL.MASK_ON:
                    loss_rcnn_mask = loss_rcnn_mask.data[0]
                    print("%srcnn_mask %.4f" % (print_prefix, loss_rcnn_mask))
                    print_prefix = ", "
                if cfg.MODEL.KEYPOINTS_ON:
                    loss_rcnn_keypoints = loss_rcnn_keypoints.data[0]
                    print("%srcnn_keypoints %.4f" %
                          (print_prefix, loss_rcnn_keypoints))

                if args.use_tfboard and not args.no_save:
                    info = {
                        'lr': lr,
                        'loss': loss_avg,
                        'loss_rpn_cls': loss_rpn_cls,
                        'loss_rpn_box': loss_rpn_bbox,
                        'loss_rcnn_cls': loss_rcnn_cls,
                        'loss_rcnn_box': loss_rcnn_bbox,
                    }
                    if cfg.MODEL.MASK_ON:
                        info['loss_rcnn_mask'] = loss_rcnn_mask
                    if cfg.MODEL.KEYPOINTS_ON:
                        info['loss_rcnn_keypoints'] = loss_rcnn_keypoints
                    for tag, value in info.items():
                        tblogger.add_scalar(tag, value, step)

                loss_avg = 0
                loss_count = 0
                timers['train_loop'].tic()

        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt) as e:
        print('Save on exception:', e)
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        # ---- Training ends ----
        if args.use_tfboard and not args.no_save:
            tblogger.close()
def main():
    """Train (or only evaluate) the FPN-based re-identification model.

    Reads command-line arguments via ``parse_args()`` and the global ``cfg``
    (loaded from ``args.cfg_file`` and overridden by ``args.set_cfgs``).
    Builds the train and query/gallery loaders, the model, the losses and a
    Detectron-style optimizer. Then runs the epoch loop with warm-up and
    either step-decay or cosine-annealing learning-rate schedules,
    evaluating and checkpointing periodically.

    Returns:
        0 when ``args.evaluate`` is set (evaluation only); otherwise None.

    Raises:
        KeyError: on an unsupported ``cfg.SOLVER.TYPE``, an unknown
            ``cfg.SOLVER.WARM_UP_METHOD``, or when
            ``cfg.REID.DETECTRON_OPTIMIZER`` is disabled (only the
            Detectron-style optimizer setup is implemented).
    """
    args = parse_args()
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False
    pin_memory = bool(use_gpu)

    # Mirror stdout into a log file under save_dir (separate train/test logs).
    log_name = 'log_test.txt' if args.evaluate else 'log_train.txt'
    sys.stdout = Logger(osp.join(args.save_dir, log_name))
    print("==========\nArgs:{}\n==========".format(args))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    if args.deterministic:
        # Seed every RNG for repeatable experiments, see
        # https://petewarden.com/2018/03/19/the-machine-learning-reproducibility-crisis/
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
        cudnn.deterministic = True
    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    ### Data ###
    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_img_dataset(
        root=args.root,
        name=args.dataset,
        split_id=args.split_id,
        cuhk03_labeled=args.cuhk03_labeled,
        cuhk03_classic_split=args.cuhk03_classic_split,
        WEIGHT_TEST=cfg.REID.WEIGHT_TEST,
    )

    trainloader = DataLoader(
        ImageDataset(dataset.train,
                     transform=TrainTransform(cfg.REID.HEIGHT, cfg.REID.WIDTH,
                                              cfg.REID.PRE_PRO_TYPE)),
        sampler=RandomIdentitySampler(
            dataset.train, num_instances=cfg.REID.TRI_NUM_INSTANCES),
        batch_size=cfg.REID.TRAIN_BATCH,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    def _make_test_loader(items):
        """Build an ordered (unshuffled, no dropped tail) loader for evaluation."""
        return DataLoader(
            ImageDataset(items,
                         transform=TestTransform(cfg.REID.HEIGHT,
                                                 cfg.REID.WIDTH,
                                                 cfg.REID.PRE_PRO_TYPE)),
            batch_size=cfg.REID.TEST_BATCH,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=pin_memory,
            drop_last=False,
        )

    if cfg.REID.WEIGHT_TEST:
        # Two query/gallery splits, evaluated together by wwn_test_2.
        queryloader_1 = _make_test_loader(dataset.query_1)
        galleryloader_1 = _make_test_loader(dataset.gallery_1)
        queryloader_2 = _make_test_loader(dataset.query_2)
        galleryloader_2 = _make_test_loader(dataset.gallery_2)
    else:
        queryloader = _make_test_loader(dataset.query)
        galleryloader = _make_test_loader(dataset.gallery)

    ### Model ###
    print("Initializing model: {}".format(
        cfg.MODEL.CONV_BODY.split('.')[-1].split('_')[1]))
    model = Generalized_FPN(num_classes=dataset.num_train_pids,
                            loss=cfg.REID.LOSS,
                            aligned=cfg.REID.ALIGNED,
                            strong_baseline=cfg.REID.STRONG_BASELINE)
    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    ### Losses ###
    if cfg.REID.FOCALLOSS:
        # Uses the same label-smoothing switch as the cross-entropy branch
        # below (the key was previously misspelled with a lowercase 'l').
        criterion_class = FocalLoss(gamma=2, alpha=0.25,
                                    labelsmooth=cfg.REID.LABLESMOOTH,
                                    num_classes=dataset.num_train_pids,
                                    epsilon=0.1)
    elif cfg.REID.LABLESMOOTH:
        criterion_class = CrossEntropyLabelSmooth(
            num_classes=dataset.num_train_pids, use_gpu=use_gpu)
    else:
        criterion_class = CrossEntropyLoss(use_gpu=use_gpu)

    criterion_class_oimwarmup = None
    criterion_metric = TripletLossAlignedReID(margin=cfg.REID.TRI_MARGIN,
                                              aligned=cfg.REID.ALIGNED)
    criterion_center = CenterLoss(num_classes=dataset.num_train_pids,
                                  feat_dim=2048)
    if cfg.REID.CAMIDCLASS:
        criterion_class_cam = nn.BCEWithLogitsLoss()
        if use_gpu:  # only move to CUDA when a GPU is actually in use
            criterion_class_cam = criterion_class_cam.cuda()
    else:
        criterion_class_cam = None
    criterion_class_mixup = None

    def load_ckpt(model, ckpt):
        """Load the compatible subset of ``ckpt`` into ``model`` (non-strict).

        Keys present in ``model.detectron_weight_mapping`` with a truthy
        entry are kept unless their second dotted component is one of the
        FPN modules in ``skipped_modules``. Keys missing from the mapping are
        kept (and printed) unless their first component is
        ``bottleneck_sub2`` or starts with ``classifier``.
        """
        mapping, _ = model.detectron_weight_mapping
        skipped_modules = ('posthoc_modules', 'topdown_lateral_modules',
                           'conv_top')
        state_dict = {}
        for name in ckpt:
            parts = name.split('.')
            try:
                if mapping[name] and parts[1] not in skipped_modules:
                    state_dict[name] = ckpt[name]
            except (KeyError, IndexError):
                # Not in the mapping (or no second dotted component).
                head = parts[0]
                if head != 'bottleneck_sub2' and \
                        head.split('_')[0] != 'classifier':
                    state_dict[name] = ckpt[name]
                    print('parameters: {}'.format(name))
        model.load_state_dict(state_dict, strict=False)

    if args.load_ckpt:
        load_name = args.load_ckpt
        print("loading checkpoint %s" % load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        load_ckpt(model, checkpoint['state_dict'])

    start_epoch = args.start_epoch
    if args.resume:
        print("Loading checkpoint from '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']

    ### Optimizer ###
    if not cfg.REID.DETECTRON_OPTIMIZER:
        raise KeyError("Unimplemented cfg.REID.DETECTRON_OPTIMIZER: {}".format(
            cfg.REID.DETECTRON_OPTIMIZER))

    gn_params = []
    bias_params = []
    nonbias_params = []
    ssn_params = []
    for key, value in dict(model.named_parameters()).items():
        if not value.requires_grad:
            print('FREEZE para: {}'.format(key))
        elif 'gn' in key:
            gn_params.append(value)
        elif 'bias' in key:
            bias_params.append(value)
        elif key.endswith('_weight') and cfg.RESNETS.SSN:
            ssn_params.append(value)
        else:
            nonbias_params.append(value)

    # All groups start at BASE_LR except SSN weights (a tenth, no decay);
    # warm-up below overwrites the learning rates at the start of training.
    params = [
        {
            'params': nonbias_params,
            'lr': cfg.SOLVER.BASE_LR,
            'weight_decay': cfg.SOLVER.WEIGHT_DECAY
        },
        {
            'params': bias_params,
            'lr': cfg.SOLVER.BASE_LR,
            'weight_decay':
            cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
        },
        {
            'params': gn_params,
            'lr': cfg.SOLVER.BASE_LR,
            'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
        },
        {
            'params': ssn_params,
            'lr': cfg.SOLVER.BASE_LR / 10,
            'weight_decay': 0
        },
    ]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params,
                                     weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    elif cfg.SOLVER.TYPE == 'Rmsprop':
        optimizer = torch.optim.RMSprop(
            params,
            momentum=cfg.SOLVER.MOMENTUM,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    else:
        raise KeyError("Unsupported optim: {}".format(cfg.SOLVER.TYPE))

    # LR of the non-bias group; tracked for logging and schedule updates.
    lr = optimizer.param_groups[0]['lr']
    if cfg.SOLVER.LR_POLICY == 'cosine_annealing':
        # NOTE(review): T_max is hard-coded to 400 - 18 epochs; confirm it
        # matches MAX_ITER minus the warm-up length for the configs in use.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(400 - 18), eta_min=0.000001)

    if cfg.REID.SWA:
        print('SWA training')
        optimizer = SWA(optimizer,
                        swa_start=cfg.REID.SWA_START,
                        swa_freq=cfg.REID.SWA_FREQ,
                        swa_lr=cfg.REID.SWA_LR)

    # Index of the next decay step (search starts at STEPS[1], as in Detectron).
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_epoch:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    def _evaluate():
        """Run the configured test protocol and return its rank-1 score."""
        if cfg.REID.WEIGHT_TEST:
            return wwn_test_2(model, queryloader_1, galleryloader_1,
                              queryloader_2, galleryloader_2, use_gpu, args)
        return wwn_test(model, queryloader, galleryloader, use_gpu, args)

    if args.evaluate:
        print("Evaluate only")
        _evaluate()
        return 0

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    print("==> Start training")
    max_epoch = cfg.SOLVER.MAX_ITER
    for epoch in range(start_epoch, max_epoch):
        # Warm up: here WARM_UP_FACTOR is used as an absolute learning rate.
        if epoch < cfg.SOLVER.WARM_UP_ITERS:
            method = cfg.SOLVER.WARM_UP_METHOD
            if method == 'constant':
                warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
            elif method == 'linear':
                # Linear ramp from WARM_UP_FACTOR at epoch 0 towards BASE_LR.
                alpha = (cfg.SOLVER.BASE_LR - cfg.SOLVER.WARM_UP_FACTOR
                         ) / cfg.SOLVER.WARM_UP_ITERS
                warmup_factor = epoch * alpha + cfg.SOLVER.WARM_UP_FACTOR
            else:
                raise KeyError(
                    'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
            lr_new = warmup_factor
            net_utils.update_learning_rate(optimizer, lr, lr_new)
            lr = optimizer.param_groups[0]['lr']
            assert lr == lr_new
        elif epoch == cfg.SOLVER.WARM_UP_ITERS:
            net_utils.update_learning_rate(optimizer, lr, cfg.SOLVER.BASE_LR)
            lr = optimizer.param_groups[0]['lr']
            assert lr == cfg.SOLVER.BASE_LR

        # Learning rate decay
        if cfg.SOLVER.LR_POLICY == 'steps_with_decay':
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    epoch == cfg.SOLVER.STEPS[decay_steps_ind]:
                print('Decay the learning on step %d' % epoch)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1
        elif cfg.SOLVER.LR_POLICY == 'cosine_annealing' and \
                epoch >= cfg.SOLVER.WARM_UP_ITERS:
            scheduler.step()
            # Read the LR that step() just wrote. Calling scheduler.get_lr()
            # here would, on PyTorch >= 1.4, apply the chainable cosine
            # update a second time (and emit a warning).
            lr_new = optimizer.param_groups[0]['lr']
            net_utils.update_learning_rate(optimizer, lr, lr_new)
            lr = optimizer.param_groups[0]['lr']
            assert lr == lr_new

        start_train_time = time.time()
        train(epoch, model, criterion_class, criterion_metric,
              criterion_center, optimizer, trainloader, use_gpu,
              criterion_class_oimwarmup, criterion_class_cam,
              criterion_class_mixup, args)
        train_time += round(time.time() - start_train_time)

        # Evaluate every eval_step epochs once past start_eval, and always
        # after the final epoch.
        periodic_eval = ((epoch + 1) > args.start_eval and args.eval_step > 0
                         and (epoch + 1) % args.eval_step == 0)
        if periodic_eval or (epoch + 1) == max_epoch:
            print("==> Test")
            if cfg.REID.SWA and (epoch + 1) >= cfg.REID.SWA_START:
                # Presumably swaps in the SWA-averaged weights and refreshes
                # BN statistics for evaluation, then swaps back for training.
                optimizer.swap_swa_sgd()
                optimizer.bn_update(trainloader, model, device='cuda')
                rank1 = _evaluate()
                optimizer.swap_swa_sgd()
            else:
                rank1 = _evaluate()

            is_best = rank1 > best_rank1
            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            # Unwrap DataParallel so the checkpoint keys carry no 'module.'.
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            save_checkpoint_best(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(
        best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
예제 #5
0
def main():
    """Train a Detectron-style Mask R-CNN model.

    Parses arguments with ``parse_cfg()`` and builds the combined-roidb
    training loader and the model. Creates an SGD or Adam optimizer with
    separate parameter groups for weights, biases and GroupNorm parameters.
    Optionally loads a checkpoint (and, with ``args.resume``, the optimizer
    state and start step) and/or Detectron weights. Then runs the iteration
    loop with warm-up and step decay. Checkpoints are saved every
    ``CHECKPOINT_PERIOD`` iterations, at the end, and on
    RuntimeError/KeyboardInterrupt.

    Raises:
        KeyError: for an unsupported ``cfg.SOLVER.TYPE`` or an unknown
            ``cfg.SOLVER.WARM_UP_METHOD``.
    """
    args = parse_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch (whole batches only).
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = GetRCNNModel()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Affine parameters of every GroupNorm layer get their own weight decay.
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name + '.weight')
            gn_param_nameset.add(name + '.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            # Bias check comes first, so GroupNorm biases land in the bias group.
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    assert (gn_param_nameset - set(nograd_param_names) -
            set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': bias_params,
        'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]
    # Parameter names per group, in the same order as `params` (kept for
    # optimizer-checkpoint reordering, currently disabled below).
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast instead of leaving `optimizer` undefined.
        raise KeyError("Unsupported optim: {}".format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logger.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # optimizer.load_state_dict is buggy on PyTorch 0.3.1, so a
            # helper restores the optimizer state instead.
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logger.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    # LR of the non-bias group; tracked for logging and schedule updates.
    lr = optimizer.param_groups[0]['lr']

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)
    if cfg.TRAIN_SYNC_BN:
        # For synchronized BN
        patch_replication_callback(maskRCNN)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # Clamp to >= 1: with more GPUs than SNAPSHOT_ITERS the period would be
    # 0 and `(step + 1) % CHECKPOINT_PERIOD` would raise ZeroDivisionError.
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))

    # Set index for decay steps
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    # Multi-view models receive 'data' as a list whose items are squeezed.
    multi_view_input = (cfg.MODEL.LR_VIEW_ON or cfg.MODEL.GIF_ON
                        or cfg.MODEL.LRASY_MAHA_ON)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up: scale BASE_LR by a constant or linearly ramped factor.
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = (cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha)
                                     + alpha)
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Accumulate gradients over iter_size minibatches per step.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    # roidb is a list of ndarrays with inconsistent length
                    if key == 'roidb':
                        continue
                    if multi_view_input and key == 'data':
                        input_data[key] = [
                            torch.squeeze(item) for item in input_data[key]
                        ]
                    input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                net_utils.train_save_ckpt(output_dir, args, step, train_size,
                                          maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        net_utils.train_save_ckpt(output_dir, args, step, train_size, maskRCNN,
                                  optimizer)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        net_utils.train_save_ckpt(output_dir, args, step, train_size, maskRCNN,
                                  optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
예제 #6
0
# Training datasets selectable with ``--dataset``:
#     name -> (cfg.TRAIN.DATASETS, cfg.MODEL.NUM_CLASSES)
# NUM_CLASSES includes the background class (COCO: 80 foreground + 1 = 81).
TRAIN_DATASET_CATALOG = {
    # COCO
    'coco2017': (('coco_2017_train', ), 81),
    'keypoints_coco2017': (('keypoints_coco_2017_train', ), 2),

    # ADE20k as a detection dataset
    'ade_train': (('ade_train', ), 446),

    # Noisy CS6 + WIDER datasets
    'cs6_noise020+WIDER': (('cs6_noise020', 'wider_train'), 2),
    'cs6_noise050+WIDER': (('cs6_noise050', 'wider_train'), 2),
    'cs6_noise080+WIDER': (('cs6_noise080', 'wider_train'), 2),
    'cs6_noise085+WIDER': (('cs6_noise085', 'wider_train'), 2),
    'cs6_noise090+WIDER': (('cs6_noise090', 'wider_train'), 2),
    'cs6_noise095+WIDER': (('cs6_noise095', 'wider_train'), 2),
    'cs6_noise100+WIDER': (('cs6_noise100', 'wider_train'), 2),

    # Noisy CS6 datasets only
    'cs6_noise020': (('cs6_noise020', ), 2),
    'cs6_noise030': (('cs6_noise030', ), 2),
    'cs6_noise040': (('cs6_noise040', ), 2),
    'cs6_noise050': (('cs6_noise050', ), 2),
    'cs6_noise060': (('cs6_noise060', ), 2),
    'cs6_noise070': (('cs6_noise070', ), 2),
    'cs6_noise080': (('cs6_noise080', ), 2),
    'cs6_noise085': (('cs6_noise085', ), 2),
    'cs6_noise090': (('cs6_noise090', ), 2),
    'cs6_noise095': (('cs6_noise095', ), 2),
    'cs6_noise100': (('cs6_noise100', ), 2),

    # Cityscapes, 7 classes
    'cityscapes': (('cityscapes_train', ), 8),

    # BDD, 7 classes
    'bdd_any_any_any': (('bdd_any_any_any_train', ), 8),
    'bdd_any_any_daytime': (('bdd_any_any_daytime_train', ), 8),
    'bdd_clear_any_daytime': (('bdd_clear_any_daytime_train', ), 8),

    # Cityscapes pedestrian sets
    'cityscapes_peds': (('cityscapes_peds_train', ), 2),

    # Cityscapes car sets (+ KITTI cars)
    'cityscapes_cars_HPlen3+kitti_car_train':
    (('cityscapes_cars_HPlen3', 'kitti_car_train'), 2),
    'cityscapes_cars_HPlen5+kitti_car_train':
    (('cityscapes_cars_HPlen5', 'kitti_car_train'), 2),
    'cityscapes_car_train+kitti_car_train':
    (('cityscapes_car_train', 'kitti_car_train'), 2),

    # KITTI car set
    'kitti_car_train': (('kitti_car_train', ), 2),

    # BDD pedestrian sets
    'bdd_peds': (('bdd_peds_train', ), 2),  # bdd peds: clear_any_daytime
    'bdd_peds_full': (('bdd_peds_full_train', ), 2),  # bdd peds: any_any_any
    # Pedestrians with constraints
    'bdd_peds_not_clear_any_daytime':
    (('bdd_peds_not_clear_any_daytime_train', ), 2),
    # 20k-sample videos
    'bdd_peds_not_clear_any_daytime_20k':
    (('bdd_peds_not_clear_any_daytime_20k_train', ), 2),
    # Source domain + target domain detections
    'bdd_peds+DETS_20k':
    (('bdd_peds_dets_20k_target_domain', 'bdd_peds_train'), 2),
    # Source domain + target domain detections -- same 18k images as HP18k
    'bdd_peds+DETS18k':
    (('bdd_peds_dets18k_target_domain', 'bdd_peds_train'), 2),
    # Only dets
    'DETS20k': (('bdd_peds_dets_20k_target_domain', ), 2),
    # Only dets18k -- same images as HP18k
    'DETS18k': (('bdd_peds_dets18k_target_domain', ), 2),
    # Only HP
    'HP': (('bdd_peds_HP_target_domain', ), 2),
    # Only HP 18k videos
    'HP18k': (('bdd_peds_HP18k_target_domain', ), 2),
    # Source domain + target domain HP
    # NOTE(review): source is listed first here, unlike the other "+" entries.
    'bdd_peds+HP': (('bdd_peds_train', 'bdd_peds_HP_target_domain'), 2),
    # Source domain + target domain HP 18k videos
    'bdd_peds+HP18k': (('bdd_peds_HP18k_target_domain', 'bdd_peds_train'), 2),

    # Source domain + target domain with different conf threshold theta
    'bdd_peds+HP18k_thresh-050':
    (('bdd_HP18k_thresh-050', 'bdd_peds_train'), 2),
    'bdd_peds+HP18k_thresh-060':
    (('bdd_HP18k_thresh-060', 'bdd_peds_train'), 2),
    'bdd_peds+HP18k_thresh-070':
    (('bdd_HP18k_thresh-070', 'bdd_peds_train'), 2),
    'bdd_peds+HP18k_thresh-090':
    (('bdd_HP18k_thresh-090', 'bdd_peds_train'), 2),

    # Data distillation on BDD -- for rebuttal
    'bdd_peds+bdd_data_dist_small':
    (('bdd_data_dist_small', 'bdd_peds_train'), 2),
    'bdd_peds+bdd_data_dist_mid': (('bdd_data_dist_mid', 'bdd_peds_train'), 2),
    'bdd_peds+bdd_data_dist': (('bdd_data_dist', 'bdd_peds_train'), 2),

    # Source domain + **labeled** target domain with varying number of images
    'bdd_peds+labeled_100':
    (('bdd_peds_not_clear_any_daytime_train_100', 'bdd_peds_train'), 2),
    'bdd_peds+labeled_075':
    (('bdd_peds_not_clear_any_daytime_train_075', 'bdd_peds_train'), 2),
    'bdd_peds+labeled_050':
    (('bdd_peds_not_clear_any_daytime_train_050', 'bdd_peds_train'), 2),
    'bdd_peds+labeled_025':
    (('bdd_peds_not_clear_any_daytime_train_025', 'bdd_peds_train'), 2),
    'bdd_peds+labeled_010':
    (('bdd_peds_not_clear_any_daytime_train_010', 'bdd_peds_train'), 2),
    'bdd_peds+labeled_005':
    (('bdd_peds_not_clear_any_daytime_train_005', 'bdd_peds_train'), 2),
    'bdd_peds+labeled_001':
    (('bdd_peds_not_clear_any_daytime_train_001', 'bdd_peds_train'), 2),

    # Source domain + target domain HP tracker bboxes only
    'bdd_peds+HP18k_track_only':
    (('bdd_HP18k_track_only', 'bdd_peds_train'), 2),

    # Subsets of bdd_HP18k with different constraints
    'bdd_peds+HP18k_any_any_night':
    (('bdd_HP18k_any_any_night', 'bdd_peds_train'), 2),
    'bdd_peds+HP18k_rainy_any_daytime':
    (('bdd_HP18k_rainy_any_daytime', 'bdd_peds_train'), 2),
    'bdd_peds+HP18k_rainy_any_night':
    (('bdd_HP18k_rainy_any_night', 'bdd_peds_train'), 2),
    'bdd_peds+HP18k_overcast,rainy_any_daytime':
    (('bdd_HP18k_overcast,rainy_any_daytime', 'bdd_peds_train'), 2),
    'bdd_peds+HP18k_overcast,rainy_any_night':
    (('bdd_HP18k_overcast,rainy_any_night', 'bdd_peds_train'), 2),
    'bdd_peds+HP18k_overcast,rainy,snowy_any_daytime':
    (('bdd_HP18k_overcast,rainy,snowy_any_daytime', 'bdd_peds_train'), 2),

    # Source domain + target domain HP18k -- after histogram matching
    'bdd_peds+HP18k_remap_hist':
    (('bdd_peds_HP18k_target_domain_remap_hist', 'bdd_peds_train'), 2),
    'bdd_peds+HP18k_remap_cityscape_hist':
    (('bdd_peds_HP18k_target_domain_remap_cityscape_hist',
      'bdd_peds_train'), 2),
    'bdd_peds+HP18k_remap_random':
    (('bdd_peds_HP18k_target_domain_remap_random', 'bdd_peds_train'), 2),

    # Source + noisy target domain -- prevent domain adv from using HP roi info
    'bdd_peds+bdd_HP18k_noisy_100k':
    (('bdd_HP18k_noisy_100k', 'bdd_peds_train'), 2),
    'bdd_peds+bdd_HP18k_noisy_080':
    (('bdd_HP18k_noisy_080', 'bdd_peds_train'), 2),
    'bdd_peds+bdd_HP18k_noisy_060':
    (('bdd_HP18k_noisy_060', 'bdd_peds_train'), 2),
    'bdd_peds+bdd_HP18k_noisy_070':
    (('bdd_HP18k_noisy_070', 'bdd_peds_train'), 2),

    # WIDER / CS6 face sets
    'wider_train': (('wider_train', ), 2),
    'cs6-subset': (('cs6-subset', ), 2),
    'cs6-subset-score': (('cs6-subset-score', ), 2),
    'cs6-subset-gt': (('cs6-subset-gt', ), 2),
    'cs6-3013-gt': (('cs6-3013-gt', ), 2),  # DEBUG: overfit on one video annots
    'cs6-subset-gt+WIDER': (('cs6-subset-gt', 'wider_train'), 2),
    'cs6-subset+WIDER': (('cs6-subset', 'wider_train'), 2),
    'cs6-train-gt': (('cs6-train-gt', ), 2),
    'cs6-train-gt-noisy-0.3': (('cs6-train-gt-noisy-0.3', ), 2),
    'cs6-train-gt-noisy-0.5': (('cs6-train-gt-noisy-0.5', ), 2),
    'cs6-train-det-score': (('cs6-train-det-score', ), 2),
    'cs6-train-det-score-0.5': (('cs6-train-det-score-0.5', ), 2),
    'cs6-train-det': (('cs6-train-det', ), 2),
    'cs6-train-det-0.5': (('cs6-train-det-0.5', ), 2),
    'cs6-train-hp': (('cs6-train-hp', ), 2),
    'cs6-train-easy-gt': (('cs6-train-easy-gt', ), 2),
    'cs6-train-easy-gt-sub': (('cs6-train-easy-gt-sub', ), 2),
    'cs6-train-easy-hp': (('cs6-train-easy-hp', ), 2),
    'cs6-train-easy-det': (('cs6-train-easy-det', ), 2),

    # Joint training with CS6 and WIDER
    'cs6-train-easy-gt-sub+WIDER':
    (('cs6-train-easy-gt-sub', 'wider_train'), 2),
    'cs6-train-gt+WIDER': (('cs6-train-gt', 'wider_train'), 2),
    'cs6-train-hp+WIDER': (('cs6-train-hp', 'wider_train'), 2),
    'cs6-train-dummy+WIDER': (('cs6-train-dummy', 'wider_train'), 2),
    'cs6-train-det+WIDER': (('cs6-train-det', 'wider_train'), 2),
    # Dets dataset created by removing tracker results from the HP json
    'cs6_train_det_from_hp+WIDER':
    (('cs6_train_det_from_hp', 'wider_train'), 2),
    # Dataset created by removing det results from the HP json -- HP tracker only
    'cs6_train_hp_tracker_only+WIDER':
    (('cs6_train_hp_tracker_only', 'wider_train'), 2),
    # HP dataset with noisy labels: used to prevent DA from getting any info from HP
    'cs6_train_hp_noisy_100+WIDER':
    (('cs6_train_hp_noisy_100', 'wider_train'), 2),
}


def _set_train_datasets(dataset_name):
    """Set cfg.TRAIN.DATASETS and cfg.MODEL.NUM_CLASSES for a ``--dataset`` value.

    Raises:
        ValueError: if ``dataset_name`` is not a key of TRAIN_DATASET_CATALOG.
    """
    entry = TRAIN_DATASET_CATALOG.get(dataset_name)
    if entry is None:
        raise ValueError("Unexpected args.dataset: {}".format(dataset_name))
    cfg.TRAIN.DATASETS, cfg.MODEL.NUM_CLASSES = entry


def _make_train_dataloader(roidb, ratio_list, ratio_index, batch_size):
    """Build the training DataLoader over ``roidb`` (drops the last partial batch)."""
    batch_sampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=batch_size,
        drop_last=True)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)


def main():
    """Main function: configure, build data/model/optimizer, and train.

    Steps:
    1. Parse the command line and resolve ``--dataset`` through
       TRAIN_DATASET_CATALOG.
    2. Load the YAML config and any overrides.
    3. Rescale batch size, BASE_LR and solver steps for the visible GPUs.
    4. Run the training loop. A checkpoint is saved every
       ``CHECKPOINT_PERIOD`` steps, at the end of training, and on
       ``RuntimeError``/``KeyboardInterrupt``.

    Raises:
        ValueError: if no CUDA use is requested, if the dataset name is
            unknown, or if ``SOLVER.TYPE`` is unsupported.
        NotImplementedError: if joint training is requested with a number of
            datasets other than two.
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    _set_train_datasets(args.dataset)

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = [int(x * step_scale + 0.5) for x in cfg.SOLVER.STEPS]
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)
    # NOTE: a DataLoader worker_init_fn that fixed the random/np/torch seeds
    # (and set cudnn.deterministic) was drafted here but is disabled.

    ### Dataset ###
    # Each roidb is built exactly once, inside the branch that uses it.
    if cfg.TRAIN.JOINT_TRAINING:
        if len(cfg.TRAIN.DATASETS) == 2:
            print('Joint training on two datasets')
        else:
            raise NotImplementedError

        joint_training_roidb = []
        for i, dataset_name in enumerate(cfg.TRAIN.DATASETS):
            # ROIDB construction: one roidb / dataloader per dataset
            timers['roidb'].tic()
            roidb, ratio_list, ratio_index = combined_roidb_for_training(
                (dataset_name, ), cfg.TRAIN.PROPOSAL_FILES)
            timers['roidb'].toc()
            logger.info('{:d} roidb entries'.format(len(roidb)))
            logger.info('Takes %.2f sec(s) to construct roidb',
                        timers['roidb'].average_time)

            # The first dataset defines the epoch size recorded as train_size
            # (the original code intended this but overwrote it every pass).
            if i == 0:
                roidb_size = len(roidb)

            dataloader = _make_train_dataloader(roidb, ratio_list,
                                                ratio_index, args.batch_size)
            # TODO: decrease num-threads when using two dataloaders
            dataiterator = iter(dataloader)

            joint_training_roidb.append({
                'dataloader': dataloader,
                'dataiterator': dataiterator,
                'dataset_name': dataset_name
            })
    else:
        timers['roidb'].tic()
        roidb, ratio_list, ratio_index = combined_roidb_for_training(
            cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
        timers['roidb'].toc()
        roidb_size = len(roidb)
        logger.info('{:d} roidb entries'.format(roidb_size))
        logger.info('Takes %.2f sec(s) to construct roidb',
                    timers['roidb'].average_time)

        dataloader = _make_train_dataloader(roidb, ratio_list, ratio_index,
                                            args.batch_size)
        dataiterator = iter(dataloader)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    if cfg.TRAIN.JOINT_SELECTIVE_FG:
        orig_fg_batch_ratio = cfg.TRAIN.FG_FRACTION

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Parameters are split into three groups (non-bias, bias, GroupNorm) so
    # each can get its own lr multiplier and weight decay.
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name + '.weight')
            gn_param_nameset.add(name + '.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    assert (gn_param_nameset - set(nograd_param_names) -
            set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': bias_params,
        'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]
    # names of parameters for each parameter group
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast instead of hitting a NameError on `optimizer` later.
        raise ValueError('Unknown SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    # lr of non-bias parameters, for command line outputs.
    lr = optimizer.param_groups[0]['lr']

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_' + str(
        cfg.TRAIN.DATASETS) + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # Clamp to 1 so `(step + 1) % CHECKPOINT_PERIOD` cannot divide by zero
    # when SNAPSHOT_ITERS < NUM_GPUS.
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))

    # Set index for decay steps
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()

            for inner_iter in range(args.iter_size):
                # use an iter counter for optional alternating batches
                if args.iter_size == 1:
                    iter_counter = step
                else:
                    iter_counter = inner_iter

                if cfg.TRAIN.JOINT_TRAINING:
                    # alternate batches between dataset[0] and dataset[1]
                    # NOTE: if available FG samples cannot fill minibatch
                    # then batchsize will be smaller than cfg.TRAIN.BATCH_SIZE_PER_IM.
                    joint_source = joint_training_roidb[iter_counter % 2]
                    print('Dataset: %s' % joint_source['dataset_name'])  # DEBUG
                    dataloader = joint_source['dataloader']
                    dataiterator = joint_source['dataiterator']

                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # end of epoch for dataloader: restart it and remember the
                    # fresh iterator for the dataset it belongs to
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)
                    if cfg.TRAIN.JOINT_TRAINING:
                        joint_training_roidb[iter_counter %
                                             2]['dataiterator'] = dataiterator

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()

            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
예제 #7
0
def main():
    """Train ``vos_model_builder.Generalized_VOS_RCNN`` on DAVIS 2017 sequences.

    Steps:
      1. Parse command-line arguments and adapt the global ``cfg`` to the
         visible GPUs, the requested batch size and the sequence length.
      2. Build the sequence roidb, the model and the optimizer (trainable
         parameters are split into non-bias / bias / GroupNorm groups).
      3. Optionally load a checkpoint (``--load_ckpt`` / ``--resume``) or
         Detectron weights (``--load_detectron``).
      4. Run the training loop over ``[args.start_step, cfg.SOLVER.MAX_ITER)``:
         each step draws one item from the sequence data loader and feeds it
         to the model frame by frame, with learning-rate warm-up and step
         decay applied at the start of every step.

    A checkpoint is saved every ``CHECKPOINT_PERIOD`` steps, after the last
    step, and when a ``RuntimeError`` (including ``NotImplementedError``) or
    ``KeyboardInterrupt`` interrupts the loop; in that case the traceback is
    printed and the exception is not re-raised.

    Exits via ``sys.exit`` when no CUDA device is available.

    Raises:
        ValueError: if CUDA is not enabled, ``args.dataset`` is not
            ``"davis2017"``, or ``cfg.SOLVER.TYPE`` is neither ``"SGD"`` nor
            ``"Adam"``.
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "davis2017":
        cfg.TRAIN.DATASETS = ('davis_train', )
        # For DAVIS, the COCO category set is used.
        cfg.MODEL.NUM_CLASSES = 81  # 80 foreground + 1 background
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    # Identity training with class replacement: use a fixed 145-class head
    # and turn off both IDENTITY_TRAINING and ADD_UNKNOWN_CLASS.
    if cfg.MODEL.IDENTITY_TRAINING and cfg.MODEL.IDENTITY_REPLACE_CLASS:
        cfg.MODEL.NUM_CLASSES = 145
        cfg.MODEL.IDENTITY_TRAINING = False
        cfg.MODEL.ADD_UNKNOWN_CLASS = False

    # Add one extra "unknown" class if requested.
    if cfg.MODEL.ADD_UNKNOWN_CLASS is True:
        cfg.MODEL.NUM_CLASSES += 1

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    # NOTE(review): iter_size only affects the step rescaling below; the
    # training loop of this script does not accumulate over iter_size.
    effective_batch_size = args.iter_size * args.batch_size

    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))
    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size.
    # The lr is additionally divided by the sequence length, presumably
    # because every frame of a sequence backpropagates its own loss into the
    # same parameters — confirm.
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    cfg.SOLVER.BASE_LR *= 1.0 / cfg.MODEL.SEQUENCE_LENGTH
    print(
        'Adjust BASE_LR linearly according to batch_size change and sequence length change:\n'
        '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(
        map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    # Note: --lr replaces the batch/sequence-scaled BASE_LR computed above.
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    merged_roidb, seq_num, seq_start_end = sequenced_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES, load_inv_db=True)

    timers['roidb'].toc()
    roidb_size = len(merged_roidb)
    logger.info('{:d} roidbs sequences.'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidbs',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch, number of sequences.
    train_size = roidb_size // args.batch_size * args.batch_size

    ### Model ###
    maskRCNN = vos_model_builder.Generalized_VOS_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Collect the weight/bias names of every GroupNorm module so they can get
    # their own weight decay (cfg.SOLVER.WEIGHT_DECAY_GN).
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name + '.weight')
            gn_param_nameset.add(name + '.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    # Names containing 'bias' are checked first, so GroupNorm biases end up
    # in the bias group, not the GN group.
    for key, value in maskRCNN.named_parameters():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    assert (gn_param_nameset - set(nograd_param_names) -
            set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [
        {'params': nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': gn_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN}
    ]
    # names of parameters for each parameter group
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast: otherwise `optimizer` stays unbound and the script dies
        # later with an unrelated UnboundLocalError.
        raise ValueError(
            "Unexpected cfg.SOLVER.TYPE: {}".format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # Load onto CPU storage first, whatever device the tensors were saved on.
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt_no_mapping(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            optimizer.load_state_dict(checkpoint['optimizer'])
            # misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN,
                              args.load_detectron,
                              force_load_all=False)

    # lr of non-bias parameters, for command line outputs.
    lr = optimizer.param_groups[0]['lr']

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # Clamp to at least 1: a period of 0 would make the `%` check below raise
    # ZeroDivisionError when SNAPSHOT_ITERS < NUM_GPUS.
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))

    # Set index for decay steps: first entry of STEPS (from index 1) that is
    # not yet behind start_step. Index 0 is skipped, presumably because
    # STEPS[0] is 0 by Detectron convention — confirm.
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    dataloader, dataiterator, warmup_length = gen_sequence_data_sampler(
        merged_roidb, seq_num, seq_start_end, use_seq_warmup=False)
    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):
            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()

            if cfg.TRAIN.USE_SEQ_WARMUP is True and step >= cfg.TRAIN.ITER_BEFORE_USE_SEQ_WARMUP and (
                (step - cfg.TRAIN.ITER_BEFORE_USE_SEQ_WARMUP) %
                    cfg.TRAIN.WARMUP_LENGTH_CHANGE_STEP) == 0:
                #TODO better dataiterator creator.
                # NotImplementedError subclasses RuntimeError, so this is
                # caught by the RuntimeError handler below (checkpoint saved,
                # traceback printed). The lines after the raise are
                # unreachable and kept as the intended implementation.
                raise NotImplementedError('not able to delete.')
                dataloader, dataiterator, warmup_length = gen_sequence_data_sampler(
                    merged_roidb, seq_num, seq_start_end, use_seq_warmup=True)
                print('update warmup length:', warmup_length)
            try:
                input_data_sequence = next(dataiterator)
            except StopIteration:
                # End of epoch: restart the data loader.
                dataiterator = iter(dataloader)
                input_data_sequence = next(dataiterator)

            # clean hidden states before training.
            maskRCNN.module.clean_hidden_states()
            maskRCNN.module.clean_flow_features()

            expected_seq_len = cfg.MODEL.SEQUENCE_LENGTH + warmup_length
            # Previously the message was `print(...)`, whose return value
            # (None) became the assertion message; build the text instead.
            assert len(input_data_sequence['data']) == expected_seq_len, \
                '{} != {}'.format(len(input_data_sequence['data']),
                                  expected_seq_len)

            # if train_part == 0: train backbone.
            # else train_part == 1: train heads.
            train_part = 1
            if cfg.TRAIN.ALTERNATE_TRAINING:
                if step % 2 == 0:
                    maskRCNN.module.freeze_conv_body_only()
                    train_part = 1
                else:
                    maskRCNN.module.train_conv_body_only()
                    train_part = 0
            # this is used for longer sequence training.
            # when reach maximum trainable length, detach the hidden states.
            cnter_for_detach_hidden_states = 0
            for inner_iter in range(cfg.MODEL.SEQUENCE_LENGTH + warmup_length):
                # Slice out the entries for frame `inner_iter` of the sequence.
                input_data = {}
                for key in input_data_sequence.keys():
                    input_data[key] = input_data_sequence[key][
                        inner_iter:inner_iter + 1]
                for key in input_data:
                    if key != 'roidb' and key != 'data_flow':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))
                    if key == 'data_flow':
                        # For every frame but the first, turn the flow array
                        # into a float32 tensor on the image's device (with a
                        # leading axis of size 1); otherwise pass None.
                        if inner_iter != 0 and input_data[key][0][0][
                                0] is not None:  # flow is not None.
                            input_data[key] = [
                                Variable(
                                    torch.tensor(
                                        np.expand_dims(
                                            np.squeeze(
                                                np.array(input_data[key],
                                                         dtype=np.float32)),
                                            0),
                                        device=input_data['data'][0].device))
                            ]
                            assert input_data['data'][0].shape[
                                -2:] == input_data[key][0].shape[
                                    -2:], "Spatial shape of image and flow are not equal."
                        else:
                            input_data[key] = [None]
                # Warm-up frames only advance the hidden states; no loss.
                if cfg.TRAIN.USE_SEQ_WARMUP and inner_iter < warmup_length:
                    maskRCNN.module.set_stop_after_hidden_states(stop=True)
                    net_outputs = maskRCNN(**input_data)
                    assert net_outputs is None
                    maskRCNN.module.detach_hidden_states()
                    maskRCNN.module.set_stop_after_hidden_states(stop=False)
                    continue

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']

                if train_part == 0 or cnter_for_detach_hidden_states >= cfg.TRAIN.MAX_TRAINABLE_SEQ_LENGTH or inner_iter == cfg.MODEL.SEQUENCE_LENGTH + warmup_length - 1:
                    # if reach the max trainable length or end of the sequence. free the graph and detach the hidden states.
                    loss.backward()
                    maskRCNN.module.detach_hidden_states()
                    cnter_for_detach_hidden_states = 0
                else:
                    # Keep the graph so later frames can backprop through the
                    # shared hidden states.
                    loss.backward(retain_graph=True)
                    cnter_for_detach_hidden_states += 1
                #TODO step every time?
                # Step only right after the graph was freed (counter reset).
                if cnter_for_detach_hidden_states == 0:
                    optimizer.step()
                    optimizer.zero_grad()

            training_stats.IterToc()
            training_stats.LogIterStats(step, lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except RuntimeError:
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
    except KeyboardInterrupt:
        del dataiterator
        logger.info('Save ckpt on Keyboard Interrupt ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
def main():
    """Train ``Generalized_RCNN`` on COCO 2017 (boxes/masks or keypoints).

    Steps:
      1. Parse command-line arguments and adapt the global ``cfg`` to the
         visible GPUs and requested batch size: BASE_LR is scaled linearly
         with ``batch_size``; SOLVER.STEPS / MAX_ITER are scaled with the
         effective batch size (``batch_size * iter_size``).
      2. Build the roidb, a ``DataLoader`` driven by a ``MinibatchSampler``
         wrapped in a ``BatchSampler`` (``drop_last=True``), the model and the
         optimizer (trainable parameters split into non-bias / bias /
         GroupNorm groups).
      3. Optionally load a checkpoint (``--load_ckpt`` / ``--resume``) or
         Detectron weights (``--load_detectron``).
      4. Run the training loop over ``[args.start_step, cfg.SOLVER.MAX_ITER)``.
         Each step applies lr warm-up / step decay, accumulates gradients over
         ``args.iter_size`` minibatches, then calls ``optimizer.step()`` once.

    A checkpoint is saved every ``CHECKPOINT_PERIOD`` steps, after the last
    step, and when a ``RuntimeError`` or ``KeyboardInterrupt`` interrupts the
    loop; in that case the traceback is printed and not re-raised.

    Exits via ``sys.exit`` when no CUDA device is available.

    Raises:
        ValueError: if CUDA is not enabled, ``args.dataset`` is not one of
            ``"coco2017"`` / ``"keypoints_coco2017"``, or ``cfg.SOLVER.TYPE``
            is neither ``"SGD"`` nor ``"Adam"``.
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train',)
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train',)
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' % (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' % (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' % (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' % (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print('Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
          '    SOLVER.STEPS: {} --> {}\n'
          '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps, cfg.SOLVER.STEPS,
                                                  old_max_iter, cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print('Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
              '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    # Note: --lr replaces the batch-scaled BASE_LR computed above.
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True
    )
    dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Collect the weight/bias names of every GroupNorm module so they can get
    # their own weight decay (cfg.SOLVER.WEIGHT_DECAY_GN).
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name+'.weight')
            gn_param_nameset.add(name+'.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    # Names containing 'bias' are checked first, so GroupNorm biases end up
    # in the bias group, not the GN group.
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    assert (gn_param_nameset - set(nograd_param_names) - set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [
        {'params': nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': gn_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN}
    ]
    # names of parameters for each parameter group
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast: otherwise `optimizer` stays unbound and the script dies
        # later with an unrelated UnboundLocalError.
        raise ValueError(
            "Unexpected cfg.SOLVER.TYPE: {}".format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        # Load onto CPU storage first, whatever device the tensors were saved on.
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print('train_size value: %d different from the one in checkpoint: %d'
                          % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0]['lr']  # lr of non-bias parameters, for command line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # Clamp to at least 1: a period of 0 would make the `%` check below raise
    # ZeroDivisionError when SNAPSHOT_ITERS < NUM_GPUS.
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))

    # Set index for decay steps: first entry of STEPS (from index 1) that is
    # not yet behind start_step. Index 0 is skipped, presumably because
    # STEPS[0] is 0 by Detectron convention — confirm.
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args,
        args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr, cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Accumulate gradients over iter_size minibatches, then one step.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # End of epoch: restart the data loader.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb': # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step+1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
예제 #9
0
def train_val(model,
              args,
              optimizer,
              lr,
              dataloader,
              train_size,
              output_dir,
              tblogger=None,
              checkpoint_period=2000,
              min_ckpt_step=15000):
    """Run the step-based training loop for ``model`` with periodic checkpoints.

    For every step in ``[args.start_step, cfg.SOLVER.MAX_ITER)`` this applies
    the configured warm-up / step-decay learning-rate schedule, accumulates
    gradients over ``args.iter_size`` mini-batches drawn from ``dataloader``
    (re-creating the iterator whenever it is exhausted), takes one optimizer
    step and logs statistics through ``TrainingStats``.

    Args:
        model: network called as ``model(**input_data)``; its output dict must
            contain ``'total_loss'``.
        args: parsed command-line arguments (reads ``start_step``,
            ``iter_size``, ``disp_interval``, ``use_tfboard``, ``no_save``).
        optimizer: optimizer whose ``param_groups[0]['lr']`` is treated as the
            current learning rate.
        lr: learning rate currently set on ``optimizer``.
        dataloader: iterable yielding the input dicts fed to ``model``.
        train_size: effective training-set size, forwarded to ``save_ckpt``.
        output_dir: directory passed to ``save_ckpt``.
        tblogger: optional TensorBoard writer; closed when training ends if
            ``args.use_tfboard and not args.no_save``.
        checkpoint_period: a checkpoint is considered every this many steps
            and on the final step (``cfg.SOLVER.MAX_ITER - 1``).
        min_ckpt_step: such checkpoints are only written once ``step + 1``
            exceeds this value.

    On ``RuntimeError`` or ``KeyboardInterrupt`` a checkpoint for the current
    step is saved and the traceback is printed instead of propagating.
    """
    dataiterator = iter(dataloader)
    model.train()

    # Locate the first decay boundary at or after the starting step. The
    # search starts at index 1, so cfg.SOLVER.STEPS[0] never triggers a decay.
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)

    try:
        logger.info('Training starts !')
        # Bound before the loop so the exception handler can always save.
        step = args.start_step

        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):
            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    # Interpolate from WARM_UP_FACTOR (step 0) towards 1.0.
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                # Warm-up finished: switch to the full base learning rate.
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()

            # Accumulate gradients over args.iter_size mini-batches before
            # a single optimizer step.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Dataloader exhausted: start a new pass over the data.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = model(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                # backward() raises on a tensor that does not require grad,
                # so only back-propagate losses attached to the graph.
                if loss.requires_grad:
                    loss.backward()

            optimizer.step()
            training_stats.IterToc()
            training_stats.LogIterStats(step, lr)

            # Periodic checkpoint (and one on the final step), but only once
            # training has progressed past min_ckpt_step. No separate
            # unconditional checkpoint is written after the loop.
            is_last_step = step == cfg.SOLVER.MAX_ITER - 1
            if (((step + 1) % checkpoint_period == 0 or is_last_step)
                    and (step + 1) > min_ckpt_step):
                save_ckpt(output_dir, args, step, train_size, model,
                          optimizer)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, model, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        # tblogger defaults to None; guard so cleanup cannot raise and mask
        # an in-flight exception.
        if tblogger is not None and args.use_tfboard and not args.no_save:
            tblogger.close()
예제 #10
0
def main():
    """Train the ApolloScape Car3D model (script entry point).

    Parses command-line arguments and merges ``args.cfg_file`` into ``cfg``,
    then applies the Car3D-specific overrides. It rescales BASE_LR and the
    solver schedule for the batch size, and builds the roidb, dataloader,
    model and optimizer. A checkpoint or Detectron weights are loaded if
    requested. Finally it runs the step-based training loop with warm-up and
    step decay of the learning rate.

    Checkpoints are written every ``CHECKPOINT_PERIOD`` steps and after the
    last step. One is also written when a ``RuntimeError`` or
    ``KeyboardInterrupt`` aborts training.

    Raises:
        ValueError: if neither ``args.cuda`` nor ``cfg.NUM_GPUS > 0`` holds,
            or if ``cfg.SOLVER.TYPE`` is neither ``"SGD"`` nor ``"Adam"``.
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    merge_cfg_from_file(args.cfg_file)

    # Some manual adjustment for the ApolloScape dataset parameters here
    cfg.OUTPUT_DIR = args.output_dir
    cfg.TRAIN.DATASETS = 'Car3D'
    cfg.MODEL.NUM_CLASSES = 8
    if cfg.CAR_CLS.SIM_MAT_LOSS:
        cfg.MODEL.NUMBER_CARS = 79
    else:
        # Loss is only cross entropy, hence, we detect only car categories in the training set.
        cfg.MODEL.NUMBER_CARS = 34
    cfg.TRAIN.MIN_AREA = 196  # 14*14
    cfg.TRAIN.USE_FLIPPED = False  # Currently I don't know how to handle the flipped case
    cfg.TRAIN.IMS_PER_BATCH = 1

    cfg.NUM_GPUS = torch.cuda.device_count()
    # NOTE(review): this does not use args.batch_size, although the message
    # printed below describes it as batch_size * iter_size; the two agree only
    # when args.batch_size == NUM_GPUS * IMS_PER_BATCH -- confirm intent.
    effective_batch_size = cfg.TRAIN.IMS_PER_BATCH * cfg.NUM_GPUS * args.iter_size

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size

    assert (args.batch_size %
            cfg.NUM_GPUS) == 0, 'batch_size: %d, NUM_GPUS: %d' % (
                args.batch_size, cfg.NUM_GPUS)

    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))
    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print(
        'Adjust BASE_LR linearly according to batch_size change:\n BASE_LR: {} --> {}'
        .format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(
        map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    # With the 3D-2D projection loss the dataset object is also needed, to
    # build the model below.
    if cfg.MODEL.LOSS_3D_2D_ON:
        roidb, ratio_list, ratio_index, ds = combined_roidb_for_training(
            cfg.TRAIN.DATASETS, args.dataset_dir)
    else:
        roidb, ratio_list, ratio_index = combined_roidb_for_training(
            cfg.TRAIN.DATASETS, args.dataset_dir)

    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    sampler = MinibatchSampler(ratio_list, ratio_index)
    dataset = RoiDataLoader(roidb,
                            cfg.MODEL.NUM_CLASSES,
                            training=True,
                            valid_keys=[
                                'has_visible_keypoints', 'boxes', 'seg_areas',
                                'gt_classes', 'gt_overlaps',
                                'box_to_gt_ind_map', 'is_crowd',
                                'car_cat_classes', 'poses', 'quaternions',
                                'im_info'
                            ])
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        drop_last=True,
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    if cfg.MODEL.LOSS_3D_2D_ON:
        maskRCNN = Generalized_RCNN(ds.Car3D)
    else:
        maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Split trainable parameters into bias / non-bias groups so they can get
    # their own lr multiplier and weight decay.
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        bias_params,
        'lr':
        0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast instead of hitting a NameError on `optimizer` later on.
        raise ValueError('Unsupported SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN,
                            checkpoint['model'],
                            ignore_list=args.ckpt_ignore_head)
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  # TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0][
        'lr']  # lr of non-bias parameters, for command line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()
    # Snapshot interval divided by the GPU count; clamped to at least 1 so
    # the modulo check in the loop cannot divide by zero.
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))
    # Set index for decay steps (search starts at index 1, so
    # cfg.SOLVER.STEPS[0] never triggers a decay)
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    # Extra factor passed to LogIterStats. A translation-loss warm-up
    # (cfg.SOLVER.WARM_UP_FACTOR_TRANS) was disabled in favour of a constant
    # 1.0. It is bound here so it exists on every path: 'constant' warm-up,
    # or resuming at a step past cfg.SOLVER.WARM_UP_ITERS. Previously it was
    # only assigned in the 'linear' branch, which caused an UnboundLocalError.
    warmup_factor_trans = 1.0
    try:
        logger.info('Training starts !')
        # Bound before the loop so the exception handler can always save.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    # Interpolate from WARM_UP_FACTOR (step 0) towards 1.0.
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                # Warm-up finished: switch to the full base learning rate.
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(
                    cfg.SOLVER.STEPS
            ) and step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Accumulate gradients over args.iter_size mini-batches before a
            # single optimizer step.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Dataloader exhausted: start a new pass over the data.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)

                # Re-weight the individual head losses in place before they
                # are logged and summed.
                net_outputs['losses'][
                    'loss_car_cls'] *= cfg.CAR_CLS.CAR_CLS_LOSS_BETA
                net_outputs['losses']['loss_rot'] *= cfg.CAR_CLS.ROT_LOSS_BETA
                if cfg.MODEL.TRANS_HEAD_ON:
                    net_outputs['losses'][
                        'loss_trans'] *= cfg.TRANS_HEAD.TRANS_LOSS_BETA

                training_stats.UpdateIterStats_car_3d(net_outputs)

                # start training
                # loss_car_cls: 2.233790, loss_rot: 0.296853, loss_trans: ~100
                loss = net_outputs['losses']['loss_car_cls'] + net_outputs[
                    'losses']['loss_rot']
                if cfg.MODEL.TRANS_HEAD_ON:
                    loss += net_outputs['losses']['loss_trans']
                if cfg.MODEL.LOSS_3D_2D_ON:
                    loss += net_outputs['losses']['UV_projection_loss']
                # The backbone / RPN / FPN loss is only added when none of
                # those parts is frozen.
                if not cfg.TRAIN.FREEZE_CONV_BODY and not cfg.TRAIN.FREEZE_RPN and not cfg.TRAIN.FREEZE_FPN:
                    loss += net_outputs['total_loss_conv']

                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr, warmup_factor_trans)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()