Example #1
def initialize_model_from_cfg(args, gpu_id=0):
    """Initialize a model from the global cfg. Loads test-time weights and
    set to evaluation mode.
    """
    Generalized_RCNN = importlib.import_module('modeling_rel.' +
                                               cfg.MODEL.TYPE).Generalized_RCNN
    model = Generalized_RCNN()
    model.eval()

    if args.cuda:
        model.cuda()

    if args.load_ckpt:
        load_name = args.load_ckpt
        logger.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils_rel.load_ckpt_rel(model, checkpoint['model'])

    if args.load_detectron:
        logger.info("loading detectron weights %s", args.load_detectron)
        load_detectron_weight(model, args.load_detectron)

    model = mynn.DataParallel(model,
                              cpu_keywords=['im_info', 'roidb'],
                              minibatch=True)

    return model
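A note on the `map_location=lambda storage, loc: storage` idiom used above: it maps every storage in the checkpoint to CPU at load time, so a GPU-trained checkpoint can be deserialized on any machine before the model is moved to its target device. A minimal, self-contained sketch (the file name `ckpt.pth` is a placeholder):

import torch

# Deserialize a checkpoint onto CPU regardless of where it was saved.
checkpoint = torch.load('ckpt.pth', map_location=lambda storage, loc: storage)
# Equivalent modern spelling: torch.load('ckpt.pth', map_location='cpu')
print(sorted(checkpoint.keys()))  # e.g. ['model', 'optimizer', 'step'] in this codebase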
Example #2
def initialize_model(load_ckpt):
    """Initialize a model from the global cfg. Loads test-time weights and
    set to evaluation mode.
    """
    model = model_builder_rel.Generalized_RCNN()
    model.train()
    model.cuda()

    load_name = load_ckpt
    checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
    net_utils_rel.load_ckpt_rel(model, checkpoint['model'])
    #model = mynn.DataParallel(model, cpu_keywords=['im_info', 'roidb'], minibatch=True)
    return model
Example #3
    def load_detector_weights(self, weight_name):
        logger.info("loading pretrained weights from %s", weight_name)
        checkpoint = torch.load(weight_name,
                                map_location=lambda storage, loc: storage)
        net_utils_rel.load_ckpt_rel(self, checkpoint['model'])
        # Freeze everything above the rel module
        for p in self.Conv_Body.parameters():
            p.requires_grad = False
        for p in self.RPN.parameters():
            p.requires_grad = False
        if not cfg.MODEL.UNFREEZE_DET:
            for p in self.Box_Head.parameters():
                p.requires_grad = False
            for p in self.Box_Outs.parameters():
                p.requires_grad = False
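Since `load_detector_weights` freezes submodules by clearing `requires_grad`, it can be useful to verify what is still trainable afterwards. A small sketch against a toy `nn.Module` (the two-layer model is purely illustrative):

import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))
# Freeze the first layer, analogous to freezing Conv_Body / RPN above.
for p in model[0].parameters():
    p.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print('trainable=%d, frozen=%d' % (trainable, frozen))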
Example #4
def initialize_model_from_cfg(args, gpu_id=0):
    """Initialize a model from the global cfg. Loads test-time weights and
    set to evaluation mode.
    """
    model = model_builder_rel.Generalized_RCNN()
    model.eval()

    if args.cuda:
        model.cuda()

    if args.load_ckpt:
        load_name = args.load_ckpt
        logger.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils_rel.load_ckpt_rel(model, checkpoint['model'])

    if args.load_detectron:
        logger.info("loading detectron weights %s", args.load_detectron)
        load_detectron_weight(model, args.load_detectron)
    model.RelDN.mix_cent_loss.centroids.data = torch.from_numpy(np.load(
        '/home/wwt/ECCV2020/stage_one/Outputs/mix_centroids_headlimittailgo.npy')).cuda()
    #model.RelDN.mix_cent_loss.centroids.data = torch.zeros((51,1024)).cuda()
    #model.RelDN.mix_cent_loss.centroids.data = (torch.randn((51,1024))*0.08+0.02).cuda()
    model = mynn.DataParallel(model, cpu_keywords=['im_info', 'roidb'], minibatch=True)
    return model
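Example #4 overwrites the learned `centroids` parameter with values precomputed in NumPy. The underlying pattern, copying an array into an existing parameter without autograd tracking, looks roughly like this (the shapes and the random array stand in for the real `np.load(...)` result):

import numpy as np
import torch
import torch.nn as nn

centroids = nn.Parameter(torch.zeros(51, 1024))  # stand-in for mix_cent_loss.centroids
precomputed = np.random.randn(51, 1024).astype(np.float32)  # stand-in for np.load(...)

# Assigning through .data bypasses autograd; append .cuda() to match a GPU model.
centroids.data = torch.from_numpy(precomputed)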
Example #5
    def _init_modules(self):
        # The VGG16 ImageNet-pretrained model is initialized in VGG16.py
        if cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS != '':
            logger.info("Loading pretrained weights from %s",
                        cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS)
            resnet_utils.load_pretrained_imagenet_weights(self)
            for p in self.Conv_Body.parameters():
                p.requires_grad = False

        if cfg.RESNETS.VRD_PRETRAINED_WEIGHTS != '':
            self.load_detector_weights(cfg.RESNETS.VRD_PRETRAINED_WEIGHTS)
        if cfg.VGG16.VRD_PRETRAINED_WEIGHTS != '':
            self.load_detector_weights(cfg.VGG16.VRD_PRETRAINED_WEIGHTS)

        if cfg.RESNETS.VG_PRETRAINED_WEIGHTS != '':
            self.load_detector_weights(cfg.RESNETS.VG_PRETRAINED_WEIGHTS)
        if cfg.VGG16.VG_PRETRAINED_WEIGHTS != '':
            self.load_detector_weights(cfg.VGG16.VG_PRETRAINED_WEIGHTS)

        if cfg.RESNETS.OI_REL_PRETRAINED_WEIGHTS != '':
            self.load_detector_weights(cfg.RESNETS.OI_REL_PRETRAINED_WEIGHTS)
        if cfg.VGG16.OI_REL_PRETRAINED_WEIGHTS != '':
            self.load_detector_weights(cfg.VGG16.OI_REL_PRETRAINED_WEIGHTS)

        if cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS != '' or cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS != '' or \
            cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS != '' or cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS != '' or \
            cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS != '' or cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS != '':
            if cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS != '':
                logger.info("loading prd pretrained weights from %s",
                            cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS)
                checkpoint = torch.load(
                    cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS,
                    map_location=lambda storage, loc: storage)
            if cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS != '':
                logger.info("loading prd pretrained weights from %s",
                            cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS)
                checkpoint = torch.load(
                    cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS,
                    map_location=lambda storage, loc: storage)
            if cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS != '':
                logger.info("loading prd pretrained weights from %s",
                            cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS)
                checkpoint = torch.load(
                    cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS,
                    map_location=lambda storage, loc: storage)
            if cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS != '':
                logger.info("loading prd pretrained weights from %s",
                            cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS)
                checkpoint = torch.load(
                    cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS,
                    map_location=lambda storage, loc: storage)
            if cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS != '':
                logger.info("loading prd pretrained weights from %s",
                            cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS)
                checkpoint = torch.load(
                    cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS,
                    map_location=lambda storage, loc: storage)
            if cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS != '':
                logger.info("loading prd pretrained weights from %s",
                            cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS)
                checkpoint = torch.load(
                    cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS,
                    map_location=lambda storage, loc: storage)
            # Do not load the final cls_score / bbox_pred output layers
            del checkpoint['model']['Box_Outs.cls_score.weight']
            del checkpoint['model']['Box_Outs.cls_score.bias']
            del checkpoint['model']['Box_Outs.bbox_pred.weight']
            del checkpoint['model']['Box_Outs.bbox_pred.bias']
            net_utils_rel.load_ckpt_rel(self.Prd_RCNN, checkpoint['model'])
            if cfg.TRAIN.FREEZE_PRD_CONV_BODY:
                for p in self.Prd_RCNN.Conv_Body.parameters():
                    p.requires_grad = False
            if cfg.TRAIN.FREEZE_PRD_BOX_HEAD:
                for p in self.Prd_RCNN.Box_Head.parameters():
                    p.requires_grad = False

        if cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS != '' or cfg.VGG16.TO_BE_FINETUNED_WEIGHTS != '':
            if cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS != '':
                logger.info(
                    "loading trained and to be finetuned weights from %s",
                    cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS)
                checkpoint = torch.load(
                    cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS,
                    map_location=lambda storage, loc: storage)
            if cfg.VGG16.TO_BE_FINETUNED_WEIGHTS != '':
                logger.info(
                    "loading trained and to be finetuned weights from %s",
                    cfg.VGG16.TO_BE_FINETUNED_WEIGHTS)
                checkpoint = torch.load(
                    cfg.VGG16.TO_BE_FINETUNED_WEIGHTS,
                    map_location=lambda storage, loc: storage)
            net_utils_rel.load_ckpt_rel(self, checkpoint['model'])
            for p in self.Conv_Body.parameters():
                p.requires_grad = False
            for p in self.RPN.parameters():
                p.requires_grad = False
            if not cfg.MODEL.UNFREEZE_DET:
                for p in self.Box_Head.parameters():
                    p.requires_grad = False
                for p in self.Box_Outs.parameters():
                    p.requires_grad = False

        if cfg.RESNETS.REL_PRETRAINED_WEIGHTS != '':
            logger.info("loading rel pretrained weights from %s",
                        cfg.RESNETS.REL_PRETRAINED_WEIGHTS)
            checkpoint = torch.load(cfg.RESNETS.REL_PRETRAINED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
            prd_rcnn_state_dict = {}
            reldn_state_dict = {}
            for name in checkpoint['model']:
                if 'Prd_RCNN' in name:
                    prd_rcnn_state_dict[name] = checkpoint['model'][name]
                if 'RelDN' in name:
                    reldn_state_dict[name] = checkpoint['model'][name]
            net_utils_rel.load_ckpt_rel(self.Prd_RCNN, prd_rcnn_state_dict)
            if cfg.TRAIN.FREEZE_PRD_CONV_BODY:
                for p in self.Prd_RCNN.Conv_Body.parameters():
                    p.requires_grad = False
            if cfg.TRAIN.FREEZE_PRD_BOX_HEAD:
                for p in self.Prd_RCNN.Box_Head.parameters():
                    p.requires_grad = False
            del reldn_state_dict['RelDN.prd_cls_scores.weight']
            del reldn_state_dict['RelDN.prd_cls_scores.bias']
            if 'RelDN.prd_sbj_scores.weight' in reldn_state_dict:
                del reldn_state_dict['RelDN.prd_sbj_scores.weight']
            if 'RelDN.prd_sbj_scores.bias' in reldn_state_dict:
                del reldn_state_dict['RelDN.prd_sbj_scores.bias']
            if 'RelDN.prd_obj_scores.weight' in reldn_state_dict:
                del reldn_state_dict['RelDN.prd_obj_scores.weight']
            if 'RelDN.prd_obj_scores.bias' in reldn_state_dict:
                del reldn_state_dict['RelDN.prd_obj_scores.bias']
            if 'RelDN.spt_cls_scores.weight' in reldn_state_dict:
                del reldn_state_dict['RelDN.spt_cls_scores.weight']
            if 'RelDN.spt_cls_scores.bias' in reldn_state_dict:
                del reldn_state_dict['RelDN.spt_cls_scores.bias']
            net_utils_rel.load_ckpt_rel(self.RelDN, reldn_state_dict)
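The REL_PRETRAINED_WEIGHTS branch above splits one checkpoint into per-submodule state dicts by matching substrings in parameter names, then drops task-specific scoring layers before loading. A compact sketch of the same idea with dict comprehensions (the toy `state` dict stands in for `checkpoint['model']`):

# Parameter names look like 'Prd_RCNN.Conv_Body.conv1.weight', 'RelDN.prd_cls_scores.weight', ...
state = {'Prd_RCNN.conv.weight': 0, 'RelDN.fc.weight': 1, 'Box_Head.fc.weight': 2}

prd_rcnn_state_dict = {k: v for k, v in state.items() if 'Prd_RCNN' in k}
reldn_state_dict = {k: v for k, v in state.items() if 'RelDN' in k}

# pop(key, None) removes a key if present without raising, replacing the
# `if key in dict: del dict[key]` pattern used above.
reldn_state_dict.pop('RelDN.prd_cls_scores.weight', None)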
Example #6
def main():
    """Main function"""

    args = parse_args()
    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "vrd":
        cfg.TRAIN.DATASETS = ('vrd_train', )
        cfg.TEST.DATASETS = ('vrd_val', )
        cfg.MODEL.NUM_CLASSES = 101
        cfg.MODEL.NUM_PRD_CLASSES = 70  # exclude background
    elif args.dataset == "vg_mini":
        cfg.TRAIN.DATASETS = ('vg_train_mini', )
        cfg.MODEL.NUM_CLASSES = 151
        cfg.MODEL.NUM_PRD_CLASSES = 50  # exclude background
    elif args.dataset == "vg":
        cfg.TRAIN.DATASETS = ('vg_train', )
        cfg.MODEL.NUM_CLASSES = 151
        cfg.MODEL.NUM_PRD_CLASSES = 50  # exclude background
    elif args.dataset == "oi_rel":
        cfg.TRAIN.DATASETS = ('oi_rel_train', )
        # cfg.MODEL.NUM_CLASSES = 62
        cfg.MODEL.NUM_CLASSES = 58
        cfg.MODEL.NUM_PRD_CLASSES = 9  # rel, exclude background
    elif args.dataset == "oi_rel_mini":
        cfg.TRAIN.DATASETS = ('oi_rel_train_mini', )
        # cfg.MODEL.NUM_CLASSES = 62
        cfg.MODEL.NUM_CLASSES = 58
        cfg.MODEL.NUM_PRD_CLASSES = 9  # rel, exclude background
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    Generalized_RCNN = importlib.import_module('modeling_rel.' +
                                               cfg.MODEL.TYPE).Generalized_RCNN
    from core.test_engine_rel_mps import get_metrics_det_boxes, get_metrics_gt_boxes

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS

    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH

    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning rate linearly based on batch size change
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(
        map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size proportionally to the change in IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index, ds = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(
        sampler=MinibatchSampler(ratio_list, ratio_index),
        batch_size=args.batch_size,
        drop_last=True)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # record backbone params, i.e., conv_body and box_head params
    gn_params = []
    backbone_bias_params = []
    backbone_bias_param_names = []
    prd_branch_bias_params = []
    prd_branch_bias_param_names = []
    backbone_nonbias_params = []
    backbone_nonbias_param_names = []
    prd_branch_nonbias_params = []
    prd_branch_nonbias_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'gn' in key:
                gn_params.append(value)
            elif 'Conv_Body' in key or 'Box_Head' in key or 'Box_Outs' in key or 'RPN' in key:
                if 'bias' in key:
                    backbone_bias_params.append(value)
                    backbone_bias_param_names.append(key)
                else:
                    backbone_nonbias_params.append(value)
                    backbone_nonbias_param_names.append(key)
            else:
                if 'bias' in key:
                    prd_branch_bias_params.append(value)
                    prd_branch_bias_param_names.append(key)
                else:
                    prd_branch_nonbias_params.append(value)
                    prd_branch_nonbias_param_names.append(key)
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': backbone_nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': backbone_bias_params,
        'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': prd_branch_nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params': prd_branch_bias_params,
        'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils_rel.load_ckpt_rel(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value %d differs from the one in the checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # optimizer.load_state_dict is buggy on PyTorch 0.3.1 (fixed on
            # master), so restore the optimizer state via the workaround below.
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  # TODO: resume for Detectron weights (load SGD momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    # lr = optimizer.param_groups[0]['lr']  # lr of non-bias parameters, for command-line outputs.
    lr = optimizer.param_groups[2]['lr']  # lr of non-backbone parameters, for command-line outputs
    backbone_lr = optimizer.param_groups[0]['lr']  # lr of backbone parameters, for command-line outputs

    device_ids = list(range(torch.cuda.device_count()))

    maskRCNN_one_gpu = mynn.DataParallel(maskRCNN,
                                         cpu_keywords=['im_info', 'roidb'],
                                         minibatch=True,
                                         device_ids=[device_ids[0]])

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_' + args.exp + '_' + \
        '_step_with_prd_cls_v' + str(cfg.MODEL.SUBTYPE)
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    # CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)
    CHECKPOINT_PERIOD = ds.len // effective_batch_size

    # Set index for decay steps
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    # metrics = get_metrics_gt_boxes(maskRCNN, timers, cfg.TEST.DATASETS[0])
    # tblogger.add_scalar(args.dataset + '_r@100', metrics, 0)
    try:
        logger.info('Training starts!')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils_rel.update_learning_rate_rel(optimizer, lr, lr_new)
                # lr = optimizer.param_groups[0]['lr']
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils_rel.update_learning_rate_rel(optimizer, lr,
                                                       cfg.SOLVER.BASE_LR)
                # lr = optimizer.param_groups[0]['lr']
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning rate on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils_rel.update_learning_rate_rel(optimizer, lr, lr_new)
                # lr = optimizer.param_groups[0]['lr']
                lr = optimizer.param_groups[2]['lr']
                backbone_lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()
            if step == args.start_step:
                for n, p in maskRCNN.named_parameters():
                    if p.requires_grad and p.grad is None:
                        logger.warning('Parameter is trainable but received no gradient: %s', n)

            training_stats.LogIterStats(step, lr, backbone_lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer)
                metrics = get_metrics_gt_boxes(maskRCNN_one_gpu, timers,
                                               cfg.TEST.DATASETS[0])
                maskRCNN.train()
                if args.use_tfboard and not args.no_save:
                    tblogger.add_scalar(args.dataset + '_metrics', metrics, step)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        # metrics = get_metrics(maskRCNN, timers, cfg.TEST.DATASETS)
        # tblogger.add_scalar(args.dataset + '_r@100', metrics, step)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
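Two pieces of arithmetic drive the schedule in Example #6: the base learning rate is scaled linearly with the batch size before training, and a warm-up factor ramps it in over the first WARM_UP_ITERS steps. Isolated from the config machinery (all constants below are illustrative, not the repo defaults):

def warmup_lr(step, base_lr=0.02, warm_up_iters=500, warm_up_factor=1.0 / 3,
              method='linear'):
    # Mirrors the warm-up branch of the training loop above.
    if step >= warm_up_iters:
        return base_lr
    if method == 'constant':
        return base_lr * warm_up_factor
    if method == 'linear':
        alpha = step / warm_up_iters
        return base_lr * (warm_up_factor * (1 - alpha) + alpha)
    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))

# Linear batch-size scaling, as done before the loop starts:
original_batch_size, batch_size, base_lr = 8, 16, 0.01
base_lr *= batch_size / original_batch_size  # 0.01 --> 0.02

print(warmup_lr(0), warmup_lr(250), warmup_lr(500))  # ramps from base_lr/3 to base_lr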
Example #7
    def load_detector_weights(self, weight_name):
        logger.info("loading pretrained weights from %s", weight_name)
        checkpoint = torch.load(weight_name,
                                map_location=lambda storage, loc: storage)
        net_utils_rel.load_ckpt_rel(self, checkpoint['model'])
Example #8
    def _init_modules(self):
        # The VGG16 ImageNet-pretrained model is initialized in VGG16.py
        if cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS != '':
            logger.info("Loading pretrained weights from %s",
                        cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS)
            resnet_utils.load_pretrained_imagenet_weights(self)
            for p in self.Conv_Body.parameters():
                p.requires_grad = False

        if cfg.RESNETS.VRD_PRETRAINED_WEIGHTS != '':
            self.load_detector_weights(cfg.RESNETS.VRD_PRETRAINED_WEIGHTS)
        if cfg.VGG16.VRD_PRETRAINED_WEIGHTS != '':
            self.load_detector_weights(cfg.VGG16.VRD_PRETRAINED_WEIGHTS)

        if cfg.RESNETS.VG_PRETRAINED_WEIGHTS != '':
            self.load_detector_weights(cfg.RESNETS.VG_PRETRAINED_WEIGHTS)
        if cfg.VGG16.VG_PRETRAINED_WEIGHTS != '':
            self.load_detector_weights(cfg.VGG16.VG_PRETRAINED_WEIGHTS)

        if cfg.RESNETS.OI_REL_PRETRAINED_WEIGHTS != '':
            self.load_detector_weights(cfg.RESNETS.OI_REL_PRETRAINED_WEIGHTS)
        if cfg.VGG16.OI_REL_PRETRAINED_WEIGHTS != '':
            self.load_detector_weights(cfg.VGG16.OI_REL_PRETRAINED_WEIGHTS)

        if cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS != '' or cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS != '' or \
            cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS != '' or cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS != '' or \
            cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS != '' or cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS != '':
            if cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS != '':
                logger.info("loading prd pretrained weights from %s",
                            cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS)
                checkpoint = torch.load(
                    cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS,
                    map_location=lambda storage, loc: storage)
            if cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS != '':
                logger.info("loading prd pretrained weights from %s",
                            cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS)
                checkpoint = torch.load(
                    cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS,
                    map_location=lambda storage, loc: storage)
            if cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS != '':
                logger.info("loading prd pretrained weights from %s",
                            cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS)
                checkpoint = torch.load(
                    cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS,
                    map_location=lambda storage, loc: storage)
            if cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS != '':
                logger.info("loading prd pretrained weights from %s",
                            cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS)
                checkpoint = torch.load(
                    cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS,
                    map_location=lambda storage, loc: storage)
            if cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS != '':
                logger.info("loading prd pretrained weights from %s",
                            cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS)
                checkpoint = torch.load(
                    cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS,
                    map_location=lambda storage, loc: storage)
            if cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS != '':
                logger.info("loading prd pretrained weights from %s",
                            cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS)
                checkpoint = torch.load(
                    cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS,
                    map_location=lambda storage, loc: storage)

            self.Box_Head_sg.heads[0].weight.data.copy_(
                checkpoint['model']['Box_Head.heads.0.weight'])
            self.Box_Head_sg.heads[0].bias.data.copy_(
                checkpoint['model']['Box_Head.heads.0.bias'])
            self.Box_Head_sg.heads[3].weight.data.copy_(
                checkpoint['model']['Box_Head.heads.3.weight'])
            self.Box_Head_sg.heads[3].bias.data.copy_(
                checkpoint['model']['Box_Head.heads.3.bias'])
            self.Box_Head_prd.heads[0].weight.data.copy_(
                checkpoint['model']['Box_Head.heads.0.weight'])
            self.Box_Head_prd.heads[0].bias.data.copy_(
                checkpoint['model']['Box_Head.heads.0.bias'])
            self.Box_Head_prd.heads[3].weight.data.copy_(
                checkpoint['model']['Box_Head.heads.3.weight'])
            self.Box_Head_prd.heads[3].bias.data.copy_(
                checkpoint['model']['Box_Head.heads.3.bias'])

        if cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS != '' or cfg.VGG16.TO_BE_FINETUNED_WEIGHTS != '':
            if cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS != '':
                logger.info(
                    "loading trained and to be finetuned weights from %s",
                    cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS)
                checkpoint = torch.load(
                    cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS,
                    map_location=lambda storage, loc: storage)
            if cfg.VGG16.TO_BE_FINETUNED_WEIGHTS != '':
                logger.info(
                    "loading trained and to be finetuned weights from %s",
                    cfg.VGG16.TO_BE_FINETUNED_WEIGHTS)
                checkpoint = torch.load(
                    cfg.VGG16.TO_BE_FINETUNED_WEIGHTS,
                    map_location=lambda storage, loc: storage)
            net_utils_rel.load_ckpt_rel(self, checkpoint['model'])
            for p in self.Conv_Body.parameters():
                p.requires_grad = False
            for p in self.RPN.parameters():
                p.requires_grad = False
            if not cfg.MODEL.UNFREEZE_DET:
                for p in self.Box_Head.parameters():
                    p.requires_grad = False
                for p in self.Box_Outs.parameters():
                    p.requires_grad = False
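Example #8 initializes two parallel box heads (`Box_Head_sg`, `Box_Head_prd`) from the same checkpoint tensors via `.data.copy_()`. The in-place copy leaves each head with independent storage, so later gradient updates can diverge; assigning the same tensor object would tie them together. A minimal sketch:

import torch.nn as nn

src = nn.Linear(4, 4)
head_sg, head_prd = nn.Linear(4, 4), nn.Linear(4, 4)

# Copy values in place; the two heads remain separate parameters.
head_sg.weight.data.copy_(src.weight.data)
head_prd.weight.data.copy_(src.weight.data)
assert head_sg.weight is not head_prd.weight  # independent storage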