예제 #1
0
def get_lfb(params_file, is_train):
    """
    Wrapper function for getting an LFB, which is either inferred given a
    baseline model, or loaded from a file.

    Args:
        params_file: Path of the model parameters used to infer the LFB.
            Required unless cfg.LFB.LOAD_LFB is set.
        is_train: If True, build the LFB for the training split; otherwise
            for the test split.

    Returns:
        The LFB returned by ``load_lfb`` when cfg.LFB.LOAD_LFB is set,
        otherwise the LFB built by ``construct_lfb`` from inferred features.

    Side effects:
        Sets the global ``cfg.GET_TRAIN_LFB`` to ``is_train`` for the duration
        of inference and always resets it to False before returning (also on
        failure). Resets the Caffe2 workspace after a successful inference,
        and calls ``write_lfb`` when cfg.LFB.WRITE_LFB is set.
    """
    if cfg.LFB.LOAD_LFB:
        return load_lfb(is_train)

    assert params_file, 'LFB.MODEL_PARAMS_FILE is not specified.'
    logger.info('Inferring LFB from %s', params_file)

    # Global flag, presumably read during model / data-loader construction to
    # pick the split (TODO confirm). Restored in the outer `finally` so a
    # failure during inference does not leak it into later model builds.
    cfg.GET_TRAIN_LFB = is_train
    try:
        timer = Timer()

        test_model = model_builder_video.ModelBuilder(
            train=False,
            use_cudnn=True,
            cudnn_exhaustive_search=True,
            split=cfg.TEST.DATA_TYPE,
        )

        suffix = 'infer_{}'.format('train' if is_train else 'test')
        test_model.build_model(
            lfb_infer_only=True,
            suffix=suffix,
            shift=1,
        )

        test_model.net.Proto().type = 'prof_dag' if cfg.PROF_DAG else 'dag'

        workspace.RunNetOnce(test_model.param_init_net)
        workspace.CreateNet(test_model.net)

        total_test_net_iters = misc.get_total_test_iters(test_model)

        test_model.start_data_loader()
        # Everything after the loader starts is guarded so the loader is
        # always shut down, even if loading params or inference raises.
        try:
            checkpoints.load_model_from_params_file_for_test(
                test_model, params_file)

            all_features = []
            all_metadata = []

            for test_iter in range(total_test_net_iters):
                timer.tic()
                workspace.RunNet(test_model.net.Proto().name)
                timer.toc()

                if test_iter == 0:
                    misc.print_net(test_model)
                    os.system('nvidia-smi')
                if test_iter % 10 == 0:
                    logger.info("Iter {}/{} Time: {}".format(
                        test_iter, total_test_net_iters, timer.diff))

                if cfg.DATASET in ['ava', 'avabox']:
                    all_features.append(get_features('box_pooled'))
                    all_metadata.append(
                        get_features('metadata{}'.format(suffix)))
                elif cfg.DATASET in ['charades', 'epic']:
                    all_features.append(get_features('pool5'))
                # NOTE(review): any other cfg.DATASET collects nothing here
                # and an empty LFB is passed to construct_lfb.

            lfb = construct_lfb(all_features, all_metadata,
                                test_model.input_db, is_train)
        finally:
            logger.info("Shutting down data loader...")
            test_model.shutdown_data_loader()

        workspace.ResetWorkspace()
        logger.info("Done ResetWorkspace...")
    finally:
        cfg.GET_TRAIN_LFB = False

    if cfg.LFB.WRITE_LFB:
        write_lfb(lfb, is_train)

    return lfb
