示例#1
0
def test_greedy_decode(mocker):
    """Check that ``greedy_decode`` returns the actions the parser scores highest.

    ``parser.forward`` is patched so that each successive call returns a
    one-hot vector selecting the next gold action; the decoded sequence must
    therefore equal the gold sequence.
    """
    tagged_words = [('John', 'NNP'), ('loves', 'VBZ'), ('Mary', 'NNP')]
    gold_actions = [
        NonTerminalAction('S'),
        NonTerminalAction('NP'), ShiftAction(), ReduceAction(),
        NonTerminalAction('VP'), ShiftAction(),
        NonTerminalAction('NP'), ShiftAction(), ReduceAction(),
        ReduceAction(),
        ReduceAction(),
    ]

    def one_hot(action):
        # Vector of zeros with a single 1 at the index of ``action``.
        index = torch.LongTensor([action_store[action]])
        return Variable(torch.zeros(len(action_store)).scatter_(0, index, 1))

    fake_outputs = [one_hot(action) for action in gold_actions]
    parser = DiscriminativeRnnGrammar(word2id, pos2id, nt2id, action_store)
    parser.initialise_stacks_and_buffers(tagged_words)
    # Every call to forward yields the next prepared vector.
    mocker.patch.object(parser, 'forward', side_effect=fake_outputs)

    result = greedy_decode(parser)

    assert len(result) == len(gold_actions)
    # Each result entry is a pair; only the first component is checked here.
    picked_actions, _unused_scores = zip(*result)
    assert list(picked_actions) == gold_actions
示例#2
0
    def test_init_non_terminal_id_out_of_range(self):
        """Constructor raises ValueError for an out-of-range non-terminal id."""
        bad_nt2id = dict(self.nt2id)

        # An id equal to the mapping's size, then a negative id.
        for invalid_id in (len(bad_nt2id), -1):
            bad_nt2id['S'] = invalid_id
            with pytest.raises(ValueError):
                DiscriminativeRnnGrammar(
                    self.word2id, self.pos2id, bad_nt2id, self.action_store)
示例#3
0
    def test_init_pos_id_out_of_range(self):
        """Constructor raises ValueError for an out-of-range POS tag id."""
        bad_pos2id = dict(self.pos2id)

        # An id equal to the mapping's size, then a negative id.
        for invalid_id in (len(bad_pos2id), -1):
            bad_pos2id['NNP'] = invalid_id
            with pytest.raises(ValueError):
                DiscriminativeRnnGrammar(
                    self.word2id, bad_pos2id, self.nt2id, self.action_store)
示例#4
0
    def test_init_word_id_out_of_range(self):
        """Constructor raises ValueError for an out-of-range word id."""
        bad_word2id = dict(self.word2id)

        # An id equal to the mapping's size, then a negative id.
        for invalid_id in (len(bad_word2id), -1):
            bad_word2id['John'] = invalid_id
            with pytest.raises(ValueError):
                DiscriminativeRnnGrammar(
                    bad_word2id, self.pos2id, self.nt2id, self.action_store)
示例#5
0
    def test_initialise_stacks_and_buffers(self):
        """After initialisation the parser is started, unfinished, and holds only the input words."""
        tagged_words = [('John', 'NNP'), ('loves', 'VBZ'), ('Mary', 'NNP')]
        expected_buffer = [word for word, _ in tagged_words]
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)

        parser.initialise_stacks_and_buffers(tagged_words)

        # Stack and history start empty; the buffer holds the words in order.
        assert len(parser.stack_buffer) == 0
        assert parser.input_buffer == expected_buffer
        assert len(parser.action_history) == 0
        assert not parser.finished
        assert parser.started
示例#6
0
    def test_forward_with_illegal_actions(self):
        """Right after initialisation only non-terminal actions receive probability mass."""
        tagged_words = [('John', 'NNP'), ('loves', 'VBZ'), ('Mary', 'NNP')]
        tolerance = 0.001
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)
        parser.initialise_stacks_and_buffers(tagged_words)

        # The parser output is exponentiated before being compared against zero.
        action_probs = parser().exp().data

        for label in ('S', 'NP', 'VP'):
            nt_index = self.action_store[NonTerminalAction(label)]
            assert action_probs[nt_index] > 0.
        # SHIFT and REDUCE must have (numerically) zero probability here.
        for masked_action in (ShiftAction(), ReduceAction()):
            masked_index = self.action_store[masked_action]
            assert -tolerance <= action_probs[masked_index] <= tolerance
