예제 #1
0
    def test_invertible(self):
        """Mapping an item to its id and back must return the original item."""
        store = ItemStore()
        # Duplicates are added on purpose; each distinct item is looked up below.
        for ch in 'aabbcc':
            store.add(ch)

        for ch in 'abc':
            item_id = store[ch]
            assert store.get_by_id(item_id) == ch
예제 #2
0
    def test_init_no_reduce_action(self):
        """The parser constructor must raise ValueError when no ReduceAction is in the store."""
        store_without_reduce = ItemStore()
        available = (
            NonTerminalAction('S'),
            NonTerminalAction('NP'),
            NonTerminalAction('VP'),
            ShiftAction(),
        )
        for action in available:
            store_without_reduce.add(action)

        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, store_without_reduce)
예제 #3
0
    def test_contains(self):
        """Added items are reported as members of the store; others are not."""
        store = ItemStore()
        for ch in 'aabbcc':
            store.add(ch)

        for added in 'abc':
            assert added in store
        # 'd' was never added.
        assert 'd' not in store
예제 #4
0
    def test_unique_ids(self):
        """Distinct items get distinct ids, each within ``[0, len(store))``."""
        store = ItemStore()
        for ch in 'aabbcc':
            store.add(ch)

        ids = [store[ch] for ch in 'abc']
        # No two items may share an id.
        assert len(ids) == len(set(ids))
        for item_id in ids:
            assert 0 <= item_id < len(store)
예제 #5
0
    def test_push_known_non_terminal_but_unknown_action(self):
        """Pushing 'S' must raise KeyError when NonTerminalAction('S') is absent from the action store."""
        # The store built here deliberately omits NonTerminalAction('S').
        store_without_s = ItemStore()
        for action in (NonTerminalAction('NP'), NonTerminalAction('VP'), ShiftAction(), ReduceAction()):
            store_without_s.add(action)

        tagged_words = list(zip(['John'], ['NNP']))
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, store_without_s)
        parser.initialise_stacks_and_buffers(tagged_words)

        with pytest.raises(KeyError):
            parser.push_non_terminal('S')
예제 #6
0
class OracleDataset(Dataset):
    """Dataset over a sequence of oracles that also collects their vocabularies.

    On construction, every word, POS tag and action found in the oracles is
    added to its own ``ItemStore``. The label of each ``NonTerminalAction``
    is additionally added to ``nt_store``.
    """

    def __init__(self, oracles: Sequence[Oracle]) -> None:
        """Keep ``oracles`` and fill the vocabulary stores from them."""
        self.oracles = oracles

        # One store per vocabulary; all start empty and are filled by load().
        self.word_store = ItemStore()  # type: ItemStore[Word]
        self.pos_store = ItemStore()  # type: ItemStore[POSTag]
        self.nt_store = ItemStore()  # type: ItemStore[NonTerminalLabel]
        self.action_store = ItemStore()  # type: ItemStore[Action]

        self.load()

    def load(self) -> None:
        """Walk every oracle once and add its items to the stores."""
        for oracle in self.oracles:
            for token in oracle.words:
                self.word_store.add(token)
            for tag in oracle.pos_tags:
                self.pos_store.add(tag)
            for act in oracle.actions:
                self.action_store.add(act)
                # Only non-terminal actions carry a label to record.
                if not isinstance(act, NonTerminalAction):
                    continue
                self.nt_store.add(act.label)

    def __len__(self) -> int:
        """Return the number of oracles."""
        return len(self.oracles)

    def __getitem__(self, index: int) -> Oracle:
        """Return the oracle at position ``index``."""
        return self.oracles[index]
예제 #7
0
    def __init__(self, oracles: Sequence[Oracle]) -> None:
        """Keep ``oracles``, create one empty store per vocabulary, then call ``load()``."""
        self.oracles = oracles

        # The four stores are independent of one another; self.load() runs
        # once they all exist.
        self.action_store = ItemStore()  # type: ItemStore[Action]
        self.nt_store = ItemStore()  # type: ItemStore[NonTerminalLabel]
        self.pos_store = ItemStore()  # type: ItemStore[POSTag]
        self.word_store = ItemStore()  # type: ItemStore[Word]

        self.load()
