def test_greedy_decode(mocker):
    """``greedy_decode`` picks, at every step, the action marked by ``forward``.

    ``parser.forward`` is patched so that each call returns a vector over the
    action vocabulary that is 1 at the gold action's ID and 0 elsewhere. The
    decoder must therefore reproduce the gold action sequence exactly.
    """
    words = ['John', 'loves', 'Mary']
    pos_tags = ['NNP', 'VBZ', 'NNP']
    gold_actions = [
        NonTerminalAction('S'),
        NonTerminalAction('NP'), ShiftAction(), ReduceAction(),
        NonTerminalAction('VP'), ShiftAction(),
        NonTerminalAction('NP'), ShiftAction(), ReduceAction(),
        ReduceAction(),
        ReduceAction(),
    ]

    def one_hot(action):
        # Zero vector of size |actions| with a single 1 at ``action``'s ID.
        index = torch.LongTensor([action_store[action]])
        return Variable(torch.zeros(len(action_store)).scatter_(0, index, 1))

    scripted_outputs = [one_hot(action) for action in gold_actions]

    parser = DiscriminativeRnnGrammar(word2id, pos2id, nt2id, action_store)
    parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
    mocker.patch.object(parser, 'forward', side_effect=scripted_outputs)

    decoded = greedy_decode(parser)

    assert len(decoded) == len(gold_actions)
    chosen, _ = zip(*decoded)
    assert list(chosen) == gold_actions
def get_actions(cls, tree: Tree) -> List[Action]:
    """Return the oracle actions for ``tree`` in pre-order.

    A node whose single child is not a ``Tree`` (a preterminal) becomes the one
    action returned by ``cls.get_action_at_pos_node``. Any other node becomes
    ``NonTerminalAction(label)``, followed by the actions of each child in
    order, followed by a closing ``ReduceAction``.
    """
    if len(tree) == 1 and not isinstance(tree[0], Tree):
        return [cls.get_action_at_pos_node(tree)]
    opening = NonTerminalAction(tree.label())
    body = [step for child in tree for step in cls.get_actions(child)]
    return [opening, *body, ReduceAction()]
class TestOracleDataset:
    """Tests for ``OracleDataset`` built from ``DiscriminativeOracle`` objects."""

    bracketed_sents = [
        '(S (NP (NNP John)) (VP (VBZ loves) (NP (NNP Mary))))',
        '(S (NP (NNP Mary)) (VP (VBZ hates) (NP (NNP John))))'  # poor John
    ]
    words = {'John', 'loves', 'hates', 'Mary'}
    pos_tags = {'NNP', 'VBZ'}
    nt_labels = {'S', 'NP', 'VP'}
    actions = {
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        NonTerminalAction('VP'),
        ShiftAction(),
        ReduceAction()
    }

    def _make_oracles(self):
        # One oracle per entry of ``bracketed_sents``, in the same order.
        return [
            DiscriminativeOracle.from_parsed_sentence(Tree.fromstring(sent))
            for sent in self.bracketed_sents
        ]

    def test_init(self):
        dataset = OracleDataset(self._make_oracles())
        expected = [
            (dataset.word_store, self.words),
            (dataset.pos_store, self.pos_tags),
            (dataset.nt_store, self.nt_labels),
            (dataset.action_store, self.actions),
        ]
        for store, items in expected:
            assert isinstance(store, ItemStore)
            assert set(store) == items

    def test_getitem(self):
        oracles = self._make_oracles()
        dataset = OracleDataset(oracles)
        # Indexing must hand back the very same oracle objects.
        for index, oracle in enumerate(oracles):
            assert dataset[index] is oracle

    def test_len(self):
        oracles = self._make_oracles()
        assert len(OracleDataset(oracles)) == len(oracles)
def test_push_known_non_terminal_but_unknown_action(self):
    """Pushing 'S' raises KeyError when ``NonTerminalAction('S')`` is unregistered.

    The parser is given an action store without the NT(S) action, while the
    label mapping ``self.nt2id`` is used unchanged.
    """
    store = ItemStore()
    for action in (NonTerminalAction('NP'), NonTerminalAction('VP'),
                   ShiftAction(), ReduceAction()):
        store.add(action)

    parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, store)
    parser.initialise_stacks_and_buffers([('John', 'NNP')])

    with pytest.raises(KeyError):
        parser.push_non_terminal('S')
