Esempio n. 1
0
def main():
    """Evaluate a backbone (optionally wrapped in an RNN memory) on episodes.

    Configuration comes from the module-level ``args`` (config/data/env file
    paths, ``tag``, ``rnn``, ``reeval``, ``testonly``, ``seed``). Results are
    kept in ``<env.results>/<env.dataset>/<tag>``: the raw per-split results
    are pickled to ``results.pkl`` and the summary stats are handed to
    ``log_results`` with ``results.csv`` as the target file.

    When ``args.reeval`` is set and ``results.pkl`` already exists, the stored
    results are re-displayed instead of re-running evaluation.

    Raises:
      ValueError: if evaluation is required but no ``weights-`` checkpoint is
        found in the save folder.
    """
    assert tf.executing_eagerly(), 'Only eager mode is supported.'
    assert args.config is not None, 'You need to pass in model config file path'
    assert args.data is not None, 'You need to pass in episode config file path'
    assert args.env is not None, 'You need to pass in environ config file path'
    assert args.tag is not None, 'You need to specify a tag'

    log.info('Command line args {}'.format(args))
    config = get_config(args.config, ExperimentConfig)
    data_config = get_config(args.data, EpisodeConfig)
    env_config = get_config(args.env, EnvironmentConfig)
    log.info('Model: \n{}'.format(config))
    log.info('Data episode: \n{}'.format(data_config))
    log.info('Environment: \n{}'.format(env_config))
    # Copy episode settings into the model config so both agree.
    config.num_classes = data_config.nway  # Assign num classes.
    config.num_steps = data_config.maxlen
    config.memory_net_config.max_classes = data_config.nway
    config.memory_net_config.max_stages = data_config.nstage
    config.memory_net_config.max_items = data_config.maxlen
    config.oml_config.num_classes = data_config.nway
    config.fix_unknown = data_config.fix_unknown  # Assign fix unknown ID.
    log.info('Number of classes {}'.format(data_config.nway))
    if 'SLURM_JOB_ID' in os.environ:
        log.info('SLURM job ID: {}'.format(os.environ['SLURM_JOB_ID']))

    # Create save folder.
    save_folder = os.path.join(env_config.results, env_config.dataset,
                               args.tag)
    results_file = os.path.join(save_folder, 'results.pkl')
    logfile = os.path.join(save_folder, 'results.csv')

    # To get CNN features from.
    model = build_pretrain_net(config)
    if args.rnn:
        rnn_memory = get_module("lstm",
                                "lstm",
                                model.backbone.get_output_dimension()[0],
                                config.lstm_config.hidden_dim,
                                layernorm=config.lstm_config.layernorm,
                                dtype=tf.float32)
        memory = RNNEncoder(
            "proto_plus_rnn_ssl_v4",
            rnn_memory,
            readout_type=config.mann_config.readout_type,
            use_pred_beta_gamma=config.hybrid_config.use_pred_beta_gamma,
            use_feature_fuse=config.hybrid_config.use_feature_fuse,
            use_feature_fuse_gate=config.hybrid_config.use_feature_fuse_gate,
            use_feature_scaling=config.hybrid_config.use_feature_scaling,
            use_feature_memory_only=config.hybrid_config.
            use_feature_memory_only,
            skip_unk_memory_update=config.hybrid_config.skip_unk_memory_update,
            use_ssl=config.hybrid_config.use_ssl,
            use_ssl_beta_gamma_write=config.hybrid_config.
            use_ssl_beta_gamma_write,
            use_ssl_temp=config.hybrid_config.use_ssl_temp,
            dtype=tf.float32)
        rnn_model = RNNEncoderNet(config, model.backbone, memory)
        f = lambda x: rnn_model.forward(x, is_training=False)  # NOQA
    else:
        # Backbone features of x[0] with a new leading axis added.
        f = lambda x: model.backbone(x[0], is_training=False)[None, :, :]  # NOQA

    K = config.num_classes
    # Get dataset.
    dataset = get_data_fs(env_config, load_train=True)

    # Get data iterators.
    if env_config.dataset in ["matterport"]:
        data = get_dataiter_sim(
            dataset,
            data_config,
            batch_size=config.optimizer_config.batch_size,
            nchw=model.backbone.config.data_format == 'NCHW',
            seed=args.seed)
    else:
        data = get_dataiter_continual(
            dataset,
            data_config,
            batch_size=config.optimizer_config.batch_size,
            nchw=model.backbone.config.data_format == 'NCHW',
            save_additional_info=True,
            random_box=data_config.random_box,
            seed=args.seed)

    if os.path.exists(results_file) and args.reeval:
        # Re-display results. The pickle is produced by this script itself;
        # never point results_file at untrusted data.
        with open(results_file, 'rb') as fin:
            results_all = pkl.load(fin)
        for split, name in zip(['trainval_fs', 'val_fs', 'test_fs'],
                               ['Train', 'Val', 'Test']):
            stats = get_stats(results_all[split], tmax=data_config.maxlen)
            log_results(stats, prefix=name, filename=logfile)
        return

    latest = latest_file(save_folder, 'weights-')
    if latest is None:
        # Fail loudly instead of handing None to the model loader.
        raise ValueError('Checkpoint not found')

    if args.rnn:
        rnn_model.load(latest)
    else:
        model.load(latest)

    data['trainval_fs'].reset()
    data['val_fs'].reset()
    data['test_fs'].reset()

    results_all = {}
    if args.testonly:
        split_list = ['test_fs']
        name_list = ['Test']
        nepisode_list = [config.num_episodes]
    else:
        split_list = ['trainval_fs', 'val_fs', 'test_fs']
        name_list = ['Train', 'Val', 'Test']
        nepisode_list = [600, config.num_episodes, config.num_episodes]

    for split, name, N in zip(split_list, name_list, nepisode_list):
        data[split].reset()
        r1 = evaluate(f, K, data[split], N)
        stats = get_stats(r1, tmax=data_config.maxlen)
        log_results(stats, prefix=name, filename=logfile)
        results_all[split] = r1
    # Close the file so the pickle is fully flushed to disk.
    with open(results_file, 'wb') as fout:
        pkl.dump(results_all, fout)
