def test_parse_logic_d3_1():
    """Run the transition parser on the sample sentence and verify its results.

    Builds a TransitionParser from a SimpleFeatureExtractor, a
    VanillaWordEmbedding, an FFActionChooser and an FFCombiner, parses
    ``test_sent[:-1]`` with the ``gold`` actions, and then checks:

    * the number of outputs returned (16),
    * the flattened values of the output at index 9,
    * that the returned dependency graph equals the one produced by
      ``dependency_graph_from_oracle(test_sent, gold)``,
    * the exact sequence of actions performed.
    """
    # These names are declared as module-level; only some are read below
    # (``vocab`` is not referenced in this function).
    global test_sent, gold, word_to_ix, vocab

    # Seed the RNG before any component is created. The components are built
    # in the same order as before; the hard-coded numbers further down
    # presumably depend on both the seed and this order.
    torch.manual_seed(1)

    extractor = SimpleFeatureExtractor()
    embedding = VanillaWordEmbedding(word_to_ix, TEST_EMBEDDING_DIM)
    chooser = FFActionChooser(TEST_EMBEDDING_DIM * NUM_FEATURES)
    combiner_net = FFCombiner(TEST_EMBEDDING_DIM)
    transition_parser = TransitionParser(
        extractor, embedding, chooser, combiner_net
    )

    # The parser receives the sentence without its final element
    # (presumably an end-of-input marker -- confirm against the fixture).
    outputs, predicted_graph, taken_actions = transition_parser(
        test_sent[:-1], gold
    )

    # Made the right number of decisions
    decision_count = len(outputs)
    assert decision_count == 16

    # Spot-check a single output: flatten it and compare to reference values.
    sampled_output = outputs[9]
    observed_values = sampled_output.view(-1).data.tolist()
    reference_values = [
        -1.2444578409194946,
        -1.3128550052642822,
        -0.8145193457603455,
    ]
    check_tensor_correctness([(reference_values, observed_values)])

    # The oracle graph is built from the full sentence, not the sliced one.
    oracle_graph = dependency_graph_from_oracle(test_sent, gold)
    assert oracle_graph == predicted_graph

    expected_actions = [0, 1, 0, 1, 0, 0, 1, 2, 0, 0, 0, 1, 2, 2, 2, 0]
    assert taken_actions == expected_actions
def test_word_embed_lookup_d2_1():
    """Verify VanillaWordEmbedding's lookup output for the sample sentence.

    Embeds ``test_sent`` and checks that the result is a list with one entry
    per element of ``test_sent``, that the first entry is an ``ag.Variable``
    of size ``(1, TEST_EMBEDDING_DIM)``, and that the embedded values match
    the hard-coded reference rows.
    """
    # These names are declared as module-level; only ``test_sent`` and
    # ``word_to_ix`` are read in this function.
    global test_sent, gold, word_to_ix, vocab

    # Seed the RNG before creating the embedding layer; the reference rows
    # below presumably depend on this seed.
    torch.manual_seed(1)

    embedding_layer = VanillaWordEmbedding(word_to_ix, TEST_EMBEDDING_DIM)
    embeds = embedding_layer(test_sent)

    # Structural checks: container length and type, then the first element.
    assert len(embeds) == len(test_sent)
    assert isinstance(embeds, list)
    first_embedding = embeds[0]
    assert isinstance(first_embedding, ag.Variable)
    assert first_embedding.size() == (1, TEST_EMBEDDING_DIM)

    observed_rows = make_list(embeds)

    # Reference values, one row per embedded element.
    expected_rows = (
        [-1.02760863, -0.56305277, -0.89229053, -0.05825018],
        [-0.42119515, -0.51069999, -1.57266521, -0.12324776],
        [3.5869894, -1.83129013, 1.59870028, -1.27700698],
        [0.41074166, -0.98800713, -0.90807337, 0.54227364],
        [-0.19550958, -0.96563596, 0.42241532, 0.267317],
        [0.11025489, -2.2590096, 0.60669959, -0.13830966],
        [0.41074166, -0.98800713, -0.90807337, 0.54227364],
        [0.32550153, -0.47914493, 1.37900829, 2.5285573],
        [0.66135216, 0.26692411, 0.06167726, 0.62131733],
    )

    # Pair each observed row with its reference row for comparison.
    check_tensor_correctness(zip(observed_rows, expected_rows))