def main():
  """Trains or evaluates the online few-shot memory model.

  Reads the model/episode/environment configs named by command-line `args`,
  optionally restores the latest checkpoint, then either runs the training
  loop (`train(...)`) or evaluates on the train/val/test episode iterators
  and dumps per-split results to `results.pkl`.

  NOTE(review): indentation was reconstructed from a whitespace-mangled
  source; the nesting of the `if not args.reeval:` section (model + data
  iterator construction) is inferred from the fact that both later branches
  use `mem_model` and `data` — confirm against the original file.
  """
  assert tf.executing_eagerly(), 'Only eager mode is supported.'
  assert args.config is not None, 'You need to pass in model config file path'
  assert args.data is not None, 'You need to pass in episode config file path'
  assert args.env is not None, 'You need to pass in environ config file path'
  assert args.tag is not None, 'You need to specify a tag'
  log.info('Command line args {}'.format(args))
  config = get_config(args.config, ExperimentConfig)
  data_config = get_config(args.data, EpisodeConfig)
  env_config = get_config(args.env, EnvironmentConfig)
  log.info('Model: \n{}'.format(config))
  log.info('Data episode: \n{}'.format(data_config))
  log.info('Environment: \n{}'.format(env_config))
  # Propagate episode settings into the model config.
  config.num_classes = data_config.nway  # Assign num classes.
  config.num_steps = data_config.maxlen
  config.memory_net_config.max_classes = data_config.nway
  config.memory_net_config.max_stages = data_config.nstage
  config.memory_net_config.max_items = data_config.maxlen
  config.oml_config.num_classes = data_config.nway
  config.fix_unknown = data_config.fix_unknown  # Assign fix unknown ID.
  log.info('Number of classes {}'.format(data_config.nway))
  if 'SLURM_JOB_ID' in os.environ:
    log.info('SLURM job ID: {}'.format(os.environ['SLURM_JOB_ID']))

  # Create save folder.
  save_folder = os.path.join(env_config.results, env_config.dataset, args.tag)
  if not args.reeval:
    model = build_pretrain_net(config)
    mem_model = build_net(config, backbone=model.backbone)
    reload_flag = None
    restore_steps = 0

    # Checkpoint folder: prefer job-scratch space keyed by SLURM job ID.
    ckpt_path = env_config.checkpoint
    if len(ckpt_path) > 0 and os.path.exists(ckpt_path):
      ckpt_folder = os.path.join(ckpt_path, os.environ['SLURM_JOB_ID'])
    else:
      ckpt_folder = save_folder

    # Reload previous checkpoint.
    if os.path.exists(ckpt_folder) and not args.eval:
      latest = latest_file(ckpt_folder, 'weights-')
      if latest is not None:
        log.info('Checkpoint already exists. Loading from {}'.format(latest))
        mem_model.load(latest)  # Not loading optimizer weights here.
        reload_flag = latest
        # Step count is encoded in the checkpoint filename suffix.
        restore_steps = int(reload_flag.split('-')[-1])
    if not args.eval:
      save_config(config, save_folder)

    # Create TB logger.
    if not args.eval:
      writer = tf.summary.create_file_writer(save_folder)
      logger = ExperimentLogger(writer)

    # Get dataset.
    dataset = get_data_fs(env_config, load_train=True)

    # Get data iterators. Seed is offset by the restored step count so the
    # data stream does not repeat after a checkpoint resume.
    if env_config.dataset in ["matterport"]:
      data = get_dataiter_sim(
          dataset,
          data_config,
          batch_size=config.optimizer_config.batch_size,
          nchw=mem_model.backbone.config.data_format == 'NCHW',
          seed=args.seed + restore_steps)
    else:
      data = get_dataiter_continual(
          dataset,
          data_config,
          batch_size=config.optimizer_config.batch_size,
          nchw=mem_model.backbone.config.data_format == 'NCHW',
          save_additional_info=True,
          random_box=data_config.random_box,
          seed=args.seed + restore_steps)

  # Load model, training loop.
  if not args.eval:
    if args.pretrain is not None and reload_flag is None:
      model.load(latest_file(args.pretrain, 'weights-'))
      if config.freeze_backbone:
        model.set_trainable(False)  # Freeze the network.
        log.info('Backbone network is now frozen')
    with writer.as_default() if writer is not None else dummy_context_mgr(
    ) as gs:
      train(
          mem_model,
          data['train_fs'],
          data['trainval_fs'],
          data['val_fs'],
          ckpt_folder,
          final_save_folder=save_folder,
          maxlen=data_config.maxlen,
          logger=logger,
          writer=writer,
          in_stage=config.in_stage,
          reload_flag=reload_flag)
  else:
    results_file = os.path.join(save_folder, 'results.pkl')
    logfile = os.path.join(save_folder, 'results.csv')
    if os.path.exists(results_file) and args.reeval:
      # Re-display results previously computed and pickled.
      results_all = pkl.load(open(results_file, 'rb'))
      for split, name in zip(['trainval_fs', 'val_fs', 'test_fs'],
                             ['Train', 'Val', 'Test']):
        stats = get_stats(results_all[split], tmax=data_config.maxlen)
        log_results(stats, prefix=name, filename=logfile)
    else:
      # Load the most recent checkpoint.
      if args.usebest:
        latest = latest_file(save_folder, 'best-')
      else:
        latest = latest_file(save_folder, 'weights-')
      if latest is not None:
        mem_model.load(latest)
      else:
        # Fall back to the pretrain folder when no experiment checkpoint.
        latest = latest_file(args.pretrain, 'weights-')
        if latest is not None:
          mem_model.load(latest)
        else:
          raise ValueError('Checkpoint not found')
      data['trainval_fs'].reset()
      data['val_fs'].reset()
      data['test_fs'].reset()
      results_all = {}
      if args.testonly:
        split_list = ['test_fs']
        name_list = ['Test']
        nepisode_list = [config.num_episodes]
      else:
        split_list = ['trainval_fs', 'val_fs', 'test_fs']
        name_list = ['Train', 'Val', 'Test']
        nepisode_list = [600, config.num_episodes, config.num_episodes]
      for split, name, N in zip(split_list, name_list, nepisode_list):
        data[split].reset()
        r1 = evaluate(mem_model, data[split], N, verbose=True)
        stats = get_stats(r1, tmax=data_config.maxlen)
        log_results(stats, prefix=name, filename=logfile)
        results_all[split] = r1
      pkl.dump(results_all, open(results_file, 'wb'))
