def test_greedy_decode(mocker):
    """Check that ``greedy_decode`` reproduces a known action sequence.

    ``parser.forward`` is patched so that successive calls return one-hot
    vectors (a single 1 at the index of the next expected action), and the
    actions returned by ``greedy_decode`` are compared with that sequence.
    """
    tagged_words = list(zip(['John', 'loves', 'Mary'], ['NNP', 'VBZ', 'NNP']))
    correct_actions = [
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        ShiftAction(),
        ReduceAction(),
        NonTerminalAction('VP'),
        ShiftAction(),
        NonTerminalAction('NP'),
        ShiftAction(),
        ReduceAction(),
        ReduceAction(),
        ReduceAction(),
    ]

    def one_hot(action):
        # Vector of zeros with a single 1 at the id of ``action``.
        scores = torch.zeros(len(action_store))
        index = torch.LongTensor([action_store[action]])
        return Variable(scores.scatter_(0, index, 1))

    retvals = [one_hot(action) for action in correct_actions]

    parser = DiscriminativeRnnGrammar(word2id, pos2id, nt2id, action_store)
    parser.initialise_stacks_and_buffers(tagged_words)
    # Each call to forward() yields the next prepared score vector.
    mocker.patch.object(parser, 'forward', side_effect=retvals)

    result = greedy_decode(parser)

    assert len(result) == len(correct_actions)
    # Each result item is an (action, log-prob) pair; only actions are checked.
    picked_actions, _ = zip(*result)
    assert list(picked_actions) == correct_actions
def test_init_no_reduce_action(self):
    """Building a parser from an action store lacking ReduceAction raises ValueError."""
    store_without_reduce = ItemStore()
    available = [
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        NonTerminalAction('VP'),
        ShiftAction(),
    ]
    for action in available:
        store_without_reduce.add(action)

    with pytest.raises(ValueError):
        DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, store_without_reduce)
def test_push_known_non_terminal_but_unknown_action(self):
    """Pushing 'S' raises KeyError when the action store has no NT(S) action.

    The store is built without ``NonTerminalAction('S')`` while the parser
    still receives ``self.nt2id``.
    """
    store_without_nt_s = ItemStore()
    available = [
        NonTerminalAction('NP'),
        NonTerminalAction('VP'),
        ShiftAction(),
        ReduceAction(),
    ]
    for action in available:
        store_without_nt_s.add(action)

    tagged_words = list(zip(['John'], ['NNP']))
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, store_without_nt_s)
    parser.initialise_stacks_and_buffers(tagged_words)

    with pytest.raises(KeyError):
        parser.push_non_terminal('S')
class TestOracleDataset:
    """Tests for ``OracleDataset`` built from two small bracketed sentences."""

    bracketed_sents = [
        '(S (NP (NNP John)) (VP (VBZ loves) (NP (NNP Mary))))',
        '(S (NP (NNP Mary)) (VP (VBZ hates) (NP (NNP John))))'  # poor John
    ]
    # Vocabulary items expected to be collected from the sentences above.
    words = {'John', 'loves', 'hates', 'Mary'}
    pos_tags = {'NNP', 'VBZ'}
    nt_labels = {'S', 'NP', 'VP'}
    actions = {
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        NonTerminalAction('VP'),
        ShiftAction(),
        ReduceAction()
    }

    def _oracles_and_dataset(self):
        """Return ``(oracles, dataset)`` freshly built from ``bracketed_sents``."""
        oracles = []
        for sent in self.bracketed_sents:
            parsed = Tree.fromstring(sent)
            oracles.append(DiscriminativeOracle.from_parsed_sentence(parsed))
        return oracles, OracleDataset(oracles)

    def test_init(self):
        """Each store is an ItemStore holding exactly the expected items."""
        _, dataset = self._oracles_and_dataset()

        assert isinstance(dataset.word_store, ItemStore)
        assert set(dataset.word_store) == self.words
        assert isinstance(dataset.pos_store, ItemStore)
        assert set(dataset.pos_store) == self.pos_tags
        assert isinstance(dataset.nt_store, ItemStore)
        assert set(dataset.nt_store) == self.nt_labels
        assert isinstance(dataset.action_store, ItemStore)
        assert set(dataset.action_store) == self.actions

    def test_getitem(self):
        """Indexing the dataset returns the very same oracle objects."""
        oracles, dataset = self._oracles_and_dataset()

        assert oracles[0] is dataset[0]
        assert oracles[1] is dataset[1]

    def test_len(self):
        """The dataset length equals the number of oracles given."""
        oracles, dataset = self._oracles_and_dataset()

        assert len(dataset) == len(oracles)