def test_forward_with_illegal_actions(self):
    """Right after initialisation only NT actions get non-negligible probability.

    SHIFT and REDUCE are expected to receive probability within 0.001 of zero.
    """
    tagged_words = list(zip(['John', 'loves', 'Mary'], ['NNP', 'VBZ', 'NNP']))
    parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                      self.action_store)
    parser.initialise_stacks_and_buffers(tagged_words)

    probs = parser().exp().data

    def prob_of(action):
        return probs[self.action_store[action]]

    for label in ('S', 'NP', 'VP'):
        assert prob_of(NonTerminalAction(label)) > 0.
    for action in (ShiftAction(), ReduceAction()):
        assert -0.001 <= prob_of(action) <= 0.001
def test_from_string(self):
    """Parse a discriminative oracle: words line, POS line, then one action per line."""
    oracle = DiscriminativeOracle.from_string(
        'asdf fdsa\nNNP VBZ\nNT(S)\nSHIFT\nSHIFT\nREDUCE')

    assert isinstance(oracle, DiscriminativeOracle)
    assert oracle.words == ['asdf', 'fdsa']
    assert oracle.pos_tags == ['NNP', 'VBZ']
    expected_actions = [NonTerminalAction('S'), ShiftAction(), ShiftAction(),
                        ReduceAction()]
    assert oracle.actions == expected_actions
def test_from_string(self):
    """Parse a generative oracle: POS line, then actions; words come from GEN(...)."""
    oracle = GenerativeOracle.from_string('NNP VBZ\nNT(S)\nGEN(asdf)\nGEN(fdsa)\nREDUCE')

    assert isinstance(oracle, GenerativeOracle)
    assert oracle.words == ['asdf', 'fdsa']
    assert oracle.pos_tags == ['NNP', 'VBZ']
    expected_actions = [NonTerminalAction('S'), GenerateAction('asdf'),
                        GenerateAction('fdsa'), ReduceAction()]
    assert oracle.actions == expected_actions
def test_from_parsed_sent(self):
    """``from_parsed_sentence`` yields the pre-order actions, POS tags and words."""
    tree = Tree.fromstring('(S (NP (NNP John)) (VP (VBZ loves) (NP (NNP Mary))))')

    oracle = DiscriminativeOracle.from_parsed_sentence(tree)

    assert isinstance(oracle, DiscriminativeOracle)
    assert oracle.actions == [
        NonTerminalAction('S'),
        NonTerminalAction('NP'), ShiftAction(), ReduceAction(),
        NonTerminalAction('VP'), ShiftAction(),
        NonTerminalAction('NP'), ShiftAction(), ReduceAction(),
        ReduceAction(),
        ReduceAction(),
    ]
    assert oracle.pos_tags == ['NNP', 'VBZ', 'NNP']
    assert oracle.words == ['John', 'loves', 'Mary']
def test_do_reduce_action(self):
    """REDUCE closes the open NP around 'John' without touching the input buffer.

    After NT(S), NT(NP) and SHIFT, a REDUCE must leave two stack elements. The
    top one is an ``NP`` tree whose only child is 'John'. The input buffer is
    unchanged, the fourth history entry is REDUCE and parsing is not finished.
    """
    words = ['John', 'loves', 'Mary']
    pos_tags = ['NNP', 'VBZ', 'NNP']
    parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                      self.action_store)
    parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
    parser.push_non_terminal('S')
    parser.push_non_terminal('NP')
    parser.shift()
    # Snapshot, not alias: if ``input_buffer`` returns the parser's live list,
    # keeping a reference would make the "unchanged" check below compare the
    # object with itself and always pass.
    prev_input_buffer = list(parser.input_buffer)

    parser.reduce()

    assert len(parser.stack_buffer) == 2
    last = parser.stack_buffer[-1]
    assert isinstance(last, Tree)
    assert last.label() == 'NP'
    assert len(last) == 1
    assert last[0] == 'John'
    assert parser.input_buffer == prev_input_buffer
    assert len(parser.action_history) == 4
    assert parser.action_history[-1] == ReduceAction()
    assert not parser.finished
def reduce(self) -> None:
    """Perform a REDUCE action on the parser state.

    The state change is delegated to ``self._reduce()``. A ``ReduceAction`` is
    then recorded via ``self._append_history``.

    Raises:
        IllegalActionError: if ``self.can_reduce()`` is falsy. In that case
            neither ``_reduce`` nor ``_append_history`` is called.
    """
    if self.can_reduce():
        self._reduce()
        self._append_history(ReduceAction())
    else:
        raise IllegalActionError("Illegal REDUCE action attempted.")