Esempio n. 2
0
def main():
    """Train a memory network, or evaluate / re-display it with ``args.eval``.

    Configuration comes from the module-level ``args``. The save folder is
    ``<env.results>/<env.dataset>/<tag>``.

    Training (``not args.eval``): builds the backbone and memory network. If
    the checkpoint folder has a ``weights-`` checkpoint, training resumes from
    it; otherwise ``args.pretrain`` is loaded when given. The config is saved
    and ``train`` runs with a TensorBoard writer on the save folder.

    Evaluation (``args.eval``): if ``args.reeval`` is set and ``results.pkl``
    exists, the stored results are re-displayed. Otherwise the ``best-``
    (with ``args.usebest``) or ``weights-`` checkpoint from the save folder is
    loaded, falling back to ``args.pretrain``. The few-shot splits are then
    evaluated and the results pickled to ``results.pkl``.

    Raises:
      ValueError: if evaluation needs a checkpoint and none is found.
    """
    assert tf.executing_eagerly(), 'Only eager mode is supported.'
    assert args.config is not None, 'You need to pass in model config file path'
    assert args.data is not None, 'You need to pass in episode config file path'
    assert args.env is not None, 'You need to pass in environ config file path'
    assert args.tag is not None, 'You need to specify a tag'

    log.info('Command line args {}'.format(args))
    config = get_config(args.config, ExperimentConfig)
    data_config = get_config(args.data, EpisodeConfig)
    env_config = get_config(args.env, EnvironmentConfig)
    log.info('Model: \n{}'.format(config))
    log.info('Data episode: \n{}'.format(data_config))
    log.info('Environment: \n{}'.format(env_config))
    # Copy episode settings into the model config so both agree.
    config.num_classes = data_config.nway  # Assign num classes.
    config.num_steps = data_config.maxlen
    config.memory_net_config.max_classes = data_config.nway
    config.memory_net_config.max_stages = data_config.nstage
    config.memory_net_config.max_items = data_config.maxlen
    config.oml_config.num_classes = data_config.nway
    config.fix_unknown = data_config.fix_unknown  # Assign fix unknown ID.
    log.info('Number of classes {}'.format(data_config.nway))

    if 'SLURM_JOB_ID' in os.environ:
        log.info('SLURM job ID: {}'.format(os.environ['SLURM_JOB_ID']))

    # Create save folder.
    save_folder = os.path.join(env_config.results, env_config.dataset,
                               args.tag)
    results_file = os.path.join(save_folder, 'results.pkl')
    logfile = os.path.join(save_folder, 'results.csv')

    # Re-displaying stored results is the only mode that needs neither the
    # model nor the data. Skipping the build on args.reeval alone left
    # model/data undefined when training with --reeval, or when evaluating
    # with --reeval before results.pkl existed.
    redisplay = bool(args.eval and args.reeval and
                     os.path.exists(results_file))

    # Only populated when training; initialised so later uses are defined.
    writer = None
    logger = None
    reload_flag = None

    if not redisplay:
        model = build_pretrain_net(config)
        mem_model = build_net(config, backbone=model.backbone)
        restore_steps = 0
        # Checkpoint folder: a per-SLURM-job subfolder of env_config.checkpoint
        # when that path exists and a job ID is set, otherwise save_folder.
        ckpt_path = env_config.checkpoint
        job_id = os.environ.get('SLURM_JOB_ID')
        if (len(ckpt_path) > 0 and os.path.exists(ckpt_path) and
                job_id is not None):
            ckpt_folder = os.path.join(ckpt_path, job_id)
        else:
            ckpt_folder = save_folder

        # Reload previous checkpoint.
        if os.path.exists(ckpt_folder) and not args.eval:
            latest = latest_file(ckpt_folder, 'weights-')
            if latest is not None:
                log.info('Checkpoint already exists. Loading from {}'.format(
                    latest))
                mem_model.load(latest)  # Not loading optimizer weights here.
                reload_flag = latest
                # The text after the last '-' is parsed as the step count and
                # used to offset the data seed when resuming.
                restore_steps = int(reload_flag.split('-')[-1])

        # Save config and create TB logger.
        if not args.eval:
            save_config(config, save_folder)
            writer = tf.summary.create_file_writer(save_folder)
            logger = ExperimentLogger(writer)

        # Get dataset.
        dataset = get_data_fs(env_config, load_train=True)

        # Get data iterators.
        if env_config.dataset in ["matterport"]:
            data = get_dataiter_sim(
                dataset,
                data_config,
                batch_size=config.optimizer_config.batch_size,
                nchw=mem_model.backbone.config.data_format == 'NCHW',
                seed=args.seed + restore_steps)
        else:
            data = get_dataiter_continual(
                dataset,
                data_config,
                batch_size=config.optimizer_config.batch_size,
                nchw=mem_model.backbone.config.data_format == 'NCHW',
                save_additional_info=True,
                random_box=data_config.random_box,
                seed=args.seed + restore_steps)

    if not args.eval:
        # Load model, training loop.
        if args.pretrain is not None and reload_flag is None:
            model.load(latest_file(args.pretrain, 'weights-'))
            if config.freeze_backbone:
                model.set_trainable(False)  # Freeze the network.
                log.info('Backbone network is now frozen')
        summary_ctx = (writer.as_default()
                       if writer is not None else dummy_context_mgr())
        with summary_ctx:
            train(mem_model,
                  data['train_fs'],
                  data['trainval_fs'],
                  data['val_fs'],
                  ckpt_folder,
                  final_save_folder=save_folder,
                  maxlen=data_config.maxlen,
                  logger=logger,
                  writer=writer,
                  in_stage=config.in_stage,
                  reload_flag=reload_flag)
    elif redisplay:
        # Re-display results. The pickle is produced by this script itself;
        # never point results_file at untrusted data.
        with open(results_file, 'rb') as fin:
            results_all = pkl.load(fin)
        for split, name in zip(['trainval_fs', 'val_fs', 'test_fs'],
                               ['Train', 'Val', 'Test']):
            stats = get_stats(results_all[split], tmax=data_config.maxlen)
            log_results(stats, prefix=name, filename=logfile)
    else:
        # Load the most recent checkpoint, falling back to args.pretrain.
        if args.usebest:
            latest = latest_file(save_folder, 'best-')
        else:
            latest = latest_file(save_folder, 'weights-')
        if latest is None and args.pretrain is not None:
            latest = latest_file(args.pretrain, 'weights-')
        if latest is None:
            raise ValueError('Checkpoint not found')
        mem_model.load(latest)

        data['trainval_fs'].reset()
        data['val_fs'].reset()
        data['test_fs'].reset()

        results_all = {}
        if args.testonly:
            split_list = ['test_fs']
            name_list = ['Test']
            nepisode_list = [config.num_episodes]
        else:
            split_list = ['trainval_fs', 'val_fs', 'test_fs']
            name_list = ['Train', 'Val', 'Test']
            nepisode_list = [600, config.num_episodes, config.num_episodes]

        for split, name, N in zip(split_list, name_list, nepisode_list):
            data[split].reset()
            r1 = evaluate(mem_model, data[split], N, verbose=True)
            stats = get_stats(r1, tmax=data_config.maxlen)
            log_results(stats, prefix=name, filename=logfile)
            results_all[split] = r1
        # Close the file so the pickle is fully flushed to disk.
        with open(results_file, 'wb') as fout:
            pkl.dump(results_all, fout)