示例#7
0
def train_model(training_file: str, learning_rate: float = 0.1):
    """Train a ``DiscriminativeRnnGrammar`` on the oracles read from a file.

    The oracles are loaded into an ``OracleDataset``, a parser is built from
    the dataset's stores, and ``train_early_stopping`` is run with an SGD
    optimiser. Nothing is returned.

    Args:
        training_file: Passed to ``read_oracles_from_file`` (presumably a
            path to the oracle file -- confirm against that helper).
        learning_rate: Learning rate for the SGD optimiser. Defaults to
            0.1, the value that used to be hard-coded.
    """
    oracles = read_oracles_from_file(DiscriminativeOracle, training_file)
    dataset = OracleDataset(oracles)
    dataset.load()
    # The collate function unwraps each batch to its first element, so the
    # training loop receives a single sample rather than a list.
    dataset_loader = DataLoader(dataset, collate_fn=lambda x: x[0])

    parser = DiscriminativeRnnGrammar(action_store=dataset.action_store,
                                      word2id=dataset.word_store,
                                      pos2id=dataset.pos_store,
                                      non_terminal2id=dataset.nt_store)
    optimiser = SGD(parser.parameters(), lr=learning_rate)

    # NOTE(review): the same loader is passed twice -- presumably once as the
    # training data and once as the validation data; confirm that a separate
    # held-out set is not intended.
    train_early_stopping(dataset_loader, dataset_loader, parser, optimiser)
示例#8
0
    def test_do_action_when_not_started(self):
        """Every parser action raises RuntimeError before initialisation."""
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)
        premature_calls = (
            ('push_non_terminal', ('S',)),
            ('shift', ()),
            ('reduce', ()),
        )

        for method_name, args in premature_calls:
            with pytest.raises(RuntimeError):
                getattr(parser, method_name)(*args)
示例#9
0
    def test_init_no_reduce_action(self):
        """Constructor raises ValueError when the action store has no REDUCE action."""
        store_without_reduce = ItemStore()
        for label in ('S', 'NP', 'VP'):
            store_without_reduce.add(NonTerminalAction(label))
        store_without_reduce.add(ShiftAction())

        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(
                self.word2id, self.pos2id, self.nt2id, store_without_reduce)
示例#10
0
    def test_do_shift_action(self):
        """SHIFT moves the first buffered word onto the stack and is recorded in the history."""
        tagged_words = [('John', 'NNP'), ('loves', 'VBZ'), ('Mary', 'NNP')]
        words = [word for word, _ in tagged_words]
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)
        parser.initialise_stacks_and_buffers(tagged_words)
        # Two non-terminals are pushed before the word is shifted.
        for label in ('S', 'NP'):
            parser.push_non_terminal(label)

        parser.shift()

        # Two non-terminals plus the shifted word.
        assert len(parser.stack_buffer) == 3
        assert parser.stack_buffer[-1] == 'John'
        assert parser.input_buffer == words[1:]
        assert len(parser.action_history) == 3
        assert parser.action_history[-1] == ShiftAction()
        assert not parser.finished
示例#11
0
    def test_init(self):
        """A freshly constructed parser is empty, not started and not finished."""
        fresh_parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)

        assert len(fresh_parser.stack_buffer) == 0
        assert len(fresh_parser.input_buffer) == 0
        assert len(fresh_parser.action_history) == 0
        assert not fresh_parser.finished
        assert not fresh_parser.started
示例#12
0
    def test_push_unknown_non_terminal(self):
        """Pushing the non-terminal label 'asdf' raises KeyError."""
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)
        parser.initialise_stacks_and_buffers([('John', 'NNP')])

        with pytest.raises(KeyError):
            parser.push_non_terminal('asdf')
示例#13
0
    def test_initalise_with_invalid_word_or_pos(self):
        """Initialisation raises ValueError for an invalid word or POS tag."""
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)
        # First an invalid word ('Bob'), then an invalid tag ('VBD');
        # presumably neither is present in the fixture vocabularies.
        invalid_inputs = (
            [('Bob', 'NNP')],
            [('John', 'VBD')],
        )

        for tagged_words in invalid_inputs:
            with pytest.raises(ValueError):
                parser.initialise_stacks_and_buffers(tagged_words)
示例#14
0
    def test_push_known_non_terminal_but_unknown_action(self):
        """Pushing 'S' raises KeyError when the action store lacks ``NonTerminalAction('S')``."""
        # Action store deliberately built without NonTerminalAction('S').
        store_without_s = ItemStore()
        for label in ('NP', 'VP'):
            store_without_s.add(NonTerminalAction(label))
        store_without_s.add(ShiftAction())
        store_without_s.add(ReduceAction())
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, store_without_s)
        parser.initialise_stacks_and_buffers([('John', 'NNP')])

        with pytest.raises(KeyError):
            parser.push_non_terminal('S')
