示例#1
0
def main():
    """Run the detector over a set of images and export per-image localizations.

    The model is loaded from exactly one of --load_ckpt / --load_detectron and
    ``im_detect_all`` is run on every image given by --image_dir or --images.
    The detections of each image are converted with ``convert_from_cls_format``
    and ``localize_obj_in_image`` and collected in a dict keyed by image stem.
    The dict is written to ``dictionary_answer.npy`` in the working directory
    and to a CSV via ``save_localization_result``.

    An image whose detection/localization raises an ``Exception`` is reported
    (with traceback) and skipped; the remaining images are still processed.
    """
    import traceback  # only used by the per-image error handler below

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    args = parse_args()
    print('Called with args:')
    print(args)

    # Exactly one of --image_dir / --images must be provided.
    assert args.image_dir or args.images
    assert bool(args.image_dir) ^ bool(args.images)

    if args.dataset.startswith("coco"):
        dataset = datasets.get_coco_dataset()
        cfg.MODEL.NUM_CLASSES = len(dataset.classes)
    elif args.dataset == "miotcd":
        dataset = datasets.get_miotcd_dataset()
        cfg.MODEL.NUM_CLASSES = 12
    elif args.dataset.startswith("keypoints_coco"):
        dataset = datasets.get_coco_dataset()
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset.startswith("bogota"):
        dataset = datasets.get_bogota_dataset()
        cfg.MODEL.NUM_CLASSES = 12
    else:
        raise ValueError('Unexpected dataset name: {}'.format(args.dataset))

    print('load cfg from file: {}'.format(args.cfg_file))
    cfg_from_file(args.cfg_file)

    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    assert bool(args.load_ckpt) ^ bool(args.load_detectron), \
        'Exactly one of --load_ckpt and --load_detectron should be specified.'
    cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS = False  # Don't need to load imagenet pretrained weights
    assert_and_infer_cfg()

    maskRCNN = Generalized_RCNN()

    if args.cuda:
        maskRCNN.cuda()

    if args.load_ckpt:
        load_name = args.load_ckpt
        print("loading checkpoint %s" % (load_name))
        # map_location keeps the loaded storages on the CPU.
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])

    if args.load_detectron:
        print("loading detectron weights %s" % args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True, device_ids=[0])  # only support single GPU

    maskRCNN.eval()
    if args.image_dir:
        imglist = misc_utils.get_imagelist_from_dir(args.image_dir)
    else:
        imglist = args.images
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    dataset_result = {}
    for i, image_path in enumerate(imglist):
        print('img', i)
        im = cv2.imread(image_path)
        assert im is not None
        im_name, _ = os.path.splitext(os.path.basename(image_path))

        try:
            timers = defaultdict(Timer)
            cls_boxes, cls_segms, cls_keyps = im_detect_all(maskRCNN, im, timers=timers)
            boxes_, _segms, _keyps, classes = convert_from_cls_format(
                cls_boxes, cls_segms, cls_keyps)
            dataset_result[im_name] = localize_obj_in_image(im_name, boxes_, classes)
        except Exception:
            # The original `except(e):` referenced an undefined name (turning
            # every failure into a NameError) and then opened pdb; report the
            # failure and continue with the next image instead.
            print('Failed to process image {}; skipping it.'.format(image_path))
            traceback.print_exc()

    np.save('dictionary_answer.npy', dataset_result)

    # Name the CSV after the last two components of the image directory,
    # e.g. "data/bogota/test/" -> "bogota_test.csv". normpath drops a trailing
    # separator so the last component is never empty.
    if args.image_dir:
        dir_parts = os.path.normpath(args.image_dir).split(os.sep)
        txt = '_'.join(dir_parts[-2:]) + '.csv'
    else:
        # Images were passed explicitly; there is no directory to name after.
        txt = 'localization_result.csv'
    save_localization_result(dataset_result, txt)
示例#2
0
def _nuclei_train_sets():
    """Return a {dataset name: training-set name} map for the nuclei datasets.

    Each base name ``X`` maps to ``X_train`` and its ``X_gan`` variant maps to
    ``X_train_gan``. ``monuseg_baseline`` and ``TNBC`` have no GAN variant.
    """
    bases = (['monuseg_all', 'monuseg_baseline_best', 'BNS_all']
             + ['monuseg_%d' % k for k in range(7)]
             + ['BNS_%d' % k for k in range(7)])
    table = {
        'monuseg_baseline': 'monuseg_baseline_train',
        'TNBC': 'TNBC_train',
    }
    for base in bases:
        table[base] = base + '_train'
        table[base + '_gan'] = base + '_train_gan'
    return table


def _configure_train_dataset(dataset_name):
    """Apply the dataset-specific training settings for ``dataset_name`` to cfg.

    Always sets cfg.TRAIN.DATASETS and cfg.MODEL.NUM_CLASSES. The nuclei
    datasets (MoNuSeg / BNS / TNBC variants) additionally set
    cfg.FPN.RPN_ANCHOR_START_SIZE = 8, cfg.TRAIN.USE_FLIPPED = False and
    cfg.SOLVER.MAX_ITER = 3000 (6000 for ``monuseg_baseline``).

    Raises:
        ValueError: if ``dataset_name`` is not a known dataset.
    """
    basic = {
        'coco2017': ('coco_2017_train', 81),
        'coco2014': ('coco_2014_train', 81),
        'monuseg2018': ('monuseg_2018_train', 2),
    }
    if dataset_name in basic:
        train_set, num_classes = basic[dataset_name]
        cfg.TRAIN.DATASETS = (train_set, )
        cfg.MODEL.NUM_CLASSES = num_classes
        return

    if dataset_name == 'monuseg_baseline_default':
        # Only the schedule length is overridden; anchors and flipping keep
        # whatever the config file specifies.
        cfg.TRAIN.DATASETS = ('monuseg_baseline_train', )
        cfg.MODEL.NUM_CLASSES = 2
        cfg.SOLVER.MAX_ITER = 3000
        return

    nuclei_sets = _nuclei_train_sets()
    if dataset_name not in nuclei_sets:
        raise ValueError("Unexpected args.dataset: {}".format(dataset_name))
    cfg.TRAIN.DATASETS = (nuclei_sets[dataset_name], )
    cfg.MODEL.NUM_CLASSES = 2
    cfg.FPN.RPN_ANCHOR_START_SIZE = 8
    cfg.SOLVER.MAX_ITER = 6000 if dataset_name == 'monuseg_baseline' else 3000
    cfg.TRAIN.USE_FLIPPED = False


