Пример #1
0
    def test_invertible(self):
        """Resolving an item's id through ``get_by_id`` gives back the same item."""
        item_store = ItemStore()
        for ch in 'aabbcc':
            item_store.add(ch)

        for ch in 'abc':
            # Round trip: item -> id -> item.
            item_id = item_store[ch]
            assert item_store.get_by_id(item_id) == ch
Пример #2
0
    def test_init_no_reduce_action(self):
        """Building the grammar is expected to raise ValueError when the
        action store contains no ``ReduceAction``."""
        action_store = ItemStore()
        # Deliberately leave ReduceAction out of the registered actions.
        for action in (
                NonTerminalAction('S'),
                NonTerminalAction('NP'),
                NonTerminalAction('VP'),
                ShiftAction(),
        ):
            action_store.add(action)

        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, action_store)
Пример #3
0
    def test_contains(self):
        """``in`` is true for every added item and false for one never added."""
        item_store = ItemStore()
        for ch in 'aabbcc':
            item_store.add(ch)

        for added in 'abc':
            assert added in item_store
        assert 'd' not in item_store
Пример #4
0
    def test_unique_ids(self):
        """Distinct items map to distinct ids, each lying in ``[0, len(store))``."""
        item_store = ItemStore()
        for ch in 'aabbcc':
            item_store.add(ch)

        item_ids = [item_store[ch] for ch in 'abc']
        # No two items may share an id.
        assert len(set(item_ids)) == len(item_ids)
        for item_id in item_ids:
            assert 0 <= item_id < len(item_store)
Пример #5
0
    def test_push_known_non_terminal_but_unknown_action(self):
        """Pushing non-terminal 'S' is expected to raise KeyError when the
        action store holds no ``NonTerminalAction('S')``."""
        # The registered actions cover NP and VP only; 'S' is left out on purpose.
        registered = (NonTerminalAction('NP'), NonTerminalAction('VP'), ShiftAction(), ReduceAction())
        action_store = ItemStore()
        for action in registered:
            action_store.add(action)
        tagged_words = [('John', 'NNP')]
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, action_store)
        parser.initialise_stacks_and_buffers(tagged_words)

        with pytest.raises(KeyError):
            parser.push_non_terminal('S')
Пример #6
0
class OracleDataset(Dataset):
    """Indexable collection of oracles plus the item stores built from them.

    On construction every oracle is scanned once; its words, POS tags,
    actions, and the labels of its non-terminal actions are added to
    ``word_store``, ``pos_store``, ``action_store``, and ``nt_store``
    respectively.
    """

    def __init__(self, oracles: Sequence[Oracle]) -> None:
        """Keep ``oracles`` and fill the four stores from their contents."""
        self.oracles = oracles
        self.word_store = ItemStore()  # type: ItemStore[Word]
        self.pos_store = ItemStore()  # type: ItemStore[POSTag]
        self.nt_store = ItemStore()  # type: ItemStore[NonTerminalLabel]
        self.action_store = ItemStore()  # type: ItemStore[Action]

        self.load()

    def load(self) -> None:
        """Add each oracle's words, POS tags, actions, and non-terminal labels
        to the matching store, in the order they are encountered."""
        for oracle in self.oracles:
            for token in oracle.words:
                self.word_store.add(token)
            for tag in oracle.pos_tags:
                self.pos_store.add(tag)
            for act in oracle.actions:
                self.action_store.add(act)
                if not isinstance(act, NonTerminalAction):
                    continue
                # Non-terminal actions additionally contribute their label.
                self.nt_store.add(act.label)

    def __getitem__(self, index: int) -> Oracle:
        """Return the oracle stored at position ``index``."""
        return self.oracles[index]

    def __len__(self) -> int:
        """Return the number of oracles in the dataset."""
        return len(self.oracles)
Пример #7
0
from rnng.actions import ShiftAction, ReduceAction, NonTerminalAction
from rnng.decoding import greedy_decode

# Shared fixtures: each vocabulary maps its entries to their position.
word2id = {word: idx for idx, word in enumerate(['John', 'loves', 'Mary'])}
pos2id = {tag: idx for idx, tag in enumerate(['NNP', 'VBZ'])}
nt2id = {label: idx for idx, label in enumerate(['S', 'NP', 'VP'])}

# One non-terminal action per label, followed by shift and reduce.
actions = [
    NonTerminalAction('S'),
    NonTerminalAction('NP'),
    NonTerminalAction('VP'),
    ShiftAction(),
    ReduceAction(),
]

# Register the actions in list order.
action_store = ItemStore()
for a in actions:
    action_store.add(a)


def test_greedy_decode(mocker):
    words = ['John', 'loves', 'Mary']
    pos_tags = ['NNP', 'VBZ', 'NNP']
    correct_actions = [
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        ShiftAction(),
        ReduceAction(),
        NonTerminalAction('VP'),
        ShiftAction(),
        NonTerminalAction('NP'),
        ShiftAction(),
        ReduceAction(),
Пример #8
0
    def test_len(self):
        """Six adds covering three distinct items leave the store with length 3."""
        item_store = ItemStore()
        for ch in 'aabbcc':
            item_store.add(ch)

        size = len(item_store)
        assert size == 3
Пример #9
0
    def test_iter(self):
        """Iterating the store yields each distinct added item exactly once."""
        item_store = ItemStore()
        for ch in 'aabbcc':
            item_store.add(ch)

        # Sort before comparing so the check does not depend on iteration order.
        assert sorted(item_store) == ['a', 'b', 'c']