Esempio n. 3
0
def test_load_one(folder, seed=0):
    """Load the checkpoint in ``folder`` and print AP on a few episodes.

    The episode and environment configs are chosen from hard-coded paths,
    depending on whether ``folder`` contains 'omniglot' or 'rooms' (and, for
    omniglot, 'ssl'). A ``best-`` checkpoint is preferred over ``weights-``.
    Five episodes from each of the trainval/val/test splits are evaluated,
    and ``stats['ap']`` is printed per split.

    Args:
      folder: experiment folder holding ``config.prototxt`` and checkpoints.
      seed: base data-iterator seed; the checkpoint's step count is added.

    Raises:
      ValueError: if ``folder`` names neither dataset, or holds no checkpoint.
    """
    config_file = os.path.join(folder, 'config.prototxt')
    config = get_config(config_file, ExperimentConfig)
    if 'omniglot' in folder:
        if 'ssl' in folder:
            data_config_file = 'configs/episodes/roaming-omniglot/roaming-omniglot-150-ssl.prototxt'  # NOQA
        else:
            data_config_file = 'configs/episodes/roaming-omniglot/roaming-omniglot-150.prototxt'  # NOQA
        env_config_file = 'configs/environ/roaming-omniglot-docker.prototxt'
    elif 'rooms' in folder:
        # NOTE(review): the non-SSL path used to be misspelled
        # ('romaing-rooms-100'). Both SSL and non-SSL now use this config;
        # confirm whether SSL rooms needs a dedicated file.
        data_config_file = 'configs/episodes/roaming-rooms/roaming-rooms-100.prototxt'  # NOQA
        env_config_file = 'configs/environ/roaming-rooms-docker.prototxt'
    else:
        raise ValueError(
            'Cannot infer dataset (omniglot/rooms) from folder {}'.format(
                folder))
    data_config = get_config(data_config_file, EpisodeConfig)
    env_config = get_config(env_config_file, EnvironmentConfig)
    log.info('Model: \n{}'.format(config))
    log.info('Data episode: \n{}'.format(data_config))
    log.info('Environment: \n{}'.format(env_config))
    # Copy episode settings into the model config so both agree.
    config.num_classes = data_config.nway  # Assign num classes.
    config.num_steps = data_config.maxlen
    config.memory_net_config.max_classes = data_config.nway
    config.memory_net_config.max_stages = data_config.nstage
    config.memory_net_config.max_items = data_config.maxlen
    config.oml_config.num_classes = data_config.nway
    config.fix_unknown = data_config.fix_unknown  # Assign fix unknown ID.
    log.info('Number of classes {}'.format(data_config.nway))

    model = build_pretrain_net(config)
    mem_model = build_net(config, backbone=model.backbone)

    # Prefer the best checkpoint; fall back to the latest regular one.
    latest = latest_file(folder, 'best-')
    if latest is None:
        latest = latest_file(folder, 'weights-')
    if latest is None:
        raise ValueError('Checkpoint not found.')

    log.info('Checkpoint already exists. Loading from {}'.format(latest))
    mem_model.load(latest)  # Not loading optimizer weights here.
    # The text after the last '-' is parsed as the step count.
    restore_steps = int(latest.split('-')[-1])

    # Get dataset.
    dataset = get_data_fs(env_config, load_train=True)

    # Get data iterators.
    if env_config.dataset in ["roaming-rooms", "matterport"]:
        data = get_dataiter_sim(
            dataset,
            data_config,
            batch_size=config.optimizer_config.batch_size,
            nchw=mem_model.backbone.config.data_format == 'NCHW',
            seed=seed + restore_steps)
    else:
        data = get_dataiter_continual(
            dataset,
            data_config,
            batch_size=config.optimizer_config.batch_size,
            nchw=mem_model.backbone.config.data_format == 'NCHW',
            save_additional_info=True,
            random_box=data_config.random_box,
            seed=seed + restore_steps)

    data['trainval_fs'].reset()
    data['val_fs'].reset()
    data['test_fs'].reset()

    # Deliberately small episode count per split.
    num_episodes = 5
    for split in ['trainval_fs', 'val_fs', 'test_fs']:
        data[split].reset()
        r1 = evaluate(mem_model, data[split], num_episodes, verbose=False)
        stats = get_stats(r1, tmax=data_config.maxlen)
        print(split, stats['ap'])