예제 #2
0
def load_feature_map(params_file, is_train):
    """
    Infer feature maps with a trained model and write them out.

    Runs the inference net over the train or test split. For each iteration
    it collects the blobs named in cfg.FEATURE_MAP_LOADER.NAME_LIST, plus the
    per-iteration metadata, labels, proposals and original boxes. The results
    are packed with ``construct_lfb`` and written with ``write_lfb``.

    Args:
        params_file: Path of the model parameters to run inference with.
        is_train: If True, process the training split; otherwise the test
            split. When cfg.LFB.ENABLED is set, this also selects which
            pickled LFB ('train_lfb.pkl' / 'val_lfb.pkl') is loaded from
            cfg.LFB.LOAD_LFB_PATH to build the model.

    Raises:
        AssertionError: If ``params_file`` or cfg.FEATURE_MAP_LOADER.OUT_DIR
            is not set.
        Exception: If cfg.DATASET is not 'ava' (the only supported dataset).
    """
    assert params_file, 'FEATURE_MAP_LOADER.MODEL_PARAMS_FILE is not specified.'
    assert cfg.FEATURE_MAP_LOADER.OUT_DIR, 'FEATURE_MAP_LOADER.OUT_DIR is not specified.'
    logger.info('Inferring feature map from %s', params_file)

    # Only AVA is supported. Check before building the model and starting the
    # data loader so an unsupported dataset fails fast without leaving the
    # loader running.
    if cfg.DATASET != "ava":
        raise Exception("Dataset {} not recognized.".format(cfg.DATASET))

    # NOTE(review): "ENALBE" looks like a typo of "ENABLE". It is kept as-is
    # because other code may read this exact key; confirm against the config
    # definition before renaming.
    cfg.FEATURE_MAP_LOADER.ENALBE = True

    # Global split flag; restored in the outer `finally` so it never leaks.
    cfg.GET_TRAIN_LFB = is_train
    try:
        timer = Timer()

        test_model = model_builder_video.ModelBuilder(
            train=False,
            use_cudnn=True,
            cudnn_exhaustive_search=True,
            split=cfg.TEST.DATA_TYPE,
        )

        suffix = 'infer_{}'.format('train' if is_train else 'test')

        loaded_lfb = None
        if cfg.LFB.ENABLED:
            lfb_path = os.path.join(
                cfg.LFB.LOAD_LFB_PATH,
                'train_lfb.pkl' if is_train else 'val_lfb.pkl')
            logger.info('Loading LFB from %s', lfb_path)
            # Pickles must be read in binary mode: text mode breaks
            # pickle.load on Python 3. pickle can execute arbitrary code, so
            # only point LOAD_LFB_PATH at trusted files.
            with open(lfb_path, 'rb') as f:
                loaded_lfb = pickle.load(f)

        test_model.build_model(
            lfb=loaded_lfb,
            suffix=suffix,
            shift=1,
        )

        test_model.net.Proto().type = 'prof_dag' if cfg.PROF_DAG else 'dag'

        workspace.RunNetOnce(test_model.param_init_net)
        workspace.CreateNet(test_model.net)

        total_test_net_iters = misc.get_total_test_iters(test_model)
        # A positive TEST_ITERS caps/overrides the number of iterations.
        if cfg.FEATURE_MAP_LOADER.TEST_ITERS > 0:
            total_test_net_iters = cfg.FEATURE_MAP_LOADER.TEST_ITERS

        test_model.start_data_loader()
        # Guard everything after the loader starts so it is always shut down.
        try:
            checkpoints.load_model_from_params_file_for_test(
                test_model, params_file)

            all_features = {
                feat_name: [] for feat_name in cfg.FEATURE_MAP_LOADER.NAME_LIST
            }
            all_metadata = []
            all_labels = []
            all_proposals = []
            all_original_boxes = []

            for test_iter in range(total_test_net_iters):
                timer.tic()
                workspace.RunNet(test_model.net.Proto().name)
                timer.toc()

                if test_iter == 0:
                    misc.print_net(test_model)
                    os.system('nvidia-smi')
                if test_iter % 10 == 0:
                    logger.info("Iter {}/{} Time: {}".format(
                        test_iter, total_test_net_iters, timer.diff))

                for feat_name in cfg.FEATURE_MAP_LOADER.NAME_LIST:
                    all_features[feat_name].append(get_features(feat_name))

                all_metadata.append(get_features('metadata{}'.format(suffix)))
                all_labels.append(get_features('labels{}'.format(suffix)))
                all_proposals.append(
                    get_features('proposals{}'.format(suffix)))
                all_original_boxes.append(
                    get_features('original_boxes{}'.format(suffix)))

            lfb = construct_lfb(all_features, all_metadata, all_labels,
                                all_proposals, all_original_boxes,
                                test_model.input_db, is_train)

            write_lfb(lfb, is_train)
        finally:
            logger.info("Shutting down data loader...")
            test_model.shutdown_data_loader()

        workspace.ResetWorkspace()
        logger.info("Done ResetWorkspace...")
    finally:
        cfg.GET_TRAIN_LFB = False