示例#15
0
    def test_do_non_terminal_action(self):
        """Pushing a non-terminal adds an empty tree to the stack and leaves the buffer alone."""
        tagged_words = [('John', 'NNP'), ('loves', 'VBZ'), ('Mary', 'NNP')]
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)
        parser.initialise_stacks_and_buffers(tagged_words)
        buffer_before = parser.input_buffer

        parser.push_non_terminal('S')

        assert len(parser.stack_buffer) == 1
        top_of_stack = parser.stack_buffer[-1]
        # The new stack element is a childless tree labelled 'S'.
        assert isinstance(top_of_stack, Tree)
        assert top_of_stack.label() == 'S'
        assert len(top_of_stack) == 0
        assert parser.input_buffer == buffer_before
        assert len(parser.action_history) == 1
        assert parser.action_history[-1] == NonTerminalAction('S')
        assert not parser.finished
示例#16
0
    def test_finished(self):
        """After a full derivation the parser is finished and rejects further calls."""
        tagged_words = [('John', 'NNP'), ('loves', 'VBZ'), ('Mary', 'NNP')]
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)
        subject = Tree('NP', ['John'])
        predicate = Tree('VP', ['loves', Tree('NP', ['Mary'])])
        exp_parse_tree = Tree('S', [subject, predicate])
        # (method name, positional args) for each step of the derivation.
        derivation = [
            ('push_non_terminal', ('S',)),
            ('push_non_terminal', ('NP',)),
            ('shift', ()),
            ('reduce', ()),
            ('push_non_terminal', ('VP',)),
            ('shift', ()),
            ('push_non_terminal', ('NP',)),
            ('shift', ()),
            ('reduce', ()),
            ('reduce', ()),
            ('reduce', ()),
        ]

        parser.initialise_stacks_and_buffers(tagged_words)
        for method_name, args in derivation:
            getattr(parser, method_name)(*args)

        assert parser.finished
        assert parser.stack_buffer[-1] == exp_parse_tree
        # Once finished, calling the parser or any action raises RuntimeError.
        with pytest.raises(RuntimeError):
            parser()
        for method_name, args in (('push_non_terminal', ('NP',)),
                                  ('shift', ()),
                                  ('reduce', ())):
            with pytest.raises(RuntimeError):
                getattr(parser, method_name)(*args)
示例#17
0
    def test_forward_when_not_started(self):
        """Calling the parser before initialisation raises RuntimeError."""
        uninitialised_parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)

        with pytest.raises(RuntimeError):
            uninitialised_parser()
示例#18
0
    def test_forward(self):
        """Mid-parse, the parser returns one log-probability per action, summing to ~1 after exp."""
        tagged_words = [('John', 'NNP'), ('loves', 'VBZ'), ('Mary', 'NNP')]
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)
        parser.initialise_stacks_and_buffers(tagged_words)
        # Advance the parser a few steps so it is scored mid-parse.
        for label in ('S', 'NP'):
            parser.push_non_terminal(label)
        parser.shift()
        parser.reduce()

        action_logprobs = parser()

        assert isinstance(action_logprobs, Variable)
        expected_size = (len(self.action_store),)
        assert action_logprobs.size() == expected_size
        # Exponentiated scores must form a probability distribution.
        total_prob = action_logprobs.exp().sum().data[0]
        assert 0.999 <= total_prob <= 1.001
示例#19
0
    def test_do_illegal_reduce_action(self):
        """REDUCE raises IllegalActionError in the two illegal configurations below."""
        tagged_words = (('John', 'NNP'), ('loves', 'VBZ'))
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)

        # Case 1: the top of the stack is an open non-terminal.
        parser.initialise_stacks_and_buffers(list(tagged_words))
        parser.push_non_terminal('S')
        with pytest.raises(IllegalActionError):
            parser.reduce()

        # Case 2: REDUCE would finish parsing while the buffer is not empty.
        # A fresh list is passed so each initialisation gets its own input.
        parser.initialise_stacks_and_buffers(list(tagged_words))
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.reduce()
示例#20
0
    def test_do_illegal_shift_action(self):
        """SHIFT raises IllegalActionError in the two illegal configurations below."""
        tagged_words = (('John', 'NNP'),)
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)

        # Case 1: there is no open non-terminal on the stack.
        parser.initialise_stacks_and_buffers(list(tagged_words))
        with pytest.raises(IllegalActionError):
            parser.shift()

        # Case 2: the only word has already been shifted, so the buffer is empty.
        # A fresh list is passed so each initialisation gets its own input.
        parser.initialise_stacks_and_buffers(list(tagged_words))
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.shift()
示例#21
0
    def test_initialise_with_empty_tagged_words(self):
        """Initialising with an empty list of tagged words raises ValueError."""
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)
        no_tagged_words = []

        with pytest.raises(ValueError):
            parser.initialise_stacks_and_buffers(no_tagged_words)