Esempio n. 4
0
def main():
  """Plot sample episodes from the val and train splits.

  Builds data iterators from the episode/environment configs named by the
  module-level ``args`` (batch size 1, ``nchw=False``, no prefetch). It then
  writes ``args.nepisode`` episode plots per split under
  ``<args.output>/val`` and ``<args.output>/train``. If ``args.restore`` is
  set, the network built from ``<args.restore>/config.prototxt`` and loaded
  from its latest ``weights-`` checkpoint is passed to ``plot_episodes`` as
  ``model``; otherwise ``model=None``.
  """
  data_config = get_config(args.data, EpisodeConfig)
  env_config = get_config(args.env, EnvironmentConfig)

  # Get dataset.
  dataset = get_data_fs(env_config, load_train=True)

  # Only ask the continual iterator for additional info on hierarchical
  # episodes.
  save_additional_info = bool(data_config.hierarchical)

  # Get data iterators.
  if env_config.dataset in ["matterport"]:
    data = get_dataiter_sim(
        dataset,
        data_config,
        batch_size=1,
        nchw=False,
        distributed=False,
        prefetch=False)
  else:
    data = get_dataiter_continual(
        dataset,
        data_config,
        batch_size=1,
        nchw=False,
        prefetch=False,
        save_additional_info=save_additional_info,
        random_box=data_config.random_box)

  # Create each output subfolder (and args.output itself) independently, so
  # a partially created tree from an earlier run is completed rather than
  # skipped.
  for split_dir in ('val', 'train'):
    os.makedirs(os.path.join(args.output, split_dir), exist_ok=True)

  if args.restore is None:
    log.info('No model provided.')
    fs_model = None
  else:
    config = os.path.join(args.restore, 'config.prototxt')
    config = get_config(config, ExperimentConfig)
    model = build_pretrain_net(config)
    fs_model = build_net(config, backbone=model.backbone)
    fs_model.load(latest_file(args.restore, 'weights-'))

  # Plot grid size grows with the episode length.
  if data_config.maxlen <= 40:
    ncol = 5
    nrow = 8
  elif data_config.maxlen <= 100:
    ncol = 6
    nrow = 10
  else:
    ncol = 6
    nrow = 25

  # Plot episodes: val first, then train.
  for split_dir, split in (('val', 'val_fs'), ('train', 'train_fs')):
    plot_episodes(
        data[split],
        args.nepisode,
        os.path.join(args.output, split_dir, 'episode_{:06d}'),
        model=fs_model,
        nway=data_config.nway,
        ncol=ncol,
        nrow=nrow,
        nstage=data_config.nstage)
