def test_greedy_decode(mocker):
    """Check that ``greedy_decode`` returns the actions favoured by the parser.

    ``parser.forward`` is patched so that successive calls return a vector
    that is 1 at the index of the next expected action and 0 elsewhere.
    The decoder must then yield exactly ``correct_actions``, in order.
    """
    words = ['John', 'loves', 'Mary']
    pos_tags = ['NNP', 'VBZ', 'NNP']
    correct_actions = [
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        ShiftAction(),
        ReduceAction(),
        NonTerminalAction('VP'),
        ShiftAction(),
        NonTerminalAction('NP'),
        ShiftAction(),
        ReduceAction(),
        ReduceAction(),
        ReduceAction(),
    ]

    # One fake forward() output per expected action.
    retvals = []
    for action in correct_actions:
        scores = torch.zeros(len(action_store))
        target = torch.LongTensor([action_store[action]])
        retvals.append(Variable(scores.scatter_(0, target, 1)))

    parser = DiscriminativeRnnGrammar(word2id, pos2id, nt2id, action_store)
    tagged_words = list(zip(words, pos_tags))
    parser.initialise_stacks_and_buffers(tagged_words)
    # side_effect with a list makes each call return the next element.
    mocker.patch.object(parser, 'forward', side_effect=retvals)

    result = greedy_decode(parser)

    assert len(result) == len(correct_actions)
    # Each result item unpacks into (action, log-prob); only actions are checked.
    picked_actions, _ = zip(*result)
    assert list(picked_actions) == correct_actions
def test_init_non_terminal_id_out_of_range(self):
    """Constructing the parser with a bad non-terminal id raises ValueError.

    The id of ``'S'`` is set first to ``len(nt2id)`` and then to ``-1``;
    both must be rejected by the constructor.
    """
    nt2id = dict(self.nt2id)  # copy so the shared fixture is not modified
    for bad_id in (len(nt2id), -1):
        nt2id['S'] = bad_id
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(
                self.word2id, self.pos2id, nt2id, self.action_store)
def test_init_pos_id_out_of_range(self):
    """Constructing the parser with a bad POS tag id raises ValueError.

    The id of ``'NNP'`` is set first to ``len(pos2id)`` and then to ``-1``;
    both must be rejected by the constructor.
    """
    pos2id = dict(self.pos2id)  # copy so the shared fixture is not modified
    for bad_id in (len(pos2id), -1):
        pos2id['NNP'] = bad_id
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(
                self.word2id, pos2id, self.nt2id, self.action_store)
def test_init_word_id_out_of_range(self):
    """Constructing the parser with a bad word id raises ValueError.

    The id of ``'John'`` is set first to ``len(word2id)`` and then to ``-1``;
    both must be rejected by the constructor.
    """
    word2id = dict(self.word2id)  # copy so the shared fixture is not modified
    for bad_id in (len(word2id), -1):
        word2id['John'] = bad_id
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(
                word2id, self.pos2id, self.nt2id, self.action_store)
def test_initialise_stacks_and_buffers(self):
    """Check the parser state right after initialising with a sentence.

    The stack and action history must be empty, the input buffer must equal
    the sentence's words, and the parser must be started but not finished.
    """
    words = ['John', 'loves', 'Mary']
    pos_tags = ['NNP', 'VBZ', 'NNP']
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)

    tagged_words = list(zip(words, pos_tags))
    parser.initialise_stacks_and_buffers(tagged_words)

    assert len(parser.stack_buffer) == 0
    assert parser.input_buffer == words
    assert len(parser.action_history) == 0
    assert not parser.finished
    assert parser.started
def test_forward_with_illegal_actions(self):
    """Check which actions get probability mass right after initialisation.

    Every non-terminal action must have positive probability, while SHIFT
    and REDUCE must have probability within 0.001 of zero.
    """
    words = ['John', 'loves', 'Mary']
    pos_tags = ['NNP', 'VBZ', 'NNP']
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)
    parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))

    # The forward pass output is exponentiated to obtain probabilities.
    action_probs = parser().exp().data

    for label in ('S', 'NP', 'VP'):
        assert action_probs[self.action_store[NonTerminalAction(label)]] > 0.
    for masked_action in (ShiftAction(), ReduceAction()):
        assert -0.001 <= action_probs[self.action_store[masked_action]] <= 0.001