def main():
  """Evaluates a CNN-feature baseline (optionally with an LSTM encoder head).

  Builds a feature function `f` — either the raw backbone features or an
  `RNNEncoderNet` forward pass when `args.rnn` is set — then runs `evaluate`
  over the episode iterators and dumps per-split results to `results.pkl`.

  NOTE(review): indentation reconstructed from a whitespace-mangled source.
  """
  assert tf.executing_eagerly(), 'Only eager mode is supported.'
  assert args.config is not None, 'You need to pass in model config file path'
  assert args.data is not None, 'You need to pass in episode config file path'
  assert args.env is not None, 'You need to pass in environ config file path'
  assert args.tag is not None, 'You need to specify a tag'
  log.info('Command line args {}'.format(args))
  config = get_config(args.config, ExperimentConfig)
  data_config = get_config(args.data, EpisodeConfig)
  env_config = get_config(args.env, EnvironmentConfig)
  log.info('Model: \n{}'.format(config))
  log.info('Data episode: \n{}'.format(data_config))
  log.info('Environment: \n{}'.format(env_config))
  # Propagate episode settings into the model config.
  config.num_classes = data_config.nway  # Assign num classes.
  config.num_steps = data_config.maxlen
  config.memory_net_config.max_classes = data_config.nway
  config.memory_net_config.max_stages = data_config.nstage
  config.memory_net_config.max_items = data_config.maxlen
  config.oml_config.num_classes = data_config.nway
  config.fix_unknown = data_config.fix_unknown  # Assign fix unknown ID.
  log.info('Number of classes {}'.format(data_config.nway))
  if 'SLURM_JOB_ID' in os.environ:
    log.info('SLURM job ID: {}'.format(os.environ['SLURM_JOB_ID']))

  # Create save folder.
  save_folder = os.path.join(env_config.results, env_config.dataset, args.tag)
  results_file = os.path.join(save_folder, 'results.pkl')
  logfile = os.path.join(save_folder, 'results.csv')

  # To get CNN features from.
  model = build_pretrain_net(config)
  if args.rnn:
    rnn_memory = get_module(
        "lstm",
        "lstm",
        model.backbone.get_output_dimension()[0],
        config.lstm_config.hidden_dim,
        layernorm=config.lstm_config.layernorm,
        dtype=tf.float32)
    memory = RNNEncoder(
        "proto_plus_rnn_ssl_v4",
        rnn_memory,
        readout_type=config.mann_config.readout_type,
        use_pred_beta_gamma=config.hybrid_config.use_pred_beta_gamma,
        use_feature_fuse=config.hybrid_config.use_feature_fuse,
        use_feature_fuse_gate=config.hybrid_config.use_feature_fuse_gate,
        use_feature_scaling=config.hybrid_config.use_feature_scaling,
        use_feature_memory_only=config.hybrid_config.use_feature_memory_only,
        skip_unk_memory_update=config.hybrid_config.skip_unk_memory_update,
        use_ssl=config.hybrid_config.use_ssl,
        use_ssl_beta_gamma_write=config.hybrid_config.use_ssl_beta_gamma_write,
        use_ssl_temp=config.hybrid_config.use_ssl_temp,
        dtype=tf.float32)
    rnn_model = RNNEncoderNet(config, model.backbone, memory)
    f = lambda x: rnn_model.forward(x, is_training=False)  # NOQA
  else:
    # Backbone features only; adds a leading singleton "time" axis.
    f = lambda x: model.backbone(x[0], is_training=False)[None, :, :]  # NOQA
  K = config.num_classes

  # Get dataset.
  dataset = get_data_fs(env_config, load_train=True)

  # Get data iterators.
  if env_config.dataset in ["matterport"]:
    data = get_dataiter_sim(
        dataset,
        data_config,
        batch_size=config.optimizer_config.batch_size,
        nchw=model.backbone.config.data_format == 'NCHW',
        seed=args.seed)
  else:
    data = get_dataiter_continual(
        dataset,
        data_config,
        batch_size=config.optimizer_config.batch_size,
        nchw=model.backbone.config.data_format == 'NCHW',
        save_additional_info=True,
        random_box=data_config.random_box,
        seed=args.seed)

  if os.path.exists(results_file) and args.reeval:
    # Re-display results previously computed and pickled.
    results_all = pkl.load(open(results_file, 'rb'))
    for split, name in zip(['trainval_fs', 'val_fs', 'test_fs'],
                           ['Train', 'Val', 'Test']):
      stats = get_stats(results_all[split], tmax=data_config.maxlen)
      log_results(stats, prefix=name, filename=logfile)
  else:
    latest = latest_file(save_folder, 'weights-')
    # NOTE(review): `latest` may be None when no checkpoint exists; `load`
    # would then fail — confirm this is the intended fail-fast behavior.
    if args.rnn:
      rnn_model.load(latest)
    else:
      model.load(latest)
    data['trainval_fs'].reset()
    data['val_fs'].reset()
    data['test_fs'].reset()
    results_all = {}
    if args.testonly:
      split_list = ['test_fs']
      name_list = ['Test']
      nepisode_list = [config.num_episodes]
    else:
      split_list = ['trainval_fs', 'val_fs', 'test_fs']
      name_list = ['Train', 'Val', 'Test']
      nepisode_list = [600, config.num_episodes, config.num_episodes]
    for split, name, N in zip(split_list, name_list, nepisode_list):
      data[split].reset()
      r1 = evaluate(f, K, data[split], N)
      stats = get_stats(r1, tmax=data_config.maxlen)
      log_results(stats, prefix=name, filename=logfile)
      results_all[split] = r1
    pkl.dump(results_all, open(results_file, 'wb'))
