def test_combiner_d2_4():
    """Check FFCombiner's output for seeded, random 6-dim head/modifier inputs.

    The expected values are tied to seed 1 and to the exact order of the
    calls below: the combiner is constructed first (its constructor may draw
    from the RNG, e.g. for weight init -- not visible here), then the head
    features are drawn, then the modifier features.
    """
    dim = 6
    torch.manual_seed(1)
    combiner = FFCombiner(dim)
    # The comprehension draws the head tensor first, then the modifier tensor,
    # matching the RNG consumption order the expected values depend on.
    head_feat, modifier_feat = [ag.Variable(torch.randn(1, dim)) for _ in range(2)]
    flat_output = combiner(head_feat, modifier_feat).view(-1).data.tolist()
    expected = [
        -0.1194517,
        0.31343767,
        0.29966655,
        0.26423377,
        -0.09193783,
        -0.56594414,
    ]
    check_tensor_correctness([(flat_output, expected)])
def test_predict_after_train_d3_1():
    """After 75 SGD epochs on one sentence, the parser must reproduce its gold parse.

    A freshly seeded TransitionParser is trained on the single example
    ``(test_sent[:-1], gold)``. ``predict`` on that same input must then
    yield exactly the dependency graph implied by the gold action sequence.
    """
    # The module-level fixtures are only read here, so no ``global`` is needed.
    torch.manual_seed(1)
    feat_extract = SimpleFeatureExtractor()
    word_embed = VanillaWordEmbedding(word_to_ix, TEST_EMBEDDING_DIM)
    act_chooser = FFActionChooser(TEST_EMBEDDING_DIM * NUM_FEATURES)
    combiner = FFCombiner(TEST_EMBEDDING_DIM)

    parser = TransitionParser(feat_extract, word_embed, act_chooser, combiner)

    # Build the optimizer once. Re-creating it every epoch is wasted work, and
    # it would silently discard accumulated state if a stateful optimizer
    # (momentum, Adam, ...) were ever used. Plain SGD keeps no state, so the
    # training trajectory is unchanged.
    optimizer = optim.SGD(parser.parameters(), lr=0.01)

    # Train. The sentence is passed without its final token -- presumably an
    # end-of-input marker; TODO confirm against test_sent's definition. A fresh
    # slice/list is passed on each call, as before, in case train() mutates it.
    for _ in range(75):
        train([(test_sent[:-1], gold)], parser, optimizer, verbose=False)

    # Predict and compare against the graph the gold actions produce.
    pred = parser.predict(test_sent[:-1])
    gold_graph = dependency_graph_from_oracle(test_sent[:-1], gold)
    assert pred == gold_graph
def test_parse_logic_d3_1():
    """Run one forced (gold-action) parse with an untrained, seeded parser.

    Checks the number of per-decision outputs, the values of one of those
    outputs, the resulting dependency graph, and the action sequence taken.
    """
    torch.manual_seed(1)

    # Arguments are evaluated left to right, so the components are built in
    # the same order as before. That matters if any constructor draws random
    # initial weights, because the expected numbers below are tied to seed 1.
    parser = TransitionParser(
        SimpleFeatureExtractor(),
        VanillaWordEmbedding(word_to_ix, TEST_EMBEDDING_DIM),
        FFActionChooser(TEST_EMBEDDING_DIM * NUM_FEATURES),
        FFCombiner(TEST_EMBEDDING_DIM),
    )
    scores, graph, actions = parser(test_sent[:-1], gold)

    # One output per decision the parser made.
    assert len(scores) == 16

    # Spot-check the output of the tenth decision.
    tenth_scores = scores[9].view(-1).data.tolist()
    expected_tenth = [
        -1.2444578409194946,
        -1.3128550052642822,
        -0.8145193457603455,
    ]
    check_tensor_correctness([(expected_tenth, tenth_scores)])

    # NOTE(review): the parser is given test_sent[:-1], but the reference graph
    # is built from the full test_sent (the train/predict test uses [:-1] for
    # both). Confirm that dependency_graph_from_oracle ignores the last token.
    expected_graph = dependency_graph_from_oracle(test_sent, gold)
    assert expected_graph == graph
    assert actions == [0, 1, 0, 1, 0, 0, 1, 2, 0, 0, 0, 1, 2, 2, 2, 0]