def test_forward_with_illegal_actions(self):
    """Check action probabilities right after initialisation.

    With no action taken yet, every NT action must have positive
    probability, while SHIFT and REDUCE must be within 0.001 of zero.
    """
    tagged_words = list(zip(['John', 'loves', 'Mary'], ['NNP', 'VBZ', 'NNP']))
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)
    parser.initialise_stacks_and_buffers(tagged_words)

    # The parser output is exponentiated to obtain probabilities.
    action_probs = parser().exp().data

    def prob_of(action):
        return action_probs[self.action_store[action]]

    for label in ('S', 'NP', 'VP'):
        assert prob_of(NonTerminalAction(label)) > 0.
    assert -0.001 <= prob_of(ShiftAction()) <= 0.001
    assert -0.001 <= prob_of(ReduceAction()) <= 0.001
def get_actions(cls, tree: Tree) -> List[Action]:
    """Return the action sequence for ``tree`` by recursive traversal.

    A tree with exactly one child that is not itself a ``Tree`` is handled
    by ``cls.get_action_at_pos_node``. Any other tree yields an NT action
    for its label, then the actions of each child in order, then a REDUCE.
    """
    has_single_leaf_child = len(tree) == 1 and not isinstance(tree[0], Tree)
    if has_single_leaf_child:
        return [cls.get_action_at_pos_node(tree)]

    opening: List[Action] = [NonTerminalAction(tree.label())]
    inner = [action for child in tree for action in cls.get_actions(child)]
    return opening + inner + [ReduceAction()]
def push_non_terminal(self, nonterm: NonTerminalLabel) -> None:
    """Push the nonterminal ``nonterm`` and record the NT action in the history.

    Raises:
        KeyError: if ``nonterm`` is not in ``self.nt2id``, or if its
            ``NonTerminalAction`` is not in ``self.action_store``.
        IllegalActionError: if ``self.can_push_non_terminal()`` is false.
    """
    label_is_known = nonterm in self.nt2id
    if not label_is_known:
        raise KeyError(f"unknown nonterminal '{nonterm}' encountered")

    action = NonTerminalAction(nonterm)
    action_is_known = action in self.action_store
    if not action_is_known:
        raise KeyError(f"unknown action '{action}' encountered")

    # Legality is checked only after both lookups have succeeded.
    if self.can_push_non_terminal():
        self._push_non_terminal(nonterm)
        self._append_history(action)
        return
    raise IllegalActionError(f"Illegal NT({nonterm}) action taken.")
def test_from_string(self):
    """Parse a generative oracle from its string form.

    The input has the POS tags on the first line and one action per
    following line; the words are expected to come from the GEN actions.
    """
    oracle_str = 'NNP VBZ\nNT(S)\nGEN(asdf)\nGEN(fdsa)\nREDUCE'
    expected_actions = [
        NonTerminalAction('S'),
        GenerateAction('asdf'),
        GenerateAction('fdsa'),
        ReduceAction()
    ]

    oracle = GenerativeOracle.from_string(oracle_str)

    assert isinstance(oracle, GenerativeOracle)
    assert oracle.words == ['asdf', 'fdsa']
    assert oracle.pos_tags == ['NNP', 'VBZ']
    assert oracle.actions == expected_actions