def train(model,
          dataiter,
          dataiter_traintest,
          dataiter_test,
          ckpt_folder,
          final_save_folder=None,
          nshot_max=5,
          maxlen=40,
          logger=None,
          writer=None,
          in_stage=False,
          is_chief=True,
          reload_flag=None):
  """Trains the online few-shot model.

  Args:
    model: Model instance; provides `train_step`, `save`, `load`, `step`.
    dataiter: Training episode iterator.
    dataiter_traintest: Train-split iterator used for periodic evaluation.
    dataiter_test: Validation iterator used for periodic evaluation.
    ckpt_folder: Path where periodic checkpoints are written.
    final_save_folder: Path for the final and `best-` checkpoints.
    nshot_max: Max number of shots to report per-shot accuracy for.
    maxlen: Max episode length passed to `get_stats`.
    logger: Optional ExperimentLogger; logging is skipped when None.
    writer: Optional TF summary writer forwarded to `train_step`.
    in_stage: Unused here; kept for interface compatibility with callers.
    is_chief: Whether this process drives progress bar / eval / saving.
    reload_flag: Checkpoint path to reload (with optimizer state) at start.
  """
  N = model.max_train_steps
  config = model.config.train_config

  def try_log(*args, **kwargs):
    # Log only when a logger was provided.
    if logger is not None:
      logger.log(*args, **kwargs)

  def try_flush(*args, **kwargs):
    if logger is not None:
      logger.flush()

  r = None
  keep = True
  restart = False
  best_val = 0.0
  # Outer loop allows restarting from the last checkpoint after a loss blowup.
  while keep:
    keep = False
    start = model.step.numpy()
    if start > 0:
      log.info('Restore from step {}'.format(start))
    it = six.moves.xrange(start, N)
    if is_chief:
      it = tqdm(it, ncols=0)
    for i, batch in zip(it, dataiter):
      tf.summary.experimental.set_step(i + 1)
      x = batch['x_s']
      y = batch['y_s']
      y_gt = batch['y_gt']
      kwargs = {'y_gt': y_gt, 'flag': batch['flag_s']}
      kwargs['writer'] = writer
      loss = model.train_step(x, y, **kwargs)

      # On the first step of a resumed run, reload optimizer state too
      # (the initial train_step created the optimizer slot variables).
      if i == start and reload_flag is not None and not restart:
        model.load(reload_flag, load_optimizer=True)

      # Synchronize distributed weights.
      if i == start and model._distributed and not restart:
        import horovod.tensorflow as hvd
        hvd.broadcast_variables(model.var_to_optimize(), root_rank=0)
        hvd.broadcast_variables(model.optimizer.variables(), root_rank=0)
        if model.config.set_backbone_lr:
          hvd.broadcast_variables(model._bb_optimizer.variables(), root_rank=0)

      if loss > 10.0 and i > start + config.steps_per_save:
        # Something wrong happened: dump the offending batch and restart
        # from the latest checkpoint.
        log.error('Something wrong happened. loss = {}'.format(loss))
        import pickle as pkl
        pkl.dump(
            batch,
            open(os.path.join(final_save_folder, 'debug.pkl'), 'wb'))
        log.error('debug file dumped')
        restart = True
        keep = True
        latest = latest_file(ckpt_folder, 'weights-')
        # BUG FIX: was `model.load(latest_file)` / logged `latest_file`,
        # which passed/printed the function object instead of the path.
        model.load(latest)
        log.error('Reloaded latest checkpoint from {}'.format(latest))
        break

      # Evaluate.
      if is_chief and ((i + 1) % config.steps_per_val == 0 or i == 0):
        for key, data_it_ in zip(['train', 'val'],
                                 [dataiter_traintest, dataiter_test]):
          data_it_.reset()
          r1 = evaluate(model, data_it_, 120)
          r = get_stats(r1, nshot_max=nshot_max, tmax=maxlen)
          for s in range(nshot_max):
            try_log('online fs acc {}/s{}'.format(key, s), i + 1,
                    r['acc_nshot'][s] * 100.0)
          try_log('online fs ap {}'.format(key), i + 1, r['ap'] * 100.0)
        try_log('lr', i + 1, model.learn_rate())
        print()

      # Save.
      if is_chief and ((i + 1) % config.steps_per_save == 0 or i == 0):
        model.save(os.path.join(ckpt_folder, 'weights-{}'.format(i + 1)))
        # Save the best checkpoint (by validation AP).
        if r is not None:
          if r['ap'] > best_val:
            model.save(
                os.path.join(final_save_folder, 'best-{}'.format(i + 1)))
            best_val = r['ap']

      # Write logs.
      if is_chief and ((i + 1) % config.steps_per_log == 0 or i == 0):
        try_log('loss', i + 1, loss)
        try_flush()
        # Update progress bar.
        post_fix_dict = {}
        post_fix_dict['lr'] = '{:.3e}'.format(model.learn_rate())
        post_fix_dict['loss'] = '{:.3e}'.format(loss)
        if r is not None:
          post_fix_dict['ap_val'] = '{:.3f}'.format(r['ap'] * 100.0)
        it.set_postfix(**post_fix_dict)

  # Final save.
  if is_chief and final_save_folder is not None:
    model.save(os.path.join(final_save_folder, 'weights-{}'.format(N)))