def main():
    """Train a Generalized_RCNN model with a step-based schedule.

    Builds cfg from the config file, then the dataset-specific overrides
    (``_configure_train_dataset``), then --set_cfgs. It rescales the LR and
    solver steps for the actual batch size and GPU count, and builds the roidb
    data loader, the model and the optimizer. It optionally resumes from
    --load_ckpt or loads --load_detectron weights, then runs the training loop.

    Checkpoints are saved every CHECKPOINT_PERIOD steps, at the end of
    training, and when a RuntimeError or KeyboardInterrupt interrupts the loop.

    Raises:
        ValueError: for an unknown dataset, missing CUDA, or an unsupported
            SOLVER.TYPE.
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    # The config file is loaded *before* the dataset-specific settings so those
    # settings take precedence over it; --set_cfgs then overrides both.
    cfg_from_file(args.cfg_file)
    _configure_train_dataset(args.dataset)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)

    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(
        map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)

    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))
    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()
    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)

    print('------------------------------------')
    print(len(roidb))
    print(cfg.TRAIN.DATASETS)
    print(cfg.TRAIN.PROPOSAL_FILES)
    # Debug sanity check: draw the boxes of the first roidb entry onto its
    # image and write the result to 'aaaaaa.png' in the working directory.
    if roidb:
        debug_img = cv2.imread(roidb[0]['image'])
        if debug_img is not None:
            for box in roidb[0]['boxes']:
                # OpenCV drawing functions require integer pixel coordinates.
                xmin, ymin, xmax, ymax = (int(v) for v in box[:4])
                debug_img = cv2.rectangle(debug_img, (xmin, ymin),
                                          (xmax, ymax), (0, 255, 0), 3)
            cv2.imwrite('aaaaaa.png', debug_img)
    print('------------------------------------')
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size
    batchSampler = BatchSampler(sampler=MinibatchSampler(
        ratio_list, ratio_index),
                                batch_size=args.batch_size,
                                drop_last=True)

    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Split trainable parameters into three groups (non-bias, bias, GroupNorm)
    # so each can get its own lr / weight decay below.
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name + '.weight')
            gn_param_nameset.add(name + '.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    assert (gn_param_nameset - set(nograd_param_names) -
            set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': bias_params,
        'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]
    # Names of the parameters in each of the param groups above
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Without this, `optimizer` would be unbound and fail much later.
        raise ValueError('Unsupported SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    # lr of non-bias parameters, for command line outputs.
    lr = optimizer.param_groups[0]['lr']

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    # Keep runs on different datasets in separate output directories.
    output_dir = output_dir + '_' + args.dataset

    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # Clamp to 1: a period of 0 (SNAPSHOT_ITERS < NUM_GPUS) would make the
    # `(step + 1) % CHECKPOINT_PERIOD` check raise ZeroDivisionError.
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))

    # Set index for decay steps
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        # Pre-bind `step` so the exception handler can checkpoint even if the
        # loop body never runs.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Accumulate gradients over iter_size minibatches before stepping.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Epoch exhausted: restart the loader.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
示例#3
0
def main():
    """Run the detector over a set of images and render the detections.

    The model is loaded from exactly one of --load_ckpt / --load_detectron and
    ``im_detect_all`` is run on every image given by --image_dir or --images.
    The detections are drawn with ``vis_utils.vis_one_image`` onto a blank
    white canvas the size of the input, written to --output_dir. With
    --merge_pdfs and more than one image, all PDFs in --output_dir are merged
    into ``results.pdf`` using the external ``pdfunite`` tool.
    """
    import glob  # only used to collect the PDFs to merge

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    args = parse_args()
    print('Called with args:')
    print(args)

    # Exactly one of --image_dir / --images must be provided.
    assert args.image_dir or args.images
    assert bool(args.image_dir) ^ bool(args.images)

    if args.dataset.startswith("coco"):
        dataset = datasets.get_coco_dataset()
        cfg.MODEL.NUM_CLASSES = len(dataset.classes)
    elif args.dataset.startswith("keypoints_coco"):
        dataset = datasets.get_coco_dataset()
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError('Unexpected dataset name: {}'.format(args.dataset))

    print('load cfg from file: {}'.format(args.cfg_file))
    cfg_from_file(args.cfg_file)

    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    assert bool(args.load_ckpt) ^ bool(args.load_detectron), \
        'Exactly one of --load_ckpt and --load_detectron should be specified.'
    cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS = False  # Don't need to load imagenet pretrained weights
    assert_and_infer_cfg()

    maskRCNN = Generalized_RCNN()

    if args.cuda:
        maskRCNN.cuda()

    if args.load_ckpt:
        load_name = args.load_ckpt
        print("loading checkpoint %s" % (load_name))
        # map_location keeps the loaded storages on the CPU.
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])

    if args.load_detectron:
        print("loading detectron weights %s" % args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True,
                                 device_ids=[0])  # only support single GPU

    maskRCNN.eval()
    print(count_parameters(maskRCNN))
    if args.image_dir:
        imglist = misc_utils.get_imagelist_from_dir(args.image_dir)
    else:
        imglist = args.images
    num_images = len(imglist)
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    for i, image_path in enumerate(imglist):
        print('img', i)
        im = cv2.imread(image_path)
        # Check the read *before* touching im.shape; otherwise a failed read
        # surfaces as an AttributeError on None instead of this assertion.
        assert im is not None
        # Detections are rendered on a white canvas, not on the photo itself.
        img = np.full(im.shape, 255, dtype=np.uint8)

        timers = defaultdict(Timer)

        cls_boxes, cls_segms, cls_keyps = im_detect_all(maskRCNN,
                                                        im,
                                                        timers=timers)

        im_name, _ = os.path.splitext(os.path.basename(image_path))
        vis_utils.vis_one_image(
            img[:, :, ::-1],  # BGR -> RGB for visualization
            im_name,
            args.output_dir,
            cls_boxes,
            cls_segms,
            cls_keyps,
            dataset=dataset,
            box_alpha=0.3,
            show_class=True,
            thresh=0.7,
            kp_thresh=2)

    if args.merge_pdfs and num_images > 1:
        merge_out_path = '{}/results.pdf'.format(args.output_dir)
        if os.path.exists(merge_out_path):
            os.remove(merge_out_path)
        # Expand the glob in Python and pass an argument list so output
        # directories containing spaces or shell metacharacters work and
        # nothing is interpreted by a shell.
        pdf_paths = sorted(glob.glob(os.path.join(args.output_dir, '*.pdf')))
        subprocess.call(['pdfunite'] + pdf_paths + [merge_out_path])
示例#4
0
def main():
    """Train a Generalized_RCNN model with an epoch-based schedule.

    Steps:
      1. Parse CLI args and adapt the global ``cfg``: training datasets,
         batch size / per-GPU images, data-loader threads, linearly scaled
         learning rate, and optional solver overrides from the command line.
      2. Build the training roidb, sampler and DataLoader.
      3. Build the model and an SGD or Adam optimizer with separate parameter
         groups for bias and non-bias weights.
      4. Optionally load a checkpoint (resuming optimizer and epoch/iter
         state when ``args.resume``) and/or Detectron weights.
      5. Run the epoch loop, saving a checkpoint every
         ``iters_per_epoch // args.ckpt_num_per_epoch`` steps (at least every
         step), at the end of every epoch, and on RuntimeError /
         KeyboardInterrupt.

    Exits via ``sys.exit`` when no CUDA device is available.

    Raises:
        ValueError: if neither ``args.cuda`` nor ``cfg.NUM_GPUS > 0``, if
            ``args.dataset`` is not recognised, or if ``cfg.SOLVER.TYPE`` is
            neither "SGD" nor "Adam".
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train',)
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train',)
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    # The config-file batch size uses the config's NUM_GPUS, which is only
    # afterwards replaced by the number of visible CUDA devices.
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    print('Batch size change from {} (in config file) to {}'.format(
        original_batch_size, args.batch_size))
    print('NUM_GPUs: %d, TRAIN.IMS_PER_BATCH: %d' % (cfg.NUM_GPUS, cfg.TRAIN.IMS_PER_BATCH))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Adjust learning based on batch size change linearly
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch size change: {} --> {}'.format(
        old_base_lr, cfg.SOLVER.BASE_LR))

    ### Overwrite some solver settings from command line arguments
    # Note: --lr replaces the linearly scaled BASE_LR computed above.
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma

    args.mGPUs = (cfg.NUM_GPUS > 1)

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    train_size = len(roidb)
    logger.info('{:d} roidb entries'.format(train_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    sampler = MinibatchSampler(ratio_list, ratio_index)
    dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)

    assert_and_infer_cfg()

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Trainable parameters are split by name: any parameter whose name
    # contains 'bias' goes into a group that may get a doubled lr and no
    # weight decay, depending on the SOLVER.BIAS_* flags.
    bias_params = []
    nonbias_params = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
            else:
                nonbias_params.append(value)
    params = [
        {'params': nonbias_params,
         'lr': cfg.SOLVER.BASE_LR,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': bias_params,
         'lr': cfg.SOLVER.BASE_LR * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0}
    ]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast; otherwise `optimizer` would be unbound further down.
        raise ValueError('Unknown SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            assert checkpoint['iters_per_epoch'] == train_size // args.batch_size, \
                "iters_per_epoch should match for resume"
            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
            if checkpoint['step'] == (checkpoint['iters_per_epoch'] - 1):
                # Resume from end of an epoch
                args.start_epoch = checkpoint['epoch'] + 1
                args.start_iter = 0
            else:
                # Resume from the middle of an epoch.
                # NOTE: dataloader is not synced with previous state
                args.start_epoch = checkpoint['epoch']
                args.start_iter = checkpoint['step'] + 1
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0]['lr']  # lr of non-bias parameters, for commmand line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    run_name = misc_utils.get_run_name()
    output_dir = misc_utils.get_output_dir(args, run_name)

    # The Tensorboard writer is only created when saving is enabled, so every
    # later use of `tblogger` must be gated on this same condition.
    tfboard_on = args.use_tfboard and not args.no_save

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    iters_per_epoch = int(train_size / args.batch_size)  # drop last
    # Clamp to at least 1: when ckpt_num_per_epoch > iters_per_epoch the
    # integer division gives 0 and `(step+1) % ckpt_interval_per_epoch`
    # would raise ZeroDivisionError.
    ckpt_interval_per_epoch = max(1, iters_per_epoch // args.ckpt_num_per_epoch)
    step = 0
    # Bound up front so the exception handler can always reference it.
    epoch = args.start_epoch
    try:
        logger.info('Training starts !')
        for epoch in range(args.start_epoch, args.start_epoch + args.num_epochs):
            # ---- Start of epoch ----
            loss_avg = 0
            timers['train_loop'].tic()

            # adjust learning rate (step decay at the epochs in lr_decay_epochs;
            # skipped when resuming mid-epoch, i.e. start_iter != 0)
            if args.lr_decay_epochs and epoch == args.lr_decay_epochs[0] and args.start_iter == 0:
                args.lr_decay_epochs.pop(0)
                net_utils.decay_learning_rate(optimizer, lr, cfg.SOLVER.GAMMA)
                lr *= cfg.SOLVER.GAMMA

            for step, input_data in zip(range(args.start_iter, iters_per_epoch), dataloader):

                for key in input_data:
                    if key != 'roidb': # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                outputs = maskRCNN(**input_data)

                rois_label = outputs['rois_label']
                cls_score = outputs['cls_score']
                bbox_pred = outputs['bbox_pred']
                loss_rpn_cls = outputs['loss_rpn_cls'].mean()
                loss_rpn_bbox = outputs['loss_rpn_bbox'].mean()
                loss_rcnn_cls = outputs['loss_rcnn_cls'].mean()
                loss_rcnn_bbox = outputs['loss_rcnn_bbox'].mean()

                loss = loss_rpn_cls + loss_rpn_bbox + loss_rcnn_cls + loss_rcnn_bbox

                if cfg.MODEL.MASK_ON:
                    loss_rcnn_mask = outputs['loss_rcnn_mask'].mean()
                    loss += loss_rcnn_mask

                if cfg.MODEL.KEYPOINTS_ON:
                    loss_rcnn_keypoints = outputs['loss_rcnn_keypoints'].mean()
                    loss += loss_rcnn_keypoints

                loss_avg += loss.data.cpu().numpy()[0]

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if (step+1) % ckpt_interval_per_epoch == 0:
                    net_utils.save_ckpt(output_dir, args, epoch, step, maskRCNN, optimizer, iters_per_epoch)

                if (step+1) % args.disp_interval == 0:
                    if (step + 1 - args.start_iter) >= args.disp_interval:  # for the case of resume
                        diff = timers['train_loop'].toc(average=False)
                        loss_avg /= args.disp_interval

                        loss_rpn_cls = loss_rpn_cls.data[0]
                        loss_rpn_bbox = loss_rpn_bbox.data[0]
                        loss_rcnn_cls = loss_rcnn_cls.data[0]
                        loss_rcnn_bbox = loss_rcnn_bbox.data[0]
                        fg_cnt = torch.sum(rois_label.data.ne(0))
                        bg_cnt = rois_label.data.numel() - fg_cnt
                        print("[%s][epoch %2d][iter %4d / %4d]"
                            % (run_name, epoch, step, iters_per_epoch))
                        print("\t\tloss: %.4f, lr: %.2e" % (loss_avg, lr))
                        print("\t\tfg/bg=(%d/%d), time cost: %f" % (fg_cnt, bg_cnt, diff))
                        print("\t\trpn_cls: %.4f, rpn_bbox: %.4f, rcnn_cls: %.4f, rcnn_bbox %.4f"
                            % (loss_rpn_cls, loss_rpn_bbox, loss_rcnn_cls, loss_rcnn_bbox))

                        print_prefix = "\t\t"
                        if cfg.MODEL.MASK_ON:
                            loss_rcnn_mask = loss_rcnn_mask.data[0]
                            print("%srcnn_mask %.4f" % (print_prefix, loss_rcnn_mask))
                            print_prefix = ", "
                        if cfg.MODEL.KEYPOINTS_ON:
                            loss_rcnn_keypoints = loss_rcnn_keypoints.data[0]
                            print("%srcnn_keypoints %.4f" % (print_prefix, loss_rcnn_keypoints))

                        if tfboard_on:
                            info = {
                                'loss': loss_avg,
                                'loss_rpn_cls': loss_rpn_cls,
                                'loss_rpn_box': loss_rpn_bbox,
                                'loss_rcnn_cls': loss_rcnn_cls,
                                'loss_rcnn_box': loss_rcnn_bbox,
                            }
                            if cfg.MODEL.MASK_ON:
                                info['loss_rcnn_mask'] = loss_rcnn_mask
                            if cfg.MODEL.KEYPOINTS_ON:
                                info['loss_rcnn_keypoints'] = loss_rcnn_keypoints
                            for tag, value in info.items():
                                tblogger.add_scalar(tag, value, iters_per_epoch * epoch + step)

                    loss_avg = 0
                    timers['train_loop'].tic()

            # ---- End of epoch ----
            # save checkpoint
            net_utils.save_ckpt(output_dir, args, epoch, step, maskRCNN, optimizer, iters_per_epoch)
            # reset timer
            timers['train_loop'].reset()
            # reset starting iter number after first epoch
            args.start_iter = 0

    except (RuntimeError, KeyboardInterrupt) as e:
        print('Save on exception:', e)
        net_utils.save_ckpt(output_dir, args, epoch, step, maskRCNN, optimizer, iters_per_epoch)
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        # ---- Training ends ----
        if tfboard_on:
            tblogger.close()
示例#5
0
# Single-image Mask R-CNN (R-50-C4, 1x) inference demo: runs bbox detection,
# NMS/limit, and mask prediction step by step instead of via im_detect_all.

# load config
dataset = datasets.get_coco_dataset()
cfg.MODEL.NUM_CLASSES = len(dataset.classes)
cfg_from_file('configs/baselines/e2e_mask_rcnn_R-50-C4_1x.yaml')
cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS = False  # Don't need to load imagenet pretrained weights
assert_and_infer_cfg()

# load model
maskRCNN = Generalized_RCNN()
# map_location returns each storage unchanged, i.e. tensors are loaded
# without being moved onto a GPU.
checkpoint = torch.load(
    '/home/work/liupeng11/code/Detectron.pytorch/models/e2e_mask_rcnn_R-50-C4_1x.pth',
    map_location=lambda storage, loc: storage)
net_utils.load_ckpt(maskRCNN, checkpoint['model'])
maskRCNN.eval()
maskRCNN = mynn.DataParallel(maskRCNN,
                             cpu_keywords=['im_info', 'roidb'],
                             minibatch=True,
                             device_ids=[0])

# load image
img_path = "/home/work/liupeng11/code/Detectron.pytorch/demo/sample_images/img1.jpg"
im = cv2.imread(img_path)
# cv2.imread returns None (rather than raising) when the file is missing or
# cannot be decoded; fail here with a clear message instead of deep inside
# the detection code.
if im is None:
    raise IOError('Could not read image: {}'.format(img_path))

# detect bounding boxes and segments
from core.test import im_detect_bbox, im_detect_mask, box_results_with_nms_and_limit, segm_results
scores, boxes, im_scale, blob_conv = im_detect_bbox(maskRCNN, im,
                                                    cfg.TEST.SCALE,
                                                    cfg.TEST.MAX_SIZE, None)
scores, boxes, cls_boxes = box_results_with_nms_and_limit(scores, boxes)
masks = im_detect_mask(maskRCNN, im_scale, boxes, blob_conv)
# im.shape[0] / im.shape[1] are the image height / width passed to segm_results.
cls_segms = segm_results(cls_boxes, masks, boxes, im.shape[0], im.shape[1])
cls_keyps = None
示例#6
0
def main():
    """Train a relationship detection model (Generalized_RCNN with a RelDN head) by iteration.

    Steps:
      1. Parse CLI args, select the training dataset and its class counts,
         load the cfg file, and apply hard-coded / per-dataset solver settings.
      2. Rescale BASE_LR linearly with batch size, and SOLVER.STEPS /
         SOLVER.MAX_ITER with the effective batch size
         (batch_size * iter_size).
      3. Build the roidb, batch sampler, DataLoader, model, and an SGD or Adam
         optimizer with five parameter groups: backbone non-bias, backbone
         bias, predicate-branch non-bias, predicate-branch bias, and 'gn'
         params.
      4. Optionally resume from a checkpoint and/or load Detectron weights.
      5. Run MAX_ITER steps with warm-up and step lr decay. Each step
         accumulates gradients over ``args.iter_size`` minibatches before a
         single ``optimizer.step()``. Checkpoints are saved periodically, at
         the end, and on RuntimeError / KeyboardInterrupt.

    Exits via ``sys.exit`` when no CUDA device is available.

    Raises:
        ValueError: if neither ``args.cuda`` nor ``cfg.NUM_GPUS > 0``, if
            ``args.dataset`` is not recognised, or if ``cfg.SOLVER.TYPE`` is
            neither "SGD" nor "Adam".
        KeyError: for an unknown ``cfg.SOLVER.WARM_UP_METHOD``.
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "vrd":
        cfg.TRAIN.DATASETS = ('vrd_train',)
        cfg.MODEL.NUM_CLASSES = 101
        cfg.MODEL.NUM_PRD_CLASSES = 70  # exclude background
    elif args.dataset == "vg_mini":
        cfg.TRAIN.DATASETS = ('vg_train_mini',)
        cfg.MODEL.NUM_CLASSES = 151
        cfg.MODEL.NUM_PRD_CLASSES = 50  # exclude background
    elif args.dataset == "vg":
        cfg.TRAIN.DATASETS = ('vg_train',)
        cfg.MODEL.NUM_CLASSES = 151
        cfg.MODEL.NUM_PRD_CLASSES = 50  # exclude background
    elif args.dataset == "oi_rel":
        cfg.TRAIN.DATASETS = ('oi_rel_train',)
        # cfg.MODEL.NUM_CLASSES = 62
        cfg.MODEL.NUM_CLASSES = 58
        cfg.MODEL.NUM_PRD_CLASSES = 9  # rel, exclude background
    elif args.dataset == "oi_rel_mini":
        cfg.TRAIN.DATASETS = ('oi_rel_train_mini',)
        # cfg.MODEL.NUM_CLASSES = 62
        cfg.MODEL.NUM_CLASSES = 58
        cfg.MODEL.NUM_PRD_CLASSES = 9  # rel, exclude background
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    # Hard-coded solver settings: these override the cfg file values, but are
    # themselves scaled / overridden by the code below (batch-size scaling,
    # --lr, --lr_decay_gamma).
    cfg.SOLVER.BASE_LR = 0.0033
    cfg.SOLVER.GAMMA = 0.33

    # Dataset-specific schedules and relationship sampling settings.
    if args.dataset == "vrd":
        cfg.SOLVER.STEPS = [1000,3000,8000,14000]
        cfg.SOLVER.MAX_ITER = 22680
        cfg.MODEL.STAGE_TWO = True
        cfg.TRAIN.FG_REL_SIZE_PER_IM = 128
        cfg.TRAIN.FG_REL_FRACTION = 0.5
    elif args.dataset == "vg":
        cfg.SOLVER.STEPS = [0,90000,120000]
        cfg.SOLVER.MAX_ITER = 125446
        cfg.MODEL.STAGE_TWO = True
        cfg.TRAIN.FG_REL_SIZE_PER_IM = 256
        cfg.TRAIN.FG_REL_FRACTION = 0.5
        #cfg.MODEL.FEATLOSS_WEIGHT = 0.05

    # NOTE(review): NUM_GPUS is replaced by the visible device count *before*
    # the "original" batch size is derived, so original_num_gpus always equals
    # cfg.NUM_GPUS here. Confirm this ordering is intended.
    cfg.NUM_GPUS = torch.cuda.device_count()
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)

    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' % (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' % (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' % (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' % (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print('Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
          '    SOLVER.STEPS: {} --> {}\n'
          '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps, cfg.SOLVER.STEPS,
                                                  old_max_iter, cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print('Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
              '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)

    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True
    )
    dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # record backbone params, i.e., conv_body and box_head params
    # Parameters are grouped by name: 'gn' params first; then backbone
    # (Conv_Body / Box_Head / Box_Outs / RPN) vs. everything else
    # ("prd_branch"), each split into bias / non-bias.
    gn_params = []
    backbone_bias_params = []
    backbone_bias_param_names = []
    prd_branch_bias_params = []
    prd_branch_bias_param_names = []
    backbone_nonbias_params = []
    backbone_nonbias_param_names = []
    prd_branch_nonbias_params = []
    prd_branch_nonbias_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'gn' in key:
                gn_params.append(value)
            elif 'Conv_Body' in key or 'Box_Head' in key or 'Box_Outs' in key or 'RPN' in key:
                if 'bias' in key:
                    backbone_bias_params.append(value)
                    backbone_bias_param_names.append(key)
                else:
                    backbone_nonbias_params.append(value)
                    backbone_nonbias_param_names.append(key)
            else:
                if 'bias' in key:
                    prd_branch_bias_params.append(value)
                    prd_branch_bias_param_names.append(key)
                else:
                    prd_branch_nonbias_params.append(value)
                    prd_branch_nonbias_param_names.append(key)
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    # Group order matters: index 0 is read as the backbone lr and index 2 as
    # the (non-backbone) lr further below.
    params = [
        {'params': backbone_nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': backbone_bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': prd_branch_nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': prd_branch_bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': gn_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN}
    ]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast; otherwise `optimizer` would be unbound further down.
        raise ValueError('Unknown SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils_rel.load_ckpt_rel(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print('train_size value: %d different from the one in checkpoint: %d'
                          % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    # lr = optimizer.param_groups[0]['lr']  # lr of non-bias parameters, for commmand line outputs.
    lr = optimizer.param_groups[2]['lr']  # lr of non-backbone parameters, for commmand line outputs.
    backbone_lr = optimizer.param_groups[0]['lr']  # lr of backbone parameters, for commmand line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step_with_prd_cls_v' + str(cfg.MODEL.SUBTYPE)
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # Save 8 evenly spaced checkpoints over the run. The period must be an
    # integer: after the STEPS/MAX_ITER rescaling above, MAX_ITER need not be
    # divisible by 8, and with a float period `(step+1) % CHECKPOINT_PERIOD`
    # would (almost) never be 0. Clamped to 1 for tiny MAX_ITER.
    CHECKPOINT_PERIOD = max(1, int(cfg.SOLVER.MAX_ITER) // 8)

    # Set index for decay steps
    # (index of the first STEPS entry, skipping STEPS[0], that is >= start_step;
    # len(STEPS) means no further decay)
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args,
        args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        step = args.start_step
        superview = []
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils_rel.update_learning_rate_rel(optimizer, lr, lr_new)
                # lr = optimizer.param_groups[0]['lr']
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils_rel.update_learning_rate_rel(optimizer, lr, cfg.SOLVER.BASE_LR)
                # lr = optimizer.param_groups[0]['lr']
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils_rel.update_learning_rate_rel(optimizer, lr, lr_new)
                # lr = optimizer.param_groups[0]['lr']
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            # Clear gradients once per optimizer step, before the inner loop,
            # so that with iter_size > 1 the backward passes of all inner
            # iterations accumulate into .grad before optimizer.step().
            optimizer.zero_grad()
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Loader exhausted: start a new pass over the data.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb': # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))
                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, step, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
                # NOTE(review): the two schedule switches below run once per
                # inner iteration of the matching step (i.e. iter_size times).
                if step == 20000:
                    cfg.immutable(False)
                    cfg.MODEL.MEMORY_SAVE_UPDATE = True
                    cfg.immutable(True)
                if step == 40000:
                    cfg.immutable(False)
                    cfg.MODEL.MEMORY_TRAIN_UPDATE = True
                    cfg.immutable(True)
                    # Initialise memory_train as memory_save divided by the
                    # per-class counts (zero counts replaced by 1 to avoid
                    # division by zero), dump it to disk, and copy the linear
                    # classifier weights into the hallucinator.
                    class_count = maskRCNN.module.RelDN.class_count.clone()
                    class_count[class_count==0] = 1
                    memory_ini = maskRCNN.module.RelDN.memory_save / class_count.float().unsqueeze(1)
                    maskRCNN.module.RelDN.memory_train.data = memory_ini.clone()
                    np.save('memory_ini.npy',memory_ini.cpu().numpy())
                    maskRCNN.module.RelDN.mix_scores.fc_hallucinator.load_state_dict(maskRCNN.module.RelDN.mix_scores.linear_classifier.state_dict())

            optimizer.step()
            training_stats.IterToc()
            training_stats.LogIterStats(step, lr, backbone_lr)
            if (step+1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        np.save('result{}.npy'.format(step),superview)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        np.save('result{}.npy'.format(step),superview)
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
示例#7
0
def main():
    """Main function: train a Generalized_RCNN detector.

    Parses command-line arguments and merges them into the global ``cfg``,
    builds the training roidb and data loader, the model and optimizer,
    optionally loads a checkpoint (``--load_ckpt``, with ``--resume`` also
    restoring the step counter and optimizer state) or Detectron weights,
    and runs the training loop with LR warm-up, step decay, periodic
    checkpointing and console / TensorBoard logging.  A checkpoint is also
    written when a ``RuntimeError`` or ``KeyboardInterrupt`` escapes the loop.

    Raises:
        ValueError: if CUDA is not requested, ``args.dataset`` is unknown,
            or ``cfg.SOLVER.TYPE`` is neither ``"SGD"`` nor ``"Adam"``.
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    print('Batch size change from {} (in config file) to {}'.format(
        original_batch_size, args.batch_size))
    print('NUM_GPUs: %d, TRAIN.IMS_PER_BATCH: %d' %
          (cfg.NUM_GPUS, cfg.TRAIN.IMS_PER_BATCH))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Adjust learning based on batch size change linearly
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch size change: {} --> {}'.
          format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    sampler = MinibatchSampler(ratio_list, ratio_index)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        drop_last=True,
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Bias and non-bias parameters go in separate param groups so the bias
    # group can get its own LR multiplier and weight-decay setting.
    bias_params = []
    nonbias_params = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
            else:
                nonbias_params.append(value)
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        bias_params,
        'lr':
        0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail here rather than with a NameError on `optimizer` further down.
        raise ValueError('Unsupported SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            assert checkpoint['train_size'] == train_size
            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0][
        'lr']  # lr of non-bias parameters, for commmand line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    run_name = misc_utils.get_run_name()
    output_dir = misc_utils.get_output_dir(args, run_name)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)

    # Set index for decay steps
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    logger.info('Training starts !')
    loss_avg = 0
    # Bind `step` before the loop so the save_ckpt calls below never hit an
    # unbound name when the loop body does not run (start_step >= MAX_ITER)
    # or an exception fires before the first iteration.
    step = args.start_step - 1
    try:
        timers['train_loop'].tic()
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                # Warm-up is over: push BASE_LR into the optimizer itself.
                # Rebinding only `lr` left the param groups at the last
                # warm-up value, or at the dummy 0 when WARM_UP_ITERS == 0.
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = lr_new
                decay_steps_ind += 1

            # Restart the data iterator whenever an epoch is exhausted.
            try:
                input_data = next(dataiterator)
            except StopIteration:
                dataiterator = iter(dataloader)
                input_data = next(dataiterator)

            for key in input_data:
                if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                    input_data[key] = list(map(Variable, input_data[key]))

            outputs = maskRCNN(**input_data)

            rois_label = outputs['rois_label']
            cls_score = outputs['cls_score']
            bbox_pred = outputs['bbox_pred']
            loss_rpn_cls = outputs['loss_rpn_cls'].mean()
            loss_rpn_bbox = outputs['loss_rpn_bbox'].mean()
            loss_rcnn_cls = outputs['loss_rcnn_cls'].mean()
            loss_rcnn_bbox = outputs['loss_rcnn_bbox'].mean()

            loss = loss_rpn_cls + loss_rpn_bbox + loss_rcnn_cls + loss_rcnn_bbox

            if cfg.MODEL.MASK_ON:
                loss_rcnn_mask = outputs['loss_rcnn_mask'].mean()
                loss += loss_rcnn_mask

            if cfg.MODEL.KEYPOINTS_ON:
                loss_rcnn_keypoints = outputs['loss_rcnn_keypoints'].mean()
                loss += loss_rcnn_keypoints

            loss_avg += loss.data.cpu().numpy()[0]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

            if ((step % args.disp_interval == 0 and
                 (step - args.start_step >= args.disp_interval))
                    or step == cfg.SOLVER.MAX_ITER - 1):
                diff = timers['train_loop'].toc(average=False)
                loss_avg /= args.disp_interval

                loss_rpn_cls = loss_rpn_cls.data[0]
                loss_rpn_bbox = loss_rpn_bbox.data[0]
                loss_rcnn_cls = loss_rcnn_cls.data[0]
                loss_rcnn_bbox = loss_rcnn_bbox.data[0]
                fg_cnt = torch.sum(rois_label.data.ne(0))
                bg_cnt = rois_label.data.numel() - fg_cnt
                print("[ %s ][ step %d ]" % (run_name, step))
                print("\t\tloss: %.4f, lr: %.2e" % (loss_avg, lr))
                print("\t\tfg/bg=(%d/%d), time cost: %f" %
                      (fg_cnt, bg_cnt, diff))
                print(
                    "\t\trpn_cls: %.4f, rpn_bbox: %.4f, rcnn_cls: %.4f, rcnn_bbox %.4f"
                    % (loss_rpn_cls, loss_rpn_bbox, loss_rcnn_cls,
                       loss_rcnn_bbox))

                print_prefix = "\t\t"
                if cfg.MODEL.MASK_ON:
                    loss_rcnn_mask = loss_rcnn_mask.data[0]
                    print("%srcnn_mask %.4f" % (print_prefix, loss_rcnn_mask))
                    print_prefix = ", "
                if cfg.MODEL.KEYPOINTS_ON:
                    loss_rcnn_keypoints = loss_rcnn_keypoints.data[0]
                    print("%srcnn_keypoints %.4f" %
                          (print_prefix, loss_rcnn_keypoints))

                if args.use_tfboard and not args.no_save:
                    info = {
                        'lr': lr,
                        'loss': loss_avg,
                        'loss_rpn_cls': loss_rpn_cls,
                        'loss_rpn_box': loss_rpn_bbox,
                        'loss_rcnn_cls': loss_rcnn_cls,
                        'loss_rcnn_box': loss_rcnn_bbox,
                    }
                    if cfg.MODEL.MASK_ON:
                        info['loss_rcnn_mask'] = loss_rcnn_mask
                    if cfg.MODEL.KEYPOINTS_ON:
                        info['loss_rcnn_keypoints'] = loss_rcnn_keypoints
                    for tag, value in info.items():
                        tblogger.add_scalar(tag, value, step)

                loss_avg = 0
                timers['train_loop'].tic()

        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt) as e:
        print('Save on exception:', e)
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        # ---- Training ends ----
        if args.use_tfboard and not args.no_save:
            tblogger.close()
def main():
    """main function

    Run a trained Generalized_RCNN over a directory (``--image_dir``) or an
    explicit list (``--images``) of images.  Detections are appended to one
    text file per class under ``<args.output_dir>_results``, which is wiped
    and recreated on every run.

    Each output line is ``<image name> <boxes[:, 4]> <boxes[:, 0:4] + 1>``,
    with every number written as ``%f`` followed by a space.  Presumably that
    is score, then x1 y1 x2 y2 shifted to 1-based pixel coordinates (PASCAL
    VOC convention). TODO confirm against ``convert_from_cls_format``.
    The class name ``motorcycle`` is written as ``motorbike``, presumably to
    match the VOC class list.

    Raises:
        ValueError: if ``args.dataset`` is not a coco / keypoints_coco name.
        IOError: if an input image cannot be read by OpenCV.
    """

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    args = parse_args()
    print('Called with args:')
    print(args)

    assert args.image_dir or args.images
    assert bool(args.image_dir) ^ bool(args.images)

    # Start from an empty results directory: class files are opened in
    # append mode below, so files left over from a previous run would
    # otherwise keep growing.
    prefix_path = args.output_dir + '_results'
    if os.path.exists(prefix_path):
        shutil.rmtree(prefix_path)
    os.mkdir(prefix_path)

    if args.dataset.startswith("coco"):
        dataset = datasets.get_coco_dataset()
        cfg.MODEL.NUM_CLASSES = len(dataset.classes)
    elif args.dataset.startswith("keypoints_coco"):
        dataset = datasets.get_coco_dataset()
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError('Unexpected dataset name: {}'.format(args.dataset))

    print('load cfg from file: {}'.format(args.cfg_file))
    cfg_from_file(args.cfg_file)

    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    assert bool(args.load_ckpt) ^ bool(args.load_detectron), \
        'Exactly one of --load_ckpt and --load_detectron should be specified.'
    cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS = False  # Don't need to load imagenet pretrained weights
    assert_and_infer_cfg()

    maskRCNN = Generalized_RCNN()

    if args.cuda:
        maskRCNN.cuda()

    if args.load_ckpt:
        load_name = args.load_ckpt
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])

    if args.load_detectron:
        print("loading detectron weights %s" % args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True,
                                 device_ids=[0])  # only support single GPU

    maskRCNN.eval()
    if args.image_dir:
        imglist = misc_utils.get_imagelist_from_dir(args.image_dir)
    else:
        imglist = args.images

    for img_path in tqdm(imglist):
        im = cv2.imread(img_path)
        # cv2.imread reports failure by returning None instead of raising;
        # an explicit check (unlike `assert`) survives `python -O`.
        if im is None:
            raise IOError('Could not read image: {}'.format(img_path))

        timers = defaultdict(Timer)

        cls_boxes, cls_segms, cls_keyps = im_detect_all(maskRCNN,
                                                        im,
                                                        timers=timers)

        im_name, _ = os.path.splitext(os.path.basename(img_path))

        boxes, _, _, classes = convert_from_cls_format(cls_boxes, cls_segms,
                                                       cls_keyps)
        # No detections for this image: nothing to write.
        if len(classes) == 0:
            continue
        # Reorder columns to [boxes[:, 4], boxes[:, 0:4] + 1].
        voc_boxes = np.zeros_like(boxes)
        voc_boxes[:, 0:1] = boxes[:, 4:5]
        voc_boxes[:, 1:3] = boxes[:, 0:2] + 1
        voc_boxes[:, 3:5] = boxes[:, 2:4] + 1

        # Collect this image's lines per class, then open each class file
        # once (in a `with` block, so it is closed even if a write fails)
        # instead of once per detection.  Line order within a file is kept.
        lines_by_class = defaultdict(list)
        for instance_idx, cls_idx in enumerate(classes):
            cls_name = dataset.classes[cls_idx]
            if cls_name == 'motorcycle':
                cls_name = 'motorbike'
            fields = "".join("%f " % item for item in voc_boxes[instance_idx])
            lines_by_class[cls_name].append("%s %s\n" % (im_name, fields))

        for cls_name, lines in lines_by_class.items():
            with open(os.path.join(prefix_path, cls_name + ".txt"),
                      "a+") as f:
                f.writelines(lines)
