def __init__(self, oracles: Sequence[Oracle]) -> None:
    """Remember *oracles*, create one empty ``ItemStore`` per vocabulary, then load.

    Four stores are created: words, POS tags, non-terminal labels and actions.
    ``self.load()`` is called only after every store exists.
    NOTE(review): ``load`` presumably fills these stores from ``self.oracles`` —
    confirm against its definition.
    """
    self.oracles = oracles
    # One store per symbol kind; the type comments record each store's element type.
    self.word_store = ItemStore()  # type: ItemStore[Word]
    self.pos_store = ItemStore()  # type: ItemStore[POSTag]
    self.nt_store = ItemStore()  # type: ItemStore[NonTerminalLabel]
    self.action_store = ItemStore()  # type: ItemStore[Action]
    # Must run last: the stores above have to be in place before loading.
    self.load()
def test_invertible(self):
    """Mapping an item to its id and back through ``get_by_id`` returns the item."""
    store = ItemStore()
    # Each letter is added twice.
    for ch in 'aabbcc':
        store.add(ch)
    for ch in 'abc':
        item_id = store[ch]
        assert store.get_by_id(item_id) == ch
def test_init_no_reduce_action(self):
    """Building the parser with no ``ReduceAction`` in the action store raises ValueError."""
    store_without_reduce = ItemStore()
    for action in (NonTerminalAction('S'), NonTerminalAction('NP'),
                   NonTerminalAction('VP'), ShiftAction()):
        store_without_reduce.add(action)
    with pytest.raises(ValueError):
        DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                 store_without_reduce)
def test_contains(self):
    """``in`` is true for every added item and false for an item never added."""
    store = ItemStore()
    for ch in 'aabbcc':
        store.add(ch)
    for present in 'abc':
        assert present in store
    assert 'd' not in store
def test_unique_ids(self):
    """Distinct items get distinct ids, each in the range ``[0, len(store))``."""
    store = ItemStore()
    for ch in 'aabbcc':
        store.add(ch)
    ids = [store[ch] for ch in 'abc']
    assert len(set(ids)) == len(ids)
    size = len(store)
    for item_id in ids:
        assert 0 <= item_id < size
def test_push_known_non_terminal_but_unknown_action(self):
    """Pushing 'S' raises KeyError when NT(S) is missing from the action store.

    'S' is still present in ``self.nt2id``; only the matching action is absent.
    """
    action_store = ItemStore()
    for action in (NonTerminalAction('NP'), NonTerminalAction('VP'),
                   ShiftAction(), ReduceAction()):
        action_store.add(action)
    parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, action_store)
    parser.initialise_stacks_and_buffers([('John', 'NNP')])
    with pytest.raises(KeyError):
        parser.push_non_terminal('S')