def test_from_string(self):
    """Parse a discriminative oracle from its string form.

    The input has the words on the first line, the POS tags on the second,
    and one action per following line.
    """
    oracle_str = 'asdf fdsa\nNNP VBZ\nNT(S)\nSHIFT\nSHIFT\nREDUCE'
    expected_actions = [
        NonTerminalAction('S'),
        ShiftAction(),
        ShiftAction(),
        ReduceAction()
    ]

    oracle = DiscriminativeOracle.from_string(oracle_str)

    assert isinstance(oracle, DiscriminativeOracle)
    assert oracle.words == ['asdf', 'fdsa']
    assert oracle.pos_tags == ['NNP', 'VBZ']
    assert oracle.actions == expected_actions
def test_from_parsed_sent(self):
    """Build a discriminative oracle from a bracketed parse.

    Checks the resulting actions, POS tags, and words.
    """
    bracketed = '(S (NP (NNP John)) (VP (VBZ loves) (NP (NNP Mary))))'
    expected_words = ['John', 'loves', 'Mary']
    expected_pos_tags = ['NNP', 'VBZ', 'NNP']
    expected_actions = [
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        ShiftAction(),
        ReduceAction(),
        NonTerminalAction('VP'),
        ShiftAction(),
        NonTerminalAction('NP'),
        ShiftAction(),
        ReduceAction(),
        ReduceAction(),
        ReduceAction(),
    ]

    parsed = Tree.fromstring(bracketed)
    oracle = DiscriminativeOracle.from_parsed_sentence(parsed)

    assert isinstance(oracle, DiscriminativeOracle)
    assert oracle.actions == expected_actions
    assert oracle.pos_tags == expected_pos_tags
    assert oracle.words == expected_words
def test_do_non_terminal_action(self):
    """Pushing NT(S) adds an empty 'S' tree to the stack and records the action.

    The input buffer must compare equal to its value before the push, and
    the parser must not be finished.
    """
    tagged_words = list(zip(['John', 'loves', 'Mary'], ['NNP', 'VBZ', 'NNP']))
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)
    parser.initialise_stacks_and_buffers(tagged_words)
    buffer_before = parser.input_buffer

    parser.push_non_terminal('S')

    assert len(parser.stack_buffer) == 1
    top = parser.stack_buffer[-1]
    assert isinstance(top, Tree)
    assert top.label() == 'S'
    assert len(top) == 0
    assert parser.input_buffer == buffer_before
    assert len(parser.action_history) == 1
    assert parser.action_history[-1] == NonTerminalAction('S')
    assert not parser.finished