def test_load_one(folder, seed=0):
  """Loads one trained experiment folder and runs a short sanity evaluation.

  Picks episode/environment config files based on substrings of `folder`,
  restores the best (or latest) checkpoint, and evaluates 5 episodes per
  split, printing the AP for each.

  Args:
    folder: Experiment directory containing `config.prototxt` + checkpoints.
    seed: Base random seed for the episode iterators.
  """
  # log.info('Command line args {}'.format(args))
  config_file = os.path.join(folder, 'config.prototxt')
  config = get_config(config_file, ExperimentConfig)
  # config.c4_config.data_format = 'NHWC'
  # config.resnet_config.data_format = 'NHWC'
  # Select episode/environment configs from the folder name.
  if 'omniglot' in folder:
    if 'ssl' in folder:
      data_config_file = 'configs/episodes/roaming-omniglot/roaming-omniglot-150-ssl.prototxt'  # NOQA
    else:
      data_config_file = 'configs/episodes/roaming-omniglot/roaming-omniglot-150.prototxt'  # NOQA
    env_config_file = 'configs/environ/roaming-omniglot-docker.prototxt'
  elif 'rooms' in folder:
    if 'ssl' in folder:
      data_config_file = 'configs/episodes/roaming-rooms/roaming-rooms-100.prototxt'  # NOQA
    else:
      # NOTE(review): 'romaing' looks like a typo of 'roaming', and the ssl /
      # non-ssl branches point at nearly the same file — confirm these paths
      # against the config files that actually exist on disk before fixing.
      data_config_file = 'configs/episodes/roaming-rooms/romaing-rooms-100.prototxt'  # NOQA
    env_config_file = 'configs/environ/roaming-rooms-docker.prototxt'
  data_config = get_config(data_config_file, EpisodeConfig)
  env_config = get_config(env_config_file, EnvironmentConfig)
  log.info('Model: \n{}'.format(config))
  log.info('Data episode: \n{}'.format(data_config))
  log.info('Environment: \n{}'.format(env_config))
  # Propagate episode settings into the model config.
  config.num_classes = data_config.nway  # Assign num classes.
  config.num_steps = data_config.maxlen
  config.memory_net_config.max_classes = data_config.nway
  config.memory_net_config.max_stages = data_config.nstage
  config.memory_net_config.max_items = data_config.maxlen
  config.oml_config.num_classes = data_config.nway
  config.fix_unknown = data_config.fix_unknown  # Assign fix unknown ID.
  log.info('Number of classes {}'.format(data_config.nway))
  model = build_pretrain_net(config)
  mem_model = build_net(config, backbone=model.backbone)
  reload_flag = None
  restore_steps = 0

  # Reload previous checkpoint; prefer the 'best-' checkpoint if present.
  latest = latest_file(folder, 'best-')
  if latest is None:
    latest = latest_file(folder, 'weights-')
  assert latest is not None, "Checkpoint not found."
  if latest is not None:
    log.info('Checkpoint already exists. Loading from {}'.format(latest))
    mem_model.load(latest)  # Not loading optimizer weights here.
    reload_flag = latest
    # Step count is encoded in the checkpoint filename suffix.
    restore_steps = int(reload_flag.split('-')[-1])

  # Get dataset.
  dataset = get_data_fs(env_config, load_train=True)

  # Get data iterators.
  if env_config.dataset in ["roaming-rooms", "matterport"]:
    data = get_dataiter_sim(
        dataset,
        data_config,
        batch_size=config.optimizer_config.batch_size,
        nchw=mem_model.backbone.config.data_format == 'NCHW',
        seed=seed + restore_steps)
  else:
    data = get_dataiter_continual(
        dataset,
        data_config,
        batch_size=config.optimizer_config.batch_size,
        nchw=mem_model.backbone.config.data_format == 'NCHW',
        save_additional_info=True,
        random_box=data_config.random_box,
        seed=seed + restore_steps)

  # Load the most recent checkpoint.
  latest = latest_file(folder, 'best-')
  if latest is None:
    latest = latest_file(folder, 'weights-')
  data['trainval_fs'].reset()
  data['val_fs'].reset()
  data['test_fs'].reset()
  results_all = {}
  split_list = ['trainval_fs', 'val_fs', 'test_fs']
  name_list = ['Train', 'Val', 'Test']
  nepisode_list = [5, 5, 5]  # Short sanity run: 5 episodes per split.
  # nepisode_list = [600, config.num_episodes, config.num_episodes]
  for split, name, N in zip(split_list, name_list, nepisode_list):
    # print(name)
    data[split].reset()
    r1 = evaluate(mem_model, data[split], N, verbose=False)
    stats = get_stats(r1, tmax=data_config.maxlen)
    print(split, stats['ap'])