Esempio n. 5
0
def main():
    """Train a memory network with distributed data iterators.

    Configuration comes from the module-level ``args``. The module-level
    ``gpus`` and ``is_chief`` also control learning-rate scaling and
    chief-only work.

    With ``config.optimizer_config.lr_scaling``, decay steps and the maximum
    number of training steps are divided by ``len(gpus)``, and every learning
    rate is multiplied by it. Only the chief resolves the save/checkpoint
    folders, resumes from the latest ``weights-`` checkpoint, saves the config
    and creates the TensorBoard writer/logger. Other workers pass ``None`` for
    these to ``train``.
    """
    assert tf.executing_eagerly(), 'Only eager mode is supported.'
    assert args.config is not None, 'You need to pass in model config file path'
    assert args.data is not None, 'You need to pass in episode config file path'
    assert args.env is not None, 'You need to pass in environ config file path'
    assert args.tag is not None, 'You need to specify a tag'

    log.info('Command line args {}'.format(args))
    config = get_config(args.config, ExperimentConfig)
    data_config = get_config(args.data, EpisodeConfig)
    env_config = get_config(args.env, EnvironmentConfig)
    log.info('Model: \n{}'.format(config))
    log.info('Data episode: \n{}'.format(data_config))
    log.info('Environment: \n{}'.format(env_config))
    # Copy episode settings into the model config so both agree.
    config.num_classes = data_config.nway  # Assign num classes.
    config.num_steps = data_config.maxlen
    config.memory_net_config.max_classes = data_config.nway
    config.memory_net_config.max_stages = data_config.nstage
    config.memory_net_config.max_items = data_config.maxlen
    config.oml_config.num_classes = data_config.nway
    config.fix_unknown = data_config.fix_unknown  # Assign fix unknown ID.

    # Modify optimization config.
    opt_config = config.optimizer_config
    if opt_config.lr_scaling:
        num_gpus = len(gpus)
        for i, steps in enumerate(opt_config.lr_decay_steps):
            opt_config.lr_decay_steps[i] = steps // num_gpus
        opt_config.max_train_steps //= num_gpus

        # Linearly scale learning rate.
        for i, lr in enumerate(opt_config.lr_list):
            opt_config.lr_list[i] = lr * float(num_gpus)

    log.info('Number of classes {}'.format(data_config.nway))

    # Build model.
    model = build_pretrain_net(config)
    mem_model = build_net(config, backbone=model.backbone, distributed=True)
    reload_flag = None
    restore_steps = 0

    job_id = os.environ.get('SLURM_JOB_ID')
    if job_id is not None:
        log.info('SLURM job ID: {}'.format(job_id))

    # Create save folder.
    if is_chief:
        save_folder = os.path.join(env_config.results, env_config.dataset,
                                   args.tag)
        ckpt_path = env_config.checkpoint
        # A per-job checkpoint subfolder is only possible with a SLURM job ID;
        # without one, fall back to save_folder instead of raising KeyError.
        if (len(ckpt_path) > 0 and os.path.exists(ckpt_path) and
                job_id is not None):
            ckpt_folder = os.path.join(ckpt_path, job_id)
            log.info('Checkpoint folder: {}'.format(ckpt_folder))
        else:
            ckpt_folder = save_folder

        # Look in the checkpoint folder first, then in the save folder.
        latest = None
        if os.path.exists(ckpt_folder):
            latest = latest_file(ckpt_folder, 'weights-')

        if latest is None and os.path.exists(save_folder):
            latest = latest_file(save_folder, 'weights-')

        if latest is not None:
            log.info(
                'Checkpoint already exists. Loading from {}'.format(latest))
            mem_model.load(latest)  # Not loading optimizer weights here.
            reload_flag = latest
            # The text after the last '-' is parsed as the step count and
            # used to offset the data seed when resuming.
            restore_steps = int(reload_flag.split('-')[-1])

        # Create TB logger.
        save_config(config, save_folder)
        writer = tf.summary.create_file_writer(save_folder)
        logger = ExperimentLogger(writer)
    else:
        save_folder = None
        ckpt_folder = None
        writer = None
        logger = None

    # Get dataset.
    dataset = get_data_fs(env_config, load_train=True)

    # Get data iterators.
    if env_config.dataset in ["matterport", "roaming-rooms"]:
        data = get_dataiter_sim(
            dataset,
            data_config,
            batch_size=config.optimizer_config.batch_size,
            nchw=mem_model.backbone.config.data_format == 'NCHW',
            distributed=True,
            seed=args.seed + restore_steps)
    else:
        data = get_dataiter_continual(
            dataset,
            data_config,
            batch_size=config.optimizer_config.batch_size,
            nchw=mem_model.backbone.config.data_format == 'NCHW',
            save_additional_info=True,
            random_box=data_config.random_box,
            distributed=True,
            seed=args.seed + restore_steps)

    # Load model, training loop.
    if args.pretrain is not None and reload_flag is None:
        mem_model.load(latest_file(args.pretrain, 'weights-'))
        if config.freeze_backbone:
            model.set_trainable(False)  # Freeze the network.
            log.info('Backbone network is now frozen')

    summary_ctx = (writer.as_default()
                   if writer is not None else dummy_context_mgr())
    with summary_ctx:
        train(mem_model,
              data['train_fs'],
              data['trainval_fs'],
              data['val_fs'],
              ckpt_folder,
              final_save_folder=save_folder,
              maxlen=data_config.maxlen,
              logger=logger,
              writer=writer,
              is_chief=is_chief,
              reload_flag=reload_flag)