示例#9
0
def main():
    """Main function"""

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 2

    # ADE20k as a detection dataset
    elif args.dataset == "ade_train":
        cfg.TRAIN.DATASETS = ('ade_train', )
        cfg.MODEL.NUM_CLASSES = 446

    # Noisy CS6+WIDER datasets
    elif args.dataset == 'cs6_noise020+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise020', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise050+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise050', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise080+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise080', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise085+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise085', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise090+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise090', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise095+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise095', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise100+WIDER':
        cfg.TRAIN.DATASETS = ('cs6_noise100', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    # Just Noisy CS6 datasets
    elif args.dataset == 'cs6_noise020':
        cfg.TRAIN.DATASETS = ('cs6_noise020', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise030':
        cfg.TRAIN.DATASETS = ('cs6_noise030', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise040':
        cfg.TRAIN.DATASETS = ('cs6_noise040', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise050':
        cfg.TRAIN.DATASETS = ('cs6_noise050', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise060':
        cfg.TRAIN.DATASETS = ('cs6_noise060', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise070':
        cfg.TRAIN.DATASETS = ('cs6_noise070', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise080':
        cfg.TRAIN.DATASETS = ('cs6_noise080', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise085':
        cfg.TRAIN.DATASETS = ('cs6_noise085', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise090':
        cfg.TRAIN.DATASETS = ('cs6_noise090', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise095':
        cfg.TRAIN.DATASETS = ('cs6_noise095', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'cs6_noise100':
        cfg.TRAIN.DATASETS = ('cs6_noise100', )
        cfg.MODEL.NUM_CLASSES = 2

    # Cityscapes 7 classes
    elif args.dataset == "cityscapes":
        cfg.TRAIN.DATASETS = ('cityscapes_train', )
        cfg.MODEL.NUM_CLASSES = 8

    # BDD 7 classes
    elif args.dataset == "bdd_any_any_any":
        cfg.TRAIN.DATASETS = ('bdd_any_any_any_train', )
        cfg.MODEL.NUM_CLASSES = 8
    elif args.dataset == "bdd_any_any_daytime":
        cfg.TRAIN.DATASETS = ('bdd_any_any_daytime_train', )
        cfg.MODEL.NUM_CLASSES = 8
    elif args.dataset == "bdd_clear_any_daytime":
        cfg.TRAIN.DATASETS = ('bdd_clear_any_daytime_train', )
        cfg.MODEL.NUM_CLASSES = 8

    # Cistyscapes Pedestrian sets
    elif args.dataset == "cityscapes_peds":
        cfg.TRAIN.DATASETS = ('cityscapes_peds_train', )
        cfg.MODEL.NUM_CLASSES = 2

    # Cityscapes Car sets
    elif args.dataset == "cityscapes_cars_HPlen3+kitti_car_train":
        cfg.TRAIN.DATASETS = ('cityscapes_cars_HPlen3', 'kitti_car_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cityscapes_cars_HPlen5+kitti_car_train":
        cfg.TRAIN.DATASETS = ('cityscapes_cars_HPlen5', 'kitti_car_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cityscapes_car_train+kitti_car_train":
        cfg.TRAIN.DATASETS = ('cityscapes_car_train', 'kitti_car_train')
        cfg.MODEL.NUM_CLASSES = 2

    # KITTI Car set
    elif args.dataset == "kitti_car_train":
        cfg.TRAIN.DATASETS = ('kitti_car_train', )
        cfg.MODEL.NUM_CLASSES = 2

    # BDD pedestrians sets
    elif args.dataset == "bdd_peds":
        cfg.TRAIN.DATASETS = ('bdd_peds_train',
                              )  # bdd peds: clear_any_daytime
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "bdd_peds_full":
        cfg.TRAIN.DATASETS = ('bdd_peds_full_train', )  # bdd peds: any_any_any
        cfg.MODEL.NUM_CLASSES = 2
    # Pedestrians with constraints
    elif args.dataset == "bdd_peds_not_clear_any_daytime":
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train', )
        cfg.MODEL.NUM_CLASSES = 2
    # Ashish's  20k samples videos
    elif args.dataset == "bdd_peds_not_clear_any_daytime_20k":
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_20k_train', )
        cfg.MODEL.NUM_CLASSES = 2
    # Source domain + Target domain detections
    elif args.dataset == "bdd_peds+DETS_20k":
        cfg.TRAIN.DATASETS = ('bdd_peds_dets_20k_target_domain',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    # Source domain + Target domain detections -- same 18k images as HP18k
    elif args.dataset == "bdd_peds+DETS18k":
        cfg.TRAIN.DATASETS = ('bdd_peds_dets18k_target_domain',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    # Only Dets
    elif args.dataset == "DETS20k":
        cfg.TRAIN.DATASETS = ('bdd_peds_dets_20k_target_domain', )
        cfg.MODEL.NUM_CLASSES = 2
    # Only Dets18k - same images as HP18k
    elif args.dataset == 'DETS18k':
        cfg.TRAIN.DATASETS = ('bdd_peds_dets18k_target_domain', )
        cfg.MODEL.NUM_CLASSES = 2
    # Only HP
    elif args.dataset == 'HP':
        cfg.TRAIN.DATASETS = ('bdd_peds_HP_target_domain', )
        cfg.MODEL.NUM_CLASSES = 2
    # Only HP 18k videos
    elif args.dataset == 'HP18k':
        cfg.TRAIN.DATASETS = ('bdd_peds_HP18k_target_domain', )
        cfg.MODEL.NUM_CLASSES = 2
    # Source domain + Target domain HP
    elif args.dataset == 'bdd_peds+HP':
        cfg.TRAIN.DATASETS = ('bdd_peds_train', 'bdd_peds_HP_target_domain')
        cfg.MODEL.NUM_CLASSES = 2
        # Source domain + Target domain HP 18k videos
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k':
        cfg.TRAIN.DATASETS = ('bdd_peds_HP18k_target_domain', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    #### Source domain + Target domain with different conf threshold theta ####
    elif args.dataset == 'bdd_peds+HP18k_thresh-050':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_thresh-050', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_thresh-060':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_thresh-060', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_thresh-070':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_thresh-070', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_thresh-090':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_thresh-090', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    ##############################

    #### Data distillation on BDD -- for rebuttal
    elif args.dataset == 'bdd_peds+bdd_data_dist_small':
        cfg.TRAIN.DATASETS = ('bdd_data_dist_small', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+bdd_data_dist_mid':
        cfg.TRAIN.DATASETS = ('bdd_data_dist_mid', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+bdd_data_dist':
        cfg.TRAIN.DATASETS = ('bdd_data_dist', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    ##############################

    #### Source domain + **Labeled** Target domain with varying number of images
    elif args.dataset == 'bdd_peds+labeled_100':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_100',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_075':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_075',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_050':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_050',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_025':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_025',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_010':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_010',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_005':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_005',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+labeled_001':
        cfg.TRAIN.DATASETS = ('bdd_peds_not_clear_any_daytime_train_001',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    ##############################

    # Source domain + Target domain HP tracker bboxes only
    elif args.dataset == 'bdd_peds+HP18k_track_only':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_track_only', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2

    ##### subsets of bdd_HP18k with different constraints
    # Source domain + HP tracker images at NIGHT
    elif args.dataset == 'bdd_peds+HP18k_any_any_night':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_any_any_night', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == 'bdd_peds+HP18k_rainy_any_daytime':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_rainy_any_daytime', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_rainy_any_night':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_rainy_any_night', 'bdd_peds_train')

        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_overcast,rainy_any_daytime':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_overcast,rainy_any_daytime',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_overcast,rainy_any_night':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_overcast,rainy_any_night',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == 'bdd_peds+HP18k_overcast,rainy,snowy_any_daytime':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_overcast,rainy,snowy_any_daytime',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    #############  end of bdd constraned subsets  #####################

    # Source domain + Target domain HP18k -- after histogram matching
    elif args.dataset == 'bdd_peds+HP18k_remap_hist':
        cfg.TRAIN.DATASETS = ('bdd_peds_HP18k_target_domain_remap_hist',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+HP18k_remap_cityscape_hist':
        cfg.TRAIN.DATASETS = (
            'bdd_peds_HP18k_target_domain_remap_cityscape_hist',
            'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    # Source domain + Target domain HP18k -- after histogram matching
    elif args.dataset == 'bdd_peds+HP18k_remap_random':
        cfg.TRAIN.DATASETS = ('bdd_peds_HP18k_target_domain_remap_random',
                              'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    # Source+Noisy Target domain -- prevent domain adv from using HP roi info
    elif args.dataset == 'bdd_peds+bdd_HP18k_noisy_100k':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_noisy_100k', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+bdd_HP18k_noisy_080':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_noisy_080', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+bdd_HP18k_noisy_060':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_noisy_060', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == 'bdd_peds+bdd_HP18k_noisy_070':
        cfg.TRAIN.DATASETS = ('bdd_HP18k_noisy_070', 'bdd_peds_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "wider_train":
        cfg.TRAIN.DATASETS = ('wider_train', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-subset":
        cfg.TRAIN.DATASETS = ('cs6-subset', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-subset-score":
        cfg.TRAIN.DATASETS = ('cs6-subset-score', )
    elif args.dataset == "cs6-subset-gt":
        cfg.TRAIN.DATASETS = ('cs6-subset-gt', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-3013-gt":
        cfg.TRAIN.DATASETS = ('cs6-3013-gt',
                              )  # DEBUG: overfit on one video annots
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-subset-gt+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-subset-gt', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-subset+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-subset', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-gt":
        cfg.TRAIN.DATASETS = ('cs6-train-gt', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-gt-noisy-0.3":
        cfg.TRAIN.DATASETS = ('cs6-train-gt-noisy-0.3', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-gt-noisy-0.5":
        cfg.TRAIN.DATASETS = ('cs6-train-gt-noisy-0.5', )
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-det-score":
        cfg.TRAIN.DATASETS = ('cs6-train-det-score', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-det-score-0.5":
        cfg.TRAIN.DATASETS = ('cs6-train-det-score-0.5', )
    elif args.dataset == "cs6-train-det":
        cfg.TRAIN.DATASETS = ('cs6-train-det', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-det-0.5":
        cfg.TRAIN.DATASETS = ('cs6-train-det-0.5', )
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-hp":
        cfg.TRAIN.DATASETS = ('cs6-train-hp', )
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-easy-gt":
        cfg.TRAIN.DATASETS = ('cs6-train-easy-gt', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-easy-gt-sub":
        cfg.TRAIN.DATASETS = ('cs6-train-easy-gt-sub', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-easy-hp":
        cfg.TRAIN.DATASETS = ('cs6-train-easy-hp', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-easy-det":
        cfg.TRAIN.DATASETS = ('cs6-train-easy-det', )
        cfg.MODEL.NUM_CLASSES = 2

        # Joint training with CS6 and WIDER
    elif args.dataset == "cs6-train-easy-gt-sub+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-train-easy-gt-sub', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cs6-train-gt+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-train-gt', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-hp+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-train-hp', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-dummy+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-train-dummy', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    elif args.dataset == "cs6-train-det+WIDER":
        cfg.TRAIN.DATASETS = ('cs6-train-det', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    # Dets dataset created by removing tracker results from the HP json
    elif args.dataset == "cs6_train_det_from_hp+WIDER":
        cfg.TRAIN.DATASETS = ('cs6_train_det_from_hp', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    # Dataset created by removing det results from the HP json -- HP tracker only
    elif args.dataset == "cs6_train_hp_tracker_only+WIDER":
        cfg.TRAIN.DATASETS = ('cs6_train_hp_tracker_only', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2
    # HP dataset with noisy labels: used to prevent DA from getting any info from HP
    elif args.dataset == "cs6_train_hp_noisy_100+WIDER":
        cfg.TRAIN.DATASETS = ('cs6_train_hp_noisy_100', 'wider_train')
        cfg.MODEL.NUM_CLASSES = 2

    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(
        map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)
    """def _init_fn(worker_id):
        random.seed(999)
        np.random.seed(999)
        torch.cuda.manual_seed(999)
        torch.cuda.manual_seed_all(999)
        torch.manual_seed(999)
        torch.backends.cudnn.deterministic = True
    """

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)
    if cfg.TRAIN.JOINT_TRAINING:
        if len(cfg.TRAIN.DATASETS) == 2:
            print('Joint training on two datasets')
        else:
            raise NotImplementedError

        joint_training_roidb = []
        for i, dataset_name in enumerate(cfg.TRAIN.DATASETS):
            # ROIDB construction
            timers['roidb'].tic()
            roidb, ratio_list, ratio_index = combined_roidb_for_training(
                (dataset_name), cfg.TRAIN.PROPOSAL_FILES)
            timers['roidb'].toc()
            roidb_size = len(roidb)
            logger.info('{:d} roidb entries'.format(roidb_size))
            logger.info('Takes %.2f sec(s) to construct roidb',
                        timers['roidb'].average_time)

            if i == 0:
                roidb_size = len(roidb)

            batchSampler = BatchSampler(sampler=MinibatchSampler(
                ratio_list, ratio_index),
                                        batch_size=args.batch_size,
                                        drop_last=True)
            dataset = RoiDataLoader(roidb,
                                    cfg.MODEL.NUM_CLASSES,
                                    training=True)
            dataloader = torch.utils.data.DataLoader(
                dataset,
                batch_sampler=batchSampler,
                num_workers=cfg.DATA_LOADER.NUM_THREADS,
                collate_fn=collate_minibatch)
            #worker_init_fn=_init_fn)
            # decrease num-threads when using two dataloaders
            dataiterator = iter(dataloader)

            joint_training_roidb.append({
                'dataloader': dataloader,
                'dataiterator': dataiterator,
                'dataset_name': dataset_name
            })
    else:
        roidb, ratio_list, ratio_index = combined_roidb_for_training(
            cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
        timers['roidb'].toc()
        roidb_size = len(roidb)
        logger.info('{:d} roidb entries'.format(roidb_size))
        logger.info('Takes %.2f sec(s) to construct roidb',
                    timers['roidb'].average_time)

        batchSampler = BatchSampler(sampler=MinibatchSampler(
            ratio_list, ratio_index),
                                    batch_size=args.batch_size,
                                    drop_last=True)
        dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batchSampler,
            num_workers=cfg.DATA_LOADER.NUM_THREADS,
            collate_fn=collate_minibatch)
        #worker_init_fn=init_fn)
        dataiterator = iter(dataloader)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    if cfg.TRAIN.JOINT_SELECTIVE_FG:
        orig_fg_batch_ratio = cfg.TRAIN.FG_FRACTION

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name + '.weight')
            gn_param_nameset.add(name + '.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    assert (gn_param_nameset - set(nograd_param_names) -
            set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        bias_params,
        'lr':
        0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]
    # names of paramerters for each paramter
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0][
        'lr']  # lr of non-bias parameters, for commmand line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_' + str(
        cfg.TRAIN.DATASETS) + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)

    # Set index for decay steps
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):
            """
            random.seed(cfg.RNG_SEED)
            np.random.seed(cfg.RNG_SEED)
            torch.cuda.manual_seed(cfg.RNG_SEED)
            torch.cuda.manual_seed_all(cfg.RNG_SEED)
            torch.manual_seed(cfg.RNG_SEED)
            torch.backends.cudnn.deterministic = True
            """

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()

            for inner_iter in range(args.iter_size):
                # use a iter counter for optional alternating batches
                if args.iter_size == 1:
                    iter_counter = step
                else:
                    iter_counter = inner_iter

                if cfg.TRAIN.JOINT_TRAINING:
                    # alternate batches between dataset[0] and dataset[1]
                    if iter_counter % 2 == 0:
                        if True:  #DEBUG:
                            print('Dataset: %s' %
                                  joint_training_roidb[0]['dataset_name'])
                        dataloader = joint_training_roidb[0]['dataloader']
                        dataiterator = joint_training_roidb[0]['dataiterator']
                        # NOTE: if available FG samples cannot fill minibatch
                        # then batchsize will be smaller than cfg.TRAIN.BATCH_SIZE_PER_IM.
                    else:
                        if True:  #DEBUG:
                            print('Dataset: %s' %
                                  joint_training_roidb[1]['dataset_name'])
                        dataloader = joint_training_roidb[1]['dataloader']
                        dataiterator = joint_training_roidb[1]['dataiterator']

                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # end of epoch for dataloader
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)
                    if cfg.TRAIN.JOINT_TRAINING:
                        if iter_counter % 2 == 0:
                            joint_training_roidb[0][
                                'dataiterator'] = dataiterator
                        else:
                            joint_training_roidb[1][
                                'dataiterator'] = dataiterator

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                # [p.data.get_device() for p in maskRCNN.parameters()]
                # [(name, p.data.get_device()) for name, p in maskRCNN.named_parameters()]
                loss.backward()

            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
示例#10
0
def main():
    """Visualize box-conditioned Mask R-CNN predictions on DAVIS 'val' sequences.

    For every sequence of ``davis_db.DAVIS_imdb(split='val')`` a visualization
    is written per frame into ``<args.output_dir>/<seq_name>/`` via
    ``vis_utils.vis_one_image`` (frame names are ``'%02d' % idx``). The boxes
    returned by ``get_bboxes`` are passed to the network as
    ``box_proposals``. A frame is skipped when ``NN.pdf`` already exists, and
    a whole sequence is skipped when its merged ``results.pdf`` exists. With
    ``args.merge_pdfs`` the per-frame PDFs of each sequence are concatenated
    into ``results.pdf`` using the external ``pdfunite`` tool.

    Exits via ``sys.exit`` when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    args = parse_args()
    print('Called with args:')
    print(args)
    print('load cfg from file: {}'.format(args.cfg_file))
    cfg_from_file(args.cfg_file)

    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    # Do not use RPN.
    # cfg.MODEL.FASTER_RCNN = False

    dataset = datasets.get_coco_dataset()
    cfg.MODEL.NUM_CLASSES = len(dataset.classes)

    assert bool(args.load_ckpt) ^ bool(args.load_detectron), \
        'Exactly one of --load_ckpt and --load_detectron should be specified.'
    cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS = False  # Don't need to load imagenet pretrained weights
    assert_and_infer_cfg()

    def load_predictor():
        """Build the box-conditioned predictor, load its weights, set eval mode."""
        model = Generalized_RCNN_Predictor_with_Boxes()

        if args.cuda:
            model.cuda()

        if args.load_ckpt:
            load_name = args.load_ckpt
            print("loading checkpoint %s" % (load_name))
            checkpoint = torch.load(
                load_name, map_location=lambda storage, loc: storage)
            net_utils.load_ckpt(model, checkpoint['model'])

        if args.load_detectron:
            print("loading detectron weights %s" % args.load_detectron)
            load_detectron_weight(model, args.load_detectron)

        model = mynn.DataParallel(
            model,
            cpu_keywords=['im_info', 'roidb'],
            minibatch=True,
            device_ids=[0])  # only support single GPU
        model.eval()
        return model

    # Built lazily on the first frame that has boxes and then reused, so the
    # network is constructed and its weights are loaded at most once.
    predictor = None

    train_db = davis_db.DAVIS_imdb(split='val')
    for seq_idx in range(train_db.get_num_sequence()):
        train_db.set_to_sequence(seq_idx)
        seq_name = train_db.get_current_seq_name()
        save_dir = osp.join(args.output_dir, seq_name)
        merge_out_path = '{}/results.pdf'.format(save_dir)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        elif osp.exists(merge_out_path):
            continue  # sequence already processed and merged

        for idx in range(train_db.get_current_seq_length()):
            pdf_path = osp.join(save_dir, '%02d.pdf' % (idx))
            print(pdf_path)
            if osp.exists(pdf_path):
                continue
            im = train_db.get_image_cv2(idx)
            assert im is not None

            # Boxes from get_bboxes are treated as [x, y, w, h, ...]; convert
            # entries 2 and 3 to absolute corners ([x1, y1, x2, y2, ...]).
            new_boxes = []
            for bbox in train_db.get_bboxes(idx):
                new_box = list(bbox)
                new_box[2] = new_box[0] + new_box[2]
                new_box[3] = new_box[1] + new_box[3]
                new_boxes.append(new_box)
            boxes = np.array(new_boxes, np.float32)
            print(boxes.shape)

            if boxes.shape[0] > 0:
                if predictor is None:
                    predictor = load_predictor()
                timers = defaultdict(Timer)
                cls_boxes, cls_segms, cls_keyps = im_detect_all(
                    predictor,
                    im,
                    timers=timers,
                    box_proposals=boxes)
            else:
                cls_boxes = None
                cls_segms = None
                cls_keyps = None

            vis_utils.vis_one_image(
                im[:, :, ::-1],  # BGR -> RGB for visualization
                '%02d' % (idx),
                save_dir,
                cls_boxes,
                cls_segms,
                cls_keyps,
                dataset=dataset,
                box_alpha=0.3,
                show_class=True,
                thresh=0.01,
                kp_thresh=2)

        if args.merge_pdfs:
            if os.path.exists(merge_out_path):
                os.remove(merge_out_path)
            # Same file set as the shell glob `*.pdf` (sorted, no dotfiles),
            # but passed as an argv list so paths containing spaces or shell
            # metacharacters are handled safely.
            page_pdfs = sorted(
                osp.join(save_dir, fname) for fname in os.listdir(save_dir)
                if fname.endswith('.pdf') and not fname.startswith('.'))
            if page_pdfs:
                subprocess.call(['pdfunite'] + page_pdfs + [merge_out_path])
示例#11
0
  correct_state_dict = {k:tmp_state_dict['module.'+k] for k in fasterRCNN.state_dict()}
  fasterRCNN.load_state_dict(correct_state_dict)
  # fasterRCNN.load_state_dict(checkpoint['model'])
  if 'pooling_mode' in checkpoint.keys():
    cfg.POOLING_MODE = checkpoint['pooling_mode']


  print('load model successfully!')

  if args.cuda:
    cfg.CUDA = True

  if args.cuda:
    fasterRCNN.cuda()

  fasterRCNN = mynn.DataParallel(fasterRCNN, minibatch=True)

  start = time.time()
  max_per_image = 100

  vis = args.vis

  if vis:
    thresh = 0.05
  else:
    thresh = 0.0

  save_name = 'faster_rcnn_{}_{}_{}'.format(args.checksession, args.checkepoch, args.checkpoint)
  num_images = len(imdb.image_index)

示例#12
0
def main():
    """Detect objects in a set of images and write a submission CSV.

    Input images come from exactly one of ``--image_dir`` / ``--images``. For
    every image, each detection whose score (last box column) is at least
    ``score_thresh`` becomes one CSV row: the image basename and the string
    ``"xmin ymin xmax ymax"``. Rows are written without a header or index to
    ``submit/submit1.csv``; the ``submit`` directory is created if missing.

    Exits via ``sys.exit`` when no CUDA device is available.
    """
    import pandas as pd

    score_thresh = 0.99  # minimum detection score kept in the submission
    submit_path = 'submit/submit1.csv'

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    args = parse_args()
    print('Called with args:')
    print(args)

    assert args.image_dir or args.images
    assert bool(args.image_dir) ^ bool(args.images)

    if args.dataset.startswith("coco"):
        dataset = datasets.get_coco_dataset()
        cfg.MODEL.NUM_CLASSES = len(dataset.classes)
    elif args.dataset.startswith("keypoints_coco"):
        dataset = datasets.get_coco_dataset()
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset.startswith("gangjin"):
        dataset = datasets.get_gangjin_dataset()
        cfg.MODEL.NUM_CLASSES = len(dataset.classes)
    else:
        raise ValueError('Unexpected dataset name: {}'.format(args.dataset))

    print('load cfg from file: {}'.format(args.cfg_file))
    cfg_from_file(args.cfg_file)

    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    assert bool(args.load_ckpt) ^ bool(args.load_detectron), \
        'Exactly one of --load_ckpt and --load_detectron should be specified.'
    cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS = False  # Don't need to load imagenet pretrained weights
    assert_and_infer_cfg()

    maskRCNN = Generalized_RCNN()

    if args.cuda:
        maskRCNN.cuda()

    if args.load_ckpt:
        load_name = args.load_ckpt
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])

    if args.load_detectron:
        print("loading detectron weights %s" % args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True,
                                 device_ids=[0])  # only support single GPU

    maskRCNN.eval()
    if args.image_dir:
        imglist = misc_utils.get_imagelist_from_dir(args.image_dir)
    else:
        imglist = args.images
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    img_ids = []
    rects = []

    for i, im_path in enumerate(imglist):
        print('img', i)
        im = cv2.imread(im_path)
        assert im is not None

        timers = defaultdict(Timer)

        cls_boxes, cls_segms, cls_keyps = im_detect_all(maskRCNN,
                                                        im,
                                                        timers=timers)

        boxes, _, _, _ = vis_utils.convert_from_cls_format(
            cls_boxes, cls_segms, cls_keyps)
        if boxes is None:
            continue
        for box in boxes:
            # Row layout used here: [xmin, ymin, xmax, ymax, ..., score].
            if float(box[-1]) < score_thresh:
                continue
            xmin, ymin, xmax, ymax = (float(v) for v in box[:4])
            img_ids.append(os.path.basename(im_path))
            rects.append('{} {} {} {}'.format(xmin, ymin, xmax, ymax))

    result_dict = {"ID": img_ids, "rects": rects}
    result = pd.DataFrame.from_dict(result_dict)

    # The target directory is not guaranteed to exist; without this the
    # write fails only after all images have been processed.
    os.makedirs(os.path.dirname(submit_path), exist_ok=True)
    result.to_csv(submit_path, header=False, index=False)
示例#13
0
def main():
    """Run detection on a list of frames and dump BDD-style JSON results.

    Input images come from exactly one of ``--image_dir`` / ``--images`` and
    must all be ``args.width`` x ``args.height`` (checked on the first image).
    Predicted boxes are rescaled to a 1280x720 frame, COCO class names are
    remapped to their BDD equivalents, and detections whose class is not in
    ``bdd_category`` are dropped. The result list is written as JSON (via
    ``MyEncoder``) to ``<args.output_dir>/<args.name>.json``.

    Raises:
        ValueError: if no input image is found, the first image cannot be
            read, or its size does not match ``args.height`` / ``args.width``.
    """
    # Coordinate frame the predictions are rescaled to.
    out_height, out_width = 720, 1280
    # COCO class names whose BDD label differs; others are kept as-is.
    coco_to_bdd = {
        'motorcycle': 'motor',
        'stop sign': 'traffic sign',
        'bicycle': 'bike',
    }

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    args = parse_args()
    print('Called with args:')
    print(args)

    assert args.image_dir or args.images
    assert bool(args.image_dir) ^ bool(args.images)

    prefix_path = args.output_dir

    os.makedirs(prefix_path, exist_ok=True)

    if args.dataset.startswith("coco"):
        dataset = datasets.get_coco_dataset()
        cfg.MODEL.NUM_CLASSES = len(dataset.classes)
    elif args.dataset.startswith("keypoints_coco"):
        dataset = datasets.get_coco_dataset()
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError('Unexpected dataset name: {}'.format(args.dataset))

    print('load cfg from file: {}'.format(args.cfg_file))
    cfg_from_file(args.cfg_file)

    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    assert bool(args.load_ckpt) ^ bool(args.load_detectron), \
        'Exactly one of --load_ckpt and --load_detectron should be specified.'
    cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS = False  # Don't need to load imagenet pretrained weights
    assert_and_infer_cfg()

    maskRCNN = Generalized_RCNN()

    if args.cuda:
        maskRCNN.cuda()

    if args.load_ckpt:
        load_name = args.load_ckpt
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])

    if args.load_detectron:
        print("loading detectron weights %s" % args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True,
                                 device_ids=[0])  # only support single GPU

    maskRCNN.eval()
    if args.image_dir:
        imglist = misc_utils.get_imagelist_from_dir(args.image_dir)
    else:
        imglist = args.images

    written_results = []

    # Validate the input resolution on the first image.
    if len(imglist) == 0:
        raise ValueError('No input images found.')
    demo_im = cv2.imread(imglist[0])
    if demo_im is None:
        raise ValueError('Cannot read image: {}'.format(imglist[0]))
    print(np.shape(demo_im))
    h, w, _ = np.shape(demo_im)
    if h != args.height or w != args.width:
        raise ValueError(
            'Image size {}x{} does not match --width/--height {}x{}'.format(
                w, h, args.width, args.height))
    h_scale = out_height / args.height
    w_scale = out_width / args.width

    for im_path in tqdm(imglist):
        im = cv2.imread(im_path)
        assert im is not None

        timers = defaultdict(Timer)

        cls_boxes, cls_segms, cls_keyps = im_detect_all(maskRCNN,
                                                        im,
                                                        timers=timers)

        # boxes rows: [x1, y1, x2, y2, ..., score]; classes: one index per row
        boxes, _, _, classes = convert_from_cls_format(cls_boxes, cls_segms,
                                                       cls_keyps)
        if boxes is None or len(classes) == 0:
            continue

        # Rescale x coordinates (columns 0, 2) and y coordinates (1, 3).
        boxes[:, [0, 2]] *= w_scale
        boxes[:, [1, 3]] *= h_scale

        for instance_idx, cls_idx in enumerate(classes):
            cls_name = dataset.classes[cls_idx]
            cls_name = coco_to_bdd.get(cls_name, cls_name)
            if cls_name not in bdd_category:
                continue

            written_results.append({
                "name": im_path.split('/')[-1],
                "timestamp": 1000,
                "category": cls_name,
                "bbox": boxes[instance_idx, :4],
                "score": boxes[instance_idx, -1]
            })

    with open(os.path.join(prefix_path, args.name + '.json'),
              'w') as outputfile:
        json.dump(written_results, outputfile, cls=MyEncoder)
示例#14
0
from __future__ import absolute_import
示例#15
0
def main():
    """Run symptom detection over a set of images and print a diagnosis report.

    Command-line arguments come from ``parse_args()``. Exactly one of
    ``--image_dir`` / ``--images`` and exactly one of ``--load_ckpt`` /
    ``--load_detectron`` must be given. For each image this:

    1. runs ``im_detect_all`` to obtain per-class detections,
    2. merges all foreground detections into one array, appending the class
       id as an extra last column,
    3. applies a loose NMS (threshold 0.7) over the merged array so that an
       image region keeps only one symptom,
    4. compares each surviving score (second-to-last column) with the
       per-class threshold ``dataset.th_cls[cls]`` and assembles a report
       from ``dataset.sents``: ``sents[cls][1]`` when the symptom is present,
       ``sents[cls][0]`` when it is absent, and ``sents[0][0]`` alone when no
       symptom is found at all.

    Exits via ``sys.exit`` when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    args = parse_args()
    print('Called with args:')
    print(args)

    # Exactly one image source must be specified.
    assert args.image_dir or args.images
    assert bool(args.image_dir) ^ bool(args.images)

    dataset = datasets.get_hospital_dataset()
    cfg.MODEL.NUM_CLASSES = 20  # with bg
    # Captured before cfg_from_file(), so a NUM_CLASSES value in the cfg file
    # does not change how many classes are iterated below.
    num_class = cfg.MODEL.NUM_CLASSES
    sents = dataset.sents        # report sentences, indexed [cls][absent/present]
    th_cls = dataset.th_cls      # per-class score thresholds
    cls2eng = dataset.cls2eng    # class id -> English name
    eng2type = dataset.eng2type  # English name -> symptom type

    print('load cfg from file: {}'.format(args.cfg_file))
    cfg_from_file(args.cfg_file)

    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    assert bool(args.load_ckpt) ^ bool(args.load_detectron), \
        'Exactly one of --load_ckpt and --load_detectron should be specified.'
    cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS = False  # Don't need to load imagenet pretrained weights
    assert_and_infer_cfg()

    maskRCNN = Generalized_RCNN()

    if args.cuda:
        maskRCNN.cuda()

    if args.load_ckpt:
        load_name = args.load_ckpt
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])

    if args.load_detectron:
        print("loading detectron weights %s" % args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True,
                                 device_ids=[0])  # only support single GPU

    maskRCNN.eval()
    if args.image_dir:
        imglist = misc_utils.get_imagelist_from_dir(args.image_dir)
    else:
        imglist = args.images
    num_images = len(imglist)
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    for i in xrange(num_images):  # for each image
        print('img', i)
        im = cv2.imread(imglist[i])
        assert im is not None

        timers = defaultdict(Timer)

        # detection
        cls_boxes, cls_segms, cls_keyps = im_detect_all(maskRCNN,
                                                        im,
                                                        timers=timers)

        # Collect the boxes of every foreground class into one array. Each
        # row gets its class id appended as a final column: pad 0 columns on
        # the left and 1 on the right, filled with `cls`.
        parts = [np.empty([0, 6], dtype=np.float32)]
        for cls in range(1, num_class):
            dets = cls_boxes[cls]
            if dets.shape[0] == 0:
                continue
            parts.append(np.pad(dets, ((0, 0), (0, 1)),
                                mode='constant',
                                constant_values=cls))
        dets_total = np.vstack(parts)

        # Loose NMS across all classes so that each region has only one symptom.
        keep = box_utils.nms(dets_total, 0.7)
        nms_dets = dets_total[keep, :]

        report, healthy = '', True
        have_sym_of_cls = [False] * num_class
        final_results = []  # meant to be returned to the web; not consumed here

        for det in nms_dets:  # for each surviving region
            score, cls = det[-2], int(det[-1])
            if score > th_cls[cls]:  # diagnosed to have the symptom
                report += sents[cls][1]
                have_sym_of_cls[cls] = True
                healthy = False

                ename = cls2eng[cls]
                final_results.append({
                    'name': ename,
                    'type': eng2type[ename],
                    'box': list(det[0:4])
                })

        # Add the "absent" sentence for every class that was not diagnosed.
        for cls in range(1, num_class):
            if not have_sym_of_cls[cls]:
                report += sents[cls][0]

        if healthy:  # nothing found: use the dedicated healthy report
            report = sents[0][0]
        print(report)

    if args.merge_pdfs and num_images > 1:
        merge_out_path = '{}/results.pdf'.format(args.output_dir)
        if os.path.exists(merge_out_path):
            os.remove(merge_out_path)
        command = "pdfunite {}/*.pdf {}".format(args.output_dir,
                                                merge_out_path)
        subprocess.call(command, shell=True)
def main():
    """Main training entry point for the relationship-detection model.

    Parses command-line arguments and configures ``cfg`` for the selected
    dataset (vrd / vg / vg80k / gvqa20k / gvqa10k / gvqa). It then rescales
    batch size, learning rate and solver steps for the number of visible
    GPUs, builds the data loader, model and optimizer(s), and optionally
    loads a checkpoint or Detectron weights. Finally it runs the training
    loop with warm-up and step decay of the learning rate. A checkpoint is
    saved every CHECKPOINT_PERIOD steps, at the end, and on an exception.

    Raises:
        ValueError: if CUDA is not requested/available or the dataset name
            is unknown.
        NotImplementedError: if cfg.MODEL.MEMORY_MODULE_STAGE is not 1 or 2.
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    # Per-dataset training split and class counts. These are set before
    # cfg_from_file(), so values in the cfg file / --set_cfgs take precedence.
    cfg.DATASET = args.dataset
    if args.dataset == "vrd":
        cfg.TRAIN.DATASETS = ('vrd_train',)
        cfg.MODEL.NUM_CLASSES = 101
        cfg.MODEL.NUM_PRD_CLASSES = 70  # exclude background
    elif args.dataset == "vg":
        cfg.TRAIN.DATASETS = ('vg_train',)
        cfg.MODEL.NUM_CLASSES = 151
        cfg.MODEL.NUM_PRD_CLASSES = 50  # exclude background
    elif args.dataset == "vg80k":
        cfg.TRAIN.DATASETS = ('vg80k_train',)
        cfg.MODEL.NUM_CLASSES = 53305 # includes background
        cfg.MODEL.NUM_PRD_CLASSES = 29086  # excludes background
    elif args.dataset == "gvqa20k":
        cfg.TRAIN.DATASETS = ('gvqa20k_train',)
        cfg.MODEL.NUM_CLASSES = 1704 # includes background
        cfg.MODEL.NUM_PRD_CLASSES = 310  # exclude background
    elif args.dataset == "gvqa10k":
        cfg.TRAIN.DATASETS = ('gvqa10k_train',)
        cfg.MODEL.NUM_CLASSES = 1704 # includes background
        cfg.MODEL.NUM_PRD_CLASSES = 310  # exclude background
    elif args.dataset == "gvqa":
        cfg.TRAIN.DATASETS = ('gvqa_train',)
        cfg.MODEL.NUM_CLASSES = 1704 # includes background
        cfg.MODEL.NUM_PRD_CLASSES = 310  # exclude background

    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    # The cfg file describes a reference setup (NUM_GPUS x IMS_PER_BATCH);
    # re-derive the per-GPU batch for the GPUs actually visible here.
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' % (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' % (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' % (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' % (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    # Larger effective batches need proportionally fewer iterations.
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print('Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
          '    SOLVER.STEPS: {} --> {}\n'
          '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps, cfg.SOLVER.STEPS,
                                                  old_max_iter, cfg.SOLVER.MAX_ITER))


    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    # Note: --lr replaces BASE_LR *after* the linear batch-size scaling above,
    # so an explicit --lr is used unscaled.
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    # Effective training sample size for one epoch
    # (rounded down to a whole number of batches, matching drop_last=True).
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True
    )
    dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # record backbone params, i.e., conv_body and box_head params
    # Parameters are split by name into: GroupNorm ('gn'), backbone
    # (Conv_Body / Box_Head / Box_Outs / RPN) and everything else (the
    # predicate branch), each further split into bias / non-bias.
    gn_params = []
    backbone_bias_params = []
    backbone_bias_param_names = []
    prd_branch_bias_params = []
    prd_branch_bias_param_names = []
    backbone_nonbias_params = []
    backbone_nonbias_param_names = []
    prd_branch_nonbias_params = []
    prd_branch_nonbias_param_names = []
    classifier_params = []
    classifier_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'gn' in key:
                gn_params.append(value)
            elif 'Conv_Body' in key or 'Box_Head' in key or 'Box_Outs' in key or 'RPN' in key:
                if 'bias' in key:
                    backbone_bias_params.append(value)
                    backbone_bias_param_names.append(key)
                else:
                    backbone_nonbias_params.append(value)
                    backbone_nonbias_param_names.append(key)
            #elif 'classifier' in key:
            #    classifier_params.append(value)
            #    classifier_param_names.append(value)
            else:
                if 'bias' in key:
                    prd_branch_bias_params.append(value)
                    prd_branch_bias_param_names.append(key)
                else:
                    prd_branch_nonbias_params.append(value)
                    prd_branch_nonbias_param_names.append(key)
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    # Group order matters: the loop below reads param_groups[0] as the
    # backbone lr and param_groups[2] as the predicate-branch lr.
    params = [
        {'params': backbone_nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': backbone_bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': prd_branch_nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': prd_branch_bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': gn_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN}#,
        #{'params': classifier_params,
        # 'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0.0005}
    ]

    # print('Initializing model optimizer.')

    # if cfg.SOLVER.TYPE == "SGD":
    #     optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    # elif cfg.SOLVER.TYPE == "Adam":
    #     optimizer = torch.optim.Adam(params)


    print('Initializing model and classifier optimizers.')
    #classifier_optim_param = {'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0.0005}
    #params.append({'params': maskRCNN.classifier.parameters(),
    #                'lr': classifier_optim_param['lr']),
    #                'momentum': classifier_optim_param['momentum'],
    #                'weight_decay': classifier_optim_param['weight_decay']})
    #params.append({'params': maskRCNN.prd_classifier.parameters(),
    #                'lr': classifier_optim_param['lr'],
    #                'momentum': classifier_optim_param['momentum'],
    #                'weight_decay': classifier_optim_param['weight_decay']})

    # Scheduler step size depends on the memory-module training stage.
    if cfg.MODEL.MEMORY_MODULE_STAGE == 1:
        step_size = 10
    elif cfg.MODEL.MEMORY_MODULE_STAGE == 2:
        step_size = 20
    else:
        raise NotImplementedError

    scheduler_params = {'step_size': step_size, 'gamma': 0.1}

    # NOTE(review): optimizer_scheduler (and criterion_optimizer_scheduler
    # below) are created but never stepped — the .step() calls in the loop
    # are commented out; lr is driven manually via update_learning_rate_rel.
    optimizer, optimizer_scheduler = init_optimizers(params, scheduler_params)

    # Stage 2 additionally trains the feature-loss modules with their own
    # optimizer (fixed lr 0.01, momentum 0.9, weight decay 5e-4).
    criterion_optimizer, criterion_optimizer_scheduler = None, None
    if cfg.MODEL.MEMORY_MODULE_STAGE == 2:
        print('Initializing criterion optimizer.')
        feat_loss_optim_param = {'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0.0005}

        optim_params = feat_loss_optim_param
        optim_params = [{'params': maskRCNN.feature_loss_sbj_obj.parameters(),
                         'lr': optim_params['lr'],
                         'momentum': optim_params['momentum'],
                         'weight_decay': optim_params['weight_decay']},
                        {'params': maskRCNN.feature_loss_prd.parameters(),
                         'lr': optim_params['lr'],
                         'momentum': optim_params['momentum'],
                         'weight_decay': optim_params['weight_decay']}
                        ]

        # Initialize criterion optimizer and scheduler
        criterion_optimizer, criterion_optimizer_scheduler = init_optimizers(optim_params, scheduler_params)
    # Stage 2 warm-starts from a stage-1 checkpoint; strict=False tolerates
    # keys that exist in only one of the two models.
    # NOTE(review): the checkpoint path is hard-coded; consider a CLI option.
    if cfg.MODEL.MEMORY_MODULE_STAGE == 2:
        weights_path = 'Outputs/e2e_relcnn_VGG16_8_epochs_gvqa_y_loss_only_1_gpu/gvqa/Feb07-10-55-03_login104-09_step_with_prd_cls_v3/ckpt/model_step1439.pth'
        weights = torch.load(weights_path)
        #print('weights', weights['model'].keys())
        #print(maskRCNN.state_dict().keys())
        maskRCNN.load_state_dict(weights['model'], strict=False)
        #print(list(maskRCNN.parameters()))
        #print(maskRCNN.state_dict().keys())
        #print(maskRCNN.state_dict()['prd_classifier.fc_hallucinator.weight'] == weights['model']['prd_classifier.fc.weight'])
        #print(torch.all(torch.eq(maskRCNN.state_dict()['prd_classifier.fc_hallucinator.weight'], weights['model']['prd_classifier.fc.weight'])))
        #print(torch.all(torch.eq(maskRCNN.state_dict()['Box_Head.heads.0.weight'], weights['model']['Box_Head.heads.0.weight'])))
        #exit()

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print('train_size value: %d different from the one in checkpoint: %d'
                          % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[2]['lr']  # lr of non-backbone parameters, for commmand line outputs.
    backbone_lr = optimizer.param_groups[0]['lr']  # lr of backbone parameters, for commmand line outputs.

    # No device_ids given: DataParallel over all visible GPUs.
    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step_with_prd_cls_v' + str(cfg.MODEL.SUBTYPE)
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        # Snapshot the final config and args next to the checkpoints.
        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            # (tblogger exists only when use_tfboard and not no_save; every
            # later use is guarded by that same condition.)
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)
    # CHECKPOINT_PERIOD = cfg.SOLVER.MAX_ITER / cfg.TRAIN.SNAPSHOT_FREQ
    # NOTE(review): hard-coded; periodic saves only happen if MAX_ITER > 200000.
    CHECKPOINT_PERIOD = 200000

    # Set index for decay steps
    # (first entry of SOLVER.STEPS at or after start_step, skipping index 0;
    # when resuming this resumes the decay schedule at the right point).
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args,
        args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        # Pre-bind step so the exception handler can save even if the loop
        # never starts.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):
            # optimizer_scheduler.step()
            # if criterion_optimizer:
            #     criterion_optimizer_scheduler.step()
            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    # Interpolate from WARM_UP_FACTOR to 1 over the warm-up.
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                # Warm-up finished: switch to the full base lr.
                net_utils.update_learning_rate_rel(optimizer, lr, cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                # NOTE(review): the criterion optimizer starts at lr 0.01 but
                # is passed the main optimizer's lr/lr_new here — confirm
                # update_learning_rate_rel's semantics give the intended lr.
                if criterion_optimizer:
                    net_utils.update_learning_rate_rel(criterion_optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new

                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            if criterion_optimizer:
                criterion_optimizer.zero_grad()
            # Gradient accumulation over iter_size minibatches per step.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Epoch exhausted: restart the loader.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)
                #print('input_data', [torch.isnan(x) for x in input_data.values()])    

                for key in input_data:
                    if key != 'roidb': # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))
                #with autograd.detect_anomaly():
                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            if criterion_optimizer:
                criterion_optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr, backbone_lr)

            if (step+1) % CHECKPOINT_PERIOD == 0:
                print('Saving Checkpoint..')
                save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    # NOTE(review): the exception is not re-raised, so the process exits
    # normally after printing the traceback; KeyboardInterrupt is not an
    # Exception subclass and so bypasses this handler.
    except Exception as e:
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
示例#17
0

# Module-level training setup for Generalized_SEGDISP (PSPNet-50 config).
# NOTE(review): this snippet appears truncated — the forward/backward pass
# that should follow the label concatenation is not visible here.
cfg_file = 'configs/baselines/e2e_pspnet-50_2x.yaml'
cfg_from_file(cfg_file)
#cfg_from_list(cfg_file)
assert_and_infer_cfg()

# Restrict the process to physical GPUs 1 and 2 (they appear as cuda:0/1).
# Presumably only effective if CUDA has not been initialised before this line.
devices_ids = [1, 2]
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
    [str(ids) for ids in devices_ids])
torch.backends.cudnn.benchmark = True
#torch.cuda.set_device(3)
len_gpus = len(devices_ids)
batch_size = 2 * len_gpus  # 2 images per GPU
#net = mynn.DataParallel(load_net().to('cuda'), minibatch=True)
net = mynn.DataParallel(Generalized_SEGDISP().to('cuda'), minibatch=True)
optimizer = optim.SGD(net.parameters(), lr=0.000875, momentum=0.9)
criterion = nn.NLLLoss(ignore_index=255)  # label 255 = ignored pixels
# NOTE(review): this rebinds the name `dataloader` from the factory to its
# result, so the factory cannot be called again afterwards.
dataloader = dataloader(batch_size, len_gpus)
#for i in range(10):
# At most 1000 batches (zip stops at the shorter of the two iterables).
for i, inputs in zip(range(1000), dataloader):
    #data, label= dataloader(batch_size, len_gpus)
    #data = Variable(data).to('cuda')
    #data  = torch.chunk(data, chunks=len_gpus, dim=0)
    #label = Variable(label).to('cuda')
    #assert torch.all(data >= 0) and torch.all(data < 19), 'label is error'
    # Each value in `inputs` is treated as a list of tensors; move each to CUDA.
    for key in inputs:
        inputs[key] = list(map(lambda x: Variable(x).to('cuda'), inputs[key]))
    data, label = inputs['data'], inputs['semseg_label_0']
    #cv2.imwrite('ims.png', data[0].cpu().numpy()[0].transpose(1,2,0)[:,:,::-1])
    # Merge the per-GPU label chunks back into one batch along dim 0.
    label = torch.cat(label, 0)
def main():
    """Run 2D detection plus 3D size estimation on images and visualise both.

    Command-line arguments come from ``parse_args()``. Exactly one of
    ``--image_dir`` / ``--images`` and exactly one of ``--load_ckpt`` /
    ``--load_detectron`` must be given. For each image:

    * ``im_detect_all`` produces per-class 2D boxes;
    * ``threeD_detect`` is called for every class with detections. Its
      results are drawn on a *copy* of the image: a rectangle in a
      class-specific colour plus "height width length" text. The copy is
      written once to ``../output1/<name>.png`` when anything was drawn;
    * ``vis_utils.vis_one_image`` renders the untouched image with the 2D
      detections into ``args.output_dir``.

    Optionally merges the per-image PDFs with ``pdfunite``.
    """
    # NOTE(review): hard-coded GPU selection overrides the caller's environment.
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    args = parse_args()
    print('Called with args:')
    print(args)

    # Exactly one image source must be specified.
    assert args.image_dir or args.images
    assert bool(args.image_dir) ^ bool(args.images)

    if args.dataset.startswith("coco"):
        dataset = datasets.get_coco_dataset()
        cfg.MODEL.NUM_CLASSES = len(dataset.classes)
    elif args.dataset.startswith("keypoints_coco"):
        dataset = datasets.get_coco_dataset()
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError('Unexpected dataset name: {}'.format(args.dataset))

    print('load cfg from file: {}'.format(args.cfg_file))
    cfg_from_file(args.cfg_file)

    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    assert bool(args.load_ckpt) ^ bool(args.load_detectron), \
        'Exactly one of --load_ckpt and --load_detectron should be specified.'
    cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS = False  # Don't need to load imagenet pretrained weights
    assert_and_infer_cfg()

    maskRCNN = Generalized_RCNN()

    if args.cuda:
        maskRCNN.cuda()

    if args.load_ckpt:
        load_name = args.load_ckpt
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])

    if args.load_detectron:
        print("loading detectron weights %s" % args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True,
                                 device_ids=[0])  # only support single GPU

    maskRCNN.eval()

    # Total number of model parameters ('zonghe' = "total").
    num_params = sum(p.numel() for p in maskRCNN.parameters())
    print('zonghe:' + str(num_params))

    if args.image_dir:
        imglist = misc_utils.get_imagelist_from_dir(args.image_dir)
    else:
        imglist = args.images
    num_images = len(imglist)
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # BGR colour of the rectangle drawn for each 3D-detected class name.
    color_class = {
        'Car': [0, 255, 255],
        'Cyclist': [255, 0, 0],
        'Pedestrian': [0, 0, 255]
    }

    for i in xrange(num_images):
        print('img', i)
        im = cv2.imread(imglist[i])
        assert im is not None

        timers = defaultdict(Timer)
        start = time.time()
        cls_boxes, cls_segms, cls_keyps = im_detect_all(maskRCNN,
                                                        im,
                                                        timers=timers)
        class_result_boxes = []
        for index, class_boxes in enumerate(cls_boxes):
            if len(class_boxes) != 0:
                results_oneclass = threeD_detect(imglist[i],
                                                 class_boxes.tolist(),
                                                 index)
                class_result_boxes.append(results_oneclass)

        im_name, _ = os.path.splitext(os.path.basename(imglist[i]))

        # Draw on a copy: drawing on `im` itself would also leak the 3D
        # annotations into the image passed to vis_one_image below.
        save_image = im.copy()
        drew_any = False
        for result_boxes in class_result_boxes:
            for box in result_boxes:
                # box layout as used here: [x1, y1, x2, y2, ..., dims, class_name]
                # with dims = (height, width, length).
                cv2.rectangle(save_image, (box[0], box[1]), (box[2], box[3]),
                              color_class[box[-1]], 2)
                threeD_info = ' '.join(
                    str(round(dim, 2)) for dim in box[-2][:3])
                cv2.putText(save_image, threeD_info, (box[0], box[1] - 20),
                            cv2.FONT_HERSHEY_COMPLEX, 1, (255, 0, 0), 2)
                drew_any = True
        if drew_any:
            # Written once per image (previously rewritten after every box).
            cv2.imwrite('../output1/%s.png' % im_name, save_image)

        end = time.time()
        print(end - start)
        vis_utils.vis_one_image(
            im[:, :, ::-1],  # BGR -> RGB for visualization
            im_name,
            args.output_dir,
            cls_boxes,
            cls_segms,
            cls_keyps,
            dataset=dataset,
            box_alpha=0.3,
            show_class=True,
            thresh=0.7,
            kp_thresh=2)

    if args.merge_pdfs and num_images > 1:
        merge_out_path = '{}/results.pdf'.format(args.output_dir)
        if os.path.exists(merge_out_path):
            os.remove(merge_out_path)
        command = "pdfunite {}/*.pdf {}".format(args.output_dir,
                                                merge_out_path)
        subprocess.call(command, shell=True)
示例#19
0
    return image


# Module-level inference setup: load a trained Generalized_SEMSEG checkpoint
# onto a single GPU and build the evaluation data loader.
cfg_file = 'configs/baselines/e2e_pspnet-101_2x.yaml'
cfg_from_file(cfg_file)

# Restrict the process to physical GPU 5 (visible as cuda:0).
devices_ids = [5]
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
    [str(ids) for ids in devices_ids])

# NOTE(review): the config names a PSPNet-101 but the checkpoint directory
# says pspnet50; load_state_dict is strict by default, so a mismatch would
# raise — confirm the two belong together.
pretrained_model = './output/pspnet50_2gpu_single_scale/Oct20-12-41-16_localhost.localdomain/ckpt/model_17_1486.pth'
checkpoint = torch.load(pretrained_model)

net = Generalized_SEMSEG()
net.load_state_dict(checkpoint['model'])
net = mynn.DataParallel(net.to('cuda'), minibatch=True)
net.eval()

# params = net.state_dict() # check whether the weights were loaded
# params_ckpt = checkpoint['model']
# a = list(params.values())
# b = list(params_ckpt.values())
# keys = list(params.keys())
# print(is_equal(a, b, keys))

len_gpus = len(devices_ids)
batch_size = 1 * len_gpus  # one image per GPU
# NOTE(review): rebinds the name `dataloader` from the factory to its result.
dataloader = dataloader(batch_size, len_gpus)

loader = transforms.Compose([
    transforms.ToTensor(),
示例#20
0
def main():
    """Train a segmentation/disparity model with an epoch-based loop.

    Parses CLI arguments, restricts the visible GPUs to ``args.device_ids``,
    sets the training datasets and class count for ``args.dataset``, merges
    the YAML config (plus ``--set`` overrides), rescales the base learning
    rate to the requested batch size, then builds the data loader, model and
    optimizer. Optionally loads PSPNet-pretrained, checkpoint or Detectron
    weights, and trains for ``args.num_epochs`` epochs.

    Checkpoints are saved every ``args.ckpt_num_per_epoch`` epochs, after the
    last epoch (if that epoch did not already save one), and when a
    ``RuntimeError`` or ``KeyboardInterrupt`` escapes the training loop.

    Raises:
        ValueError: if CUDA use is not requested, or if ``args.dataset`` or
            ``cfg.SOLVER.TYPE`` is not recognised.
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    if not (args.cuda or cfg.NUM_GPUS > 0):
        raise ValueError("Need Cuda device to run !")
    # The GPU mask must be applied before anything touches the CUDA driver:
    # torch.cuda.is_available() initialises it, after which a change to
    # CUDA_VISIBLE_DEVICES is silently ignored.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        [str(ids) for ids in args.device_ids])
    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")
    torch.backends.cudnn.benchmark = True
    cfg.CUDA = True

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cityscapes":
        cfg.TRAIN.DATASETS = ('cityscapes_semseg_train', )
        cfg.MODEL.NUM_CLASSES = 19
    elif args.dataset == "cityscape_train_on_val":
        cfg.TRAIN.DATASETS = ('cityscape_train_on_val', )
        cfg.MODEL.NUM_CLASSES = 19
    elif args.dataset == "cityscapes_coarse":
        cfg.TRAIN.DATASETS = ('cityscapes_coarse', )
        cfg.MODEL.NUM_CLASSES = 19
    elif args.dataset == "cityscapes_trainval":
        cfg.TRAIN.DATASETS = ('cityscapes_trainval', )
        cfg.MODEL.NUM_CLASSES = 19
    elif args.dataset == "cityscapes_all":
        cfg.TRAIN.DATASETS = ('cityscapes_all', )
        cfg.MODEL.NUM_CLASSES = 19
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    # Loaded after the dataset settings above, so values in the file (and
    # --set overrides) presumably take precedence over them.
    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    print('Batch size change from {} (in config file) to {}'.format(
        original_batch_size, args.batch_size))
    print('NUM_GPUs: %d, TRAIN.IMS_PER_BATCH: %d' %
          (cfg.NUM_GPUS, cfg.TRAIN.IMS_PER_BATCH))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Adjust learning based on batch size change linearly
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch size change: {} --> {}'.
          format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    if cfg.SEM.SEM_ON or cfg.DISP.DISP_ON:
        roidb, ratio_list, ratio_index = combined_roidb_for_training_semseg(
            cfg.TRAIN.DATASETS)
    else:
        roidb, ratio_list, ratio_index = combined_roidb_for_training(
            cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    train_size = len(roidb)
    logger.info('{:d} roidb entries'.format(train_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    # `torchdata` is presumably a project-local loader module: it is passed
    # `maxsize`, which torch.utils.data.DataLoader does not accept.
    dataloader = torchdata.DataLoader(
        dataset,
        batch_size=args.batch_size,  # we have modified data_parallel
        collate_fn=collate_minibatch_semseg,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        drop_last=True,
        shuffle=True,
        maxsize=cfg.TRAIN.IMS_PER_BATCH * 2)

    assert_and_infer_cfg()

    # SECURITY NOTE: cfg.MODEL.TYPE comes from the config file and is
    # evaluated as Python code; only run with trusted config files.
    maskRCNN = eval(cfg.MODEL.TYPE)()
    if len(cfg.SEM.PSPNET_PRETRAINED_WEIGHTS) > 1:
        print("loading pspnet weights")
        pretrained = torch.load(cfg.SEM.PSPNET_PRETRAINED_WEIGHTS,
                                map_location=lambda storage, loc: storage)
        pretrained = pretrained['model']
        if cfg.SEM.SPN_ON:
            maskRCNN.pspnet.load_state_dict(pretrained, strict=True)
        elif 'deeplab' in cfg.SEM.DECODER_TYPE and 'uber' in cfg.SEM.PSPNET_PRETRAINED_WEIGHTS:
            # Keep only encoder weights and strip their 'encoder.' prefix so
            # they match the keys of maskRCNN.encoder.
            encoder = dict()
            for k, v in pretrained.items():
                if 'decoder' in k:
                    continue
                encoder[k.replace('encoder.', '')] = v
            maskRCNN.encoder.load_state_dict(encoder, strict=False)
        else:
            maskRCNN.load_state_dict(pretrained, strict=False)
        print("weights load success")

    if cfg.SEM.SPN_ON:
        # The PSPNet sub-network is frozen: eval mode and no gradients.
        maskRCNN.pspnet.eval()
        for p in maskRCNN.pspnet.parameters():
            p.requires_grad = False

    if cfg.CUDA:
        maskRCNN.to('cuda')

    ### Optimizer ###
    # Bias and non-bias parameters get separate groups so biases can use a
    # different learning rate / weight decay.
    bias_params = []
    nonbias_params = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
            else:
                nonbias_params.append(value)
    params = [{
        'params': nonbias_params,
        'lr': cfg.SOLVER.BASE_LR,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        bias_params,
        'lr':
        cfg.SOLVER.BASE_LR * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
        if cfg.SOLVER.LR_POLICY == 'ReduceLROnPlateau':
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, 'min', patience=10)
            print("Using ReduceLROnPlateau as Lr reduce policy!")
        else:
            print("Using STEP as Lr reduce policy!")
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    elif "poly" in cfg.SOLVER.TYPE:
        # Covers both 'poly' and 'step_poly'.
        optimizer = create_optimizers(maskRCNN, args)
        print("Using Poly as Lr reduce policy!")
    else:
        raise ValueError(
            "Unexpected cfg.SOLVER.TYPE: {}".format(cfg.SOLVER.TYPE))

    args.max_iters = (int(train_size / args.batch_size)) * args.num_epochs
    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            assert checkpoint['iters_per_epoch'] == train_size // args.batch_size, \
                "iters_per_epoch should match for resume"
            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
            if checkpoint['step'] == (checkpoint['iters_per_epoch'] - 1):
                # Resume from end of an epoch
                args.start_epoch = checkpoint['epoch'] + 1
                args.start_iter = 0
            else:
                # Resume from the middle of an epoch.
                # NOTE: dataloader is not synced with previous state
                args.start_epoch = checkpoint['epoch']
                args.start_iter = checkpoint['step'] + 1
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    if cfg.SOLVER.TYPE == 'step_poly':
        lr = cfg.SOLVER.BASE_LR / (cfg.SOLVER.GAMMA**len(args.lr_decay_epochs))
    else:
        lr = optimizer.param_groups[0][
            'lr']  # lr of non-bias parameters, for commmand line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name()
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)

    iters_per_epoch = int(train_size / args.batch_size)  # drop last
    args.iters_per_epoch = iters_per_epoch
    try:
        logger.info('Training starts !')
        args.step = args.start_iter
        global_step = iters_per_epoch * args.start_epoch + args.step
        for args.epoch in range(args.start_epoch,
                                args.start_epoch + args.num_epochs):
            # ---- Start of epoch ----
            # adjust learning rate
            if args.lr_decay_epochs and args.epoch == args.lr_decay_epochs[
                    0] and args.start_iter == 0 and cfg.SOLVER.LR_POLICY == 'steps_with_decay':
                args.lr_decay_epochs.pop(0)
                net_utils.decay_learning_rate(optimizer, lr, cfg.SOLVER.GAMMA)
                lr *= cfg.SOLVER.GAMMA

            for args.step, input_data in zip(
                    range(args.start_iter, iters_per_epoch), dataloader):

                if cfg.DISP.DISP_ON:
                    # Concatenate each 'data' tensor with its 'data_R'
                    # counterpart (presumably the right stereo view) along dim 0.
                    input_data['data'] = list(
                        map(lambda x, y: torch.cat((x, y), dim=0),
                            input_data['data'], input_data['data_R']))
                    if cfg.SEM.DECODER_TYPE.endswith('3D'):
                        input_data['disp_scans'] = torch.arange(
                            1, cfg.DISP.MAX_DISPLACEMENT + 1).float().view(
                                1, cfg.DISP.MAX_DISPLACEMENT).repeat(
                                    args.batch_size, 1)
                    del input_data['data_R']

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(
                            map(
                                lambda x: Variable(x, requires_grad=False).to(
                                    'cuda'), input_data[key]))
                training_stats.IterTic()
                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs)
                loss = net_outputs['total_loss']

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if cfg.SOLVER.TYPE == 'poly':
                    lr = adjust_learning_rate(optimizer, global_step, args)

                if cfg.SOLVER.TYPE == 'step_poly':
                    lr = step_adjust_learning_rate(optimizer, lr, global_step,
                                                   args)

                training_stats.IterToc()

                if args.step % args.disp_interval == 0:
                    log_training_stats(training_stats, global_step, lr)
                global_step += 1
            # ---- End of epoch ----
            # ReduceLROnPlateau is stepped once per epoch on the last loss.
            if cfg.SOLVER.TYPE == 'SGD' and cfg.SOLVER.LR_POLICY == 'ReduceLROnPlateau':
                lr_scheduler.step(loss)
                lr = optimizer.param_groups[0]['lr']
            # save checkpoint
            if (args.epoch + 1) % args.ckpt_num_per_epoch == 0:
                net_utils.save_ckpt(output_dir, args, maskRCNN, optimizer)
            # reset starting iter number after first epoch
            args.start_iter = 0

        # ---- Training ends ----
        # Save the final model unless the last epoch already saved one.
        if (args.epoch + 1) % args.ckpt_num_per_epoch:
            net_utils.save_ckpt(output_dir, args, maskRCNN, optimizer)
    except (RuntimeError, KeyboardInterrupt):
        logger.info('Save ckpt on exception ...')
        net_utils.save_ckpt(output_dir, args, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
def main():
    """Detect heads in a set of images and save each image's boxes with np.save.

    Images come from ``--image_dir`` or ``--images`` (exactly one must be
    given). Weights come from ``--load_ckpt`` or ``--load_detectron`` (exactly
    one must be given). For every image, ``im_detect_all`` is run and
    ``cls_boxes[1]`` (reported as the head count) is written to
    ``<output_dir>/<image stem>.npy``.

    Raises:
        ValueError: for an unsupported ``args.dataset``.
        IOError: if OpenCV cannot read an input image.
    """
    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    args = parse_args()
    print('Called with args:')
    print(args)

    assert args.image_dir or args.images
    assert bool(args.image_dir) ^ bool(args.images)

    # Both supported datasets share the same head dataset (background + head).
    if args.dataset in ("pascal_parts_heads", "scuthead_a"):
        dataset = datasets.get_head_dataset()
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError('Unexpected dataset name: {}'.format(args.dataset))

    print('load cfg from file: {}'.format(args.cfg_file))
    cfg_from_file(args.cfg_file)

    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    assert bool(args.load_ckpt) ^ bool(args.load_detectron), \
        'Exactly one of --load_ckpt and --load_detectron should be specified.'
    cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS = False  # Don't need to load imagenet pretrained weights
    assert_and_infer_cfg()

    maskRCNN = Generalized_RCNN()

    if args.cuda:
        maskRCNN.cuda()

    if args.load_ckpt:
        load_name = args.load_ckpt
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])

    if args.load_detectron:
        print("loading detectron weights %s" % args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True,
                                 device_ids=[0])  # only support single GPU

    maskRCNN.eval()
    if args.image_dir:
        imglist = misc_utils.get_imagelist_from_dir(args.image_dir)
    else:
        imglist = args.images
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # enumerate works on both Python 2 and 3 (xrange is Python-2-only).
    for i, img_path in enumerate(imglist):
        im = cv2.imread(img_path)
        # cv2.imread returns None instead of raising for unreadable files.
        if im is None:
            raise IOError('Could not read image: {}'.format(img_path))

        timers = defaultdict(Timer)

        cls_boxes, _, _ = im_detect_all(maskRCNN, im, timers=timers)

        im_name, _ = os.path.splitext(os.path.basename(img_path))
        # np.save appends the ".npy" extension to this path.
        outputfile = os.path.join(args.output_dir, im_name)

        head_boxes = cls_boxes[1]
        print('img :', i, '   num_heads :', len(head_boxes), ' img_path :',
              img_path)
        np.save(outputfile, head_boxes)
def main():
    """Train the ApolloScape Car3D model with an iteration-based loop.

    Parses CLI arguments, merges the YAML config, applies Car3D-specific
    overrides, then rescales the learning rate and solver steps to the
    requested batch size. Builds the data loader, model and optimizer,
    optionally loads a checkpoint or Detectron weights, and runs
    ``cfg.SOLVER.MAX_ITER`` steps with warm-up, step decay and gradient
    accumulation over ``args.iter_size`` minibatches per step.

    A checkpoint is saved every ``CHECKPOINT_PERIOD`` steps, after the last
    step, and when a ``RuntimeError`` or ``KeyboardInterrupt`` escapes the
    loop.

    Raises:
        ValueError: if CUDA use is not requested, or ``cfg.SOLVER.TYPE`` is
            neither ``"SGD"`` nor ``"Adam"``.
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    merge_cfg_from_file(args.cfg_file)

    # Some manual adjustment for the ApolloScape dataset parameters here
    cfg.OUTPUT_DIR = args.output_dir
    cfg.TRAIN.DATASETS = 'Car3D'
    cfg.MODEL.NUM_CLASSES = 8
    if cfg.CAR_CLS.SIM_MAT_LOSS:
        cfg.MODEL.NUMBER_CARS = 79
    else:
        # Loss is only cross entropy, hence, we detect only car categories in the training set.
        cfg.MODEL.NUMBER_CARS = 34
    cfg.TRAIN.MIN_AREA = 196   # 14*14
    cfg.TRAIN.USE_FLIPPED = False  # Currently I don't know how to handle the flipped case
    cfg.TRAIN.IMS_PER_BATCH = 1

    cfg.NUM_GPUS = torch.cuda.device_count()

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size

    assert (args.batch_size % cfg.NUM_GPUS) == 0, 'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)

    # Images per optimizer step: gradients from `iter_size` minibatches of
    # `batch_size` images are accumulated before each step. Computed only
    # after args.batch_size is resolved, so a --bs override is honoured.
    effective_batch_size = args.batch_size * args.iter_size

    print('effective_batch_size = batch_size * iter_size = %d * %d' % (args.batch_size, args.iter_size))
    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' % (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' % (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' % (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print('Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
          '    SOLVER.STEPS: {} --> {}\n'
          '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps, cfg.SOLVER.STEPS, old_max_iter, cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print('Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
              '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    # With the 3D->2D projection loss the roidb builder also returns the
    # dataset object, whose Car3D member is handed to the model below.
    if cfg.MODEL.LOSS_3D_2D_ON:
        roidb, ratio_list, ratio_index, ds = combined_roidb_for_training(cfg.TRAIN.DATASETS, args.dataset_dir)
    else:
        roidb, ratio_list, ratio_index = combined_roidb_for_training(cfg.TRAIN.DATASETS, args.dataset_dir)

    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    sampler = MinibatchSampler(ratio_list, ratio_index)
    dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        training=True,
        valid_keys=['has_visible_keypoints', 'boxes', 'seg_areas', 'gt_classes', 'gt_overlaps', 'box_to_gt_ind_map',
                    'is_crowd', 'car_cat_classes', 'poses', 'quaternions', 'im_info'])
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        drop_last=True,
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    if cfg.MODEL.LOSS_3D_2D_ON:
        maskRCNN = Generalized_RCNN(ds.Car3D)
    else:
        maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Bias and non-bias parameters get separate groups so biases can use a
    # different learning rate / weight decay.
    bias_params = []
    nonbias_params = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
            else:
                nonbias_params.append(value)
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [
        {'params': nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0}
    ]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        raise ValueError('Unexpected cfg.SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'], ignore_list=args.ckpt_ignore_head)
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print('train_size value: %d different from the one in checkpoint: %d'
                          % (train_size, checkpoint['train_size']))

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  # TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0]['lr']  # lr of non-bias parameters, for commmand line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'], minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()
    # Clamp to at least 1 so the `% CHECKPOINT_PERIOD` below cannot divide by
    # zero when SNAPSHOT_ITERS < NUM_GPUS.
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))
    # Index of the next decay step; STEPS[0] is skipped by starting at 1.
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args,
        args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):
            # Warm up
            warmup_factor_trans = 1.0
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                    warmup_factor_trans = 1.0
                else:
                    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr, cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                            step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Accumulate gradients over iter_size minibatches, then step once.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Loader exhausted: start a new pass over the data.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)

                if not cfg.TRAIN.HOMOSCEDASTIC:
                    net_outputs['losses']['loss_car_cls'] *= cfg.CAR_CLS.CAR_CLS_LOSS_BETA
                    net_outputs['losses']['loss_rot'] *= cfg.CAR_CLS.ROT_LOSS_BETA
                    if cfg.MODEL.TRANS_HEAD_ON:
                        net_outputs['losses']['loss_trans'] *= cfg.TRANS_HEAD.TRANS_LOSS_BETA

                training_stats.UpdateIterStats_car_3d(net_outputs)

                # start training
                loss = net_outputs['losses']['loss_car_cls'] + net_outputs['losses']['loss_rot']
                if cfg.MODEL.TRANS_HEAD_ON:
                    loss += net_outputs['losses']['loss_trans']
                if cfg.MODEL.LOSS_3D_2D_ON:
                    loss += net_outputs['losses']['UV_projection_loss']
                if not cfg.TRAIN.FREEZE_CONV_BODY and not cfg.TRAIN.FREEZE_RPN and not cfg.TRAIN.FREEZE_FPN:
                    loss += net_outputs['total_loss_conv']

                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr, warmup_factor_trans)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
示例#23
0
        dataset,
        batch_size=bs,
        sampler=sampler,
        num_workers=gpus,
        collate_fn=collate_minibatch_semseg)
    return dataloader
    # return torch.randn(bs*gpus, 3, 720, 720), \
    #         torch.LongTensor(np.random.randint(0, 19, (bs*gpus, 90, 90), dtype=np.long))

# Data-pipeline smoke test: build the network, optimizer and loss, then pull at
# most 1000 batches from the loader, move every tensor to the GPU and print
# the batch index.
# NOTE(review): devices_ids is [3,4] but CUDA_VISIBLE_DEVICES is hard-coded to
# "2,3" -- the two disagree, and only len(devices_ids) is used below. Confirm
# which physical GPUs are intended.
devices_ids=[3,4]
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
torch.backends.cudnn.benchmark=True
#torch.cuda.set_device(3)
len_gpus = len(devices_ids)
batch_size = 2 * len_gpus  # two images per GPU
net = mynn.DataParallel(load_net().to('cuda'), minibatch=True)
# NOTE(review): optimizer and criterion are not used in the loop below.
optimizer = optim.SGD(net.parameters(), lr=0.000875, momentum=0.9)
criterion = nn.NLLLoss(ignore_index=255)  # targets equal to 255 are ignored
# Presumably the dataloader(bs, gpus) factory defined above; this rebinds the
# module-level name to the loader object it returns.
dataloader= dataloader(batch_size, len_gpus)
#for i in range(10):
# zip stops at 1000 iterations or when the loader is exhausted.
for i, inputs in zip(range(1000), dataloader):
    #data, label= dataloader(batch_size, len_gpus)
    #data = Variable(data).to('cuda')
    #data  = torch.chunk(data, chunks=len_gpus, dim=0)
    #label = Variable(label).to('cuda')
    #assert torch.all(data >= 0) and torch.all(data < 19), 'label is error'
    # Each batch entry is a list of tensors; wrap each and move it to the GPU.
    for key in inputs:
        inputs[key] = list(map(lambda x:Variable(x).to('cuda'), inputs[key]))
    data, label = inputs['data'], inputs['semseg_label_0']
    print(i)
    #cv2.imwrite('ims.png', data[0].cpu().numpy()[0].transpose(1,2,0)[:,:,::-1])
示例#24
0
def main():
    """Train the model returned by ``GetRCNNModel`` using the global ``cfg``.

    Steps:
      1. Parse arguments (``parse_cfg``) and build the training roidb, the
         aspect-ratio-grouped batch sampler and the DataLoader.
      2. Build the model and split its trainable parameters into three optimizer
         groups (non-bias, bias, GroupNorm), each with its own weight decay.
      3. Create the optimizer (``cfg.SOLVER.TYPE`` is "SGD" or "Adam").
      4. Optionally load a checkpoint (``args.load_ckpt``) and, when
         ``args.resume`` is set, restore the start step and optimizer state.
         Optionally load Detectron weights (``args.load_detectron``).
      5. Run the training loop from ``args.start_step`` to
         ``cfg.SOLVER.MAX_ITER``. The loop applies warm-up and step
         learning-rate decay and accumulates gradients over
         ``args.iter_size`` minibatches per optimizer step. It saves a
         checkpoint every ``CHECKPOINT_PERIOD`` steps, at the end, and when a
         ``RuntimeError`` or ``KeyboardInterrupt`` interrupts training.

    Raises:
        ValueError: if ``cfg.SOLVER.TYPE`` is neither "SGD" nor "Adam".
        KeyError: if ``cfg.SOLVER.WARM_UP_METHOD`` is not "constant"/"linear".
    """
    args = parse_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch: the trailing incomplete batch
    # is dropped, mirroring drop_last=True on the sampler below.
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = GetRCNNModel()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Names of GroupNorm affine parameters; they get cfg.SOLVER.WEIGHT_DECAY_GN.
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name + '.weight')
            gn_param_nameset.add(name + '.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    for key, value in maskRCNN.named_parameters():
        if value.requires_grad:
            # The 'bias' test comes first, so GroupNorm biases land in the bias
            # group rather than the GN group (the assert below accounts for this).
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    assert (gn_param_nameset - set(nograd_param_names) -
            set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': bias_params,
        'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]
    # Names of parameters for each param group (kept for the optional
    # optimizer-checkpoint reordering helper referenced below).
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast: otherwise `optimizer` stays unbound and the function dies
        # later with an UnboundLocalError that hides the real cause.
        raise ValueError('Unknown SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logger.info("loading checkpoint %s", load_name)
        # Load onto CPU storage regardless of the device the checkpoint was saved from.
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1,
            # so the project helper is used instead.
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logger.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    # lr of non-bias parameters, for command line outputs.
    lr = optimizer.param_groups[0]['lr']

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)
    if cfg.TRAIN_SYNC_BN:
        # Shu:For synchorinized BN
        patch_replication_callback(maskRCNN)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # Clamp to at least 1: int(SNAPSHOT_ITERS / NUM_GPUS) is 0 when
    # SNAPSHOT_ITERS < NUM_GPUS, and `% 0` would raise ZeroDivisionError
    # (not caught by the handler below, so no checkpoint would be saved).
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))

    # Index of the next decay step at or after start_step (search starts at 1,
    # so STEPS[0] is never treated as a decay point).
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        # Bind `step` before the loop so the exception handler can always use it.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    # Interpolate from WARM_UP_FACTOR (step 0) towards 1.
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Gradients accumulate over iter_size minibatches before one step.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Epoch finished: restart the loader.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if cfg.MODEL.LR_VIEW_ON or cfg.MODEL.GIF_ON or cfg.MODEL.LRASY_MAHA_ON:
                        if key != 'roidb' and key != 'data':  # roidb is a list of ndarrays with inconsistent length
                            input_data[key] = list(
                                map(Variable, input_data[key]))
                        if key == 'data':
                            # Drop size-1 dimensions from every image tensor.
                            input_data[key] = [
                                torch.squeeze(item) for item in input_data[key]
                            ]
                            input_data[key] = list(
                                map(Variable, input_data[key]))
                    else:
                        if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                            input_data[key] = list(
                                map(Variable, input_data[key]))
                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                net_utils.train_save_ckpt(output_dir, args, step, train_size,
                                          maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        net_utils.train_save_ckpt(output_dir, args, step, train_size, maskRCNN,
                                  optimizer)

    except (RuntimeError, KeyboardInterrupt):
        # Release the loader iterator (and its workers) before saving.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        net_utils.train_save_ckpt(output_dir, args, step, train_size, maskRCNN,
                                  optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
示例#25
0
def main():
    """Train the few-shot ``Generalized_RCNN`` model.

    Steps:
      1. Parse arguments and select the dataset and config file (FPN or C4
         variant per ``args.close_fpn``). Apply the co-attention and
         relation-RCNN ablation switches and seed the RNGs when
         ``args.random_seed`` is set.
      2. Adapt ``cfg`` to the actual number of visible GPUs and to the batch
         size:
         - scale ``BASE_LR`` linearly with the batch size;
         - scale ``SOLVER.STEPS`` and ``MAX_ITER`` with the effective batch
           size;
         - scale the FPN proposal collect size.
      3. Build the roidb, the query set, and the sampler/DataLoader.
      4. Build the model, group its parameters (non-bias / bias / GroupNorm)
         and create the optimizer. Optionally load a checkpoint or Detectron
         weights.
      5. Run the training loop with warm-up, step learning-rate decay,
         gradient accumulation over ``args.iter_size`` minibatches and
         gradient value clipping. It saves a checkpoint every
         ``CHECKPOINT_PERIOD`` steps, at the end, and when a
         ``RuntimeError`` or ``KeyboardInterrupt`` interrupts training.

    Raises:
        SystemExit: if no CUDA device is available.
        ValueError: if CUDA use is not requested, ``args.dataset`` is unknown,
            or ``cfg.SOLVER.TYPE`` is neither "SGD" nor "Adam".
        KeyError: if ``cfg.SOLVER.WARM_UP_METHOD`` is not "constant"/"linear".
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "fis_cell":
        cfg.TRAIN.DATASETS = ('fis_cell_train_val',)
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train',)
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train',)
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    if args.close_fpn:
        args.cfg_file = "configs/few_shot/e2e_mask_rcnn_R-50-C4_1x_{}.yaml".format(args.group)
        cfg.OUTPUT_DIR = 'Outputs_wo_fpn'
    else:
        args.cfg_file = "configs/few_shot/e2e_mask_rcnn_R-50-FPN_1x_{}.yaml".format(args.group)

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    cfg.RNG_SEED = args.random_seed
    if cfg.RNG_SEED is None:
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True
    else:
        print('Make the experiment results deterministic.')
        random.seed(cfg.RNG_SEED)
        np.random.seed(cfg.RNG_SEED)
        torch.manual_seed(cfg.RNG_SEED)
        torch.cuda.manual_seed_all(cfg.RNG_SEED)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    cfg.SEEN = args.seen

    # Ablation switches: disable co-attention / relation RCNN heads.
    if args.close_co_atten:
        cfg.CO_ATTEN = False
        if not args.close_fpn:
            cfg.OUTPUT_DIR = 'Outputs_wo_co_atten'

    if args.close_relation_rcnn:
        cfg.RELATION_RCNN = False
        if not args.close_fpn:
            cfg.FAST_RCNN.ROI_BOX_HEAD = 'fast_rcnn_heads.roi_2mlp_head'
            cfg.MRCNN.ROI_MASK_HEAD = 'mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs'
        else:
            cfg.FAST_RCNN.ROI_BOX_HEAD = 'torchResNet.ResNet_roi_conv5_head'
            cfg.MRCNN.ROI_MASK_HEAD = 'mask_rcnn_heads.mask_rcnn_fcn_head_v0upshare'

        if not args.close_co_atten and not args.close_fpn:
            cfg.OUTPUT_DIR = 'Outputs_wo_relation_rcnn'

    if args.output_dir is not None:
        cfg.OUTPUT_DIR = args.output_dir

    if args.deform_conv:
        cfg.MODEL.USE_DEFORM = True

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' % (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' % (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' % (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' % (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print('Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
          '    SOLVER.STEPS: {} --> {}\n'
          '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps, cfg.SOLVER.STEPS,
                                                  old_max_iter, cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print('Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
              '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    # Note: an explicit --lr replaces the batch-size-scaled BASE_LR computed above.
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    imdb, roidb, ratio_list, ratio_index, query, cat_list = combined_roidb(
        cfg.TRAIN.DATASETS, True)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    # Effective training sample size for one epoch (incomplete last batch dropped,
    # matching drop_last=True below).
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True
    )
    dataset = RoiDataLoader(
        roidb, ratio_list, ratio_index, query,
        cfg.MODEL.NUM_CLASSES,
        training=True, cat_list=cat_list, shot=args.shot)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Names of GroupNorm affine parameters; they get cfg.SOLVER.WEIGHT_DECAY_GN.
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name+'.weight')
            gn_param_nameset.add(name+'.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    for key, value in maskRCNN.named_parameters():
        if value.requires_grad:
            # The 'bias' test comes first, so GroupNorm biases land in the bias
            # group rather than the GN group (the assert below accounts for this).
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    assert (gn_param_nameset - set(nograd_param_names) - set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [
        {'params': nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': gn_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN}
    ]
    # Names of parameters for each param group (kept for the optional
    # optimizer-checkpoint reordering helper referenced below).
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail fast: otherwise `optimizer` stays unbound and the function dies
        # later with an UnboundLocalError that hides the real cause.
        raise ValueError('Unknown SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logger.info("loading checkpoint %s", load_name)
        # Load onto CPU storage regardless of the device the checkpoint was saved from.
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            # NOTE(review): the dataset above was already built with the
            # command-line args.shot; restoring args.shot here only affects code
            # that reads it afterwards (LogIterStats, save_ckpt via args) — confirm intent.
            args.shot = checkpoint['shot']
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print('train_size value: %d different from the one in checkpoint: %d'
                          % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # Uses torch's own optimizer.load_state_dict. On PyTorch 0.3.1 that
            # method had a bug; switch back to the commented helper below if needed.
            optimizer.load_state_dict(checkpoint['optimizer'])
            # misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logger.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    # lr of non-bias parameters, for command line outputs.
    lr = optimizer.param_groups[0]['lr']

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            #from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            #tblogger = SummaryWriter(output_dir)
            from loggers.logger import Logger
            tblogger = Logger(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # Clamp to at least 1: int(SNAPSHOT_ITERS / NUM_GPUS) is 0 when
    # SNAPSHOT_ITERS < NUM_GPUS, and `% 0` would raise ZeroDivisionError
    # (not caught by the handler below, so no checkpoint would be saved).
    CHECKPOINT_PERIOD = max(1, int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS))

    # Index of the next decay step at or after start_step (search starts at 1,
    # so STEPS[0] is never treated as a decay point).
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args,
        args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        # Bind `step` before the loop so the exception handler can always use it.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                #for p in maskRCNN.module.Conv_Body.parameters():
                #    p.requires_grad = False
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    # Interpolate from WARM_UP_FACTOR (step 0) towards 1.
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                #for p in maskRCNN.module.Conv_Body.parameters():
                #    p.requires_grad = True
                net_utils.update_learning_rate(optimizer, lr, cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Gradients accumulate over iter_size minibatches before one step.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Epoch finished: restart the loader.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb' and key != 'query': # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))
                    if key == 'query':
                        # `query` is nested one level deeper: wrap each inner element.
                        input_data[key] = [list(map(Variable, q)) for q in input_data[key]]

                # NOTE(review): anomaly detection is a debugging aid. Per the PyTorch
                # docs it makes backward raise a RuntimeError on NaN results, which the
                # handler below turns into a checkpoint save. It also slows training
                # noticeably — consider disabling it for production runs.
                with torch.autograd.detect_anomaly():
                    net_outputs = maskRCNN(**input_data)
                    training_stats.UpdateIterStats(net_outputs, inner_iter)
                    loss = net_outputs['total_loss']
                    loss.backward()
                    # Clamp each gradient element to [-0.4, 0.4].
                    # NOTE(review): with iter_size > 1 this clips the partially
                    # accumulated gradient after every minibatch, not once before
                    # optimizer.step() — confirm intent.
                    torch.nn.utils.clip_grad_value_(maskRCNN.module.parameters(), clip_value=0.4)
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr, input_data, args.shot)

            if (step+1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

            if (step+1) % args.disp_interval == 0:
                log_training_stats(training_stats, step, lr)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        # Release the loader iterator (and its workers) before saving.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
示例#26
0
def _build_roidb_dataloader(cfg, datasets, timers, timer_key, desc,
                            batch_size, training):
    """Build the roidb for ``datasets`` and wrap it in a ``DataLoader``.

    Args:
        cfg: config object returned by ``set_configs``.
        datasets: dataset name(s) forwarded to ``combined_roidb_for_training``
            and to ``RoiDataLoader(dataset=...)``.
        timers: mapping name -> ``Timer`` used to time roidb construction.
        timer_key: key into ``timers`` for this split.
        desc: split name used in the log messages (e.g. ``'val'``).
        batch_size: minibatch size summed over all GPUs.
        training: forwarded to ``RoiDataLoader``; only the training split is
            shuffled.

    Returns:
        ``(dataloader, category_to_id_map, prd_category_to_id_map)``; the two
        maps are the last two values returned by
        ``combined_roidb_for_training``.
    """
    timers[timer_key].tic()
    roidb, _, _, category_to_id_map, prd_category_to_id_map = \
        combined_roidb_for_training(datasets)
    timers[timer_key].toc()
    logger.info('{:d} {} roidb entries'.format(len(roidb), desc))
    logger.info('Takes %.2f sec(s) to construct %s roidb',
                timers[timer_key].average_time, desc)

    dataset = RoiDataLoader(roidb,
                            cfg.MODEL.NUM_CLASSES,
                            training=training,
                            dataset=datasets)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch,
        shuffle=training,
        drop_last=True)
    return dataloader, category_to_id_map, prd_category_to_id_map


def _build_param_groups(model, cfg):
    """Partition the trainable parameters of ``model`` into optimizer groups.

    Group order matters: ``main`` reads the lr of group 0 (backbone non-bias)
    and group 2 (non-backbone non-bias). Every group starts at lr 0; the real
    value is set by the warm-up schedule in the training loop.

    Returns:
        Five param-group dicts, in order: backbone non-bias, backbone bias,
        non-backbone (prd branch) non-bias, non-backbone bias, GroupNorm.
    """
    backbone_markers = ('Conv_Body', 'Box_Head', 'Box_Outs', 'RPN')
    gn_params = []
    backbone_bias_params, backbone_nonbias_params = [], []
    prd_branch_bias_params, prd_branch_nonbias_params = [], []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        # NOTE(review): these are plain substring tests (any name containing
        # 'gn' goes to the GroupNorm group) -- kept as originally written.
        if 'gn' in key:
            gn_params.append(value)
        elif any(marker in key for marker in backbone_markers):
            if 'bias' in key:
                backbone_bias_params.append(value)
            else:
                backbone_nonbias_params.append(value)
        elif 'bias' in key:
            prd_branch_bias_params.append(value)
        else:
            prd_branch_nonbias_params.append(value)

    bias_weight_decay = (cfg.SOLVER.WEIGHT_DECAY
                         if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0)
    # lr 0 is a placeholder; the "doubled bias lr" form is kept so the intent
    # stays visible even though it evaluates to 0.
    bias_lr = 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1)
    return [{
        'params': backbone_nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': backbone_bias_params,
        'lr': bias_lr,
        'weight_decay': bias_weight_decay
    }, {
        'params': prd_branch_nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': prd_branch_bias_params,
        'lr': bias_lr,
        'weight_decay': bias_weight_decay
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]


def main():
    """Train the model step by step, or evaluate it once.

    When ``cfg.EVAL_SUBSET`` is 'test' or 'unseen', the function evaluates
    once and returns.

    Flow:
      1. Parse CLI args and build the config via ``set_configs``.
      2. Build the training/val/test dataloaders, plus an 'unseen' one for
         ``--dataset vhico``.
      3. Build ``Generalized_RCNN`` and its optimizer.
      4. Optionally load/resume a checkpoint.
      5. Either run ``run_eval`` once and return, or train from
         ``args.start_step`` to ``cfg.SOLVER.MAX_ITER`` with LR warm-up and
         step decay.

    During training a checkpoint is saved every ``cfg.SAVE_MODEL_ITER``
    steps and at the end. A ``RuntimeError`` or ``KeyboardInterrupt`` during
    training saves a checkpoint and prints the traceback instead of
    propagating.

    Raises:
        ValueError: if ``cfg.SOLVER.TYPE`` is neither 'SGD' nor 'Adam', or if
            ``cfg.EVAL_SUBSET == 'unseen'`` without ``--dataset vhico`` (the
            unseen split is only built for vhico).
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    cfg = set_configs(args)
    # Fail fast instead of crashing much later on an unbound local
    # (no optimizer / no unseen dataloader).
    if cfg.SOLVER.TYPE not in ('SGD', 'Adam'):
        raise ValueError('Unsupported SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))
    if cfg.EVAL_SUBSET == 'unseen' and args.dataset != 'vhico':
        raise ValueError(
            "EVAL_SUBSET 'unseen' requires --dataset vhico, got {}".format(
                args.dataset))

    timers = defaultdict(Timer)
    batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH

    ### --------------------------------------------------------------------------------
    ### Datasets ###
    ### --------------------------------------------------------------------------------
    dataloader_training, category_to_id_map, prd_category_to_id_map = \
        _build_roidb_dataloader(cfg, cfg.TRAIN.DATASETS, timers,
                                'roidb_training', 'training', batch_size,
                                training=True)
    dataiterator_training = iter(dataloader_training)

    # NOTE(review): the val loader is built and logged but not used below.
    dataloader_val, _, _ = _build_roidb_dataloader(
        cfg, cfg.VAL.DATASETS, timers, 'roidb_val', 'val', batch_size,
        training=False)
    dataloader_test, _, _ = _build_roidb_dataloader(
        cfg, cfg.TEST.DATASETS, timers, 'roidb_test', 'test', batch_size,
        training=False)
    dataloader_unseen = None
    if args.dataset == 'vhico':
        dataloader_unseen, _, _ = _build_roidb_dataloader(
            cfg, cfg.UNSEEN.DATASETS, timers, 'roidb_unseen', 'test unseen',
            batch_size, training=False)

    ### --------------------------------------------------------------------------------
    ### Model ###
    ### --------------------------------------------------------------------------------
    maskRCNN = Generalized_RCNN(category_to_id_map=category_to_id_map,
                                prd_category_to_id_map=prd_category_to_id_map,
                                args=args)
    if cfg.CUDA:
        maskRCNN.cuda()

    ### --------------------------------------------------------------------------------
    ### Optimizer ###
    ### --------------------------------------------------------------------------------
    params = _build_param_groups(maskRCNN, cfg)
    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    else:  # "Adam" -- validated at the top of main()
        optimizer = torch.optim.Adam(params)

    ### --------------------------------------------------------------------------------
    ### Load checkpoint
    ### --------------------------------------------------------------------------------
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])

        banner = '--------------------------------------------------------------------------------'
        print(banner)
        print('loading checkpoint %s' % load_name)
        print(banner)

        if args.resume:
            print('resume')
            args.start_step = checkpoint['step'] + 1
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()
    else:
        print('args.load_ckpt', args.load_ckpt)

    # Current lr of non-backbone (group 2) and backbone (group 0) params; used
    # for logging and as the "current lr" passed to update_learning_rate_rel.
    lr = optimizer.param_groups[2]['lr']
    backbone_lr = optimizer.param_groups[0]['lr']

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### --------------------------------------------------------------------------------
    ### Training Setups ###
    ### --------------------------------------------------------------------------------
    args.run_name = args.out_dir
    output_dir = misc_utils.get_output_dir(args, args.out_dir)
    args.cfg_filename = os.path.basename(args.cfg_file)

    # Stays None unless saving is enabled and tensorboard was requested.
    tblogger = None
    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            tblogger = SummaryWriter(output_dir)

    ### --------------------------------------------------------------------------------
    ### Training Loop ###
    ### --------------------------------------------------------------------------------
    maskRCNN.train()

    # Index of the next cfg.SOLVER.STEPS entry at which to decay the lr
    # (entry 0 is skipped; len(STEPS) means "no more decays").
    decay_steps_ind = next(
        (i for i in range(1, len(cfg.SOLVER.STEPS))
         if cfg.SOLVER.STEPS[i] >= args.start_step),
        len(cfg.SOLVER.STEPS))

    training_stats = TrainingStats(args, args.disp_interval, tblogger, True)
    val_stats = ValStats(args, args.disp_interval, tblogger, False)  # unused
    test_stats = TestStats(args, args.disp_interval, tblogger, False)

    best_total_loss = np.inf
    best_eval_result = 0

    ### --------------------------------------------------------------------------------
    ### EVAL ###
    ### --------------------------------------------------------------------------------
    if cfg.EVAL_SUBSET in ('unseen', 'test'):
        if cfg.EVAL_SUBSET == 'unseen':
            print('testing unseen ...')
            eval_loader = dataloader_unseen
        else:
            print('testing ...')
            eval_loader = dataloader_test
        try:
            run_eval(args,
                     cfg,
                     maskRCNN,
                     eval_loader,
                     step=0,
                     output_dir=output_dir,
                     test_stats=test_stats,
                     best_eval_result=best_eval_result,
                     eval_subset=cfg.EVAL_SUBSET)
        finally:
            # The eval path returns before the training try/finally below.
            if tblogger is not None:
                tblogger.close()
        return

    ### --------------------------------------------------------------------------------
    ### TRAIN ###
    ### --------------------------------------------------------------------------------
    step = args.start_step
    try:
        logger.info('Training starts !')
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):
            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate_rel(optimizer, lr,
                                                   cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            ## train: accumulate gradients over args.iter_size minibatches
            training_stats.IterTic()
            optimizer.zero_grad()

            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator_training)
                except StopIteration:
                    # Epoch exhausted: restart the training iterator.
                    print('recurrence data loader')
                    dataiterator_training = iter(dataloader_training)
                    input_data = next(dataiterator_training)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)

                training_stats.UpdateIterStats(net_outputs['gt_label'],
                                               inner_iter)
                loss = net_outputs['gt_label']['total_loss']
                loss.backward()

            optimizer.step()
            training_stats.IterToc()
            training_stats.LogIterStats(step, lr, backbone_lr)

            if (step + 1) % cfg.SAVE_MODEL_ITER == 0:
                save_ckpt(output_dir, args, step, batch_size, maskRCNN,
                          optimizer, False, best_total_loss)

        # ---- Training ends ----
        save_ckpt(output_dir, args, step, batch_size, maskRCNN, optimizer,
                  False, best_total_loss)

    except (RuntimeError, KeyboardInterrupt):
        # Drop the iterator first so its loader workers shut down before the
        # checkpoint is written.
        del dataiterator_training
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, batch_size, maskRCNN, optimizer,
                  False, best_total_loss)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if tblogger is not None:
            tblogger.close()
# Example #27
# 0
def main():
    """Train ``Generalized_RCNN`` epoch by epoch.

    Flow:
      1. Parse CLI args and pick the training datasets and class count from
         ``--dataset``.
      2. Merge ``args.cfg_file`` and ``--set`` overrides into the global
         ``cfg``.
      3. Rescale the batch size and base LR to the visible GPUs.
      4. Build the roidb dataloader, model and optimizer.
      5. Optionally load a checkpoint (and resume its optimizer, epoch and
         step) or Detectron weights.
      6. Train for ``args.num_epochs`` epochs, checkpointing
         ``args.ckpt_num_per_epoch`` times per epoch and at every epoch end.

    A ``RuntimeError`` or ``KeyboardInterrupt`` during training saves a
    checkpoint and prints the traceback instead of propagating.

    Raises:
        ValueError: on an unknown ``--dataset`` or ``cfg.SOLVER.TYPE``, or
            when CUDA is not requested.
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "pascal_parts_heads":
        cfg.TRAIN.DATASETS = ('pascal_parts_heads', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "scuthead_a":
        cfg.TRAIN.DATASETS = ('scuthead_a', )
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    print('Batch size change from {} (in config file) to {}'.format(
        original_batch_size, args.batch_size))
    print('NUM_GPUs: %d, TRAIN.IMS_PER_BATCH: %d' %
          (cfg.NUM_GPUS, cfg.TRAIN.IMS_PER_BATCH))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Adjust learning based on batch size change linearly
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch size change: {} --> {}'.
          format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma

    timers = defaultdict(Timer)

    ### Dataset ###
    print("Checkpoint 0")
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    print("Checkpoint 1")
    train_size = len(roidb)
    logger.info('{:d} roidb entries'.format(train_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    sampler = MinibatchSampler(ratio_list, ratio_index)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)

    assert_and_infer_cfg()

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer: bias params may get a doubled lr and no weight decay ###
    bias_params = []
    nonbias_params = []
    for key, value in maskRCNN.named_parameters():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
            else:
                nonbias_params.append(value)
    params = [{
        'params': nonbias_params,
        'lr': cfg.SOLVER.BASE_LR,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': bias_params,
        'lr': cfg.SOLVER.BASE_LR * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Without this, `optimizer` would be unbound and crash much later.
        raise ValueError('Unsupported SOLVER.TYPE: {}'.format(cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            assert checkpoint['iters_per_epoch'] == train_size // args.batch_size, \
                "iters_per_epoch should match for resume"
            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
            if checkpoint['step'] == (checkpoint['iters_per_epoch'] - 1):
                # Resume from end of an epoch
                args.start_epoch = checkpoint['epoch'] + 1
                args.start_iter = 0
            else:
                # Resume from the middle of an epoch.
                # NOTE: dataloader is not synced with previous state
                args.start_epoch = checkpoint['epoch']
                args.start_iter = checkpoint['step'] + 1
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    # lr of non-bias parameters, for command line outputs. Read after the
    # optimizer state is restored, so a resumed run sees the decayed value.
    lr = optimizer.param_groups[0]['lr']

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name()
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    # Stays None unless saving is enabled and tensorboard was requested.
    tblogger = None
    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    training_stats = TrainingStats(args, args.disp_interval, tblogger)

    iters_per_epoch = train_size // args.batch_size  # drop last
    print('{}= int({}/{})'.format(iters_per_epoch, train_size,
                                  args.batch_size))
    args.iters_per_epoch = iters_per_epoch
    # Clamp to 1: with fewer iterations than checkpoints per epoch the
    # interval would be 0 and the modulo below would raise ZeroDivisionError.
    ckpt_interval_per_epoch = max(1,
                                  iters_per_epoch // args.ckpt_num_per_epoch)
    print('{}={}//{}'.format(ckpt_interval_per_epoch, iters_per_epoch,
                             args.ckpt_num_per_epoch))

    # Drop decay epochs that are already in the past (e.g. when resuming);
    # otherwise a stale head entry would never match and would block every
    # later decay.
    while args.lr_decay_epochs and args.lr_decay_epochs[0] < args.start_epoch:
        args.lr_decay_epochs.pop(0)

    try:
        logger.info('Training starts !')
        args.step = args.start_iter
        global_step = iters_per_epoch * args.start_epoch + args.step
        for args.epoch in range(args.start_epoch,
                                args.start_epoch + args.num_epochs):
            # ---- Start of epoch ----
            print(args.epoch)
            # adjust learning rate
            if args.lr_decay_epochs and args.epoch == args.lr_decay_epochs[0]:
                args.lr_decay_epochs.pop(0)
                # When resuming mid-epoch, the restored optimizer already
                # carries this epoch's decayed lr, so only decay at an epoch
                # boundary. The entry is consumed either way.
                if args.start_iter == 0:
                    net_utils.decay_learning_rate(optimizer, lr,
                                                  cfg.SOLVER.GAMMA)
                    lr *= cfg.SOLVER.GAMMA

            for args.step, input_data in zip(
                    range(args.start_iter, iters_per_epoch), dataloader):

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                training_stats.IterTic()
                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs)
                loss = net_outputs['total_loss']
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                training_stats.IterToc()

                if (args.step + 1) % ckpt_interval_per_epoch == 0:
                    net_utils.save_ckpt(output_dir, args, maskRCNN, optimizer)

                if args.step % args.disp_interval == 0:
                    log_training_stats(training_stats, global_step, lr)

                global_step += 1

            # ---- End of epoch ----
            # save checkpoint
            net_utils.save_ckpt(output_dir, args, maskRCNN, optimizer)
            # reset starting iter number after first epoch
            args.start_iter = 0

        # ---- Training ends ----
        if iters_per_epoch % args.disp_interval != 0:
            # log last stats at the end
            log_training_stats(training_stats, global_step, lr)

    except (RuntimeError, KeyboardInterrupt):
        logger.info('Save ckpt on exception ...')
        net_utils.save_ckpt(output_dir, args, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if tblogger is not None:
            tblogger.close()
# Example #28
# 0

# Smoke-training script for DispSeg. It loads the config, wraps the model in
# mynn.DataParallel, and runs 10 SGD steps on batches from `dataloader`,
# printing every loss term after each step.
cfg_file = 'e2e_segdisp-R-50_3Dpool_1x.yaml'
cfg_from_file(cfg_file)
print(cfg.SEM)
print(cfg.DISP)

devices_ids = [5]
# Restrict the visible GPUs before the first CUDA call (`.to('cuda')` below);
# changing this after CUDA has been initialised has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
    [str(ids) for ids in devices_ids])
torch.backends.cudnn.benchmark = True
len_gpus = len(devices_ids)
batch_size = 2 * len_gpus  # 2 samples per visible GPU

net = mynn.DataParallel(DispSeg().to('cuda'), minibatch=True)
optimizer = optim.SGD(net.parameters(), lr=0.000875, momentum=0.9)
criterion = nn.NLLLoss(ignore_index=255)  # NOTE(review): currently unused

for _ in range(10):
    inputs = dataloader(batch_size, len_gpus)
    # Split every input along dim 0 into one chunk per GPU -- presumably the
    # per-device layout mynn.DataParallel(minibatch=True) expects; confirm.
    for key in inputs:
        inputs[key] = torch.chunk(inputs[key], chunks=len_gpus, dim=0)
    optimizer.zero_grad()
    loss = net(**inputs)
    # backward() is required: without it optimizer.step() has no gradients
    # to apply. Each loss term is averaged over whatever per-device dimension
    # the gather produced (assumed one value per GPU -- confirm), then summed.
    losses = loss['losses']
    total_loss = sum(term.mean() for term in losses.values())
    total_loss.backward()
    optimizer.step()
    for k in losses:
        # .mean() keeps .item() valid with more than one GPU.
        print(losses[k].mean().item())
# Example #29
# 0
def _dispseg_param_groups(module):
    """Return the two optimizer param groups used for a DispSeg sub-module.

    Trainable parameters (``requires_grad``) of ``module`` are split by name:
    names containing ``'bias'`` go into the second group, all others into the
    first.

    Returns:
        list[dict]: ``[non-bias group, bias group]``. The non-bias group uses
        ``cfg.SOLVER.BASE_LR`` and ``cfg.SOLVER.WEIGHT_DECAY``; the bias group
        uses ``BASE_LR * (BIAS_DOUBLE_LR + 1)`` and applies weight decay only
        when ``cfg.SOLVER.BIAS_WEIGHT_DECAY`` is set.
    """
    bias_params = []
    nonbias_params = []
    for key, value in dict(module.named_parameters()).items():
        if not value.requires_grad:
            continue
        if 'bias' in key:
            bias_params.append(value)
        else:
            nonbias_params.append(value)

    return [{
        'params': nonbias_params,
        'lr': cfg.SOLVER.BASE_LR,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': bias_params,
        'lr': cfg.SOLVER.BASE_LR * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }]


def main():
    """Train DispSeg (semantic segmentation / disparity) with two optimizers.

    Parses command-line arguments, adapts the config to the visible GPUs and
    the requested batch size, builds the training dataloader, optionally
    resumes from a checkpoint and runs the epoch-based training loop.
    ``optimizerP`` updates ``dispSeg.pspnet`` and ``optimizerS`` updates
    ``dispSeg.glassGCN``. A checkpoint is saved at the end of every epoch and
    when training stops with a ``RuntimeError`` or ``KeyboardInterrupt``.

    Raises:
        SystemExit: if no CUDA device is available.
        ValueError: on an unknown dataset, unsupported ``SOLVER.TYPE``, or if
            CUDA use is disabled.
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    use_cuda = args.cuda or cfg.NUM_GPUS > 0
    if use_cuda:
        # Restrict the visible GPUs BEFORE any CUDA query: once the CUDA
        # runtime is initialised (torch.cuda.is_available() may do that),
        # later changes to CUDA_VISIBLE_DEVICES may be ignored.
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
            [str(ids) for ids in args.device_ids])

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")
    if not use_cuda:
        raise ValueError("Need Cuda device to run !")
    torch.backends.cudnn.benchmark = True
    cfg.CUDA = True

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "cityscapes":
        cfg.TRAIN.DATASETS = ('cityscapes_semseg_train', )
        cfg.MODEL.NUM_CLASSES = 19
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    print('Batch size change from {} (in config file) to {}'.format(
        original_batch_size, args.batch_size))
    print('NUM_GPUs: %d, TRAIN.IMS_PER_BATCH: %d' %
          (cfg.NUM_GPUS, cfg.TRAIN.IMS_PER_BATCH))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Adjust learning rate linearly with the batch size change
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch size change: {} --> {}'.
          format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma

    timers = defaultdict(Timer)

    ### Dataset ###
    # Semantic-segmentation / disparity training uses its own roidb builder
    # and collate function.
    semseg_or_disp = cfg.SEM.SEM_ON or cfg.DISP.DISP_ON
    timers['roidb'].tic()
    if semseg_or_disp:
        roidb, ratio_list, ratio_index = combined_roidb_for_training_semseg(
            cfg.TRAIN.DATASETS)
    else:
        roidb, ratio_list, ratio_index = combined_roidb_for_training(
            cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    train_size = len(roidb)
    logger.info('{:d} roidb entries'.format(train_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    sampler = MinibatchSampler(ratio_list, ratio_index)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch_semseg
        if semseg_or_disp else collate_minibatch)

    assert_and_infer_cfg()

    ### Model ###
    dispSeg = DispSeg()

    if cfg.CUDA:
        dispSeg.to('cuda')

    ### Optimizers: one for the PSPNet branch, one for the glassGCN branch
    pspnet_params = _dispseg_param_groups(dispSeg.pspnet)
    segdisp3d_params = _dispseg_param_groups(dispSeg.glassGCN)

    if cfg.SOLVER.TYPE == "SGD":
        optimizerP = torch.optim.SGD(pspnet_params,
                                     momentum=cfg.SOLVER.MOMENTUM)
        optimizerS = torch.optim.SGD(segdisp3d_params,
                                     momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizerP = torch.optim.Adam(pspnet_params)
        optimizerS = torch.optim.Adam(segdisp3d_params)
    else:
        # Fail early instead of with a NameError on optimizerP further down.
        raise ValueError('Unsupported SOLVER.TYPE: {}'.format(
            cfg.SOLVER.TYPE))

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        # save_ckpt below is given the whole dispSeg model and optimizerS, so
        # restore into those same objects.
        # NOTE(review): assumes net_utils.load_ckpt accepts a DispSeg model.
        net_utils.load_ckpt(dispSeg, checkpoint['model'])

        if args.resume:
            assert checkpoint['iters_per_epoch'] == train_size // args.batch_size, \
                "iters_per_epoch should match for resume"
            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            misc_utils.load_optimizer_state_dict(optimizerS,
                                                 checkpoint['optimizer'])
            if checkpoint['step'] == (checkpoint['iters_per_epoch'] - 1):
                # Resume from end of an epoch
                args.start_epoch = checkpoint['epoch'] + 1
                args.start_iter = 0
            else:
                # Resume from the middle of an epoch.
                # NOTE: dataloader is not synced with previous state
                args.start_epoch = checkpoint['epoch']
                args.start_iter = checkpoint['step'] + 1
        del checkpoint
        torch.cuda.empty_cache()

    # lr of non-bias parameters, for command line outputs.
    lr = optimizerP.param_groups[0]['lr']

    dispSeg = mynn.DataParallel(dispSeg,
                                cpu_keywords=['im_info', 'roidb'],
                                minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name()
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            # Set the Tensorboard logger (SummaryWriter must be imported at
            # module level).
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    dispSeg.train()
    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)

    iters_per_epoch = int(train_size / args.batch_size)  # drop last
    args.iters_per_epoch = iters_per_epoch
    try:
        logger.info('Training starts !')
        args.step = args.start_iter
        global_step = iters_per_epoch * args.start_epoch + args.step
        for args.epoch in range(args.start_epoch,
                                args.start_epoch + args.num_epochs):
            # ---- Start of epoch ----

            # Decay the LR at the epochs listed in args.lr_decay_epochs, but
            # only when the epoch starts from iteration 0.
            if (args.lr_decay_epochs
                    and args.epoch == args.lr_decay_epochs[0]
                    and args.start_iter == 0):
                args.lr_decay_epochs.pop(0)
                net_utils.decay_learning_rate(optimizerP, lr, cfg.SOLVER.GAMMA)
                net_utils.decay_learning_rate(optimizerS, lr, cfg.SOLVER.GAMMA)
                lr *= cfg.SOLVER.GAMMA

            for args.step, input_data in zip(
                    range(args.start_iter, iters_per_epoch), dataloader):

                if cfg.DISP.DISP_ON:
                    # Concatenate each left-view tensor with its right-view
                    # counterpart along dim 0.
                    input_data['data'] = list(
                        map(lambda x, y: torch.cat((x, y), dim=0),
                            input_data['data'], input_data['data_R']))
                    if cfg.SEM.DECODER_TYPE.endswith('3Ddeepsup'):
                        input_data['disp_scans'] = torch.arange(
                            0, cfg.DISP.MAX_DISPLACEMENT).float().view(
                                1, cfg.DISP.MAX_DISPLACEMENT, 1,
                                1).repeat(args.batch_size, 1, 1, 1)
                        input_data['semseg_scans'] = torch.arange(
                            0, cfg.MODEL.NUM_CLASSES).long().view(
                                1, cfg.MODEL.NUM_CLASSES, 1,
                                1).repeat(args.batch_size, 1, 1, 1)
                    del input_data['data_R']

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(
                            map(
                                lambda x: Variable(x, requires_grad=False).to(
                                    'cuda'), input_data[key]))
                training_stats.IterTic()
                net_outputs = dispSeg(**input_data)
                training_stats.UpdateIterStats(net_outputs)
                loss = net_outputs['total_loss']

                optimizerP.zero_grad()
                optimizerS.zero_grad()
                loss.backward()
                optimizerP.step()
                optimizerS.step()
                training_stats.IterToc()

                if args.step % args.disp_interval == 0:
                    log_training_stats(training_stats, global_step, lr)
                global_step += 1
            # ---- End of epoch ----
            # save checkpoint
            net_utils.save_ckpt(output_dir, args, dispSeg, optimizerS)
            # reset starting iter number after first epoch
            args.start_iter = 0

        # ---- Training ends ----

    except (RuntimeError, KeyboardInterrupt):
        logger.info('Save ckpt on exception ...')
        net_utils.save_ckpt(output_dir, args, dispSeg, optimizerS)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
def main():
    """Train a relationship-detection Generalized_RCNN with periodic validation.

    Parses command-line arguments and configures the dataset and solver. The
    LR is scaled with the batch size, and solver steps / max iterations are
    scaled with the effective batch size. It optionally resumes from a
    checkpoint, then runs the step-based training loop.

    When saving is enabled (``not args.no_save``), the model is validated
    every ``cfg.TRAIN.EVAL_PERIOD`` steps and at the last step. The best
    model by average per-class top-1 accuracy is tracked in
    ``<output_dir>/ckpt/best.json``. Regular checkpoints are written every
    10000 steps, at the end of training, and when an exception interrupts
    training.

    Raises:
        SystemExit: if no CUDA device is available.
        ValueError: on an unknown dataset, unsupported ``SOLVER.TYPE``, or if
            CUDA use is disabled.
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    cfg.DATASET = args.dataset
    # name -> (TRAIN.DATASETS, TEST.DATASETS,
    #          NUM_CLASSES (includes background),
    #          NUM_PRD_CLASSES (excludes background))
    dataset_specs = {
        "vg80k": (('vg80k_train', ), ('vg80k_val', ), 53305, 29086),
        "vg8k": (('vg8k_train', ), ('vg8k_val', ), 5331, 2000),
        "vrd": (('vrd_train', ), ('vrd_val', ), 101, 70),
        "vg": (('vg_train', ), ('vg_val', ), 151, 50),
        "gvqa20k": (('gvqa20k_train', ), ('gvqa20k_val', ), 1704, 310),
        "gvqa10k": (('gvqa10k_train', ), ('gvqa10k_val', ), 1704, 310),
        "gvqa": (('gvqa_train', ), ('gvqa_val', ), 1704, 310),
    }
    if args.dataset not in dataset_specs:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))
    (cfg.TRAIN.DATASETS, cfg.TEST.DATASETS, cfg.MODEL.NUM_CLASSES,
     cfg.MODEL.NUM_PRD_CLASSES) = dataset_specs[args.dataset]

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    if args.seed:
        cfg.RNG_SEED = args.seed

    # Some imports need to be done after loading the config to avoid using default values
    from datasets.roidb_rel import combined_roidb_for_training
    from modeling.model_builder_rel import Generalized_RCNN
    from core.test_engine_rel import run_eval_inference, run_inference
    from core.test_engine_rel import get_inference_dataset, get_roidb_and_dataset

    logger.info('Training with config:')
    logger.info(pprint.pformat(cfg))

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = [int(x * step_scale + 0.5) for x in cfg.SOLVER.STEPS]
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(sampler=MinibatchSampler(
        ratio_list, ratio_index),
                                batch_size=args.batch_size,
                                drop_last=True)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    if cfg.MODEL.DECOUPLE:
        # Only the 'so_sem_embeddings.2' layers remain trainable.
        for key, value in maskRCNN.named_parameters():
            if 'so_sem_embeddings.2' not in key:
                value.requires_grad = False

    # Split trainable params into GN params, backbone (conv_body, box head/
    # outputs, RPN) params and relationship-branch params, each further split
    # into bias / non-bias.
    gn_params = []
    backbone_bias_params = []
    backbone_nonbias_params = []
    prd_branch_bias_params = []
    prd_branch_nonbias_params = []
    for key, value in maskRCNN.named_parameters():
        if not value.requires_grad:
            continue
        if 'gn' in key:
            gn_params.append(value)
        elif 'Conv_Body' in key or 'Box_Head' in key or 'Box_Outs' in key or 'RPN' in key:
            if 'bias' in key:
                backbone_bias_params.append(value)
            else:
                backbone_nonbias_params.append(value)
        else:
            if 'bias' in key:
                prd_branch_bias_params.append(value)
            else:
                prd_branch_nonbias_params.append(value)

    # Learning rate of 0 is a dummy value to be set properly at the start of
    # training. Group order matters: index 0 is reported as backbone_lr and
    # index 2 as lr below.
    params = [{
        'params': backbone_nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': backbone_bias_params,
        'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': prd_branch_nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': prd_branch_bias_params,
        'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)
    else:
        # Fail early instead of with a NameError on `optimizer` further down.
        raise ValueError('Unsupported SOLVER.TYPE: {}'.format(
            cfg.SOLVER.TYPE))

    ### Load checkpoint
    load_ckpt_dir = './'
    if args.load_ckpt_dir:
        load_name = get_checkpoint_resume_file(args.load_ckpt_dir)
        load_ckpt_dir = args.load_ckpt_dir
    elif args.load_ckpt:
        load_name = args.load_ckpt
        load_ckpt_dir = os.path.dirname(args.load_ckpt)

    if args.load_ckpt or args.load_ckpt_dir:
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)

        if cfg.MODEL.DECOUPLE:
            # Drop the embedding layers so they are not restored from the
            # checkpoint.
            for name in ('RelDN.so_sem_embeddings.2.weight',
                         'RelDN.so_sem_embeddings.2.bias',
                         'RelDN.prd_sem_embeddings.2.weight',
                         'RelDN.prd_sem_embeddings.2.bias'):
                del checkpoint['model'][name]

        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            # 'train_size' may be missing in older checkpoints.
            if 'train_size' in checkpoint and \
                    checkpoint['train_size'] != train_size:
                print(
                    'train_size value: %d different from the one in checkpoint: %d'
                    % (train_size, checkpoint['train_size']))

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])

        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    # lr of non-backbone / backbone parameters, for command line outputs.
    lr = optimizer.param_groups[2]['lr']
    backbone_lr = optimizer.param_groups[0]['lr']

    prd_categories = maskRCNN.prd_categories
    obj_categories = maskRCNN.obj_categories
    prd_freq_dict = maskRCNN.prd_freq_dict
    obj_freq_dict = maskRCNN.obj_freq_dict

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step_with_prd_cls_v' + str(
        cfg.MODEL.SUBTYPE)
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    # Only defined when saving; validation (which writes eval.csv and
    # best.json) is skipped otherwise.
    ckpt_dir = None
    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        ckpt_dir = os.path.join(output_dir, 'ckpt')
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)

        # Carry the best-so-far record over when resuming, else start fresh.
        resume_best_path = os.path.join(load_ckpt_dir, 'best.json')
        if args.resume and os.path.exists(resume_best_path):
            logger.info('Loading best json from :' + resume_best_path)
            with open(resume_best_path) as f:
                best = json.load(f)
        else:
            best = {}
            best['avg_per_class_acc'] = 0.0
            best['iteration'] = 0
            best['accuracies'] = []
        with open(os.path.join(ckpt_dir, 'best.json'), 'w') as f:
            json.dump(best, f)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    args.output_dir = output_dir
    args.do_val = True
    args.use_gt_boxes = True
    args.use_gt_labels = True

    logger.info('Creating val roidb')
    val_dataset_name, val_proposal_file = get_inference_dataset(0)
    val_roidb, val_dataset, start_ind, end_ind, total_num_images = get_roidb_and_dataset(
        val_dataset_name, val_proposal_file, None, args.do_val)
    logger.info('Done')

    ### Training Loop ###
    maskRCNN.train()

    CHECKPOINT_PERIOD = 10000
    EVAL_PERIOD = cfg.TRAIN.EVAL_PERIOD
    # Index of the next LR decay step (index 0 of SOLVER.STEPS is skipped).
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Validation below switches the shared module to eval mode.
            maskRCNN.train()
            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate_rel(optimizer, lr,
                                                   cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate_rel(optimizer, lr, lr_new)
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Accumulate gradients over iter_size minibatches per step.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()

            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr, backbone_lr)

            is_eval_step = ((step + 1) % EVAL_PERIOD == 0
                            or step == cfg.SOLVER.MAX_ITER - 1)
            # Validation writes into output_dir / ckpt_dir, which only exist
            # when saving is enabled.
            if is_eval_step and ckpt_dir is not None:
                logger.info('Validating model')
                eval_model = maskRCNN.module
                eval_model = mynn.DataParallel(
                    eval_model,
                    cpu_keywords=['im_info', 'roidb'],
                    device_ids=[0],
                    minibatch=True)
                eval_model.eval()
                all_results = run_eval_inference(eval_model,
                                                 val_roidb,
                                                 args,
                                                 val_dataset,
                                                 val_dataset_name,
                                                 val_proposal_file,
                                                 ind_range=None,
                                                 multi_gpu_testing=False,
                                                 check_expected_results=True)
                csv_path = os.path.join(output_dir, 'eval.csv')
                all_results = all_results[0]
                generate_csv_file_from_det_obj(all_results, csv_path,
                                               obj_categories, prd_categories,
                                               obj_freq_dict, prd_freq_dict)
                overall_metrics, per_class_metrics = get_metrics_from_csv(
                    csv_path)
                obj_acc = per_class_metrics[(csv_path, 'obj', 'top1')]
                sbj_acc = per_class_metrics[(csv_path, 'sbj', 'top1')]
                prd_acc = per_class_metrics[(csv_path, 'rel', 'top1')]
                avg_obj_sbj = (obj_acc + sbj_acc) / 2.0
                avg_acc = (prd_acc + avg_obj_sbj) / 2.0

                best_path = os.path.join(ckpt_dir, 'best.json')
                with open(best_path) as f:
                    best = json.load(f)
                if avg_acc > best['avg_per_class_acc']:
                    print('Found new best validation accuracy at {:2.2f}%'.
                          format(avg_acc))
                    print('Saving best model..')
                    best['avg_per_class_acc'] = avg_acc
                    best['iteration'] = step
                    best['per_class_metrics'] = {
                        'obj_top1': obj_acc,
                        'sbj_top1': sbj_acc,
                        'prd_top1': prd_acc
                    }
                    best['overall_metrics'] = {
                        'obj_top1': overall_metrics[(csv_path, 'obj', 'top1')],
                        'sbj_top1': overall_metrics[(csv_path, 'sbj', 'top1')],
                        'prd_top1': overall_metrics[(csv_path, 'rel', 'top1')]
                    }
                    save_best_ckpt(output_dir, args, step, train_size,
                                   maskRCNN, optimizer)
                    with open(best_path, 'w') as f:
                        json.dump(best, f)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                print('Saving Checkpoint..')
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except Exception:
        # Drop the data iterator before saving (releases its reference to
        # any DataLoader worker state).
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()