예제 #3
0
def train(opts):
    """Train a model.

    Optionally builds train/test long-term feature banks, then constructs the
    test and train models. It resumes from a checkpoint or pre-trained weights
    when configured, and runs SGD up to cfg.SOLVER.MAX_ITER. Checkpoints are
    saved every cfg.CHECKPOINT.CHECKPOINT_PERIOD iterations (and at the last
    iteration), and the test net is evaluated every cfg.TRAIN.EVAL_PERIOD
    iterations. Metrics are logged and written to TensorBoard under
    ``<cfg.CHECKPOINT.DIR>/tb``.

    Args:
        opts: Parsed command-line options. Passed to
            ``misc.generate_random_seed``; ``opts.node_id`` is used to build
            the precise-BN auxiliary model.

    Side effects:
        Overrides cfg.AVA.FULL_EVAL, cfg.AVA.DETECTION_SCORE_THRESH and
        cfg.CHARADES.NUM_TEST_CLIPS with their training-time values. When
        cfg.TRAIN.TEST_AFTER_TRAIN is set, points cfg.TEST.PARAMS_FILE at the
        last checkpoint and runs ``test_net``.
    """
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])

    # Generate seed.
    misc.generate_random_seed(opts)

    # Create checkpoint dir.
    checkpoint_dir = checkpoints.create_and_get_checkpoint_directory()
    logger.info('Checkpoint directory created: {}'.format(checkpoint_dir))

    # Create TensorBoard logger (closed after training to flush events).
    tb_writer = SummaryWriter(os.path.join(cfg.CHECKPOINT.DIR, 'tb'))

    # Setting training-time-specific configurations.
    cfg.AVA.FULL_EVAL = cfg.AVA.FULL_EVAL_DURING_TRAINING
    cfg.AVA.DETECTION_SCORE_THRESH = cfg.AVA.DETECTION_SCORE_THRESH_TRAIN
    cfg.CHARADES.NUM_TEST_CLIPS = cfg.CHARADES.NUM_TEST_CLIPS_DURING_TRAINING

    test_lfb, train_lfb = None, None

    if cfg.LFB.ENABLED:
        test_lfb = get_lfb(cfg.LFB.MODEL_PARAMS_FILE, is_train=False)
        train_lfb = get_lfb(cfg.LFB.MODEL_PARAMS_FILE, is_train=True)

    # Build test_model.
    # We build test_model first, so that we don't overwrite init.
    test_model, test_timer, test_meter = create_wrapper(
        is_train=False,
        lfb=test_lfb,
    )
    total_test_iters = misc.get_total_test_iters(test_model)
    logger.info('Test iters: {}'.format(total_test_iters))

    # Build train_model.
    train_model, train_timer, train_meter = create_wrapper(
        is_train=True,
        lfb=train_lfb,
    )

    # Build BN auxiliary model.
    if cfg.TRAIN.COMPUTE_PRECISE_BN:
        bn_aux = bn_helper.BatchNormHelper()
        bn_aux.create_bn_aux_model(node_id=opts.node_id)

    # Load checkpoint or pre-trained weight.
    # See checkpoints.load_model_from_params_file for more details.
    start_model_iter = 0
    if cfg.CHECKPOINT.RESUME or cfg.TRAIN.PARAMS_FILE:
        start_model_iter = checkpoints.load_model_from_params_file(train_model)

    logger.info("------------- Training model... -------------")
    train_meter.reset()
    last_checkpoint = checkpoints.get_checkpoint_resume_file()

    # Iteration at which precise-BN statistics were last recomputed. When a
    # checkpoint and an evaluation fall on the same iteration, no training
    # step runs in between. The stats just saved in the checkpoint are then
    # reused for evaluation rather than re-estimated, which would cost another
    # pass and make the evaluated stats differ from the checkpointed ones.
    bn_stats_iter = None

    # Loaders and the TensorBoard writer are released even if the loop raises
    # (e.g. when misc.check_nan_losses aborts training).
    try:
        for curr_iter in range(start_model_iter, cfg.SOLVER.MAX_ITER):
            train_model.UpdateWorkspaceLr(curr_iter)

            train_timer.tic()
            # SGD step.
            workspace.RunNet(train_model.net.Proto().name)
            train_timer.toc()

            if curr_iter == start_model_iter:
                misc.print_net(train_model)
                os.system('nvidia-smi')
                misc.show_flops_params(train_model)

            misc.check_nan_losses()

            # Checkpoint.
            if ((curr_iter + 1) % cfg.CHECKPOINT.CHECKPOINT_PERIOD == 0
                    or curr_iter + 1 == cfg.SOLVER.MAX_ITER):
                if cfg.TRAIN.COMPUTE_PRECISE_BN:
                    bn_aux.compute_and_update_bn_stats(curr_iter)
                    bn_stats_iter = curr_iter

                last_checkpoint = os.path.join(
                    checkpoint_dir, 'c2_model_iter{}.pkl'.format(curr_iter + 1))
                checkpoints.save_model_params(model=train_model,
                                              params_file=last_checkpoint,
                                              model_iter=curr_iter)

            train_meter.calculate_and_log_all_metrics_train(
                curr_iter,
                train_timer,
                suffix='_train',
                tb_writer=tb_writer)

            # Evaluation.
            if (curr_iter + 1) % cfg.TRAIN.EVAL_PERIOD == 0:
                if (cfg.TRAIN.COMPUTE_PRECISE_BN
                        and bn_stats_iter != curr_iter):
                    bn_aux.compute_and_update_bn_stats(curr_iter)
                    bn_stats_iter = curr_iter

                test_meter.reset()
                logger.info("=> Testing model")
                for test_iter in range(total_test_iters):
                    test_timer.tic()
                    workspace.RunNet(test_model.net.Proto().name)
                    test_timer.toc()

                    test_meter.calculate_and_log_all_metrics_test(
                        test_iter,
                        test_timer,
                        total_test_iters,
                        suffix='_test')

                test_meter.finalize_metrics(name='iter%d' % (curr_iter + 1))
                test_meter.compute_and_log_best()
                test_meter.log_final_metrics(curr_iter)

                tb_writer.add_scalar('Test/mini_MAP', test_meter.full_map,
                                     curr_iter + 1)

                # Finalize and reset train_meter after test.
                train_meter.finalize_metrics(is_train=True)

                json_stats = metrics.get_json_stats_dict(
                    train_meter, test_meter, curr_iter)
                misc.log_json_stats(json_stats)

                train_meter.reset()
    finally:
        # Flush pending TensorBoard events and release the event file.
        tb_writer.close()
        train_model.shutdown_data_loader()
        test_model.shutdown_data_loader()

    if cfg.TRAIN.TEST_AFTER_TRAIN:
        cfg.TEST.PARAMS_FILE = last_checkpoint
        test_net(test_lfb)
