Example #1
import os
import torch
from contextlib import redirect_stdout

# The model classes, evaluation helpers, and the globals USE_CUDA and
# eval_type are assumed to be defined elsewhere in the original source file.

def eval_network(fn_in_model):
    # Input
    #  fn_in_model : filename of the saved model checkpoint
    # Derive the output filenames from the checkpoint filename
    fn_out_res = fn_in_model.replace('.tar', '.txt')
    fn_out_res_test = fn_out_res.replace('/net_', '/res_test_')

    # Load and evaluate the network in filename 'fn_in_model'
    assert os.path.isfile(fn_in_model)
    print('  Checkpoint found...')
    print('  Processing model: ' + fn_in_model)
    print('  Writing to file: ' + fn_out_res_test)
    checkpoint = torch.load(fn_in_model,
                            map_location='cpu')  # evaluate model on CPU
    input_lang = checkpoint['input_lang']
    output_lang = checkpoint['output_lang']
    emb_size = checkpoint['emb_size']
    nlayers = checkpoint['nlayers']
    dropout_p = checkpoint['dropout']
    input_size = input_lang.n_symbols
    output_size = output_lang.n_symbols
    samples_val = checkpoint['episodes_validation']
    disable_memory = checkpoint['disable_memory']
    max_length_eval = checkpoint['max_length_eval']
    # Older checkpoints pre-date the disable_attention option; default to
    # using attention when the flag is absent
    if 'args' not in checkpoint or 'disable_attention' not in checkpoint[
            'args']:
        use_attention = True
    else:
        use_attention = not checkpoint['args'].disable_attention
    # Encoder: meta network with external memory, unless memory is disabled
    if disable_memory:
        encoder = WrapperEncoderRNN(emb_size, input_size, output_size, nlayers,
                                    dropout_p)
    else:
        encoder = MetaNetRNN(emb_size, input_size, output_size, nlayers,
                             dropout_p)
    # Decoder: with or without attention over the encoder states
    if use_attention:
        decoder = AttnDecoderRNN(emb_size, output_size, nlayers, dropout_p)
    else:
        decoder = DecoderRNN(emb_size, output_size, nlayers, dropout_p)
    if USE_CUDA:
        encoder = encoder.cuda()
        decoder = decoder.cuda()
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])

    with open(fn_out_res_test, 'w') as f_test:
        with redirect_stdout(f_test):
            if 'episode' in checkpoint:
                print(' Loading episode ' + str(checkpoint['episode']) +
                      ' of ' + str(checkpoint['num_episodes']))
            describe_model(encoder)
            describe_model(decoder)
            if eval_type == 'val':
                print(
                    'Evaluating VALIDATION performance on pre-generated validation set'
                )
                acc_val_gen, acc_val_retrieval = evaluation_battery(
                    samples_val,
                    encoder,
                    decoder,
                    input_lang,
                    output_lang,
                    max_length_eval,
                    verbose=True)
                print('Acc Retrieval (val): ' +
                      str(round(acc_val_retrieval, 1)))
                print('Acc Generalize (val): ' + str(round(acc_val_gen, 1)))
            elif eval_type == 'addprim_jump':
                print('Evaluating TEST performance on SCAN addprim_jump')
                print('  ...support set is just the isolated primitives')
                mybatch = scan_evaluation_prim_only('addprim_jump', 'test',
                                                    input_lang, output_lang)
                acc_val_gen, acc_val_retrieval = evaluation_battery(
                    [mybatch],
                    encoder,
                    decoder,
                    input_lang,
                    output_lang,
                    max_length_eval,
                    verbose=True)
            elif eval_type == 'length':
                print('Evaluating TEST performance on SCAN length')
                print(
                    '  ...over multiple support sets as contributed by the pre-generated validation set'
                )
                samples_val = scan_evaluation_val_support(
                    'length', 'test', input_lang, output_lang, samples_val)
                acc_val_gen, acc_val_retrieval = evaluation_battery(
                    samples_val,
                    encoder,
                    decoder,
                    input_lang,
                    output_lang,
                    max_length_eval,
                    verbose=True)
                print('Acc Retrieval (val): ' +
                      str(round(acc_val_retrieval, 1)))
                print('Acc Generalize (val): ' + str(round(acc_val_gen, 1)))
            elif eval_type == 'template_around_right':
                print('Evaluating TEST performance on SCAN template_around_right')
                print('  ...with just the direction mappings as support set')
                mybatch = scan_evaluation_dir_only('template_around_right',
                                                   'test', input_lang,
                                                   output_lang)
                acc_val_gen, acc_val_retrieval = evaluation_battery(
                    [mybatch],
                    encoder,
                    decoder,
                    input_lang,
                    output_lang,
                    max_length_eval,
                    verbose=True)
            else:
                assert False, 'unrecognized eval_type: ' + str(eval_type)
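
A minimal usage sketch, assuming the checkpoint naming convention above; the directory and filenames here are hypothetical, and eval_type is read as a module-level global inside eval_network, so it must be assigned first:

import glob

eval_type = 'val'  # module-level global consumed by eval_network
for fn in sorted(glob.glob('out_models/net_*.tar')):  # hypothetical location
    eval_network(fn)  # writes results next to each checkpoint as res_test_*.txt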
Example #2
        print('  Set learning rate to ' + str(adam_learning_rate))
        encoder_optimizer = optim.Adam(encoder.parameters(),
                                       lr=adam_learning_rate)
        decoder_optimizer = optim.Adam(decoder.parameters(),
                                       lr=adam_learning_rate)
        print("")
        print("Architecture options...")
        print(" Decoder attention is USED") if use_attention else print(
            " Decoder attention is NOT used")
        print(" External memory is USED") if not disable_memory else print(
            " External memory is NOT used")
        print(" Reconstruction loss is USED"
              ) if not disable_recon_loss else print(
                  " Reconstruction loss is NOT used")
        print("")
        describe_model(encoder)
        describe_model(decoder)

        # create validation episodes
        tabu_episodes = set()
        samples_val = []
        for i in range(num_episodes_val):
            sample = generate_episode_test(tabu_episodes)
            samples_val.append(sample)
            tabu_episodes = tabu_update(tabu_episodes, sample['identifier'])
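
        # For reference, a minimal sketch of what tabu_update is assumed to do
        # (hypothetical; the actual helper is defined elsewhere in the source):
        #   def tabu_update(tabu_episodes, identifier):
        #       tabu_episodes.add(identifier)  # never re-sample this episode
        #       return tabu_episodes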

        # train over a set of random episodes
        avg_train_loss = 0.
        counter = 0  # used to count updates since the loss was last reported
        start = time.time()
        for episode in range(1, num_episodes + 1):
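            # Minimal sketch of a typical per-episode update (hypothetical
            # helper names, not the original loop body):
            #   sample = generate_episode_train(tabu_episodes)
            #   loss = train(sample, encoder, decoder, encoder_optimizer,
            #                decoder_optimizer, input_lang, output_lang)
            #   avg_train_loss += loss
            #   counter += 1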