Esempio n. 6
0
def main():
  """Extract features on cached test episodes for t-SNE visualisation.

  Configuration comes from the module-level ``args``. If
  ``<results>/<dataset>/<tag>/results_tsne.pkl`` already exists, its contents
  are passed to ``visualize`` and that result is returned. Otherwise the
  backbone (optionally wrapped in an RNN + prototype memory) is built and the
  ``best-`` (with ``args.usebest``) or ``weights-`` checkpoint is loaded.
  Next, 100 test episodes are cached to ``<env.data_folder>/test_fs.pkl`` if
  that file does not exist yet. ``evaluate`` runs on those episodes, and the
  results are pickled to ``results_tsne.pkl``.

  Raises:
    ValueError: if no checkpoint is found in the save folder.
  """
  assert tf.executing_eagerly(), 'Only eager mode is supported.'
  assert args.config is not None, 'You need to pass in model config file path'
  assert args.data is not None, 'You need to pass in episode config file path'
  assert args.env is not None, 'You need to pass in environ config file path'
  assert args.tag is not None, 'You need to specify a tag'

  log.info('Command line args {}'.format(args))
  config = get_config(args.config, ExperimentConfig)
  data_config = get_config(args.data, EpisodeConfig)
  env_config = get_config(args.env, EnvironmentConfig)
  log.info('Model: \n{}'.format(config))
  log.info('Data episode: \n{}'.format(data_config))
  log.info('Environment: \n{}'.format(env_config))
  # Copy episode settings into the model config so both agree.
  config.num_classes = data_config.nway  # Assign num classes.
  config.num_steps = data_config.maxlen
  config.memory_net_config.max_classes = data_config.nway
  config.memory_net_config.max_stages = data_config.nstage
  config.memory_net_config.max_items = data_config.maxlen
  config.oml_config.num_classes = data_config.nway
  config.fix_unknown = data_config.fix_unknown  # Assign fix unknown ID.
  log.info('Number of classes {}'.format(data_config.nway))
  if 'SLURM_JOB_ID' in os.environ:
    log.info('SLURM job ID: {}'.format(os.environ['SLURM_JOB_ID']))

  # Create save folder.
  save_folder = os.path.join(env_config.results, env_config.dataset, args.tag)
  results_file = os.path.join(save_folder, 'results_tsne.pkl')

  # Plot features. The pickles here are written by this script itself;
  # never point these paths at untrusted data.
  if os.path.exists(results_file):
    with open(results_file, 'rb') as fin:
      batch_list = pkl.load(fin)
    return visualize(batch_list, save_folder)

  # To get CNN features from.
  model = build_pretrain_net(config)
  if args.rnn:
    rnn_memory = get_module(
        "lstm",
        "lstm",
        model.backbone.get_output_dimension()[0],
        config.lstm_config.hidden_dim,
        layernorm=config.lstm_config.layernorm,
        dtype=tf.float32)
    proto_memory = get_module(
        'ssl_min_dist_proto_memory',
        'proto_memory',
        config.memory_net_config.radius_init,
        max_classes=config.memory_net_config.max_classes,
        fix_unknown=config.fix_unknown,
        unknown_id=config.num_classes if config.fix_unknown else None,
        similarity=config.memory_net_config.similarity,
        static_beta_gamma=not config.hybrid_config.use_pred_beta_gamma,
        dtype=tf.float32)
    memory = RNNEncoder(
        "proto_plus_rnn_ssl_v4",
        rnn_memory,
        proto_memory,
        readout_type=config.mann_config.readout_type,
        use_pred_beta_gamma=config.hybrid_config.use_pred_beta_gamma,
        use_feature_fuse=config.hybrid_config.use_feature_fuse,
        use_feature_fuse_gate=config.hybrid_config.use_feature_fuse_gate,
        use_feature_scaling=config.hybrid_config.use_feature_scaling,
        use_feature_memory_only=config.hybrid_config.use_feature_memory_only,
        skip_unk_memory_update=config.hybrid_config.skip_unk_memory_update,
        use_ssl=config.hybrid_config.use_ssl,
        use_ssl_beta_gamma_write=config.hybrid_config.use_ssl_beta_gamma_write,
        use_ssl_temp=config.hybrid_config.use_ssl_temp,
        dtype=tf.float32)
    rnn_model = RNNEncoderNet(config, model.backbone, memory)
    f = lambda x, y: rnn_model.forward(x, y, is_training=False)  # NOQA
  else:
    # Backbone features of x[0] with a new leading axis; y is ignored.
    f = lambda x, y: model.backbone(x[0], is_training=False)[None, :, :] # NOQA

  K = config.num_classes
  # Get dataset.
  dataset = get_data_fs(env_config, load_train=True)

  # Get data iterators.
  if env_config.dataset in ["matterport"]:
    data = get_dataiter_sim(
        dataset,
        data_config,
        batch_size=config.optimizer_config.batch_size,
        nchw=model.backbone.config.data_format == 'NCHW',
        seed=args.seed)
  else:
    data = get_dataiter_continual(
        dataset,
        data_config,
        batch_size=config.optimizer_config.batch_size,
        nchw=model.backbone.config.data_format == 'NCHW',
        save_additional_info=True,
        random_box=data_config.random_box,
        seed=args.seed)

  if args.usebest:
    latest = latest_file(save_folder, 'best-')
  else:
    latest = latest_file(save_folder, 'weights-')
  if latest is None:
    # Fail loudly instead of handing None to the model loader.
    raise ValueError('Checkpoint not found')

  if args.rnn:
    rnn_model.load(latest)
  else:
    model.load(latest)

  data['trainval_fs'].reset()
  data['val_fs'].reset()
  data['test_fs'].reset()

  results_all = {}
  split_list = ['test_fs']
  nepisode_list = [100]

  for split, N in zip(split_list, nepisode_list):
    data[split].reset()

    # Cache the first N episodes so repeated runs see identical batches.
    data_fname = '{}/{}.pkl'.format(env_config.data_folder, split)
    if not os.path.exists(data_fname):
      # zip stops after N items without pulling an extra batch.
      batch_list = [batch for _, batch in zip(range(N), data[split])]
      with open(data_fname, 'wb') as fout:
        pkl.dump(batch_list, fout)

    with open(data_fname, 'rb') as fin:
      batch_list = pkl.load(fin)
    r1 = evaluate(f, K, batch_list, N)
    results_all[split] = r1
  with open(results_file, 'wb') as fout:
    pkl.dump(results_all, fout)