def train_model(training_file: str):
    """Train a ``DiscriminativeRnnGrammar`` on the oracles in ``training_file``.

    Reads discriminative oracles from the file, loads them into an
    ``OracleDataset``, builds a parser from the dataset's stores and hands
    everything to ``train_early_stopping`` with an SGD optimiser.

    Args:
        training_file: Path of the oracle file to read.
    """
    dataset = OracleDataset(
        read_oracles_from_file(DiscriminativeOracle, training_file))
    dataset.load()

    # The collate function returns the first element of each batch as-is.
    loader = DataLoader(dataset, collate_fn=lambda batch: batch[0])

    parser = DiscriminativeRnnGrammar(
        action_store=dataset.action_store,
        word2id=dataset.word_store,
        pos2id=dataset.pos_store,
        non_terminal2id=dataset.nt_store,
    )
    optimiser = SGD(parser.parameters(), 0.1)

    # NOTE(review): the same loader is passed as both of the first two
    # arguments -- presumably training and validation data; confirm intended.
    train_early_stopping(loader, loader, parser, optimiser)
def test_do_action_when_not_started(self):
    """Every parser action raises RuntimeError before initialisation."""
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)

    premature_actions = (
        lambda: parser.push_non_terminal('S'),
        lambda: parser.shift(),
        lambda: parser.reduce(),
    )
    for do_action in premature_actions:
        with pytest.raises(RuntimeError):
            do_action()
def test_init_no_reduce_action(self):
    """The constructor raises ValueError if the action store has no REDUCE."""
    action_store = ItemStore()
    # Every action used elsewhere in these tests except ReduceAction().
    actions_without_reduce = [
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        NonTerminalAction('VP'),
        ShiftAction(),
    ]
    for action in actions_without_reduce:
        action_store.add(action)

    with pytest.raises(ValueError):
        DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, action_store)
def test_do_shift_action(self):
    """Check the parser state after NT(S), NT(NP), SHIFT.

    The shifted word must be on top of the stack and gone from the input
    buffer, and SHIFT must be the most recent entry in the action history.
    """
    words = ['John', 'loves', 'Mary']
    pos_tags = ['NNP', 'VBZ', 'NNP']
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)
    parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))

    for label in ('S', 'NP'):
        parser.push_non_terminal(label)
    parser.shift()

    # Two open non-terminals plus the shifted word.
    assert len(parser.stack_buffer) == 3
    top_of_stack = parser.stack_buffer[-1]
    assert top_of_stack == 'John'
    assert parser.input_buffer == words[1:]
    assert len(parser.action_history) == 3
    assert parser.action_history[-1] == ShiftAction()
    assert not parser.finished
def test_init(self):
    """A freshly constructed parser is empty, not started and not finished."""
    fresh_parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)

    # All three buffers start out empty.
    assert len(fresh_parser.stack_buffer) == 0
    assert len(fresh_parser.input_buffer) == 0
    assert len(fresh_parser.action_history) == 0

    assert not fresh_parser.finished
    assert not fresh_parser.started
def test_push_unknown_non_terminal(self):
    """Pushing a non-terminal label the parser does not know raises KeyError."""
    tagged_words = list(zip(['John'], ['NNP']))
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)
    parser.initialise_stacks_and_buffers(tagged_words)

    with pytest.raises(KeyError):
        parser.push_non_terminal('asdf')
def test_initalise_with_invalid_word_or_pos(self):
    """Initialising with a rejected word or POS tag raises ValueError.

    ``('Bob', 'NNP')`` and ``('John', 'VBD')`` are each passed as a
    one-token sentence and must both be rejected.
    """
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)

    invalid_sentences = (
        [('Bob', 'NNP')],
        [('John', 'VBD')],
    )
    for tagged_words in invalid_sentences:
        with pytest.raises(ValueError):
            parser.initialise_stacks_and_buffers(tagged_words)
def test_push_known_non_terminal_but_unknown_action(self):
    """Pushing a non-terminal whose action is not in the store raises KeyError.

    The parser gets the fixture's ``nt2id`` but an action store built here
    without ``NonTerminalAction('S')``; pushing ``'S'`` must then fail.
    """
    available_actions = [
        NonTerminalAction('NP'),
        NonTerminalAction('VP'),
        ShiftAction(),
        ReduceAction(),
    ]
    action_store = ItemStore()
    for action in available_actions:
        action_store.add(action)

    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, action_store)
    parser.initialise_stacks_and_buffers(list(zip(['John'], ['NNP'])))

    with pytest.raises(KeyError):
        parser.push_non_terminal('S')
