Code Example #1
def get_lfb(params_file, is_train):
    """
    Wrapper function for getting an LFB, which is either inferred given a
    baseline model, or loaded from a file.
    """

    if cfg.LFB.LOAD_LFB:
        return load_lfb(is_train)

    assert params_file, 'LFB.MODEL_PARAMS_FILE is not specified.'
    logger.info('Inferring LFB from %s' % params_file)

    cfg.GET_TRAIN_LFB = is_train

    timer = Timer()

    test_model = model_builder_video.ModelBuilder(
        train=False,
        use_cudnn=True,
        cudnn_exhaustive_search=True,
        split=cfg.TEST.DATA_TYPE,
    )

    suffix = 'infer_{}'.format('train' if is_train else 'test')
    test_model.build_model(
        lfb_infer_only=True,
        suffix=suffix,
        shift=1,
    )

    if cfg.PROF_DAG:
        test_model.net.Proto().type = 'prof_dag'
    else:
        test_model.net.Proto().type = 'dag'

    workspace.RunNetOnce(test_model.param_init_net)
    workspace.CreateNet(test_model.net)

    total_test_net_iters = misc.get_total_test_iters(test_model)

    test_model.start_data_loader()

    checkpoints.load_model_from_params_file_for_test(test_model, params_file)

    all_features = []
    all_metadata = []

    for test_iter in range(total_test_net_iters):

        timer.tic()
        workspace.RunNet(test_model.net.Proto().name)
        timer.toc()

        if test_iter == 0:
            misc.print_net(test_model)
            os.system('nvidia-smi')
        if test_iter % 10 == 0:
            logger.info("Iter {}/{} Time: {}".format(test_iter,
                                                     total_test_net_iters,
                                                     timer.diff))

        if cfg.DATASET in ['ava', 'avabox']:
            all_features.append(get_features('box_pooled'))
            all_metadata.append(get_features('metadata{}'.format(suffix)))
        elif cfg.DATASET in ['charades', 'epic']:
            all_features.append(get_features('pool5'))

    lfb = construct_lfb(all_features, all_metadata, test_model.input_db,
                        is_train)

    logger.info("Shutting down data loader...")
    test_model.shutdown_data_loader()

    workspace.ResetWorkspace()
    logger.info("Done ResetWorkspace...")

    cfg.GET_TRAIN_LFB = False

    if cfg.LFB.WRITE_LFB:
        write_lfb(lfb, is_train)

    return lfb
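
Note: `get_features` is not defined in any of these snippets. Below is a
minimal sketch of what it plausibly does, assuming the per-GPU blob naming
('gpu_{}/...') used in the test loops of the later examples; the exact blob
layout is an assumption, not the project's confirmed API.

import numpy as np
from caffe2.python import workspace

def get_features(blob_name):
    """Fetch a blob from every GPU workspace and stack the per-GPU copies."""
    gpu_blobs = [
        workspace.FetchBlob('gpu_{}/{}'.format(gpu_id, blob_name))
        for gpu_id in range(cfg.NUM_GPUS)
    ]
    return np.concatenate(gpu_blobs, axis=0)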
Code Example #2
def train(opts):
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
    logger = logging.getLogger(__name__)

    assert opts.test_net, "opts.test_net == False is not implemented."

    # generate seed
    misc.generate_random_seed(opts)

    # create checkpoint dir
    checkpoint_dir = checkpoints.create_and_get_checkpoint_directory()
    logger.info('Checkpoint directory created: {}'.format(checkpoint_dir))

    # -------------------------------------------------------------------------
    # build test_model
    # we build test_model first, as we don't want to overwrite init (if any)
    # -------------------------------------------------------------------------
    test_model, test_timer, test_meter = create_wrapper(is_train=False)
    total_test_iters = int(
        math.ceil(cfg.TEST.DATASET_SIZE / float(cfg.TEST.BATCH_SIZE)))
    logger.info('Test iters: {}'.format(total_test_iters))

    # -------------------------------------------------------------------------
    # now, build train_model
    # -------------------------------------------------------------------------
    train_model, train_timer, train_meter = create_wrapper(is_train=True)

    # -------------------------------------------------------------------------
    # build the BN auxiliary model (BN, always BN!)
    # -------------------------------------------------------------------------
    if cfg.TRAIN.COMPUTE_PRECISE_BN:
        bn_aux = bn_helper.BatchNormHelper()
        bn_aux.create_bn_aux_model(node_id=opts.node_id)

    # resume from a checkpoint or a pre-trained params file;
    # see checkpoints.load_model_from_params_file for more details
    start_model_iter = 0
    if cfg.CHECKPOINT.RESUME or cfg.TRAIN.PARAMS_FILE:
        start_model_iter = checkpoints.load_model_from_params_file(train_model)

    # -------------------------------------------------------------------------
    # now, start training
    # -------------------------------------------------------------------------
    logger.info("------------- Training model... -------------")
    train_meter.reset()
    last_checkpoint = checkpoints.get_checkpoint_resume_file()

    for curr_iter in range(start_model_iter, cfg.SOLVER.MAX_ITER):
        # set lr
        train_model.UpdateWorkspaceLr(curr_iter)

        # do SGD on 1 training mini-batch
        train_timer.tic()
        workspace.RunNet(train_model.net.Proto().name)
        train_timer.toc()

        # Debug flag: set to True to dump the (un-normalized) input frames
        # and labels to disk for inspection.
        test_debug = False
        if test_debug:
            save_path = 'temp_save/'
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            data_blob = workspace.FetchBlob('gpu_0/data')
            label_blob = workspace.FetchBlob('gpu_0/labels')
            label_blob1 = workspace.FetchBlob('gpu_1/labels')
            data_blob = data_blob * cfg.MODEL.STD + cfg.MODEL.MEAN
            print(label_blob)
            print(label_blob1)
            for i in range(data_blob.shape[0]):
                for j in range(data_blob.shape[2]):
                    temp_img = data_blob[i, :, j, :, :]
                    temp_img = temp_img.transpose([1, 2, 0])
                    temp_img = temp_img.astype(np.uint8)
                    fname = save_path + 'ori_' + str(curr_iter) \
                        + '_' + str(i) + '_' + str(j) + '.jpg'
                    cv2.imwrite(fname, temp_img)

        # show net info after the first iteration
        if curr_iter == start_model_iter:
            misc.print_net(train_model)
            os.system('nvidia-smi')
            misc.show_flops_params(train_model)

        # check for NaN losses
        misc.check_nan_losses()

        if (curr_iter + 1) % cfg.CHECKPOINT.CHECKPOINT_PERIOD == 0 \
                or curr_iter + 1 == cfg.SOLVER.MAX_ITER:
            # --------------------------------------------------------
            # we update bn before testing or checkpointing
            if cfg.TRAIN.COMPUTE_PRECISE_BN:
                bn_aux.compute_and_update_bn_stats(curr_iter)
            # --------------------------------------------------------
            last_checkpoint = os.path.join(
                checkpoint_dir, 'c2_model_iter{}.pkl'.format(curr_iter + 1))
            checkpoints.save_model_params(model=train_model,
                                          params_file=last_checkpoint,
                                          model_iter=curr_iter)

        train_meter.calculate_and_log_all_metrics_train(curr_iter, train_timer)

        # --------------------------------------------------------
        # test model
        # --------------------------------------------------------
        if (curr_iter + 1) % cfg.TRAIN.EVAL_PERIOD == 0:
            # we update bn before testing or checkpointing
            if cfg.TRAIN.COMPUTE_PRECISE_BN:
                bn_aux.compute_and_update_bn_stats(curr_iter)

            # start test
            test_meter.reset()
            logger.info("=> Testing model")
            for test_iter in range(0, total_test_iters):
                test_timer.tic()
                workspace.RunNet(test_model.net.Proto().name)
                test_timer.toc()

                test_meter.calculate_and_log_all_metrics_test(
                    test_iter, test_timer, total_test_iters)

            # finishing test
            test_meter.finalize_metrics()
            test_meter.compute_and_log_best()
            test_meter.log_final_metrics(curr_iter)

            # --------------------------------------------------------
            # we finalize and reset train_meter after each test
            train_meter.finalize_metrics()

            json_stats = metrics.get_json_stats_dict(train_meter, test_meter,
                                                     curr_iter)
            misc.log_json_stats(json_stats)

            train_meter.reset()

    if cfg.TRAIN.TEST_AFTER_TRAIN:

        # -------------------------------------------------------------------------
        # training finished; test
        # -------------------------------------------------------------------------
        cfg.TEST.PARAMS_FILE = last_checkpoint

        cfg.TEST.OUTPUT_NAME = 'softmax'
        # 10-clip center-crop
        # cfg.TEST.TEST_FULLY_CONV = False
        # test_net()
        # logger.info("10-clip center-crop testing finished")

        # 10-clip spatial fcn
        cfg.TEST.TEST_FULLY_CONV = True
        test_net()
        logger.info("10-clip spatial fcn testing finished")
Code Example #3
def train(opts):
    """Train a model."""

    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
    logger = logging.getLogger(__name__)

    # Generate seed.
    misc.generate_random_seed(opts)

    # Create checkpoint dir.
    checkpoint_dir = checkpoints.create_and_get_checkpoint_directory()
    logger.info('Checkpoint directory created: {}'.format(checkpoint_dir))

    # Create TensorBoard logger.
    tb_writer = SummaryWriter(os.path.join(cfg.CHECKPOINT.DIR, 'tb'))

    # Set training-time-specific configurations.
    cfg.AVA.FULL_EVAL = cfg.AVA.FULL_EVAL_DURING_TRAINING
    cfg.AVA.DETECTION_SCORE_THRESH = cfg.AVA.DETECTION_SCORE_THRESH_TRAIN
    cfg.CHARADES.NUM_TEST_CLIPS = cfg.CHARADES.NUM_TEST_CLIPS_DURING_TRAINING

    test_lfb, train_lfb = None, None

    if cfg.LFB.ENABLED:
        test_lfb = get_lfb(cfg.LFB.MODEL_PARAMS_FILE, is_train=False)
        train_lfb = get_lfb(cfg.LFB.MODEL_PARAMS_FILE, is_train=True)

    # Build test_model.
    # We build test_model first, so that we don't overwrite init.
    test_model, test_timer, test_meter = create_wrapper(
        is_train=False,
        lfb=test_lfb,
    )
    total_test_iters = misc.get_total_test_iters(test_model)
    logger.info('Test iters: {}'.format(total_test_iters))

    # Build train_model.
    train_model, train_timer, train_meter = create_wrapper(
        is_train=True,
        lfb=train_lfb,
    )

    # Build the BN auxiliary model.
    if cfg.TRAIN.COMPUTE_PRECISE_BN:
        bn_aux = bn_helper.BatchNormHelper()
        bn_aux.create_bn_aux_model(node_id=opts.node_id)

    # Load checkpoint or pre-trained weight.
    # See checkpoints.load_model_from_params_file for more details.
    start_model_iter = 0
    if cfg.CHECKPOINT.RESUME or cfg.TRAIN.PARAMS_FILE:
        start_model_iter = checkpoints.load_model_from_params_file(train_model)

    logger.info("------------- Training model... -------------")
    train_meter.reset()
    last_checkpoint = checkpoints.get_checkpoint_resume_file()

    for curr_iter in range(start_model_iter, cfg.SOLVER.MAX_ITER):
        train_model.UpdateWorkspaceLr(curr_iter)

        train_timer.tic()
        # SGD step.
        workspace.RunNet(train_model.net.Proto().name)
        train_timer.toc()

        if curr_iter == start_model_iter:
            misc.print_net(train_model)
            os.system('nvidia-smi')
            misc.show_flops_params(train_model)

        misc.check_nan_losses()

        # Checkpoint.
        if (curr_iter + 1) % cfg.CHECKPOINT.CHECKPOINT_PERIOD == 0 \
                or curr_iter + 1 == cfg.SOLVER.MAX_ITER:
            if cfg.TRAIN.COMPUTE_PRECISE_BN:
                bn_aux.compute_and_update_bn_stats(curr_iter)

            last_checkpoint = os.path.join(
                checkpoint_dir, 'c2_model_iter{}.pkl'.format(curr_iter + 1))
            checkpoints.save_model_params(model=train_model,
                                          params_file=last_checkpoint,
                                          model_iter=curr_iter)

        train_meter.calculate_and_log_all_metrics_train(curr_iter,
                                                        train_timer,
                                                        suffix='_train',
                                                        tb_writer=tb_writer)

        # Evaluation.
        if (curr_iter + 1) % cfg.TRAIN.EVAL_PERIOD == 0:
            if cfg.TRAIN.COMPUTE_PRECISE_BN:
                bn_aux.compute_and_update_bn_stats(curr_iter)

            test_meter.reset()
            logger.info("=> Testing model")
            for test_iter in range(0, total_test_iters):
                test_timer.tic()
                workspace.RunNet(test_model.net.Proto().name)
                test_timer.toc()

                test_meter.calculate_and_log_all_metrics_test(test_iter,
                                                              test_timer,
                                                              total_test_iters,
                                                              suffix='_test')

            test_meter.finalize_metrics(name='iter%d' % (curr_iter + 1))
            test_meter.compute_and_log_best()
            test_meter.log_final_metrics(curr_iter)

            tb_writer.add_scalar('Test/mini_MAP', test_meter.full_map,
                                 curr_iter + 1)

            # Finalize and reset train_meter after test.
            train_meter.finalize_metrics(is_train=True)

            json_stats = metrics.get_json_stats_dict(train_meter, test_meter,
                                                     curr_iter)
            misc.log_json_stats(json_stats)

            train_meter.reset()

    train_model.shutdown_data_loader()
    test_model.shutdown_data_loader()

    if cfg.TRAIN.TEST_AFTER_TRAIN:
        cfg.TEST.PARAMS_FILE = last_checkpoint
        test_net(test_lfb)
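
Note: the internals of bn_helper are not shown here. Below is a minimal,
framework-agnostic sketch of the precise-BN technique that
compute_and_update_bn_stats implements: run several minibatches forward,
average the per-batch BN statistics, and write them back into the running
mean/variance buffers before evaluation. `run_one_minibatch` and the stats
layout are hypothetical.

def compute_precise_bn_stats(run_one_minibatch, bn_layer_names,
                             num_batches=200):
    """Average per-batch BN means/variances over several minibatches.

    `run_one_minibatch` is a hypothetical callable that runs one forward
    pass and returns {layer_name: (batch_mean, batch_var)}.
    """
    means = {name: 0.0 for name in bn_layer_names}
    variances = {name: 0.0 for name in bn_layer_names}
    for _ in range(num_batches):
        stats = run_one_minibatch()
        for name, (mu, var) in stats.items():
            means[name] += mu / num_batches
            variances[name] += var / num_batches
    # The averaged statistics would then be written into each BN layer's
    # running mean/variance blobs before testing or checkpointing.
    return means, variances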
Code Example #4
def test_net_one_section():
    """
    To save test-time memory, we perform multi-clip test in multiple "sections":
    e.g., 10-clip test can be done in 2 sections of 5-clip test
    """
    timer = Timer()
    results = []
    seen_inds = defaultdict(int)

    logger.warning('Testing started...')  # for monitoring cluster jobs
    test_model = model_builder_video.ModelBuilder(name='{}_test'.format(
        cfg.MODEL.MODEL_NAME),
                                                  train=False,
                                                  use_cudnn=True,
                                                  cudnn_exhaustive_search=True,
                                                  split=cfg.TEST.DATA_TYPE)

    test_model.build_model()

    if cfg.PROF_DAG:
        test_model.net.Proto().type = 'prof_dag'
    else:
        test_model.net.Proto().type = 'dag'

    workspace.RunNetOnce(test_model.param_init_net)
    workspace.CreateNet(test_model.net)

    misc.save_net_proto(test_model.net)
    misc.save_net_proto(test_model.param_init_net)

    total_test_net_iters = int(
        math.ceil(
            float(cfg.TEST.DATASET_SIZE * cfg.TEST.NUM_TEST_CLIPS) /
            cfg.TEST.BATCH_SIZE))

    if cfg.TEST.PARAMS_FILE:
        checkpoints.load_model_from_params_file_for_test(
            test_model, cfg.TEST.PARAMS_FILE)
    else:
        raise Exception('No params files specified for testing model.')

    for test_iter in range(total_test_net_iters):
        timer.tic()
        workspace.RunNet(test_model.net.Proto().name)
        timer.toc()

        if test_iter == 0:
            misc.print_net(test_model)
            os.system('nvidia-smi')

        # Debug flag: set to True to dump the (un-normalized) input frames
        # and labels to disk for inspection.
        test_debug = False
        if test_debug:
            save_path = 'temp_save/'
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            data_blob = workspace.FetchBlob('gpu_0/data')
            label_blob = workspace.FetchBlob('gpu_0/labels')
            print(label_blob)
            data_blob = data_blob * cfg.MODEL.STD + cfg.MODEL.MEAN
            for i in range(data_blob.shape[0]):
                for j in range(4):
                    temp_img = data_blob[i, :, j, :, :]
                    temp_img = temp_img.transpose([1, 2, 0])
                    temp_img = temp_img.astype(np.uint8)
                    fname = save_path + 'ori_' + str(test_iter) \
                        + '_' + str(i) + '_' + str(j) + '.jpg'
                    cv2.imwrite(fname, temp_img)
        """
        When testing, we assume all samples in the same gpu are of the same id
        """
        video_ids_list = []  # for logging
        for gpu_id in range(cfg.NUM_GPUS):
            prefix = 'gpu_{}/'.format(gpu_id)

            softmax_gpu = workspace.FetchBlob(prefix + cfg.TEST.OUTPUT_NAME)
            softmax_gpu = softmax_gpu.reshape((softmax_gpu.shape[0], -1))
            video_id_gpu = workspace.FetchBlob(prefix + 'labels')

            for i in range(len(video_id_gpu)):
                seen_inds[video_id_gpu[i]] += 1

            video_ids_list.append(video_id_gpu[0])
            # print(video_id_gpu)

            # collect results
            for i in range(softmax_gpu.shape[0]):
                probs = softmax_gpu[i].tolist()
                vid = video_id_gpu[i]
                if seen_inds[vid] > cfg.TEST.NUM_TEST_CLIPS:
                    logger.warning(
                        'Video id {} has already been seen. Skipping.'.format(
                            vid))
                    continue

                save_pairs = [vid, probs]
                results.append(save_pairs)

        # ---- log
        eta = timer.average_time * (total_test_net_iters - test_iter - 1)
        eta = str(datetime.timedelta(seconds=int(eta)))
        logger.info(('{}/{} iter ({}/{} videos):' +
                     ' Time: {:.3f} (ETA: {}). ID: {}').format(
                         test_iter,
                         total_test_net_iters,
                         len(seen_inds),
                         cfg.TEST.DATASET_SIZE,
                         timer.diff,
                         eta,
                         video_ids_list,
                     ))

    return results
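
Each entry of `results` is a [video_id, probs] pair for a single clip. Below
is a minimal sketch (not part of the original code) of how the clip-level
scores could be aggregated into video-level predictions by averaging over
the cfg.TEST.NUM_TEST_CLIPS clips of each video:

import numpy as np
from collections import defaultdict

def aggregate_clip_results(results):
    """Average per-clip softmax scores into one score vector per video."""
    per_video = defaultdict(list)
    for vid, probs in results:
        per_video[vid].append(probs)
    return {vid: np.mean(np.asarray(p), axis=0)
            for vid, p in per_video.items()}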
Code Example #5
File: feature_loader.py Project: CV-IP/CRCNN-Action
def load_feature_map(params_file, is_train):
    assert params_file, 'FEATURE_MAP_LOADER.MODEL_PARAMS_FILE is not specified.'
    assert cfg.FEATURE_MAP_LOADER.OUT_DIR, 'FEATURE_MAP_LOADER.OUT_DIR is not specified.'
    logger.info('Inferring feature map from %s' % params_file)

    cfg.FEATURE_MAP_LOADER.ENABLE = True

    cfg.GET_TRAIN_LFB = is_train

    timer = Timer()

    test_model = model_builder_video.ModelBuilder(
        train=False,
        use_cudnn=True,
        cudnn_exhaustive_search=True,
        split=cfg.TEST.DATA_TYPE,
    )

    suffix = 'infer_{}'.format('train' if is_train else 'test')

    if cfg.LFB.ENABLED:
        lfb_path = os.path.join(cfg.LFB.LOAD_LFB_PATH,
                                'train_lfb.pkl' if is_train else 'val_lfb.pkl')
        logger.info('Loading LFB from %s' % lfb_path)
        with open(lfb_path, 'rb') as f:
            lfb = pickle.load(f)

        test_model.build_model(
            lfb=lfb,
            suffix=suffix,
            shift=1,
        )

    else:
        test_model.build_model(
            lfb=None,
            suffix=suffix,
            shift=1,
        )

    if cfg.PROF_DAG:
        test_model.net.Proto().type = 'prof_dag'
    else:
        test_model.net.Proto().type = 'dag'

    workspace.RunNetOnce(test_model.param_init_net)
    workspace.CreateNet(test_model.net)

    total_test_net_iters = misc.get_total_test_iters(test_model)

    test_model.start_data_loader()

    checkpoints.load_model_from_params_file_for_test(test_model, params_file)

    all_features = {}
    for feat_name in cfg.FEATURE_MAP_LOADER.NAME_LIST:
        all_features[feat_name] = []

    all_metadata = []

    all_labels = []
    all_proposals = []
    all_original_boxes = []

    if cfg.FEATURE_MAP_LOADER.TEST_ITERS > 0:
        total_test_net_iters = cfg.FEATURE_MAP_LOADER.TEST_ITERS

    for test_iter in range(total_test_net_iters):

        timer.tic()
        workspace.RunNet(test_model.net.Proto().name)
        timer.toc()

        if test_iter == 0:
            misc.print_net(test_model)
            os.system('nvidia-smi')
        if test_iter % 10 == 0:
            logger.info("Iter {}/{} Time: {}".format(test_iter,
                                                     total_test_net_iters,
                                                     timer.diff))

        if cfg.DATASET == "ava":
            for feat_name in cfg.FEATURE_MAP_LOADER.NAME_LIST:
                all_features[feat_name].append(get_features(feat_name))

            all_metadata.append(get_features('metadata{}'.format(suffix)))

            all_labels.append(get_features('labels{}'.format(suffix)))
            all_proposals.append(get_features('proposals{}'.format(suffix)))
            all_original_boxes.append(
                get_features('original_boxes{}'.format(suffix)))

        # elif cfg.DATASET in ['charades', 'epic']:
        #     all_features.append(get_features('pool5'))
        else:
            raise Exception("Dataset {} not recognized.".format(cfg.DATASET))

    lfb = construct_lfb(all_features, all_metadata, all_labels, all_proposals,
                        all_original_boxes, test_model.input_db, is_train)

    write_lfb(lfb, is_train)

    logger.info("Shutting down data loader...")
    test_model.shutdown_data_loader()

    workspace.ResetWorkspace()
    logger.info("Done ResetWorkspace...")

    cfg.GET_TRAIN_LFB = False
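
Note: `write_lfb` is not shown in these snippets. Below is a minimal
sketch, assuming it pickles the constructed bank under
cfg.FEATURE_MAP_LOADER.OUT_DIR (the directory asserted at the top of
load_feature_map) using the same file names that the LFB loading branch
above expects:

import os
import pickle

def write_lfb(lfb, is_train):
    fname = 'train_lfb.pkl' if is_train else 'val_lfb.pkl'
    out_path = os.path.join(cfg.FEATURE_MAP_LOADER.OUT_DIR, fname)
    # Binary mode is required for pickle files.
    with open(out_path, 'wb') as f:
        pickle.dump(lfb, f, pickle.HIGHEST_PROTOCOL)
    logger.info('LFB written to %s' % out_path)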
Code Example #6
File: test_net.py Project: CV-IP/CRCNN-Action
def test_one_crop(lfb=None, suffix='', shift=None):
    """Test one crop."""
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
    np.random.seed(cfg.RNG_SEED)

    cfg.AVA.FULL_EVAL = True

    if lfb is None and cfg.LFB.ENABLED:
        print_cfg()
        lfb = get_lfb(cfg.LFB.MODEL_PARAMS_FILE, is_train=False)

    print_cfg()

    workspace.ResetWorkspace()
    logger.info("Done ResetWorkspace...")

    timer = Timer()

    logger.warning('Testing started...')  # for monitoring cluster jobs

    if shift is None:
        shift = cfg.TEST.CROP_SHIFT
    test_model = model_builder_video.ModelBuilder(train=False,
                                                  use_cudnn=True,
                                                  cudnn_exhaustive_search=True,
                                                  split=cfg.TEST.DATA_TYPE)

    test_model.build_model(lfb=lfb, suffix=suffix, shift=shift)

    if cfg.PROF_DAG:
        test_model.net.Proto().type = 'prof_dag'
    else:
        test_model.net.Proto().type = 'dag'

    workspace.RunNetOnce(test_model.param_init_net)
    workspace.CreateNet(test_model.net)

    misc.save_net_proto(test_model.net)
    misc.save_net_proto(test_model.param_init_net)

    total_test_net_iters = misc.get_total_test_iters(test_model)

    test_model.start_data_loader()
    test_meter = metrics.MetricsCalculator(
        model=test_model,
        split=cfg.TEST.DATA_TYPE,
        video_idx_to_name=test_model.input_db._video_idx_to_name,
        total_num_boxes=(test_model.input_db._num_boxes_used
                         if cfg.DATASET in ['ava', 'avabox'] else None))

    if cfg.TEST.PARAMS_FILE:
        checkpoints.load_model_from_params_file_for_test(
            test_model, cfg.TEST.PARAMS_FILE)
    else:
        raise Exception('No params files specified for testing model.')

    begin_time = time.time()

    for test_iter in range(total_test_net_iters):
        timer.tic()
        workspace.RunNet(test_model.net.Proto().name)
        timer.toc()

        if test_iter == 0:
            misc.print_net(test_model)
            os.system('nvidia-smi')
            misc.show_flops_params(test_model)

        test_meter.calculate_and_log_all_metrics_test(test_iter, timer,
                                                      total_test_net_iters,
                                                      suffix)

    logger.info(
        'Total testing time: {:.2f}s'.format(time.time() - begin_time))

    test_meter.finalize_metrics(name=get_test_name(shift))
    test_meter.log_final_metrics(test_iter, total_test_net_iters)
    test_model.shutdown_data_loader()
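
Note: test_one_crop takes a `shift` argument (defaulting to
cfg.TEST.CROP_SHIFT), which suggests a driver that evaluates several crops.
Below is a hypothetical sketch of such a driver; the shift values and the
suffix format are assumptions, not the original implementation.

def test_all_crops(lfb=None):
    """Run test_one_crop over several spatial shifts, reusing one LFB."""
    if lfb is None and cfg.LFB.ENABLED:
        lfb = get_lfb(cfg.LFB.MODEL_PARAMS_FILE, is_train=False)
    for shift in range(3):  # e.g., left/top, center, right/bottom crops
        test_one_crop(lfb=lfb, suffix='_shift{}'.format(shift), shift=shift)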
Code Example #7
def test_net_one_section(full_label_fname=None, store_vis=False):
    """
    To save test-time memory, we perform multi-clip test in multiple
    "sections":
    e.g., 10-clip test can be done in 2 sections of 5-clip test
    Args:
        full_label_id: If set uses this LMDB file, and assumes the full labels
            are being provided
        store_vis: Store visualization of what the model learned, CAM
            style stuff
    """
    timer = Timer()
    results = []
    seen_inds = defaultdict(int)

    logger.warning('Testing started...')  # for monitoring cluster jobs
    test_model = model_builder_video.ModelBuilder(
        name='{}_test'.format(cfg.MODEL.MODEL_NAME),
        train=False,
        use_cudnn=True,
        cudnn_exhaustive_search=True,
        split=cfg.TEST.DATA_TYPE,
        split_dir_name=(full_label_fname if full_label_fname is not None else
                        cfg.TEST.DATA_TYPE))

    test_model.build_model()

    if cfg.PROF_DAG:
        test_model.net.Proto().type = 'prof_dag'
    else:
        test_model.net.Proto().type = 'dag'

    workspace.RunNetOnce(test_model.param_init_net)
    workspace.CreateNet(test_model.net)

    misc.save_net_proto(test_model.net)
    misc.save_net_proto(test_model.param_init_net)

    total_test_net_iters = int(
        math.ceil(
            float(cfg.TEST.DATASET_SIZE * cfg.TEST.NUM_TEST_CLIPS) /
            cfg.TEST.BATCH_SIZE))

    if cfg.TEST.PARAMS_FILE:
        checkpoints.load_model_from_params_file_for_test(
            test_model, cfg.TEST.PARAMS_FILE)
    else:
        cfg.TEST.PARAMS_FILE = checkpoints.get_checkpoint_resume_file()
        checkpoints.load_model_from_params_file_for_test(
            test_model, cfg.TEST.PARAMS_FILE)
        logger.info('No params file was specified for testing; falling back '
                    'to the last trained checkpoint {}'.format(
                        cfg.TEST.PARAMS_FILE))
        # raise Exception('No params files specified for testing model.')

    for test_iter in range(total_test_net_iters):
        timer.tic()
        workspace.RunNet(test_model.net.Proto().name)
        timer.toc()

        if test_iter == 0:
            misc.print_net(test_model)
            os.system('nvidia-smi')

        # Debug flag: set to True to dump the (un-normalized) input frames
        # and labels to disk for inspection.
        test_debug = False
        if test_debug:
            save_path = 'temp_save/'
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            data_blob = workspace.FetchBlob('gpu_0/data')
            label_blob = workspace.FetchBlob('gpu_0/labels')
            print(label_blob)
            data_blob = data_blob * cfg.MODEL.STD + cfg.MODEL.MEAN
            for i in range(data_blob.shape[0]):
                for j in range(4):
                    temp_img = data_blob[i, :, j, :, :]
                    temp_img = temp_img.transpose([1, 2, 0])
                    temp_img = temp_img.astype(np.uint8)
                    fname = save_path + 'ori_' + str(test_iter) \
                        + '_' + str(i) + '_' + str(j) + '.jpg'
                    cv2.imwrite(fname, temp_img)
        """
        When testing, we assume all samples in the same gpu are of the same id.
        ^ This comment is from the original code. Anyway not sure why it should
        be the case.. we are extracting out the labels for each element of the
        batch anyway... Where is this assumption being used?
        ^ Checked with Xiaolong, ignore this.
        """
        video_ids_list = []  # for logging
        for gpu_id in range(cfg.NUM_GPUS):
            prefix = 'gpu_{}/'.format(gpu_id)

            # Note that this is called softmax_gpu, but could also be
            # sigmoid.
            softmax_gpu = workspace.FetchBlob(prefix + 'activation')
            softmax_gpu = softmax_gpu.reshape((softmax_gpu.shape[0], -1))
            # Average fc7 over time and space to get a compact feature. It
            # has already been passed through an AvgPool op, but might not
            # have been averaged all the way down.
            fc7 = np.mean(workspace.FetchBlob(prefix + 'fc7'),
                          axis=(-1, -2, -3))
            # IMP! The label blob at test time contains the "index" to the
            # video, and not the video class. This is how the lmdb gen scripts
            # are set up. @xiaolonw needs it to get predictions for each video
            # and then re-reads the label file to get the actual class labels
            # to compute the test accuracy.
            video_id_gpu = workspace.FetchBlob(prefix + 'labels')
            temporal_crop_id = [None] * len(video_id_gpu)
            spatial_crop_id = [None] * len(video_id_gpu)
            if full_label_fname is not None:
                video_id_gpu, temporal_crop_id, spatial_crop_id = (
                    label_id_to_parts(video_id_gpu))

            for i in range(len(video_id_gpu)):
                seen_inds[video_id_gpu[i]] += 1

            video_ids_list.append(video_id_gpu[0])
            # print(video_id_gpu)

            if store_vis:
                save_dir = osp.join(cfg.CHECKPOINT.DIR,
                                    'vis_{}'.format(full_label_fname))
                data_blob = workspace.FetchBlob(prefix + 'data')
                label_blob = workspace.FetchBlob(prefix + 'labels')
                fc7_full = workspace.FetchBlob(prefix + 'fc7_beforeAvg')
                data_blob = data_blob * cfg.MODEL.STD + cfg.MODEL.MEAN
                for i in range(data_blob.shape[0]):
                    if temporal_crop_id[i] != 0 or spatial_crop_id[i] != 1:
                        # Only visualizing the first center clip
                        continue
                    gen_store_vis(frames=data_blob[i],
                                  fc7_feats=fc7_full[i],
                                  outfpath=osp.join(save_dir,
                                                    str(video_id_gpu[i])))

            # collect results
            for i in range(softmax_gpu.shape[0]):
                probs = softmax_gpu[i].tolist()
                vid = video_id_gpu[i]
                if seen_inds[vid] > cfg.TEST.NUM_TEST_CLIPS:
                    logger.warning(
                        'Video id {} has already been seen. Skipping.'.format(
                            vid))
                    continue

                save_pairs = [
                    vid, probs, temporal_crop_id[i], spatial_crop_id[i], fc7[i]
                ]
                results.append(save_pairs)

        # ---- log
        eta = timer.average_time * (total_test_net_iters - test_iter - 1)
        eta = str(datetime.timedelta(seconds=int(eta)))
        logger.info(('{}/{} iter ({}/{} videos):' +
                     ' Time: {:.3f} (ETA: {}). ID: {}').format(
                         test_iter,
                         total_test_net_iters,
                         len(seen_inds),
                         cfg.TEST.DATASET_SIZE,
                         timer.diff,
                         eta,
                         video_ids_list,
                     ))

    return results
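
Note: `gen_store_vis` is not shown. Below is a minimal CAM-style sketch,
assuming fc7_feats has shape (C, T, H, W) and frames has shape
(3, T', H', W') as the fetched blobs above suggest; it averages the feature
map over channels, normalizes it, and overlays it on each frame. This is an
assumption about the visualization, not the original implementation.

import os
import os.path as osp
import cv2
import numpy as np

def gen_store_vis(frames, fc7_feats, outfpath):
    out_dir = osp.dirname(outfpath)
    if out_dir and not osp.exists(out_dir):
        os.makedirs(out_dir)
    heat = fc7_feats.mean(axis=0)  # (T, H, W) channel-averaged activations
    num_frames = frames.shape[1]
    for t in range(num_frames):
        frame = frames[:, t].transpose([1, 2, 0]).astype(np.uint8)
        # Map the frame index to the (possibly shorter) feature time axis.
        ft = min(t * heat.shape[0] // num_frames, heat.shape[0] - 1)
        h = heat[ft]
        h = (h - h.min()) / (h.max() - h.min() + 1e-8)
        h = cv2.resize((h * 255).astype(np.uint8),
                       (frame.shape[1], frame.shape[0]))
        overlay = cv2.addWeighted(
            frame, 0.5, cv2.applyColorMap(h, cv2.COLORMAP_JET), 0.5, 0)
        cv2.imwrite('{}_{:03d}.jpg'.format(outfpath, t), overlay)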