Example #1
import torch

# NTM, get_training_sequence and plot_copy_results come from elsewhere in this
# repository; the module paths below are assumed and may need adjusting.
from ntm import NTM
from train import get_training_sequence
from utils import plot_copy_results


def eval(model_path):
    # Model hyperparameters; these must match the ones used during training.
    vector_length = 8
    memory_size = (128, 20)
    hidden_layer_size = 100
    # `args` is the argparse namespace parsed at module level in the original
    # script; --ff selects a feed-forward controller instead of an LSTM one.
    lstm_controller = not args.ff

    model = NTM(vector_length, hidden_layer_size, memory_size, lstm_controller)

    print(f"Loading model from {model_path}")
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)
    model.eval()

    # Evaluate on the training length (20) and on a much longer sequence (100)
    # to check how well the copy behaviour generalizes.
    for sequence_length in [20, 100]:
        input, target = get_training_sequence(sequence_length, sequence_length,
                                              vector_length)
        state = model.get_initial_state()
        # Feed the whole input sequence (including the delimiter channel) into the NTM.
        for vector in input:
            _, state = model(vector, state)
        # Then query with zero vectors so the model reads the copy back out of memory.
        y_out = torch.zeros(target.size())
        for j in range(len(target)):
            y_out[j], state = model(torch.zeros(1, vector_length + 1), state)
        # Threshold the raw sigmoid outputs to bits (useful for bit-error counts).
        y_out_binarized = y_out.clone().data
        y_out_binarized.apply_(lambda x: 0 if x < 0.5 else 1)

        plot_copy_results(target, y_out, vector_length)
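Since eval() reads `args` as a module-level global, it needs some argument parsing defined in the same script. Below is a minimal sketch of what that wiring might look like; only the --ff flag is implied by the code above, and the positional model_path argument and help strings are assumptions.

import argparse

# Hypothetical module-level argument parsing that eval() relies on.
parser = argparse.ArgumentParser(description="Evaluate a trained NTM on the copy task")
parser.add_argument("model_path", help="path to a saved state_dict checkpoint")
parser.add_argument("--ff", action="store_true",
                    help="use a feed-forward controller instead of an LSTM")
args = parser.parse_args()

eval(args.model_path)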