def test_get_valid_actions_adds_lambda_productions_only_for_correct_type(self):
    """A lambda production is not offered when the stack top has another type.

    The stack top is ``"t"``, while the only lambda-stack entry is keyed by
    ``("s", "x")``. The returned ``"global"`` actions must therefore be exactly
    the ``"t"`` entry, with the ``"s -> x"`` context action not appended.
    """
    lambda_stacks = {("s", "x"): ["t"]}
    valid_actions = {
        "s": {"global": (torch.Tensor([1, 1]), torch.Tensor([2, 2]), [1, 2])},
        "t": {"global": (torch.Tensor([3, 3]), torch.Tensor([4, 4]), [3, 4])},
    }
    context_actions = {"s -> x": (torch.Tensor([5]), torch.Tensor([6]), 5)}
    state = LambdaGrammarStatelet(
        ["t"], lambda_stacks, valid_actions, context_actions, is_nonterminal
    )

    # Query twice. The second pass must see identical results, which checks
    # that get_valid_actions() did not accidentally modify the statelet.
    for _ in range(2):
        returned = state.get_valid_actions()
        global_actions = returned["global"]
        assert_almost_equal(global_actions[0].cpu().numpy(), [3, 3])
        assert_almost_equal(global_actions[1].cpu().numpy(), [4, 4])
        assert global_actions[2] == [3, 4]
def test_get_valid_actions_uses_top_of_stack(self):
    """The valid actions returned are the entry for the nonterminal on top of the stack."""
    # NOTE(review): a second method with this exact name appears later in this
    # view. If both are defined in the same class, this earlier definition is
    # shadowed and never runs. Confirm, then rename or remove one of them.
    s_actions, t_actions, e_actions = object(), object(), object()

    # Each case: (stack contents, valid-actions mapping, expected result).
    # Every case gets its own dict literal, as in a fresh construction.
    cases = [
        (['s'], {'s': s_actions, 't': t_actions}, s_actions),
        (['t'], {'s': s_actions, 't': t_actions}, t_actions),
        (['e'], {'s': s_actions, 't': t_actions, 'e': e_actions}, e_actions),
    ]
    for stack, valid_actions, expected in cases:
        statelet = LambdaGrammarStatelet(stack, {}, valid_actions, {}, is_nonterminal)
        assert statelet.get_valid_actions() == expected
def test_get_valid_actions_adds_lambda_productions(self):
    """A matching lambda production is appended to the stack top's global actions.

    The stack top is ``'s'`` and the lambda-stack entry is keyed by
    ``('s', 'x')``. The ``'s -> x'`` context action (5 / 6 / id 5) is therefore
    expected after the regular ``'s'`` actions in every returned component.
    """
    lambda_stacks = {('s', 'x'): ['s']}
    valid_actions = {
        's': {'global': (torch.Tensor([1, 1]), torch.Tensor([2, 2]), [1, 2])},
    }
    context_actions = {'s -> x': (torch.Tensor([5]), torch.Tensor([6]), 5)}
    state = LambdaGrammarStatelet(
        ['s'], lambda_stacks, valid_actions, context_actions, is_nonterminal
    )

    # Run the same checks twice. Identical results on the second call show
    # that get_valid_actions() did not accidentally modify the statelet.
    for _ in range(2):
        returned = state.get_valid_actions()
        global_actions = returned['global']
        assert_almost_equal(global_actions[0].cpu().numpy(), [1, 1, 5])
        assert_almost_equal(global_actions[1].cpu().numpy(), [2, 2, 6])
        assert global_actions[2] == [1, 2, 5]
def test_get_valid_actions_uses_top_of_stack(self):
    """Whichever nonterminal is on top of the stack selects the valid actions."""
    # NOTE(review): an earlier method with this exact name appears above in
    # this view. If both are defined in the same class, this definition
    # replaces it. Confirm, then rename or remove one of them.
    s_actions = object()
    t_actions = object()
    e_actions = object()

    def check(top, valid_actions, expected):
        # Build a statelet whose stack holds only `top` and compare its output.
        state = LambdaGrammarStatelet([top], {}, valid_actions, {}, is_nonterminal)
        assert state.get_valid_actions() == expected

    check("s", {"s": s_actions, "t": t_actions}, s_actions)
    check("t", {"s": s_actions, "t": t_actions}, t_actions)
    check("e", {"s": s_actions, "t": t_actions, "e": e_actions}, e_actions)