コード例 #1
0
 def test_get_valid_actions_adds_lambda_productions_only_for_correct_type(self):
     """Check that the ``s -> x`` entry is left out when the stack holds ``t``.

     The statelet is built with ``["t"]`` as its first argument (presumably
     the nonterminal stack -- confirm against ``LambdaGrammarStatelet``) and a
     ``("s", "x")`` entry in its second argument.  The expected result is
     exactly the ``"global"`` triple registered under ``"t"``; none of the
     values attached to ``"s -> x"`` (5 / 6 / 5) may show up in it.
     """
     s_global = (torch.Tensor([1, 1]), torch.Tensor([2, 2]), [1, 2])
     t_global = (torch.Tensor([3, 3]), torch.Tensor([4, 4]), [3, 4])
     s_to_x = (torch.Tensor([5]), torch.Tensor([6]), 5)
     state = LambdaGrammarStatelet(
         ["t"],
         {("s", "x"): ["t"]},
         {"s": {"global": s_global}, "t": {"global": t_global}},
         {"s -> x": s_to_x},
         is_nonterminal,
     )
     # Two identical passes: the second one would fail if the first call to
     # get_valid_actions() had accidentally modified the state.
     for _ in range(2):
         global_actions = state.get_valid_actions()["global"]
         assert_almost_equal(global_actions[0].cpu().numpy(), [3, 3])
         assert_almost_equal(global_actions[1].cpu().numpy(), [4, 4])
         assert global_actions[2] == [3, 4]
コード例 #2
0
 def test_get_valid_actions_uses_top_of_stack(self):
     """Check that ``get_valid_actions`` returns the entry keyed by the stack symbol.

     Each scenario builds a statelet whose first argument is a one-element
     list and whose third argument maps symbols to opaque sentinel objects;
     the call must hand back the sentinel stored under that one symbol.
     """
     sentinels = {'s': object(), 't': object(), 'e': object()}
     # (symbol placed on the stack, symbols present in the actions mapping)
     scenarios = [
         ('s', ('s', 't')),
         ('t', ('s', 't')),
         ('e', ('s', 't', 'e')),
     ]
     for symbol, registered in scenarios:
         # Fresh containers per scenario, so no statelet shares input with another.
         actions_by_symbol = {name: sentinels[name] for name in registered}
         state = LambdaGrammarStatelet([symbol], {}, actions_by_symbol, {}, is_nonterminal)
         assert state.get_valid_actions() == sentinels[symbol]
コード例 #3
0
 def test_get_valid_actions_adds_lambda_productions(self):
     """Check that the ``s -> x`` entry is appended to the ``'global'`` actions of ``s``.

     The statelet gets ``['s']`` as its first argument and a ``('s', 'x')``
     entry in its second.  The expected ``'global'`` triple is the one
     registered under ``'s'`` with the values attached to ``'s -> x'``
     (5 / 6 / 5) added at the end of each component.
     """
     s_global = (torch.Tensor([1, 1]), torch.Tensor([2, 2]), [1, 2])
     s_to_x = (torch.Tensor([5]), torch.Tensor([6]), 5)
     state = LambdaGrammarStatelet(
         ['s'],
         {('s', 'x'): ['s']},
         {'s': {'global': s_global}},
         {'s -> x': s_to_x},
         is_nonterminal,
     )

     def check_global_actions():
         found = state.get_valid_actions()['global']
         assert_almost_equal(found[0].cpu().numpy(), [1, 1, 5])
         assert_almost_equal(found[1].cpu().numpy(), [2, 2, 6])
         assert found[2] == [1, 2, 5]

     check_global_actions()
     # Run the same checks again to make sure the first call did not
     # accidentally modify the state.
     check_global_actions()
コード例 #4
0
 def test_get_valid_actions_uses_top_of_stack(self):
     """Check that the actions returned are those stored under the stack symbol.

     Opaque sentinel objects stand in for the per-symbol action entries, so
     each assertion passes only if ``get_valid_actions`` hands back the very
     entry registered for the single symbol given in the first argument.
     """
     s_marker, t_marker, e_marker = object(), object(), object()

     def valid_actions_for(stack, actions_by_symbol):
         # Build a statelet with empty second and fourth arguments and query it.
         statelet = LambdaGrammarStatelet(stack, {}, actions_by_symbol, {}, is_nonterminal)
         return statelet.get_valid_actions()

     assert valid_actions_for(["s"], {"s": s_marker, "t": t_marker}) == s_marker
     assert valid_actions_for(["t"], {"s": s_marker, "t": t_marker}) == t_marker
     assert valid_actions_for(
         ["e"], {"s": s_marker, "t": t_marker, "e": e_marker}
     ) == e_marker