예제 #8
0
from rnng.models import DiscriminativeRnnGrammar
from rnng.utils import ItemStore
from rnng.actions import ShiftAction, ReduceAction, NonTerminalAction
from rnng.decoding import greedy_decode

# Module-level toy fixtures: vocabularies mapping each word, POS tag and
# non-terminal label to an integer id.
word2id = {'John': 0, 'loves': 1, 'Mary': 2}
pos2id = {'NNP': 0, 'VBZ': 1}
nt2id = {'S': 0, 'NP': 1, 'VP': 2}
# The actions registered in the shared store: one non-terminal action per
# label in nt2id, plus shift and reduce.
actions = [
    NonTerminalAction('S'),
    NonTerminalAction('NP'),
    NonTerminalAction('VP'),
    ShiftAction(),
    ReduceAction()
]
# Add every action to a single ItemStore, in the order listed above.
action_store = ItemStore()
for a in actions:
    action_store.add(a)


def test_greedy_decode(mocker):
    words = ['John', 'loves', 'Mary']
    pos_tags = ['NNP', 'VBZ', 'NNP']
    correct_actions = [
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        ShiftAction(),
        ReduceAction(),
        NonTerminalAction('VP'),
        ShiftAction(),
        NonTerminalAction('NP'),
예제 #9
0
class TestDiscRNNGrammar:
    """Tests for ``DiscriminativeRnnGrammar``: construction, parser actions and the forward pass."""

    # Toy vocabularies shared by every test in the class.
    word2id = {'John': 0, 'loves': 1, 'Mary': 2}
    pos2id = {'NNP': 0, 'VBZ': 1}
    nt2id = {'S': 2, 'NP': 1, 'VP': 0}
    action_store = ItemStore()
    actions = [
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        NonTerminalAction('VP'),
        ShiftAction(),
        ReduceAction(),
    ]
    for a in actions:
        action_store.add(a)

    def _make_parser(self, action_store=None):
        """Build a parser from the class vocabularies; ``action_store`` defaults to the shared store."""
        if action_store is None:
            action_store = self.action_store
        return DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, action_store)

    def _make_started_parser(self, words, pos_tags, action_store=None):
        """Build a parser and initialise it with the zipped ``words`` and ``pos_tags``."""
        parser = self._make_parser(action_store)
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        return parser

    def test_init(self):
        """A fresh parser has empty buffers and history and is neither started nor finished."""
        parser = self._make_parser()

        assert len(parser.stack_buffer) == 0
        assert len(parser.input_buffer) == 0
        assert len(parser.action_history) == 0
        assert not parser.finished
        assert not parser.started

    def test_init_no_shift_action(self):
        """Construction raises ValueError when the store has no ShiftAction."""
        store_without_shift = ItemStore()
        for action in (NonTerminalAction('S'), NonTerminalAction('NP'), NonTerminalAction('VP'), ReduceAction()):
            store_without_shift.add(action)

        with pytest.raises(ValueError):
            self._make_parser(store_without_shift)

    def test_init_no_reduce_action(self):
        """Construction raises ValueError when the store has no ReduceAction."""
        store_without_reduce = ItemStore()
        for action in (NonTerminalAction('S'), NonTerminalAction('NP'), NonTerminalAction('VP'), ShiftAction()):
            store_without_reduce.add(action)

        with pytest.raises(ValueError):
            self._make_parser(store_without_reduce)

    def test_init_word_id_out_of_range(self):
        """Construction raises ValueError for a word id that is too large or negative."""
        word2id = dict(self.word2id)

        # First an id equal to the vocabulary size, then a negative one.
        for bad_id in (len(word2id), -1):
            word2id['John'] = bad_id
            with pytest.raises(ValueError):
                DiscriminativeRnnGrammar(word2id, self.pos2id, self.nt2id, self.action_store)

    def test_init_pos_id_out_of_range(self):
        """Construction raises ValueError for a POS id that is too large or negative."""
        pos2id = dict(self.pos2id)

        for bad_id in (len(pos2id), -1):
            pos2id['NNP'] = bad_id
            with pytest.raises(ValueError):
                DiscriminativeRnnGrammar(self.word2id, pos2id, self.nt2id, self.action_store)

    def test_init_non_terminal_id_out_of_range(self):
        """Construction raises ValueError for a non-terminal id that is too large or negative."""
        nt2id = dict(self.nt2id)

        for bad_id in (len(nt2id), -1):
            nt2id['S'] = bad_id
            with pytest.raises(ValueError):
                DiscriminativeRnnGrammar(self.word2id, self.pos2id, nt2id, self.action_store)

    def test_initialise_stacks_and_buffers(self):
        """Initialising fills the input buffer with the words and marks the parser as started."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']

        parser = self._make_started_parser(words, pos_tags)

        assert len(parser.stack_buffer) == 0
        assert parser.input_buffer == words
        assert len(parser.action_history) == 0
        assert not parser.finished
        assert parser.started

    def test_initialise_with_empty_tagged_words(self):
        """Initialising with an empty sentence raises ValueError."""
        parser = self._make_parser()

        with pytest.raises(ValueError):
            parser.initialise_stacks_and_buffers([])

    def test_initalise_with_invalid_word_or_pos(self):
        """Initialising with a word or a POS tag missing from the vocabularies raises ValueError."""
        parser = self._make_parser()

        # 'Bob' is not in word2id; 'VBD' is not in pos2id.
        for tagged_words in ([('Bob', 'NNP')], [('John', 'VBD')]):
            with pytest.raises(ValueError):
                parser.initialise_stacks_and_buffers(tagged_words)

    def test_do_non_terminal_action(self):
        """Pushing a non-terminal puts an empty labelled tree on the stack and records the action."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_started_parser(words, pos_tags)
        buffer_before = parser.input_buffer

        parser.push_non_terminal('S')

        assert len(parser.stack_buffer) == 1
        top = parser.stack_buffer[-1]
        assert isinstance(top, Tree)
        assert top.label() == 'S'
        assert len(top) == 0
        assert parser.input_buffer == buffer_before
        assert len(parser.action_history) == 1
        assert parser.action_history[-1] == NonTerminalAction('S')
        assert not parser.finished

    def test_do_illegal_push_non_terminal_action(self):
        """Pushing a non-terminal is illegal on an empty buffer or after 100 pushes."""
        words = ['John']
        pos_tags = ['NNP']
        parser = self._make_parser()

        # Buffer is empty
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.push_non_terminal('NP')

        # More than 100 open nonterminals
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        for _ in range(100):
            parser.push_non_terminal('S')
        with pytest.raises(IllegalActionError):
            parser.push_non_terminal('NP')

    def test_push_unknown_non_terminal(self):
        """Pushing a label that is not in nt2id raises KeyError."""
        parser = self._make_started_parser(['John'], ['NNP'])

        with pytest.raises(KeyError):
            parser.push_non_terminal('asdf')

    def test_push_known_non_terminal_but_unknown_action(self):
        """Pushing 'S' raises KeyError when NonTerminalAction('S') is not in the action store."""
        # 'S' is in nt2id, but its action is left out of this store.
        store_without_s = ItemStore()
        for action in (NonTerminalAction('NP'), NonTerminalAction('VP'), ShiftAction(), ReduceAction()):
            store_without_s.add(action)
        parser = self._make_started_parser(['John'], ['NNP'], store_without_s)

        with pytest.raises(KeyError):
            parser.push_non_terminal('S')

    def test_do_shift_action(self):
        """Shifting moves the first buffered word onto the stack and records the action."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_started_parser(words, pos_tags)
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')

        parser.shift()

        assert len(parser.stack_buffer) == 3
        top = parser.stack_buffer[-1]
        assert top == 'John'
        assert parser.input_buffer == words[1:]
        assert len(parser.action_history) == 3
        assert parser.action_history[-1] == ShiftAction()
        assert not parser.finished

    def test_do_illegal_shift_action(self):
        """Shifting is illegal without an open non-terminal or with an empty buffer."""
        words = ['John']
        pos_tags = ['NNP']
        parser = self._make_parser()

        # No open nonterminal
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        with pytest.raises(IllegalActionError):
            parser.shift()

        # Buffer is empty
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.shift()

    def test_do_reduce_action(self):
        """Reducing closes the open 'NP' into a tree containing the shifted word."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_started_parser(words, pos_tags)
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        buffer_before = parser.input_buffer

        parser.reduce()

        assert len(parser.stack_buffer) == 2
        top = parser.stack_buffer[-1]
        assert isinstance(top, Tree)
        assert top.label() == 'NP'
        assert len(top) == 1
        assert top[0] == 'John'
        assert parser.input_buffer == buffer_before
        assert len(parser.action_history) == 4
        assert parser.action_history[-1] == ReduceAction()
        assert not parser.finished

    def test_do_illegal_reduce_action(self):
        """Reducing is illegal on an open non-terminal or when it would finish with words left."""
        words = ['John', 'loves']
        pos_tags = ['NNP', 'VBZ']
        parser = self._make_parser()

        # Top of stack is an open nonterminal
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        with pytest.raises(IllegalActionError):
            parser.reduce()

        # Buffer is not empty and REDUCE will finish parsing
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.reduce()

    def test_do_action_when_not_started(self):
        """Every action raises RuntimeError before the parser is initialised."""
        parser = self._make_parser()

        with pytest.raises(RuntimeError):
            parser.push_non_terminal('S')
        with pytest.raises(RuntimeError):
            parser.shift()
        with pytest.raises(RuntimeError):
            parser.reduce()

    def test_forward(self):
        """The forward pass returns one log-probability per action, and they sum to about 1."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_started_parser(words, pos_tags)
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()

        action_logprobs = parser()

        assert isinstance(action_logprobs, Variable)
        assert action_logprobs.size() == (len(self.action_store),)
        total_prob = action_logprobs.exp().sum().data[0]
        assert 0.999 <= total_prob <= 1.001

    def test_forward_with_illegal_actions(self):
        """Right after initialisation, shift and reduce get (near) zero probability."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_started_parser(words, pos_tags)

        action_probs = parser().exp().data

        for label in ('S', 'NP', 'VP'):
            assert action_probs[self.action_store[NonTerminalAction(label)]] > 0.
        assert -0.001 <= action_probs[self.action_store[ShiftAction()]] <= 0.001
        assert -0.001 <= action_probs[self.action_store[ReduceAction()]] <= 0.001

    def test_forward_when_not_started(self):
        """The forward pass raises RuntimeError before the parser is initialised."""
        parser = self._make_parser()

        with pytest.raises(RuntimeError):
            parser()

    def test_finished(self):
        """A full action sequence yields the expected tree, after which the parser rejects further use."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_parser()
        expected_tree = Tree('S', [
            Tree('NP', ['John']),
            Tree('VP', ['loves', Tree('NP', ['Mary'])]),
        ])

        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        # (S (NP John) (VP loves (NP Mary)))
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()
        parser.push_non_terminal('VP')
        parser.shift()
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()
        parser.reduce()
        parser.reduce()

        assert parser.finished
        final_tree = parser.stack_buffer[-1]
        assert final_tree == expected_tree
        # Once finished, the forward pass and every action must fail.
        with pytest.raises(RuntimeError):
            parser()
        with pytest.raises(RuntimeError):
            parser.push_non_terminal('NP')
        with pytest.raises(RuntimeError):
            parser.shift()
        with pytest.raises(RuntimeError):
            parser.reduce()
예제 #10
0
    def test_len(self):
        """The store's length counts distinct items only."""
        store = ItemStore()
        # Six additions, but only three distinct characters.
        for ch in 'aabbcc':
            store.add(ch)

        assert len(store) == 3
예제 #11
0
    def test_iter(self):
        """Iterating the store yields each distinct item exactly once."""
        store = ItemStore()
        for ch in 'aabbcc':
            store.add(ch)

        # Both sides are sorted so the check does not depend on iteration order.
        assert sorted('abc') == sorted(store)