from torch.autograd import Variable from rnng.models import DiscriminativeRnnGrammar from rnng.utils import ItemStore from rnng.actions import ShiftAction, ReduceAction, NonTerminalAction from rnng.decoding import greedy_decode word2id = {'John': 0, 'loves': 1, 'Mary': 2} pos2id = {'NNP': 0, 'VBZ': 1} nt2id = {'S': 0, 'NP': 1, 'VP': 2} actions = [ NonTerminalAction('S'), NonTerminalAction('NP'), NonTerminalAction('VP'), ShiftAction(), ReduceAction() ] action_store = ItemStore() for a in actions: action_store.add(a) def test_greedy_decode(mocker): words = ['John', 'loves', 'Mary'] pos_tags = ['NNP', 'VBZ', 'NNP'] correct_actions = [ NonTerminalAction('S'), NonTerminalAction('NP'), ShiftAction(), ReduceAction(), NonTerminalAction('VP'),
def test_eq(self):
    """GenerateAction equality depends on both the word and the action type."""
    action = GenerateAction('asdf')
    same_word = GenerateAction(action.word)
    assert action == same_word
    assert action != GenerateAction('fdsa')
    assert action != ReduceAction()
def test_from_invalid_string(self):
    """``ReduceAction.from_string`` rejects an unrelated string with ValueError."""
    not_a_reduce = 'asdf'
    with pytest.raises(ValueError):
        ReduceAction.from_string(not_a_reduce)
def test_from_string(self):
    """``ReduceAction.from_string`` accepts the class's ``as_str`` form."""
    parsed = ReduceAction.from_string(self.as_str)
    assert isinstance(parsed, ReduceAction)
def test_eq(self):
    """Distinct ReduceAction instances are equal; a ShiftAction is not."""
    first, second = ReduceAction(), ReduceAction()
    assert first == second
    assert ReduceAction() != ShiftAction()
def test_hash(self):
    """A ReduceAction hashes the same as its string form ``self.as_str``."""
    expected = hash(self.as_str)
    assert hash(ReduceAction()) == expected
def test_to_string(self):
    """``str`` of a ReduceAction equals ``self.as_str``."""
    rendered = str(ReduceAction())
    assert rendered == self.as_str
def __init__(self,
             word2id: Mapping[Word, WordId],
             pos2id: Mapping[POSTag, POSId],
             non_terminal2id: Mapping[NonTerminalLabel, NonTerminalId],
             action_store: ItemStore,
             word_dim: int = 32,
             pos_dim: int = 12,
             non_terminal_dim: int = 60,
             action_dim: int = 16,
             input_dim: int = 128,
             hidden_dim: int = 128,
             num_layers: int = 2,
             dropout: float = 0.) -> None:
    """Validate the vocabularies and build all submodules of the parser.

    Args:
        word2id: word -> ID; every ID must lie in ``[0, len(word2id))``.
        pos2id: POS tag -> ID; every ID must lie in ``[0, len(pos2id))``.
        non_terminal2id: nonterminal label -> ID; every ID must lie in
            ``[0, len(non_terminal2id))``.
        action_store: action vocabulary; must contain ``ShiftAction()`` and
            ``ReduceAction()``. Its length sets the output size of
            ``summary2actions``.
        word_dim, pos_dim, non_terminal_dim, action_dim: embedding sizes.
        input_dim: size of the vectors fed to the stack/buffer/history
            StackLSTMs and to the composition LSTMs.
        hidden_dim: hidden size of the three StackLSTMs and of the summary.
        num_layers: number of layers of every StackLSTM and composition LSTM.
        dropout: passed to every LSTM and used by the summary ``nn.Dropout``.

    Raises:
        ValueError: if SHIFT or REDUCE is missing from ``action_store``, or if
            any ID in the three mappings is out of range.
    """
    if ShiftAction() not in action_store:
        raise ValueError('SHIFT action ID must be specified')
    if ReduceAction() not in action_store:
        raise ValueError('REDUCE action ID must be specified')

    num_words = len(word2id)
    num_pos = len(pos2id)
    num_non_terminals = len(non_terminal2id)
    num_actions = len(action_store)

    # IDs index directly into the embedding tables below, so they must be
    # within [0, vocabulary size).
    for wid in word2id.values():
        if wid < 0 or wid >= num_words:
            raise ValueError(f'word ID of {wid} is out of range')
    for pid in pos2id.values():
        if pid < 0 or pid >= num_pos:
            raise ValueError(f'POS tag ID of {pid} is out of range')
    for nid in non_terminal2id.values():
        if nid < 0 or nid >= num_non_terminals:
            raise ValueError(f'nonterminal ID of {nid} is out of range')

    super().__init__()
    self.word2id = word2id
    self.pos2id = pos2id
    self.nt2id = non_terminal2id
    self.action_store = action_store
    self.num_words = num_words
    self.num_pos = num_pos
    self.num_non_terminals = num_non_terminals
    self.num_actions = num_actions
    self.word_dim = word_dim
    self.pos_dim = pos_dim
    self.non_terminal_dim = non_terminal_dim
    self.action_dim = action_dim
    self.input_dim = input_dim
    self.hidden_dim = hidden_dim
    self.num_layers = num_layers
    self.dropout = dropout

    # Parser states (empty until the stacks and buffers are initialised)
    self._stack = []  # type: List[StackElement]
    self._buffer = []  # type: List[Word]
    self._history = []  # type: List[Action]
    self._num_open_non_terminals = 0
    self._started = False

    # Parser state encoders
    # NOTE: submodules are registered in creation order; reordering them
    # changes random initialisation and parameter/state_dict ordering.
    self.stack_lstm = StackLSTM(input_dim, hidden_dim, num_layers=num_layers, dropout=dropout)
    # can use an LSTM, but this is easier.
    self.buffer_lstm = StackLSTM(input_dim, hidden_dim, num_layers=num_layers, dropout=dropout)
    # can use LSTM, but this is more efficient
    self.history_lstm = StackLSTM(input_dim, hidden_dim, num_layers=num_layers, dropout=dropout)

    # Composition: forward and backward LSTMs, input_dim -> input_dim
    self.compose_fwd_lstm = nn.LSTM(input_dim, input_dim, num_layers=num_layers, dropout=dropout)
    self.compose_bwd_lstm = nn.LSTM(input_dim, input_dim, num_layers=num_layers, dropout=dropout)

    # Transformations (linear + ReLU projections into input_dim / hidden_dim)
    self.word2lstm = nn.Sequential(
        OrderedDict([('linear', nn.Linear(word_dim + pos_dim, input_dim)),
                     ('relu', nn.ReLU())]))
    self.nt2lstm = nn.Sequential(
        OrderedDict([('linear', nn.Linear(non_terminal_dim, input_dim)),
                     ('relu', nn.ReLU())]))
    self.action2lstm = nn.Sequential(
        OrderedDict([('linear', nn.Linear(action_dim, input_dim)),
                     ('relu', nn.ReLU())]))
    # Concatenated forward/backward composition output (2 * input_dim) -> input_dim
    self.fwdbwd2composed = nn.Sequential(
        OrderedDict([('linear', nn.Linear(2 * input_dim, input_dim)),
                     ('relu', nn.ReLU())]))
    self.lstms2summary = nn.Sequential(
        OrderedDict([  # Stack LSTMs to parser summary
            ('dropout', nn.Dropout(dropout)),
            # 3 * hidden_dim: presumably the concatenated stack, buffer and
            # history encodings — confirm in forward()
            ('linear', nn.Linear(3 * hidden_dim, hidden_dim)),
            ('relu', nn.ReLU())
        ]))
    self.summary2actions = nn.Linear(hidden_dim, num_actions)

    # Embeddings
    self.word_embedding = nn.Embedding(num_words, word_dim)
    self.pos_embedding = nn.Embedding(num_pos, pos_dim)
    self.non_terminal_embedding = nn.Embedding(num_non_terminals, non_terminal_dim)
    self.action_embedding = nn.Embedding(num_actions, action_dim)

    # Guard parameters for stack, buffer, and action history
    # NOTE(review): torch.Tensor(n) allocates without filling values; these
    # presumably get initialised elsewhere (e.g. a reset_parameters method)
    # — confirm, otherwise they start as arbitrary memory contents.
    self.stack_guard = nn.Parameter(torch.Tensor(input_dim))
    self.buffer_guard = nn.Parameter(torch.Tensor(input_dim))
    self.history_guard = nn.Parameter(torch.Tensor(input_dim))

    # Final embeddings (empty dict caches here; filled elsewhere)
    # NOTE(review): the key types of _nt_emb/_action_emb are inferred by
    # analogy with _word_emb — confirm against where they are populated.
    self._word_emb = {}  # type: Dict[WordId, Variable]
    self._nt_emb = {}  # type: Dict[NonTerminalId, Variable]
    self._action_emb = {}  # type: Dict[int, Variable]
def test_init_no_shift_action(self):
    """Constructing the parser without SHIFT in the action store raises ValueError."""
    store_without_shift = ItemStore()
    for action in (NonTerminalAction('S'), NonTerminalAction('NP'),
                   NonTerminalAction('VP'), ReduceAction()):
        store_without_shift.add(action)

    with pytest.raises(ValueError):
        DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                 store_without_shift)
