Beispiel #1
0
def test_parse_logic_d3_1():
    """End-to-end check of TransitionParser on the shared test sentence.

    Builds a parser from SimpleFeatureExtractor, VanillaWordEmbedding,
    FFActionChooser and FFCombiner. It runs the parser on ``test_sent[:-1]``
    with the ``gold`` sequence, then verifies:

    * the number of decisions made (16),
    * the numeric values of the 10th decision output,
    * that the produced dependency graph matches the oracle graph,
    * the exact sequence of action ids taken.

    The module-level names ``test_sent``, ``gold`` and ``word_to_ix`` are
    only read here, so no ``global`` declaration is needed.
    """
    # Fixed seed so the hard-coded expected numbers below are reproducible.
    torch.manual_seed(1)

    # Arguments are evaluated left to right, so the components are built in
    # the same order as they are listed (relevant if they draw random numbers).
    parser = TransitionParser(
        SimpleFeatureExtractor(),
        VanillaWordEmbedding(word_to_ix, TEST_EMBEDDING_DIM),
        FFActionChooser(TEST_EMBEDDING_DIM * NUM_FEATURES),
        FFCombiner(TEST_EMBEDDING_DIM),
    )
    # NOTE(review): the parser gets the sentence without its last token while
    # the oracle below uses the full sentence -- presumably the final token is
    # an end marker; confirm against the fixture definition.
    output, dep_graph, actions_done = parser(test_sent[:-1], gold)

    # One output per decision taken by the parser.
    assert len(output) == 16

    # Spot-check a single decision's output values.
    observed = output[9].view(-1).data.tolist()
    expected = [-1.2444578409194946, -1.3128550052642822, -0.8145193457603455]
    check_tensor_correctness([(expected, observed)])

    oracle_graph = dependency_graph_from_oracle(test_sent, gold)
    assert oracle_graph == dep_graph

    expected_actions = [0, 1, 0, 1, 0, 0, 1, 2, 0, 0, 0, 1, 2, 2, 2, 0]
    assert actions_done == expected_actions
Beispiel #2
0
def test_word_embed_lookup_d2_1():
    """Check VanillaWordEmbedding lookups against known values.

    Embeds the shared ``test_sent`` and verifies that:

    * the result is a list with one entry per token,
    * each entry is an ``ag.Variable`` of size ``(1, TEST_EMBEDDING_DIM)``
      (checked on the first entry),
    * the values match the hard-coded table below, which is produced under
      ``torch.manual_seed(1)``.

    The module-level names ``test_sent`` and ``word_to_ix`` are only read
    here, so no ``global`` declaration is needed.
    """
    # Fixed seed so the hard-coded expected embeddings are reproducible.
    torch.manual_seed(1)

    embedder = VanillaWordEmbedding(word_to_ix, TEST_EMBEDDING_DIM)
    embeds = embedder(test_sent)
    assert len(embeds) == len(test_sent)
    assert isinstance(embeds, list)
    assert isinstance(embeds[0], ag.Variable)
    assert embeds[0].size() == (1, TEST_EMBEDDING_DIM)

    embeds_list = make_list(embeds)

    # Expected embedding per token of test_sent, in order.
    true = (
        [-1.02760863, -0.56305277, -0.89229053, -0.05825018],
        [-0.42119515, -0.51069999, -1.57266521, -0.12324776],
        [3.5869894, -1.83129013, 1.59870028, -1.27700698],
        [0.41074166, -0.98800713, -0.90807337, 0.54227364],
        [-0.19550958, -0.96563596, 0.42241532, 0.267317],
        [0.11025489, -2.2590096, 0.60669959, -0.13830966],
        [0.41074166, -0.98800713, -0.90807337, 0.54227364],
        [0.32550153, -0.47914493, 1.37900829, 2.5285573],
        [0.66135216, 0.26692411, 0.06167726, 0.62131733],
    )
    # zip() silently stops at the shorter input, so without this guard a
    # missing or extra embedding would go unchecked. len(embeds) was already
    # asserted equal to len(test_sent) above.
    assert len(true) == len(test_sent)

    # Materialise the pairs so the checker gets a concrete list rather than a
    # one-shot iterator.
    pairs = list(zip(embeds_list, true))
    check_tensor_correctness(pairs)