예제 #4
0
def test_one_crop(lfb=None, suffix='', shift=None):
    """Test one crop.

    Builds the test model, optionally with a long-term feature bank, and
    loads cfg.TEST.PARAMS_FILE. It then runs every test iteration and
    finalizes and logs the metrics.

    Args:
        lfb: Pre-computed LFB used to build the model. If None and
            cfg.LFB.ENABLED is set, it is obtained via ``get_lfb``.
        suffix: Suffix passed to ``build_model`` and to the metric logger.
        shift: Crop shift passed to ``build_model``; defaults to
            cfg.TEST.CROP_SHIFT when None.

    Raises:
        Exception: If cfg.TEST.PARAMS_FILE is not set.
    """
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
    np.random.seed(cfg.RNG_SEED)

    cfg.AVA.FULL_EVAL = True

    if lfb is None and cfg.LFB.ENABLED:
        print_cfg()
        lfb = get_lfb(cfg.LFB.MODEL_PARAMS_FILE, is_train=False)

    print_cfg()

    workspace.ResetWorkspace()
    logger.info("Done ResetWorkspace...")

    timer = Timer()

    logger.warning('Testing started...')  # for monitoring cluster jobs

    # Fail before building the model and starting the data loader, so a
    # missing params file does not leave loader threads running.
    if not cfg.TEST.PARAMS_FILE:
        raise Exception('No params files specified for testing model.')

    if shift is None:
        shift = cfg.TEST.CROP_SHIFT
    test_model = model_builder_video.ModelBuilder(train=False,
                                                  use_cudnn=True,
                                                  cudnn_exhaustive_search=True,
                                                  split=cfg.TEST.DATA_TYPE)

    test_model.build_model(lfb=lfb, suffix=suffix, shift=shift)

    test_model.net.Proto().type = 'prof_dag' if cfg.PROF_DAG else 'dag'

    workspace.RunNetOnce(test_model.param_init_net)
    workspace.CreateNet(test_model.net)

    misc.save_net_proto(test_model.net)
    misc.save_net_proto(test_model.param_init_net)

    total_test_net_iters = misc.get_total_test_iters(test_model)

    test_model.start_data_loader()
    # Everything after the loader starts is guarded so it is always shut down.
    try:
        test_meter = metrics.MetricsCalculator(
            model=test_model,
            split=cfg.TEST.DATA_TYPE,
            video_idx_to_name=test_model.input_db._video_idx_to_name,
            total_num_boxes=(test_model.input_db._num_boxes_used
                             if cfg.DATASET in ['ava', 'avabox'] else None))

        checkpoints.load_model_from_params_file_for_test(
            test_model, cfg.TEST.PARAMS_FILE)

        begin_time = time.time()

        for test_iter in range(total_test_net_iters):
            timer.tic()
            workspace.RunNet(test_model.net.Proto().name)
            timer.toc()

            if test_iter == 0:
                misc.print_net(test_model)
                os.system('nvidia-smi')
                misc.show_flops_params(test_model)

            test_meter.calculate_and_log_all_metrics_test(test_iter, timer,
                                                          total_test_net_iters,
                                                          suffix)

        # Wall-clock time of the whole test loop.
        logger.info('TTTTTTTIME: {}'.format(time.time() - begin_time))

        test_meter.finalize_metrics(name=get_test_name(shift))
        # Index of the last iteration (the value the loop variable ends at).
        # It is computed explicitly so zero iterations does not hit a
        # NameError.
        test_meter.log_final_metrics(total_test_net_iters - 1,
                                     total_test_net_iters)
    finally:
        test_model.shutdown_data_loader()