class TestDiscRNNGrammar:
    """Tests for ``DiscriminativeRnnGrammar``: construction, actions and forward."""

    word2id = {'John': 0, 'loves': 1, 'Mary': 2}
    pos2id = {'NNP': 0, 'VBZ': 1}
    nt2id = {'S': 2, 'NP': 1, 'VP': 0}
    action_store = ItemStore()
    actions = [NonTerminalAction('S'), NonTerminalAction('NP'), NonTerminalAction('VP'),
               ShiftAction(), ReduceAction()]
    for a in actions:
        action_store.add(a)

    def test_init(self):
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                          self.action_store)
        assert len(parser.stack_buffer) == 0
        assert len(parser.input_buffer) == 0
        assert len(parser.action_history) == 0
        assert not parser.finished
        assert not parser.started

    def test_init_no_shift_action(self):
        action_store = ItemStore()
        actions = [NonTerminalAction('S'), NonTerminalAction('NP'), NonTerminalAction('VP'),
                   ReduceAction()]
        for a in actions:
            action_store.add(a)
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, action_store)

    def test_init_no_reduce_action(self):
        action_store = ItemStore()
        actions = [NonTerminalAction('S'), NonTerminalAction('NP'), NonTerminalAction('VP'),
                   ShiftAction()]
        for a in actions:
            action_store.add(a)
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, action_store)

    def test_init_word_id_out_of_range(self):
        word2id = dict(self.word2id)
        word2id['John'] = len(word2id)
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(word2id, self.pos2id, self.nt2id, self.action_store)
        word2id['John'] = -1
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(word2id, self.pos2id, self.nt2id, self.action_store)

    def test_init_pos_id_out_of_range(self):
        pos2id = dict(self.pos2id)
        pos2id['NNP'] = len(pos2id)
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, pos2id, self.nt2id, self.action_store)
        pos2id['NNP'] = -1
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, pos2id, self.nt2id, self.action_store)

    def test_init_non_terminal_id_out_of_range(self):
        nt2id = dict(self.nt2id)
        nt2id['S'] = len(nt2id)
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, self.pos2id, nt2id, self.action_store)
        nt2id['S'] = -1
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, self.pos2id, nt2id, self.action_store)

    def test_initialise_stacks_and_buffers(self):
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                          self.action_store)
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        assert len(parser.stack_buffer) == 0
        assert parser.input_buffer == words
        assert len(parser.action_history) == 0
        assert not parser.finished
        assert parser.started

    def test_initialise_with_empty_tagged_words(self):
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                          self.action_store)
        with pytest.raises(ValueError):
            parser.initialise_stacks_and_buffers([])

    def test_initialise_with_invalid_word_or_pos(self):
        # Unknown word ('Bob') and unknown POS tag ('VBD') are both rejected.
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                          self.action_store)
        with pytest.raises(ValueError):
            parser.initialise_stacks_and_buffers([('Bob', 'NNP')])
        with pytest.raises(ValueError):
            parser.initialise_stacks_and_buffers([('John', 'VBD')])

    def test_do_non_terminal_action(self):
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                          self.action_store)
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        # Snapshot, not alias: a live reference would make the "buffer
        # unchanged" check compare the object with itself.
        prev_input_buffer = list(parser.input_buffer)
        parser.push_non_terminal('S')
        assert len(parser.stack_buffer) == 1
        last = parser.stack_buffer[-1]
        assert isinstance(last, Tree)
        assert last.label() == 'S'
        assert len(last) == 0
        assert parser.input_buffer == prev_input_buffer
        assert len(parser.action_history) == 1
        assert parser.action_history[-1] == NonTerminalAction('S')
        assert not parser.finished

    def test_do_illegal_push_non_terminal_action(self):
        words = ['John']
        pos_tags = ['NNP']
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                          self.action_store)
        # Buffer is empty
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.push_non_terminal('NP')
        # 100 nonterminals already open: pushing a 101st is illegal
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        for _ in range(100):
            parser.push_non_terminal('S')
        with pytest.raises(IllegalActionError):
            parser.push_non_terminal('NP')

    def test_push_unknown_non_terminal(self):
        words = ['John']
        pos_tags = ['NNP']
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                          self.action_store)
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        with pytest.raises(KeyError):
            parser.push_non_terminal('asdf')

    def test_push_known_non_terminal_but_unknown_action(self):
        actions = [NonTerminalAction('NP'), NonTerminalAction('VP'), ShiftAction(),
                   ReduceAction()]
        action_store = ItemStore()
        for a in actions:
            action_store.add(a)
        words = ['John']
        pos_tags = ['NNP']
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, action_store)
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        with pytest.raises(KeyError):
            parser.push_non_terminal('S')

    def test_do_shift_action(self):
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                          self.action_store)
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        assert len(parser.stack_buffer) == 3
        last = parser.stack_buffer[-1]
        assert last == 'John'
        assert parser.input_buffer == words[1:]
        assert len(parser.action_history) == 3
        assert parser.action_history[-1] == ShiftAction()
        assert not parser.finished

    def test_do_illegal_shift_action(self):
        words = ['John']
        pos_tags = ['NNP']
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                          self.action_store)
        # No open nonterminal
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        with pytest.raises(IllegalActionError):
            parser.shift()
        # Buffer is empty
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.shift()

    def test_do_reduce_action(self):
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                          self.action_store)
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        # Snapshot, not alias (see test_do_non_terminal_action).
        prev_input_buffer = list(parser.input_buffer)
        parser.reduce()
        assert len(parser.stack_buffer) == 2
        last = parser.stack_buffer[-1]
        assert isinstance(last, Tree)
        assert last.label() == 'NP'
        assert len(last) == 1
        assert last[0] == 'John'
        assert parser.input_buffer == prev_input_buffer
        assert len(parser.action_history) == 4
        assert parser.action_history[-1] == ReduceAction()
        assert not parser.finished

    def test_do_illegal_reduce_action(self):
        words = ['John', 'loves']
        pos_tags = ['NNP', 'VBZ']
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                          self.action_store)
        # Top of stack is an open nonterminal
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        with pytest.raises(IllegalActionError):
            parser.reduce()
        # Buffer is not empty and REDUCE will finish parsing
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.reduce()

    def test_do_action_when_not_started(self):
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                          self.action_store)
        with pytest.raises(RuntimeError):
            parser.push_non_terminal('S')
        with pytest.raises(RuntimeError):
            parser.shift()
        with pytest.raises(RuntimeError):
            parser.reduce()

    def test_forward(self):
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                          self.action_store)
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()
        action_logprobs = parser()
        assert isinstance(action_logprobs, Variable)
        assert action_logprobs.size() == (len(self.action_store),)
        # Log-probabilities must exponentiate to a distribution.
        sum_prob = action_logprobs.exp().sum().data[0]
        assert 0.999 <= sum_prob <= 1.001

    def test_forward_with_illegal_actions(self):
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                          self.action_store)
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        action_probs = parser().exp().data
        assert action_probs[self.action_store[NonTerminalAction('S')]] > 0.
        assert action_probs[self.action_store[NonTerminalAction('NP')]] > 0.
        assert action_probs[self.action_store[NonTerminalAction('VP')]] > 0.
        assert -0.001 <= action_probs[self.action_store[ShiftAction()]] <= 0.001
        assert -0.001 <= action_probs[self.action_store[ReduceAction()]] <= 0.001

    def test_forward_when_not_started(self):
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                          self.action_store)
        with pytest.raises(RuntimeError):
            parser()

    def test_finished(self):
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id,
                                          self.action_store)
        exp_parse_tree = Tree('S', [Tree('NP', ['John']),
                                    Tree('VP', ['loves', Tree('NP', ['Mary'])])])

        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()
        parser.push_non_terminal('VP')
        parser.shift()
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()
        parser.reduce()
        parser.reduce()

        assert parser.finished
        parse_tree = parser.stack_buffer[-1]
        assert parse_tree == exp_parse_tree
        # Once finished, neither forward nor any action is allowed.
        with pytest.raises(RuntimeError):
            parser()
        with pytest.raises(RuntimeError):
            parser.push_non_terminal('NP')
        with pytest.raises(RuntimeError):
            parser.shift()
        with pytest.raises(RuntimeError):
            parser.reduce()