def main():
  """Renders sample episodes from the train/val iterators as images.

  Loads the episode/environment configs from `args`, builds a batch-size-1
  (non-prefetching) data iterator, optionally restores a few-shot model to
  annotate predictions, and plots `args.nepisode` episodes per split into
  `args.output/{val,train}/`.
  """
  data_config = get_config(args.data, EpisodeConfig)
  env_config = get_config(args.env, EnvironmentConfig)

  # Load the underlying few-shot dataset.
  dataset = get_data_fs(env_config, load_train=True)

  # Hierarchical episodes carry extra per-item metadata.
  save_additional_info = True if data_config.hierarchical else False
  if env_config.dataset in ["matterport"]:
    data = get_dataiter_sim(
        dataset,
        data_config,
        batch_size=1,
        nchw=False,
        distributed=False,
        prefetch=False)
  else:
    data = get_dataiter_continual(
        dataset,
        data_config,
        batch_size=1,
        nchw=False,
        prefetch=False,
        save_additional_info=save_additional_info,
        random_box=data_config.random_box)

  # Prepare output folders, one sub-folder per split.
  if not os.path.exists(args.output):
    os.makedirs(args.output)
    os.makedirs(os.path.join(args.output, 'val'))
    os.makedirs(os.path.join(args.output, 'train'))

  # Optionally restore a few-shot model for prediction overlays.
  if args.restore is None:
    log.info('No model provided.')
    fs_model = None
  else:
    config = os.path.join(args.restore, 'config.prototxt')
    config = get_config(config, ExperimentConfig)
    model = build_pretrain_net(config)
    fs_model = build_net(config, backbone=model.backbone)
    fs_model.load(latest_file(args.restore, 'weights-'))

  # Grid size scales with the maximum episode length.
  if data_config.maxlen <= 40:
    ncol, nrow = 5, 8
  elif data_config.maxlen <= 100:
    ncol, nrow = 6, 10
  else:
    ncol, nrow = 6, 25

  # Plot episodes, val split first then train split.
  for split in ('val', 'train'):
    plot_episodes(
        data['{}_fs'.format(split)],
        args.nepisode,
        os.path.join(args.output, split, 'episode_{:06d}'),
        model=fs_model,
        nway=data_config.nway,
        ncol=ncol,
        nrow=nrow,
        nstage=data_config.nstage)
