Exemplo n.º 1
0
 def test_get_valid_actions_uses_top_of_stack(self):
     """get_valid_actions returns the action set registered for the top-of-stack nonterminal."""
     s_actions = object()
     t_actions = object()
     e_actions = object()
     # Sentinels are compared with ``is``: the statelet should hand back the exact object
     # registered for the nonterminal, not a copy or an equal-looking value.
     state = GrammarStatelet(['s'], {'s': s_actions, 't': t_actions}, is_nonterminal)
     assert state.get_valid_actions() is s_actions
     state = GrammarStatelet(['t'], {'s': s_actions, 't': t_actions}, is_nonterminal)
     assert state.get_valid_actions() is t_actions
     state = GrammarStatelet(['e'], {'s': s_actions, 't': t_actions, 'e': e_actions}, is_nonterminal)
     assert state.get_valid_actions() is e_actions
     # With a one-element stack the top and bottom coincide, so the cases above cannot tell
     # which end is consulted. NOTE(review): assumes the top of the stack is the LAST list
     # element — confirm against GrammarStatelet.
     state = GrammarStatelet(['s', 't'], {'s': s_actions, 't': t_actions}, is_nonterminal)
     assert state.get_valid_actions() is t_actions
     state = GrammarStatelet(['t', 's'], {'s': s_actions, 't': t_actions}, is_nonterminal)
     assert state.get_valid_actions() is s_actions
Exemplo n.º 2
0
 def test_get_valid_actions_uses_top_of_stack(self):
     """get_valid_actions returns the action set registered for the top-of-stack nonterminal."""
     s_actions = object()
     t_actions = object()
     e_actions = object()
     # Sentinels are compared with ``is``: the statelet should hand back the exact object
     # registered for the nonterminal, not a copy or an equal-looking value.
     state = GrammarStatelet(["s"], {"s": s_actions, "t": t_actions}, is_nonterminal)
     assert state.get_valid_actions() is s_actions
     state = GrammarStatelet(["t"], {"s": s_actions, "t": t_actions}, is_nonterminal)
     assert state.get_valid_actions() is t_actions
     state = GrammarStatelet(
         ["e"], {"s": s_actions, "t": t_actions, "e": e_actions}, is_nonterminal
     )
     assert state.get_valid_actions() is e_actions
     # With a one-element stack the top and bottom coincide, so the cases above cannot tell
     # which end is consulted. NOTE(review): assumes the top of the stack is the LAST list
     # element — confirm against GrammarStatelet.
     state = GrammarStatelet(["s", "t"], {"s": s_actions, "t": t_actions}, is_nonterminal)
     assert state.get_valid_actions() is t_actions
     state = GrammarStatelet(["t", "s"], {"s": s_actions, "t": t_actions}, is_nonterminal)
     assert state.get_valid_actions() is s_actions
 def test_get_valid_actions_adds_lambda_productions(self):
     """Context actions for an in-scope lambda variable are appended to the 'global' actions."""
     # NOTE(review): this call passes five positional arguments (presumably: nonterminal
     # stack, lambda stacks, valid actions, context actions, is_nonterminal), unlike the
     # three-argument GrammarStatelet calls nearby — confirm which statelet class is in scope.
     base_global = (torch.Tensor([1, 1]), torch.Tensor([2, 2]), [1, 2])
     lambda_production = (torch.Tensor([5]), torch.Tensor([6]), 5)
     state = GrammarStatelet(['s'],
                             {('s', 'x'): ['s']},
                             {'s': {'global': base_global}},
                             {'s -> x': lambda_production},
                             is_nonterminal)
     # Two identical passes: the second one would fail if the first call had mutated
     # the state's own action tables.
     for _ in range(2):
         merged = state.get_valid_actions()['global']
         assert_almost_equal(merged[0].cpu().numpy(), [1, 1, 5])
         assert_almost_equal(merged[1].cpu().numpy(), [2, 2, 6])
         assert merged[2] == [1, 2, 5]
Exemplo n.º 4
0
 def test_get_valid_actions_uses_top_of_stack(self):
     """get_valid_actions returns the action set registered for the top-of-stack nonterminal."""
     s_actions = object()
     t_actions = object()
     e_actions = object()
     # Sentinels are compared with ``is``: the statelet should hand back the exact object
     # registered for the nonterminal, not a copy or an equal-looking value.
     state = GrammarStatelet(['s'], {
         's': s_actions,
         't': t_actions
     }, is_nonterminal)
     assert state.get_valid_actions() is s_actions
     state = GrammarStatelet(['t'], {
         's': s_actions,
         't': t_actions
     }, is_nonterminal)
     assert state.get_valid_actions() is t_actions
     state = GrammarStatelet(['e'], {
         's': s_actions,
         't': t_actions,
         'e': e_actions
     }, is_nonterminal)
     assert state.get_valid_actions() is e_actions
     # With a one-element stack the top and bottom coincide, so the cases above cannot tell
     # which end is consulted. NOTE(review): assumes the top of the stack is the LAST list
     # element — confirm against GrammarStatelet.
     state = GrammarStatelet(['s', 't'], {
         's': s_actions,
         't': t_actions
     }, is_nonterminal)
     assert state.get_valid_actions() is t_actions
     state = GrammarStatelet(['t', 's'], {
         's': s_actions,
         't': t_actions
     }, is_nonterminal)
     assert state.get_valid_actions() is s_actions