import torch from torch.autograd import Variable from rnng.models import DiscriminativeRnnGrammar from rnng.utils import ItemStore from rnng.actions import ShiftAction, ReduceAction, NonTerminalAction from rnng.decoding import greedy_decode word2id = {'John': 0, 'loves': 1, 'Mary': 2} pos2id = {'NNP': 0, 'VBZ': 1} nt2id = {'S': 0, 'NP': 1, 'VP': 2} actions = [ NonTerminalAction('S'), NonTerminalAction('NP'), NonTerminalAction('VP'), ShiftAction(), ReduceAction() ] action_store = ItemStore() for a in actions: action_store.add(a) def test_greedy_decode(mocker): words = ['John', 'loves', 'Mary'] pos_tags = ['NNP', 'VBZ', 'NNP'] correct_actions = [ NonTerminalAction('S'), NonTerminalAction('NP'), ShiftAction(), ReduceAction(),
def test_from_invalid_string(self):
    """``NonTerminalAction.from_string`` raises ValueError for 'asdf'."""
    invalid_action_str = 'asdf'
    with pytest.raises(ValueError):
        NonTerminalAction.from_string(invalid_action_str)
def test_from_string(self):
    """Parsing the ``self.as_str`` form of NT(NP) yields a NonTerminalAction labelled 'NP'."""
    expected_label = 'NP'
    action_str = self.as_str.format(label=expected_label)

    parsed = NonTerminalAction.from_string(action_str)

    assert isinstance(parsed, NonTerminalAction)
    assert parsed.label == expected_label
def test_eq(self):
    """NT actions compare equal only to NT actions carrying the same label."""
    action = NonTerminalAction('NP')
    same_label = NonTerminalAction(action.label)
    other_label = NonTerminalAction('asdf')
    other_kind = ShiftAction()

    assert action == same_label
    assert action != other_label
    assert action != other_kind
def test_hash(self):
    """The hash of an NT action equals the hash of its ``self.as_str`` string form."""
    label = 'NP'
    expected_hash = hash(self.as_str.format(label=label))

    action = NonTerminalAction(label)

    assert hash(action) == expected_hash
def test_to_string(self):
    """``str()`` of an NT action matches ``self.as_str`` formatted with its label."""
    label = 'NP'
    expected_str = self.as_str.format(label=label)

    action = NonTerminalAction(label)

    assert str(action) == expected_str
class TestDiscRNNGrammar:
    """Tests for ``DiscriminativeRnnGrammar``.

    Covers construction, initialisation, the NT/SHIFT/REDUCE actions
    (legal and illegal uses), the forward pass, and the finished state.
    """

    # Shared fixtures: vocabularies and the full action inventory.
    word2id = {'John': 0, 'loves': 1, 'Mary': 2}
    pos2id = {'NNP': 0, 'VBZ': 1}
    nt2id = {'S': 2, 'NP': 1, 'VP': 0}
    action_store = ItemStore()
    actions = [
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        NonTerminalAction('VP'),
        ShiftAction(),
        ReduceAction(),
    ]
    for a in actions:
        action_store.add(a)

    def _make_parser(self, word2id=None, pos2id=None, nt2id=None, action_store=None):
        """Construct a parser, using the class fixture for every argument left as None."""
        return DiscriminativeRnnGrammar(
            self.word2id if word2id is None else word2id,
            self.pos2id if pos2id is None else pos2id,
            self.nt2id if nt2id is None else nt2id,
            self.action_store if action_store is None else action_store,
        )

    @staticmethod
    def _start(parser, words, pos_tags):
        """(Re-)initialise ``parser`` with ``words`` zipped with ``pos_tags``."""
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))

    def test_init(self):
        """A fresh parser has empty buffers/history and is neither started nor finished."""
        parser = self._make_parser()

        assert len(parser.stack_buffer) == 0
        assert len(parser.input_buffer) == 0
        assert len(parser.action_history) == 0
        assert not parser.finished
        assert not parser.started

    def test_init_no_shift_action(self):
        """An action store without ShiftAction is rejected with ValueError."""
        store_without_shift = ItemStore()
        available = [
            NonTerminalAction('S'),
            NonTerminalAction('NP'),
            NonTerminalAction('VP'),
            ReduceAction(),
        ]
        for action in available:
            store_without_shift.add(action)

        with pytest.raises(ValueError):
            self._make_parser(action_store=store_without_shift)

    def test_init_no_reduce_action(self):
        """An action store without ReduceAction is rejected with ValueError."""
        store_without_reduce = ItemStore()
        available = [
            NonTerminalAction('S'),
            NonTerminalAction('NP'),
            NonTerminalAction('VP'),
            ShiftAction(),
        ]
        for action in available:
            store_without_reduce.add(action)

        with pytest.raises(ValueError):
            self._make_parser(action_store=store_without_reduce)

    def test_init_word_id_out_of_range(self):
        """A word id equal to the vocabulary size, or -1, raises ValueError."""
        for bad_id in (len(self.word2id), -1):
            word2id = dict(self.word2id)
            word2id['John'] = bad_id
            with pytest.raises(ValueError):
                self._make_parser(word2id=word2id)

    def test_init_pos_id_out_of_range(self):
        """A POS id equal to the tag-set size, or -1, raises ValueError."""
        for bad_id in (len(self.pos2id), -1):
            pos2id = dict(self.pos2id)
            pos2id['NNP'] = bad_id
            with pytest.raises(ValueError):
                self._make_parser(pos2id=pos2id)

    def test_init_non_terminal_id_out_of_range(self):
        """A nonterminal id equal to the label-set size, or -1, raises ValueError."""
        for bad_id in (len(self.nt2id), -1):
            nt2id = dict(self.nt2id)
            nt2id['S'] = bad_id
            with pytest.raises(ValueError):
                self._make_parser(nt2id=nt2id)

    def test_initialise_stacks_and_buffers(self):
        """Initialising fills the input buffer with the words and marks the parser started."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_parser()

        self._start(parser, words, pos_tags)

        assert len(parser.stack_buffer) == 0
        assert parser.input_buffer == words
        assert len(parser.action_history) == 0
        assert not parser.finished
        assert parser.started

    def test_initialise_with_empty_tagged_words(self):
        """Initialising with an empty list raises ValueError."""
        parser = self._make_parser()

        with pytest.raises(ValueError):
            parser.initialise_stacks_and_buffers([])

    def test_initalise_with_invalid_word_or_pos(self):
        """A word or POS tag missing from the vocabularies raises ValueError."""
        parser = self._make_parser()
        unknown_word = [('Bob', 'NNP')]
        unknown_pos = [('John', 'VBD')]

        for tagged_words in (unknown_word, unknown_pos):
            with pytest.raises(ValueError):
                parser.initialise_stacks_and_buffers(tagged_words)

    def test_do_non_terminal_action(self):
        """Pushing NT(S) adds an empty 'S' tree to the stack and records the action."""
        parser = self._make_parser()
        self._start(parser, ['John', 'loves', 'Mary'], ['NNP', 'VBZ', 'NNP'])
        buffer_before = parser.input_buffer

        parser.push_non_terminal('S')

        assert len(parser.stack_buffer) == 1
        top = parser.stack_buffer[-1]
        assert isinstance(top, Tree)
        assert top.label() == 'S'
        assert len(top) == 0
        assert parser.input_buffer == buffer_before
        assert len(parser.action_history) == 1
        assert parser.action_history[-1] == NonTerminalAction('S')
        assert not parser.finished

    def test_do_illegal_push_non_terminal_action(self):
        """Pushing a nonterminal raises IllegalActionError in two illegal situations."""
        words = ['John']
        pos_tags = ['NNP']
        parser = self._make_parser()

        # Buffer is empty
        self._start(parser, words, pos_tags)
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.push_non_terminal('NP')

        # More than 100 open nonterminals
        self._start(parser, words, pos_tags)
        for _ in range(100):
            parser.push_non_terminal('S')
        with pytest.raises(IllegalActionError):
            parser.push_non_terminal('NP')

    def test_push_unknown_non_terminal(self):
        """Pushing the label 'asdf', which is not in ``nt2id``, raises KeyError."""
        parser = self._make_parser()
        self._start(parser, ['John'], ['NNP'])

        with pytest.raises(KeyError):
            parser.push_non_terminal('asdf')

    def test_push_known_non_terminal_but_unknown_action(self):
        """Pushing 'S' raises KeyError when the action store has no NT(S) action."""
        store_without_nt_s = ItemStore()
        available = [
            NonTerminalAction('NP'),
            NonTerminalAction('VP'),
            ShiftAction(),
            ReduceAction(),
        ]
        for action in available:
            store_without_nt_s.add(action)
        parser = self._make_parser(action_store=store_without_nt_s)
        self._start(parser, ['John'], ['NNP'])

        with pytest.raises(KeyError):
            parser.push_non_terminal('S')

    def test_do_shift_action(self):
        """SHIFT moves the first buffered word onto the stack and records the action."""
        words = ['John', 'loves', 'Mary']
        parser = self._make_parser()
        self._start(parser, words, ['NNP', 'VBZ', 'NNP'])
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')

        parser.shift()

        assert len(parser.stack_buffer) == 3
        top = parser.stack_buffer[-1]
        assert top == 'John'
        assert parser.input_buffer == words[1:]
        assert len(parser.action_history) == 3
        assert parser.action_history[-1] == ShiftAction()
        assert not parser.finished

    def test_do_illegal_shift_action(self):
        """SHIFT raises IllegalActionError in two illegal situations."""
        words = ['John']
        pos_tags = ['NNP']
        parser = self._make_parser()

        # No open nonterminal
        self._start(parser, words, pos_tags)
        with pytest.raises(IllegalActionError):
            parser.shift()

        # Buffer is empty
        self._start(parser, words, pos_tags)
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.shift()

    def test_do_reduce_action(self):
        """REDUCE after NT(S) NT(NP) SHIFT leaves an 'NP' tree containing 'John' on top."""
        parser = self._make_parser()
        self._start(parser, ['John', 'loves', 'Mary'], ['NNP', 'VBZ', 'NNP'])
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        buffer_before = parser.input_buffer

        parser.reduce()

        assert len(parser.stack_buffer) == 2
        top = parser.stack_buffer[-1]
        assert isinstance(top, Tree)
        assert top.label() == 'NP'
        assert len(top) == 1
        assert top[0] == 'John'
        assert parser.input_buffer == buffer_before
        assert len(parser.action_history) == 4
        assert parser.action_history[-1] == ReduceAction()
        assert not parser.finished

    def test_do_illegal_reduce_action(self):
        """REDUCE raises IllegalActionError in two illegal situations."""
        words = ['John', 'loves']
        pos_tags = ['NNP', 'VBZ']
        parser = self._make_parser()

        # Top of stack is an open nonterminal
        self._start(parser, words, pos_tags)
        parser.push_non_terminal('S')
        with pytest.raises(IllegalActionError):
            parser.reduce()

        # Buffer is not empty and REDUCE will finish parsing
        self._start(parser, words, pos_tags)
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.reduce()

    def test_do_action_when_not_started(self):
        """Every action raises RuntimeError on a parser that was never initialised."""
        parser = self._make_parser()

        with pytest.raises(RuntimeError):
            parser.push_non_terminal('S')
        with pytest.raises(RuntimeError):
            parser.shift()
        with pytest.raises(RuntimeError):
            parser.reduce()

    def test_forward(self):
        """Calling the parser returns a Variable of log-probs, one per action, summing to ~1."""
        parser = self._make_parser()
        self._start(parser, ['John', 'loves', 'Mary'], ['NNP', 'VBZ', 'NNP'])
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()

        action_logprobs = parser()

        assert isinstance(action_logprobs, Variable)
        assert action_logprobs.size() == (len(self.action_store),)
        # Exponentiated outputs should sum to one, within tolerance.
        sum_prob = action_logprobs.exp().sum().data[0]
        assert 0.999 <= sum_prob <= 1.001

    def test_forward_with_illegal_actions(self):
        """Right after initialisation NT actions have positive probability; SHIFT and REDUCE ~0."""
        parser = self._make_parser()
        self._start(parser, ['John', 'loves', 'Mary'], ['NNP', 'VBZ', 'NNP'])

        action_probs = parser().exp().data

        def prob_of(action):
            return action_probs[self.action_store[action]]

        for label in ('S', 'NP', 'VP'):
            assert prob_of(NonTerminalAction(label)) > 0.
        assert -0.001 <= prob_of(ShiftAction()) <= 0.001
        assert -0.001 <= prob_of(ReduceAction()) <= 0.001

    def test_forward_when_not_started(self):
        """Calling a parser that was never initialised raises RuntimeError."""
        parser = self._make_parser()

        with pytest.raises(RuntimeError):
            parser()

    def test_finished(self):
        """After the full action sequence the parser is finished and refuses further use.

        The tree left on top of the stack must equal the expected parse, and
        the forward call plus every action must raise RuntimeError.
        """
        parser = self._make_parser()
        exp_parse_tree = Tree('S', [
            Tree('NP', ['John']),
            Tree('VP', ['loves', Tree('NP', ['Mary'])]),
        ])
        self._start(parser, ['John', 'loves', 'Mary'], ['NNP', 'VBZ', 'NNP'])

        # (S (NP John) ...
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()
        # ... (VP loves (NP Mary)))
        parser.push_non_terminal('VP')
        parser.shift()
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()
        parser.reduce()
        parser.reduce()

        assert parser.finished
        parse_tree = parser.stack_buffer[-1]
        assert parse_tree == exp_parse_tree

        with pytest.raises(RuntimeError):
            parser()
        with pytest.raises(RuntimeError):
            parser.push_non_terminal('NP')
        with pytest.raises(RuntimeError):
            parser.shift()
        with pytest.raises(RuntimeError):
            parser.reduce()