from rnng.models import DiscriminativeRnnGrammar from rnng.utils import ItemStore from rnng.actions import ShiftAction, ReduceAction, NonTerminalAction from rnng.decoding import greedy_decode word2id = {'John': 0, 'loves': 1, 'Mary': 2} pos2id = {'NNP': 0, 'VBZ': 1} nt2id = {'S': 0, 'NP': 1, 'VP': 2} actions = [ NonTerminalAction('S'), NonTerminalAction('NP'), NonTerminalAction('VP'), ShiftAction(), ReduceAction() ] action_store = ItemStore() for a in actions: action_store.add(a) def test_greedy_decode(mocker): words = ['John', 'loves', 'Mary'] pos_tags = ['NNP', 'VBZ', 'NNP'] correct_actions = [ NonTerminalAction('S'), NonTerminalAction('NP'), ShiftAction(), ReduceAction(), NonTerminalAction('VP'), ShiftAction(), NonTerminalAction('NP'),
class TestDiscRNNGrammar:
    """Unit tests for ``DiscriminativeRnnGrammar`` over a toy three-word vocabulary.

    The class-level vocabularies and ``action_store`` are shared by all tests and
    are never mutated by them; tests needing a different configuration build
    their own dicts or action stores.
    """

    word2id = {'John': 0, 'loves': 1, 'Mary': 2}
    pos2id = {'NNP': 0, 'VBZ': 1}
    nt2id = {'S': 2, 'NP': 1, 'VP': 0}
    action_store = ItemStore()
    actions = [NonTerminalAction('S'), NonTerminalAction('NP'), NonTerminalAction('VP'),
               ShiftAction(), ReduceAction()]
    for a in actions:
        action_store.add(a)
    # Drop the loop variable so it does not linger as a class attribute.
    del a

    @staticmethod
    def _make_action_store(actions):
        """Return a fresh ``ItemStore`` with *actions* added in order."""
        store = ItemStore()
        for action in actions:
            store.add(action)
        return store

    def _make_parser(self, action_store=None):
        """Build a parser from the class vocabularies.

        Uses the shared class-level ``action_store`` unless one is given.
        """
        if action_store is None:
            action_store = self.action_store
        return DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, action_store)

    def test_init(self):
        parser = self._make_parser()
        assert len(parser.stack_buffer) == 0
        assert len(parser.input_buffer) == 0
        assert len(parser.action_history) == 0
        assert not parser.finished
        assert not parser.started

    def test_init_no_shift_action(self):
        action_store = self._make_action_store([
            NonTerminalAction('S'), NonTerminalAction('NP'), NonTerminalAction('VP'),
            ReduceAction()])
        with pytest.raises(ValueError):
            self._make_parser(action_store)

    def test_init_no_reduce_action(self):
        action_store = self._make_action_store([
            NonTerminalAction('S'), NonTerminalAction('NP'), NonTerminalAction('VP'),
            ShiftAction()])
        with pytest.raises(ValueError):
            self._make_parser(action_store)

    def test_init_word_id_out_of_range(self):
        word2id = dict(self.word2id)
        word2id['John'] = len(word2id)
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(word2id, self.pos2id, self.nt2id, self.action_store)
        word2id['John'] = -1
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(word2id, self.pos2id, self.nt2id, self.action_store)

    def test_init_pos_id_out_of_range(self):
        pos2id = dict(self.pos2id)
        pos2id['NNP'] = len(pos2id)
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, pos2id, self.nt2id, self.action_store)
        pos2id['NNP'] = -1
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, pos2id, self.nt2id, self.action_store)

    def test_init_non_terminal_id_out_of_range(self):
        nt2id = dict(self.nt2id)
        nt2id['S'] = len(nt2id)
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, self.pos2id, nt2id, self.action_store)
        nt2id['S'] = -1
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, self.pos2id, nt2id, self.action_store)

    def test_initialise_stacks_and_buffers(self):
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_parser()
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        assert len(parser.stack_buffer) == 0
        assert parser.input_buffer == words
        assert len(parser.action_history) == 0
        assert not parser.finished
        assert parser.started

    def test_initialise_with_empty_tagged_words(self):
        parser = self._make_parser()
        with pytest.raises(ValueError):
            parser.initialise_stacks_and_buffers([])

    def test_initalise_with_invalid_word_or_pos(self):
        parser = self._make_parser()
        with pytest.raises(ValueError):
            parser.initialise_stacks_and_buffers([('Bob', 'NNP')])
        with pytest.raises(ValueError):
            parser.initialise_stacks_and_buffers([('John', 'VBD')])

    def test_do_non_terminal_action(self):
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_parser()
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        # Copy, don't alias: if the parser mutated its buffer in place, a bare
        # reference would make the "buffer unchanged" assertion below vacuous.
        prev_input_buffer = list(parser.input_buffer)
        parser.push_non_terminal('S')
        assert len(parser.stack_buffer) == 1
        last = parser.stack_buffer[-1]
        assert isinstance(last, Tree)
        assert last.label() == 'S'
        assert len(last) == 0
        assert parser.input_buffer == prev_input_buffer
        assert len(parser.action_history) == 1
        assert parser.action_history[-1] == NonTerminalAction('S')
        assert not parser.finished

    def test_do_illegal_push_non_terminal_action(self):
        words = ['John']
        pos_tags = ['NNP']
        parser = self._make_parser()
        # Buffer is empty
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.push_non_terminal('NP')
        # After 100 open non-terminals, pushing another is illegal
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        for _ in range(100):
            parser.push_non_terminal('S')
        with pytest.raises(IllegalActionError):
            parser.push_non_terminal('NP')

    def test_push_unknown_non_terminal(self):
        words = ['John']
        pos_tags = ['NNP']
        parser = self._make_parser()
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        with pytest.raises(KeyError):
            parser.push_non_terminal('asdf')

    def test_push_known_non_terminal_but_unknown_action(self):
        # 'S' is in nt2id, but NT(S) is deliberately left out of the action store.
        action_store = self._make_action_store([
            NonTerminalAction('NP'), NonTerminalAction('VP'), ShiftAction(), ReduceAction()])
        words = ['John']
        pos_tags = ['NNP']
        parser = self._make_parser(action_store)
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        with pytest.raises(KeyError):
            parser.push_non_terminal('S')

    def test_do_shift_action(self):
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_parser()
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        assert len(parser.stack_buffer) == 3
        last = parser.stack_buffer[-1]
        assert last == 'John'
        assert parser.input_buffer == words[1:]
        assert len(parser.action_history) == 3
        assert parser.action_history[-1] == ShiftAction()
        assert not parser.finished

    def test_do_illegal_shift_action(self):
        words = ['John']
        pos_tags = ['NNP']
        parser = self._make_parser()
        # No open nonterminal
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        with pytest.raises(IllegalActionError):
            parser.shift()
        # Buffer is empty
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.shift()

    def test_do_reduce_action(self):
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_parser()
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        # Copy, don't alias (see test_do_non_terminal_action).
        prev_input_buffer = list(parser.input_buffer)
        parser.reduce()
        assert len(parser.stack_buffer) == 2
        last = parser.stack_buffer[-1]
        assert isinstance(last, Tree)
        assert last.label() == 'NP'
        assert len(last) == 1
        assert last[0] == 'John'
        assert parser.input_buffer == prev_input_buffer
        assert len(parser.action_history) == 4
        assert parser.action_history[-1] == ReduceAction()
        assert not parser.finished

    def test_do_illegal_reduce_action(self):
        words = ['John', 'loves']
        pos_tags = ['NNP', 'VBZ']
        parser = self._make_parser()
        # Top of stack is an open nonterminal
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        with pytest.raises(IllegalActionError):
            parser.reduce()
        # Buffer is not empty and REDUCE will finish parsing
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.reduce()

    def test_do_action_when_not_started(self):
        parser = self._make_parser()
        with pytest.raises(RuntimeError):
            parser.push_non_terminal('S')
        with pytest.raises(RuntimeError):
            parser.shift()
        with pytest.raises(RuntimeError):
            parser.reduce()

    def test_forward(self):
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_parser()
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()
        action_logprobs = parser()
        assert isinstance(action_logprobs, Variable)
        assert action_logprobs.size() == (len(self.action_store),)
        sum_prob = action_logprobs.exp().sum().data[0]
        assert 0.999 <= sum_prob <= 1.001

    def test_forward_with_illegal_actions(self):
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_parser()
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        action_probs = parser().exp().data
        assert action_probs[self.action_store[NonTerminalAction('S')]] > 0.
        assert action_probs[self.action_store[NonTerminalAction('NP')]] > 0.
        assert action_probs[self.action_store[NonTerminalAction('VP')]] > 0.
        assert -0.001 <= action_probs[self.action_store[ShiftAction()]] <= 0.001
        assert -0.001 <= action_probs[self.action_store[ReduceAction()]] <= 0.001

    def test_forward_when_not_started(self):
        parser = self._make_parser()
        with pytest.raises(RuntimeError):
            parser()

    def test_finished(self):
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_parser()
        exp_parse_tree = Tree('S', [Tree('NP', ['John']),
                                    Tree('VP', ['loves', Tree('NP', ['Mary'])])])
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()
        parser.push_non_terminal('VP')
        parser.shift()
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()
        parser.reduce()
        parser.reduce()
        assert parser.finished
        parse_tree = parser.stack_buffer[-1]
        assert parse_tree == exp_parse_tree
        with pytest.raises(RuntimeError):
            parser()
        with pytest.raises(RuntimeError):
            parser.push_non_terminal('NP')
        with pytest.raises(RuntimeError):
            parser.shift()
        with pytest.raises(RuntimeError):
            parser.reduce()
def test_len(self):
    """Re-adding an item does not grow the store: three distinct letters, length 3."""
    store = ItemStore()
    for ch in 'aabbcc':
        store.add(ch)
    assert len(store) == 3
def test_iter(self):
    """Iterating the store yields each distinct item exactly once (order not checked)."""
    store = ItemStore()
    for ch in 'aabbcc':
        store.add(ch)
    assert sorted(store) == ['a', 'b', 'c']