def test_do_action_when_not_started(self):
    """Check that every parser action raises RuntimeError on a fresh parser.

    The parser is constructed but ``initialise_stacks_and_buffers`` is never
    called, so NT, SHIFT and REDUCE are each expected to be rejected.
    """
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)

    premature_actions = (
        lambda: parser.push_non_terminal('S'),
        parser.shift,
        parser.reduce,
    )
    for attempt in premature_actions:
        with pytest.raises(RuntimeError):
            attempt()
def test_finished(self):
    """Drive a complete parse of 'John loves Mary' and check the finished state.

    Once the last REDUCE has been applied the test expects ``finished`` to be
    truthy, the expected tree to sit on top of ``stack_buffer``, and any
    further forward call or action to raise RuntimeError.
    """
    tagged_sentence = list(zip(['John', 'loves', 'Mary'],
                               ['NNP', 'VBZ', 'NNP']))
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)
    expected_tree = Tree('S', [
        Tree('NP', ['John']),
        Tree('VP', ['loves', Tree('NP', ['Mary'])]),
    ])
    parser.initialise_stacks_and_buffers(tagged_sentence)

    # Transition sequence: a string means "push that non-terminal label",
    # anything else is a bound action method to call as-is.
    shift, reduce_ = parser.shift, parser.reduce
    transitions = [
        'S', 'NP', shift, reduce_,
        'VP', shift, 'NP', shift,
        reduce_, reduce_, reduce_,
    ]
    for step in transitions:
        if isinstance(step, str):
            parser.push_non_terminal(step)
        else:
            step()

    assert parser.finished
    assert parser.stack_buffer[-1] == expected_tree

    # A finished parser must refuse the forward pass and every action.
    rejected_calls = (
        parser,
        lambda: parser.push_non_terminal('NP'),
        parser.shift,
        parser.reduce,
    )
    for call in rejected_calls:
        with pytest.raises(RuntimeError):
            call()
def test_forward(self):
    """Check the forward call returns log-probabilities over all actions.

    After four actions (NT S, NT NP, SHIFT, REDUCE) the parser is called; the
    result must be a ``Variable`` with one entry per action in
    ``self.action_store`` whose exponentials sum to 1 within 1e-3.
    """
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)
    tagged_sentence = list(zip(['John', 'loves', 'Mary'],
                               ['NNP', 'VBZ', 'NNP']))
    parser.initialise_stacks_and_buffers(tagged_sentence)

    for label in ('S', 'NP'):
        parser.push_non_terminal(label)
    parser.shift()
    parser.reduce()

    log_probs = parser()

    assert isinstance(log_probs, Variable)
    assert log_probs.size() == (len(self.action_store),)

    # NOTE(review): `.data[0]` presumably relies on sum() yielding a
    # one-element tensor (pre-0.4 PyTorch) -- confirm the targeted version.
    total_prob = log_probs.exp().sum().data[0]
    assert 0.999 <= total_prob <= 1.001
def test_do_illegal_reduce_action(self):
    """Check that REDUCE raises IllegalActionError in two configurations.

    The same parser instance is re-initialised between the two scenarios.
    """
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)

    def restart():
        # Hand the parser a freshly built sentence list on every restart.
        parser.initialise_stacks_and_buffers(
            list(zip(['John', 'loves'], ['NNP', 'VBZ'])))

    # Top of stack is an open nonterminal
    restart()
    parser.push_non_terminal('S')
    with pytest.raises(IllegalActionError):
        parser.reduce()

    # Buffer is not empty and REDUCE will finish parsing
    restart()
    parser.push_non_terminal('S')
    parser.shift()
    with pytest.raises(IllegalActionError):
        parser.reduce()
def test_do_reduce_action(self):
    """Check the parser state after a legal REDUCE that closes an NP.

    Sequence applied: NT S, NT NP, SHIFT, REDUCE. The test then inspects the
    stack buffer, input buffer, action history and ``finished`` flag.
    """
    tagged_sentence = list(zip(['John', 'loves', 'Mary'],
                               ['NNP', 'VBZ', 'NNP']))
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)
    parser.initialise_stacks_and_buffers(tagged_sentence)

    for label in ('S', 'NP'):
        parser.push_non_terminal(label)
    parser.shift()

    # NOTE(review): if `input_buffer` hands back the same mutable object each
    # time, the equality check below is trivially true -- confirm it is a
    # snapshot or take a copy here.
    buffer_before_reduce = parser.input_buffer
    parser.reduce()

    # Stack buffer: two entries, topmost being the closed NP subtree.
    assert len(parser.stack_buffer) == 2
    top_of_stack = parser.stack_buffer[-1]
    assert isinstance(top_of_stack, Tree)
    assert top_of_stack.label() == 'NP'
    assert len(top_of_stack) == 1
    assert top_of_stack[0] == 'John'

    # REDUCE is expected to leave the input buffer as it was.
    assert parser.input_buffer == buffer_before_reduce

    # All four actions are recorded, the last one being the REDUCE.
    assert len(parser.action_history) == 4
    assert parser.action_history[-1] == ReduceAction()

    assert not parser.finished