def test_do_non_terminal_action(self):
    """Check the parser state after pushing a single non-terminal.

    The stack must hold one childless ``Tree`` labelled ``'S'``, the input
    buffer must be untouched, and the action history must contain exactly
    the corresponding ``NonTerminalAction``.
    """
    words = ['John', 'loves', 'Mary']
    pos_tags = ['NNP', 'VBZ', 'NNP']
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)
    parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))

    parser.push_non_terminal('S')

    assert len(parser.stack_buffer) == 1
    last = parser.stack_buffer[-1]
    assert isinstance(last, Tree)
    assert last.label() == 'S'
    assert len(last) == 0
    # Compare against the independent ``words`` list rather than a reference
    # taken from ``parser.input_buffer`` before the action: if that attribute
    # returns the parser's internal list, such a reference would alias it and
    # the comparison could never fail.
    assert parser.input_buffer == words
    assert len(parser.action_history) == 1
    assert parser.action_history[-1] == NonTerminalAction('S')
    assert not parser.finished
def test_finished(self):
    """Run a complete action sequence and check the finished parser.

    After the last REDUCE the parser must report ``finished`` and hold the
    expected tree on top of the stack. A further forward pass or action
    must raise RuntimeError.
    """
    words = ['John', 'loves', 'Mary']
    pos_tags = ['NNP', 'VBZ', 'NNP']
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)
    subject = Tree('NP', ['John'])
    predicate = Tree('VP', ['loves', Tree('NP', ['Mary'])])
    exp_parse_tree = Tree('S', [subject, predicate])

    parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))

    # (S (NP John)
    parser.push_non_terminal('S')
    parser.push_non_terminal('NP')
    parser.shift()
    parser.reduce()
    # (VP loves (NP Mary)
    parser.push_non_terminal('VP')
    parser.shift()
    parser.push_non_terminal('NP')
    parser.shift()
    # Close NP, VP and S in turn.
    parser.reduce()
    parser.reduce()
    parser.reduce()

    assert parser.finished
    parse_tree = parser.stack_buffer[-1]
    assert parse_tree == exp_parse_tree

    # Nothing may be done with a finished parser.
    forbidden_calls = (
        lambda: parser(),
        lambda: parser.push_non_terminal('NP'),
        lambda: parser.shift(),
        lambda: parser.reduce(),
    )
    for call in forbidden_calls:
        with pytest.raises(RuntimeError):
            call()
def test_forward_when_not_started(self):
    """Calling the parser before initialisation raises RuntimeError."""
    uninitialised_parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)

    with pytest.raises(RuntimeError):
        uninitialised_parser()
def test_forward(self):
    """Check the type, size and normalisation of the forward pass output.

    After NT(S), NT(NP), SHIFT, REDUCE the parser must return a ``Variable``
    with one entry per action in the store, whose exponentiated entries sum
    to 1 within a tolerance of 0.001.
    """
    words = ['John', 'loves', 'Mary']
    pos_tags = ['NNP', 'VBZ', 'NNP']
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)
    parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))

    # Move the parser into a mid-sentence state before the forward pass.
    for label in ('S', 'NP'):
        parser.push_non_terminal(label)
    parser.shift()
    parser.reduce()

    action_logprobs = parser()

    assert isinstance(action_logprobs, Variable)
    assert action_logprobs.size() == (len(self.action_store),)
    total = action_logprobs.exp().sum()
    sum_prob = total.data[0]
    assert 0.999 <= sum_prob <= 1.001
def test_do_illegal_reduce_action(self):
    """REDUCE raises IllegalActionError in two illegal configurations."""
    words = ['John', 'loves']
    pos_tags = ['NNP', 'VBZ']
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)

    def restart():
        # Build a fresh tagged-word list on every (re)initialisation.
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))

    # Case 1: the top of the stack is an open non-terminal.
    restart()
    parser.push_non_terminal('S')
    with pytest.raises(IllegalActionError):
        parser.reduce()

    # Case 2: the buffer is not empty and REDUCE would finish parsing.
    restart()
    parser.push_non_terminal('S')
    parser.shift()
    with pytest.raises(IllegalActionError):
        parser.reduce()
def test_do_illegal_shift_action(self):
    """SHIFT raises IllegalActionError in two illegal configurations."""
    words = ['John']
    pos_tags = ['NNP']
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)

    def restart():
        # Build a fresh tagged-word list on every (re)initialisation.
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))

    # Case 1: there is no open non-terminal on the stack.
    restart()
    with pytest.raises(IllegalActionError):
        parser.shift()

    # Case 2: the input buffer is already empty.
    restart()
    parser.push_non_terminal('S')
    parser.shift()
    with pytest.raises(IllegalActionError):
        parser.shift()
def test_initialise_with_empty_tagged_words(self):
    """Initialising the parser with an empty sentence raises ValueError."""
    parser = DiscriminativeRnnGrammar(
        self.word2id, self.pos2id, self.nt2id, self.action_store)
    empty_sentence = []

    with pytest.raises(ValueError):
        parser.initialise_stacks_and_buffers(empty_sentence)