def main():
  """Distributed training entry point for the online few-shot memory model.

  Like the single-GPU trainer but: scales the LR schedule by the number of
  GPUs, builds the model/data iterators with `distributed=True`, and only
  the chief process creates folders, restores checkpoints, and logs.

  NOTE(review): relies on module-level `gpus` and `is_chief` (not visible in
  this chunk) — presumably set up by the Horovod/GPU initialization code;
  confirm. Indentation reconstructed from a whitespace-mangled source.
  """
  assert tf.executing_eagerly(), 'Only eager mode is supported.'
  assert args.config is not None, 'You need to pass in model config file path'
  assert args.data is not None, 'You need to pass in episode config file path'
  assert args.env is not None, 'You need to pass in environ config file path'
  assert args.tag is not None, 'You need to specify a tag'
  log.info('Command line args {}'.format(args))
  config = get_config(args.config, ExperimentConfig)
  data_config = get_config(args.data, EpisodeConfig)
  env_config = get_config(args.env, EnvironmentConfig)
  log.info('Model: \n{}'.format(config))
  log.info('Data episode: \n{}'.format(data_config))
  log.info('Environment: \n{}'.format(env_config))
  # Propagate episode settings into the model config.
  config.num_classes = data_config.nway  # Assign num classes.
  config.num_steps = data_config.maxlen
  config.memory_net_config.max_classes = data_config.nway
  config.memory_net_config.max_stages = data_config.nstage
  config.memory_net_config.max_items = data_config.maxlen
  config.oml_config.num_classes = data_config.nway
  config.fix_unknown = data_config.fix_unknown  # Assign fix unknown ID.

  # Modify optimization config for multi-GPU training.
  if config.optimizer_config.lr_scaling:
    # Fewer steps per worker: shrink schedule by the number of GPUs.
    for i in range(len(config.optimizer_config.lr_decay_steps)):
      config.optimizer_config.lr_decay_steps[i] //= len(gpus)
    config.optimizer_config.max_train_steps //= len(gpus)

    # Linearly scale learning rate.
    for i in range(len(config.optimizer_config.lr_list)):
      config.optimizer_config.lr_list[i] *= float(len(gpus))
  log.info('Number of classes {}'.format(data_config.nway))

  # Build model.
  model = build_pretrain_net(config)
  mem_model = build_net(config, backbone=model.backbone, distributed=True)
  reload_flag = None
  restore_steps = 0
  if 'SLURM_JOB_ID' in os.environ:
    log.info('SLURM job ID: {}'.format(os.environ['SLURM_JOB_ID']))

  # Create save folder (chief process only).
  if is_chief:
    save_folder = os.path.join(env_config.results, env_config.dataset,
                               args.tag)
    ckpt_path = env_config.checkpoint
    if len(ckpt_path) > 0 and os.path.exists(ckpt_path):
      # Job-scratch checkpoint space keyed by SLURM job ID.
      ckpt_folder = os.path.join(ckpt_path, os.environ['SLURM_JOB_ID'])
      log.info('Checkpoint folder: {}'.format(ckpt_folder))
    else:
      ckpt_folder = save_folder
    # Look for an existing checkpoint, scratch space first.
    latest = None
    if os.path.exists(ckpt_folder):
      latest = latest_file(ckpt_folder, 'weights-')
    if latest is None and os.path.exists(save_folder):
      latest = latest_file(save_folder, 'weights-')
    if latest is not None:
      log.info(
          'Checkpoint already exists. Loading from {}'.format(latest))
      mem_model.load(latest)  # Not loading optimizer weights here.
      reload_flag = latest
      # Step count is encoded in the checkpoint filename suffix.
      restore_steps = int(reload_flag.split('-')[-1])
    # Create TB logger.
    save_config(config, save_folder)
    writer = tf.summary.create_file_writer(save_folder)
    logger = ExperimentLogger(writer)
  else:
    # Non-chief workers neither save nor log.
    save_folder = None
    ckpt_folder = None
    writer = None
    logger = None

  # Get dataset.
  dataset = get_data_fs(env_config, load_train=True)

  # Get data iterators (sharded/distributed).
  if env_config.dataset in ["matterport", "roaming-rooms"]:
    data = get_dataiter_sim(
        dataset,
        data_config,
        batch_size=config.optimizer_config.batch_size,
        nchw=mem_model.backbone.config.data_format == 'NCHW',
        distributed=True,
        seed=args.seed + restore_steps)
  else:
    data = get_dataiter_continual(
        dataset,
        data_config,
        batch_size=config.optimizer_config.batch_size,
        nchw=mem_model.backbone.config.data_format == 'NCHW',
        save_additional_info=True,
        random_box=data_config.random_box,
        distributed=True,
        seed=args.seed + restore_steps)

  # Load model, training loop.
  if args.pretrain is not None and reload_flag is None:
    mem_model.load(latest_file(args.pretrain, 'weights-'))
    if config.freeze_backbone:
      model.set_trainable(False)  # Freeze the network.
      log.info('Backbone network is now frozen')
  with writer.as_default() if writer is not None else dummy_context_mgr(
  ) as gs:
    train(
        mem_model,
        data['train_fs'],
        data['trainval_fs'],
        data['val_fs'],
        ckpt_folder,
        final_save_folder=save_folder,
        maxlen=data_config.maxlen,
        logger=logger,
        writer=writer,
        is_chief=is_chief,
        reload_flag=reload_flag)
