def eval_network(fn_in_model):
    """Load a saved encoder/decoder checkpoint and evaluate it.

    Rebuilds the encoder/decoder from the hyper-parameters stored in the
    checkpoint, restores the trained weights, and runs the evaluation
    battery selected by the module-level `eval_type`, redirecting all
    evaluation output into a 'res_test_*.txt' file next to the model file.

    Input
     fn_in_model : filename of saved model (.tar checkpoint)

    Raises
     FileNotFoundError : if fn_in_model does not exist
     ValueError : if the module-level eval_type is not recognized
    """
    # Create filename for output
    # (e.g. 'dir/net_xyz.tar' -> 'dir/res_test_xyz.txt')
    fn_out_res = fn_in_model
    fn_out_res = fn_out_res.replace('.tar', '.txt')
    fn_out_res_test = fn_out_res.replace('/net_', '/res_test_')

    # Explicit check rather than `assert`, so the error survives `python -O`.
    if not os.path.isfile(fn_in_model):
        raise FileNotFoundError('checkpoint file not found: ' + fn_in_model)
    print(' Checkpoint found...')
    print(' Processing model: ' + fn_in_model)
    print(' Writing to file: ' + fn_out_res_test)

    # Load checkpoint onto CPU; modules are moved to GPU below if available.
    checkpoint = torch.load(fn_in_model, map_location='cpu')
    input_lang = checkpoint['input_lang']
    output_lang = checkpoint['output_lang']
    emb_size = checkpoint['emb_size']
    nlayers = checkpoint['nlayers']
    dropout_p = checkpoint['dropout']
    input_size = input_lang.n_symbols
    output_size = output_lang.n_symbols
    samples_val = checkpoint['episodes_validation']
    disable_memory = checkpoint['disable_memory']
    max_length_eval = checkpoint['max_length_eval']

    # Older checkpoints predate the 'disable_attention' flag; default to
    # attention ON for those.
    if 'args' not in checkpoint or 'disable_attention' not in checkpoint[
            'args']:
        use_attention = True
    else:
        args = checkpoint['args']
        use_attention = not args.disable_attention

    # Rebuild the architecture exactly as it was configured at training time.
    if disable_memory:
        encoder = WrapperEncoderRNN(emb_size, input_size, output_size,
                                    nlayers, dropout_p)
    else:
        encoder = MetaNetRNN(emb_size, input_size, output_size, nlayers,
                             dropout_p)
    if use_attention:
        decoder = AttnDecoderRNN(emb_size, output_size, nlayers, dropout_p)
    else:
        decoder = DecoderRNN(emb_size, output_size, nlayers, dropout_p)
    if USE_CUDA:
        encoder = encoder.cuda()
        decoder = decoder.cuda()
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])

    # Everything printed during evaluation goes into the result file.
    with open(fn_out_res_test, 'w') as f_test:
        with redirect_stdout(f_test):
            if 'episode' in checkpoint:
                print(' Loading epoch ' + str(checkpoint['episode']) +
                      ' of ' + str(checkpoint['num_episodes']))
            describe_model(encoder)
            describe_model(decoder)
            if eval_type == 'val':
                print(
                    'Evaluating VALIDATION performance on pre-generated validation set'
                )
                acc_val_gen, acc_val_retrieval = evaluation_battery(
                    samples_val, encoder, decoder, input_lang, output_lang,
                    max_length_eval, verbose=True)
                print('Acc Retrieval (val): ' +
                      str(round(acc_val_retrieval, 1)))
                print('Acc Generalize (val): ' + str(round(acc_val_gen, 1)))
            elif eval_type == 'addprim_jump':
                print('Evaluating TEST performance on SCAN addprim_jump')
                print(' ...support set is just the isolated primitives')
                mybatch = scan_evaluation_prim_only('addprim_jump', 'test',
                                                    input_lang, output_lang)
                acc_val_gen, acc_val_retrieval = evaluation_battery(
                    [mybatch], encoder, decoder, input_lang, output_lang,
                    max_length_eval, verbose=True)
            elif eval_type == 'length':
                print('Evaluating TEST performance on SCAN length')
                print(
                    ' ...over multiple support sets as contributed by the pre-generated validation set'
                )
                samples_val = scan_evaluation_val_support(
                    'length', 'test', input_lang, output_lang, samples_val)
                acc_val_gen, acc_val_retrieval = evaluation_battery(
                    samples_val, encoder, decoder, input_lang, output_lang,
                    max_length_eval, verbose=True)
                print('Acc Retrieval (val): ' +
                      str(round(acc_val_retrieval, 1)))
                print('Acc Generalize (val): ' + str(round(acc_val_gen, 1)))
            elif eval_type == 'template_around_right':
                print('Evaluating TEST performance on the SCAN around right')
                print(' ...with just direction mappings as support set')
                mybatch = scan_evaluation_dir_only('template_around_right',
                                                   'test', input_lang,
                                                   output_lang)
                acc_val_gen, acc_val_retrieval = evaluation_battery(
                    [mybatch], encoder, decoder, input_lang, output_lang,
                    max_length_eval, verbose=True)
            else:
                # `assert False` would be stripped under `python -O`,
                # silently skipping evaluation; fail loudly instead.
                raise ValueError('unrecognized eval_type: ' + str(eval_type))
# NOTE(review): this chunk begins mid-script (its enclosing scope is not
# visible here) and the final training loop continues past this excerpt.

# One Adam optimizer per module, both sharing the same learning rate.
print(' Set learning rate to ' + str(adam_learning_rate))
encoder_optimizer = optim.Adam(encoder.parameters(), lr=adam_learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=adam_learning_rate)

# Report the architecture options selected for this run.
print("")
print("Architecture options...")
print(" Decoder attention is USED") if use_attention else print(
    " Decoder attention is NOT used")
print(" External memory is USED") if not disable_memory else print(
    " External memory is NOT used")
print(" Reconstruction loss is USED"
      ) if not disable_recon_loss else print(
          " Reconstruction loss is NOT used")
print("")
describe_model(encoder)
describe_model(decoder)

# create validation episodes
# Each generated episode's 'identifier' is added to the tabu set so the
# same validation episode is never generated twice.
tabu_episodes = set([])
samples_val = []
for i in range(num_episodes_val):
    sample = generate_episode_test(tabu_episodes)
    samples_val.append(sample)
    tabu_episodes = tabu_update(tabu_episodes, sample['identifier'])

# train over a set of random episodes
avg_train_loss = 0.  # loss accumulator (presumably reset when reported — loop body not in view)
counter = 0  # used to count updates since the loss was last reported
start = time.time()  # wall-clock start, for progress reporting
for episode in range(1, num_episodes + 1):