def main():
  """Extracts features on test episodes for t-SNE visualization.

  If `results_tsne.pkl` already exists, visualizes it directly. Otherwise
  builds a feature function `f` (backbone only, or an RNNEncoder with a
  prototype memory when `args.rnn` is set), caches raw episode batches to a
  per-split pickle, evaluates the features, and pickles the results.

  NOTE(review): indentation reconstructed from a whitespace-mangled source.
  """
  assert tf.executing_eagerly(), 'Only eager mode is supported.'
  assert args.config is not None, 'You need to pass in model config file path'
  assert args.data is not None, 'You need to pass in episode config file path'
  assert args.env is not None, 'You need to pass in environ config file path'
  assert args.tag is not None, 'You need to specify a tag'
  log.info('Command line args {}'.format(args))
  config = get_config(args.config, ExperimentConfig)
  data_config = get_config(args.data, EpisodeConfig)
  env_config = get_config(args.env, EnvironmentConfig)
  log.info('Model: \n{}'.format(config))
  log.info('Data episode: \n{}'.format(data_config))
  log.info('Environment: \n{}'.format(env_config))
  # Propagate episode settings into the model config.
  config.num_classes = data_config.nway  # Assign num classes.
  config.num_steps = data_config.maxlen
  config.memory_net_config.max_classes = data_config.nway
  config.memory_net_config.max_stages = data_config.nstage
  config.memory_net_config.max_items = data_config.maxlen
  config.oml_config.num_classes = data_config.nway
  config.fix_unknown = data_config.fix_unknown  # Assign fix unknown ID.
  log.info('Number of classes {}'.format(data_config.nway))
  if 'SLURM_JOB_ID' in os.environ:
    log.info('SLURM job ID: {}'.format(os.environ['SLURM_JOB_ID']))

  # Create save folder.
  save_folder = os.path.join(env_config.results, env_config.dataset, args.tag)
  results_file = os.path.join(save_folder, 'results_tsne.pkl')

  # Plot features directly if they were already computed.
  if os.path.exists(results_file):
    batch_list = pkl.load(open(results_file, 'rb'))
    return visualize(batch_list, save_folder)

  # To get CNN features from.
  model = build_pretrain_net(config)
  if args.rnn:
    rnn_memory = get_module(
        "lstm",
        "lstm",
        model.backbone.get_output_dimension()[0],
        config.lstm_config.hidden_dim,
        layernorm=config.lstm_config.layernorm,
        dtype=tf.float32)
    proto_memory = get_module(
        'ssl_min_dist_proto_memory',
        'proto_memory',
        config.memory_net_config.radius_init,
        max_classes=config.memory_net_config.max_classes,
        fix_unknown=config.fix_unknown,
        unknown_id=config.num_classes if config.fix_unknown else None,
        similarity=config.memory_net_config.similarity,
        static_beta_gamma=not config.hybrid_config.use_pred_beta_gamma,
        dtype=tf.float32)
    memory = RNNEncoder(
        "proto_plus_rnn_ssl_v4",
        rnn_memory,
        proto_memory,
        readout_type=config.mann_config.readout_type,
        use_pred_beta_gamma=config.hybrid_config.use_pred_beta_gamma,
        use_feature_fuse=config.hybrid_config.use_feature_fuse,
        use_feature_fuse_gate=config.hybrid_config.use_feature_fuse_gate,
        use_feature_scaling=config.hybrid_config.use_feature_scaling,
        use_feature_memory_only=config.hybrid_config.use_feature_memory_only,
        skip_unk_memory_update=config.hybrid_config.skip_unk_memory_update,
        use_ssl=config.hybrid_config.use_ssl,
        use_ssl_beta_gamma_write=config.hybrid_config.use_ssl_beta_gamma_write,
        use_ssl_temp=config.hybrid_config.use_ssl_temp,
        dtype=tf.float32)
    rnn_model = RNNEncoderNet(config, model.backbone, memory)
    f = lambda x, y: rnn_model.forward(x, y, is_training=False)  # NOQA
  else:
    # Backbone features only; adds a leading singleton "time" axis.
    f = lambda x, y: model.backbone(x[0], is_training=False)[None, :, :]  # NOQA
  K = config.num_classes

  # Get dataset.
  dataset = get_data_fs(env_config, load_train=True)

  # Get data iterators.
  if env_config.dataset in ["matterport"]:
    data = get_dataiter_sim(
        dataset,
        data_config,
        batch_size=config.optimizer_config.batch_size,
        nchw=model.backbone.config.data_format == 'NCHW',
        seed=args.seed)
  else:
    data = get_dataiter_continual(
        dataset,
        data_config,
        batch_size=config.optimizer_config.batch_size,
        nchw=model.backbone.config.data_format == 'NCHW',
        save_additional_info=True,
        random_box=data_config.random_box,
        seed=args.seed)

  # Restore checkpoint, preferring 'best-' when requested.
  if args.usebest:
    latest = latest_file(save_folder, 'best-')
  else:
    latest = latest_file(save_folder, 'weights-')
  if args.rnn:
    rnn_model.load(latest)
  else:
    model.load(latest)
  data['trainval_fs'].reset()
  data['val_fs'].reset()
  data['test_fs'].reset()
  results_all = {}
  split_list = ['test_fs']
  name_list = ['Test']
  nepisode_list = [100]
  for split, name, N in zip(split_list, name_list, nepisode_list):
    data[split].reset()
    # Cache raw episode batches to disk so repeated runs skip the iterator.
    data_fname = '{}/{}.pkl'.format(env_config.data_folder, split)
    if not os.path.exists(data_fname):
      batch_list = []
      for i, batch in zip(range(N), data[split]):
        for k in batch.keys():
          # NOTE(review): this assignment is a no-op as written; presumably
          # it once converted tensors (e.g. `.numpy()`) before pickling —
          # confirm against the original file.
          batch[k] = batch[k]
        batch_list.append(batch)
      pkl.dump(batch_list, open(data_fname, 'wb'))
    batch_list = pkl.load(open(data_fname, 'rb'))
    r1 = evaluate(f, K, batch_list, N)
    results_all[split] = r1
  pkl.dump(results_all, open(results_file, 'wb'))