Example #1
0
def test_greedy_decode(mocker):
    """Expect ``greedy_decode`` to reproduce the gold action sequence.

    The parser's ``forward`` is patched so that successive calls return one
    score vector per gold action, each holding 1 at that action's index in
    ``action_store`` and 0 everywhere else.
    """
    tagged_words = [('John', 'NNP'), ('loves', 'VBZ'), ('Mary', 'NNP')]
    gold_actions = [
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        ShiftAction(),
        ReduceAction(),
        NonTerminalAction('VP'),
        ShiftAction(),
        NonTerminalAction('NP'),
        ShiftAction(),
        ReduceAction(),
        ReduceAction(),
        ReduceAction(),
    ]
    # One mocked ``forward`` output per decoding step.
    mocked_outputs = []
    for gold in gold_actions:
        scores = torch.zeros(len(action_store)).scatter_(
            0, torch.LongTensor([action_store[gold]]), 1)
        mocked_outputs.append(Variable(scores))
    parser = DiscriminativeRnnGrammar(word2id, pos2id, nt2id, action_store)
    parser.initialise_stacks_and_buffers(tagged_words)
    mocker.patch.object(parser, 'forward', side_effect=mocked_outputs)

    decoded = greedy_decode(parser)

    assert len(decoded) == len(gold_actions)
    # Each decoded item is an (action, log-probability) pair; only actions are checked.
    picked_actions, _ = zip(*decoded)
    assert list(picked_actions) == gold_actions
Example #2
0
    def test_init_no_reduce_action(self):
        """Building the parser from an action store without REDUCE must raise ``ValueError``."""
        store_without_reduce = ItemStore()
        for action in (
                NonTerminalAction('S'),
                NonTerminalAction('NP'),
                NonTerminalAction('VP'),
                ShiftAction(),
        ):
            store_without_reduce.add(action)

        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(
                self.word2id, self.pos2id, self.nt2id, store_without_reduce)
Example #3
0
    def test_push_known_non_terminal_but_unknown_action(self):
        """Pushing ``S`` must raise ``KeyError`` when the action store has no ``NT(S)`` action."""
        # The store deliberately leaves out NonTerminalAction('S').
        store_without_nt_s = ItemStore()
        for action in (
                NonTerminalAction('NP'),
                NonTerminalAction('VP'),
                ShiftAction(),
                ReduceAction(),
        ):
            store_without_nt_s.add(action)
        tagged_words = [('John', 'NNP')]
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, store_without_nt_s)
        parser.initialise_stacks_and_buffers(tagged_words)

        with pytest.raises(KeyError):
            parser.push_non_terminal('S')
Example #4
0
class TestOracleDataset:
    """Tests for ``OracleDataset`` built from two small bracketed sentences."""

    bracketed_sents = [
        '(S (NP (NNP John)) (VP (VBZ loves) (NP (NNP Mary))))',
        '(S (NP (NNP Mary)) (VP (VBZ hates) (NP (NNP John))))'  # poor John
    ]
    # Expected contents of the dataset's stores for the sentences above.
    words = {'John', 'loves', 'hates', 'Mary'}
    pos_tags = {'NNP', 'VBZ'}
    nt_labels = {'S', 'NP', 'VP'}
    actions = {
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        NonTerminalAction('VP'),
        ShiftAction(),
        ReduceAction()
    }

    def _build_oracles(self):
        """Return one ``DiscriminativeOracle`` per entry of ``bracketed_sents``."""
        parsed = (Tree.fromstring(sent) for sent in self.bracketed_sents)
        return [DiscriminativeOracle.from_parsed_sentence(tree) for tree in parsed]

    def test_init(self):
        """Every store is an ``ItemStore`` whose items match the expected sets."""
        dataset = OracleDataset(self._build_oracles())

        expected_by_store = (
            ('word_store', self.words),
            ('pos_store', self.pos_tags),
            ('nt_store', self.nt_labels),
            ('action_store', self.actions),
        )
        for store_name, expected_items in expected_by_store:
            assert isinstance(getattr(dataset, store_name), ItemStore)
            assert set(getattr(dataset, store_name)) == expected_items

    def test_getitem(self):
        """Indexing the dataset returns the very oracle objects it was built from."""
        source_oracles = self._build_oracles()

        dataset = OracleDataset(source_oracles)

        assert source_oracles[0] is dataset[0]
        assert source_oracles[1] is dataset[1]

    def test_len(self):
        """The dataset length equals the number of oracles passed in."""
        source_oracles = self._build_oracles()

        dataset = OracleDataset(source_oracles)

        assert len(dataset) == len(source_oracles)
Example #5
0
    def test_forward_with_illegal_actions(self):
        """Right after initialisation, NT actions get positive probability and SHIFT/REDUCE get ~0."""
        tagged_words = [('John', 'NNP'), ('loves', 'VBZ'), ('Mary', 'NNP')]
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)
        parser.initialise_stacks_and_buffers(tagged_words)

        # forward() output is exponentiated to turn log-probabilities into probabilities.
        probs = parser().exp().data

        for label in ('S', 'NP', 'VP'):
            assert probs[self.action_store[NonTerminalAction(label)]] > 0.
        for action_cls in (ShiftAction, ReduceAction):
            assert -0.001 <= probs[self.action_store[action_cls()]] <= 0.001
Example #6
0
    def get_actions(cls, tree: Tree) -> List[Action]:
        """Return the action sequence derived from ``tree`` by a pre-order walk.

        A node with exactly one child that is not itself a ``Tree`` is handed
        to ``cls.get_action_at_pos_node`` and contributes that single action.
        Any other node contributes ``NonTerminalAction(label)``, then the
        actions of each child in order, then a closing ``ReduceAction``.
        """
        is_pos_node = len(tree) == 1 and not isinstance(tree[0], Tree)
        if is_pos_node:
            return [cls.get_action_at_pos_node(tree)]

        opening = NonTerminalAction(tree.label())
        nested = [act for subtree in tree for act in cls.get_actions(subtree)]
        return [opening, *nested, ReduceAction()]
Example #7
0
    def push_non_terminal(self, nonterm: NonTerminalLabel) -> None:
        """Validate and perform an ``NT(nonterm)`` action, then record it in the history.

        Raises:
            KeyError: if ``nonterm`` is not in ``self.nt2id``, or if the
                corresponding ``NonTerminalAction`` is not in
                ``self.action_store``.
            IllegalActionError: if ``self.can_push_non_terminal()`` is false.
        """
        label_is_known = nonterm in self.nt2id
        if not label_is_known:
            raise KeyError(f"unknown nonterminal '{nonterm}' encountered")

        nt_action = NonTerminalAction(nonterm)
        if nt_action in self.action_store:
            # Legality is only checked once both lookups have succeeded.
            if self.can_push_non_terminal():
                self._push_non_terminal(nonterm)
                self._append_history(nt_action)
                return
            raise IllegalActionError(f"Illegal NT({nonterm}) action taken.")
        raise KeyError(f"unknown action '{nt_action}' encountered")
Example #8
0
    def test_from_string(self):
        """``GenerativeOracle.from_string`` parses a POS-tag line followed by one action per line."""
        s = 'NNP VBZ\nNT(S)\nGEN(asdf)\nGEN(fdsa)\nREDUCE'
        expected_actions = [
            NonTerminalAction('S'),
            GenerateAction('asdf'),
            GenerateAction('fdsa'),
            ReduceAction()
        ]

        parsed = GenerativeOracle.from_string(s)

        assert isinstance(parsed, GenerativeOracle)
        # The words are expected to match the GEN(...) actions in the string.
        assert parsed.words == ['asdf', 'fdsa']
        assert parsed.pos_tags == ['NNP', 'VBZ']
        assert parsed.actions == expected_actions
Example #9
0
    def test_from_string(self):
        """``DiscriminativeOracle.from_string`` parses a words line, a POS-tag line, then one action per line."""
        s = 'asdf fdsa\nNNP VBZ\nNT(S)\nSHIFT\nSHIFT\nREDUCE'
        expected_actions = [
            NonTerminalAction('S'),
            ShiftAction(),
            ShiftAction(),
            ReduceAction()
        ]

        parsed = DiscriminativeOracle.from_string(s)

        assert isinstance(parsed, DiscriminativeOracle)
        assert parsed.words == ['asdf', 'fdsa']
        assert parsed.pos_tags == ['NNP', 'VBZ']
        assert parsed.actions == expected_actions
Example #10
0
    def test_from_parsed_sent(self):
        """``from_parsed_sentence`` yields the expected actions, POS tags and words for one bracketed tree."""
        bracketed = '(S (NP (NNP John)) (VP (VBZ loves) (NP (NNP Mary))))'
        # Expected values, listed in the order the oracle attributes are checked below.
        gold_actions = [
            NonTerminalAction('S'),
            NonTerminalAction('NP'),
            ShiftAction(),
            ReduceAction(),
            NonTerminalAction('VP'),
            ShiftAction(),
            NonTerminalAction('NP'),
            ShiftAction(),
            ReduceAction(),
            ReduceAction(),
            ReduceAction(),
        ]
        gold_pos_tags = ['NNP', 'VBZ', 'NNP']
        gold_words = ['John', 'loves', 'Mary']

        tree = Tree.fromstring(bracketed)
        oracle = DiscriminativeOracle.from_parsed_sentence(tree)

        assert isinstance(oracle, DiscriminativeOracle)
        assert oracle.actions == gold_actions
        assert oracle.pos_tags == gold_pos_tags
        assert oracle.words == gold_words
Example #11
0
    def test_do_non_terminal_action(self):
        """Pushing ``S`` puts an empty ``S`` tree on the stack and records the action; the buffer compares unchanged."""
        tagged_words = [('John', 'NNP'), ('loves', 'VBZ'), ('Mary', 'NNP')]
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)
        parser.initialise_stacks_and_buffers(tagged_words)
        buffer_before = parser.input_buffer

        parser.push_non_terminal('S')

        assert len(parser.stack_buffer) == 1
        top = parser.stack_buffer[-1]
        assert isinstance(top, Tree)
        assert top.label() == 'S'
        assert len(top) == 0
        assert parser.input_buffer == buffer_before
        assert len(parser.action_history) == 1
        assert parser.action_history[-1] == NonTerminalAction('S')
        assert not parser.finished
Example #12
0
import torch
from torch.autograd import Variable

from rnng.models import DiscriminativeRnnGrammar
from rnng.utils import ItemStore
from rnng.actions import ShiftAction, ReduceAction, NonTerminalAction
from rnng.decoding import greedy_decode

# Toy vocabularies used by the tests below; each id is the item's position in its list.
word2id = {word: idx for idx, word in enumerate(['John', 'loves', 'Mary'])}
pos2id = {tag: idx for idx, tag in enumerate(['NNP', 'VBZ'])}
nt2id = {label: idx for idx, label in enumerate(['S', 'NP', 'VP'])}

# Every action the tests use, added to the store in this order.
actions = [NonTerminalAction(label) for label in ('S', 'NP', 'VP')]
actions.append(ShiftAction())
actions.append(ReduceAction())
action_store = ItemStore()
for a in actions:
    action_store.add(a)


def test_greedy_decode(mocker):
    words = ['John', 'loves', 'Mary']
    pos_tags = ['NNP', 'VBZ', 'NNP']
    correct_actions = [
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        ShiftAction(),
        ReduceAction(),
Example #13
0
 def test_from_invalid_string(self):
     """``NonTerminalAction.from_string`` must raise ``ValueError`` for the string ``'asdf'``."""
     from_string = NonTerminalAction.from_string
     with pytest.raises(ValueError):
         from_string('asdf')
Example #14
0
 def test_from_string(self):
     """Parsing the ``self.as_str`` template filled with ``NP`` yields a ``NonTerminalAction`` labelled ``NP``."""
     expected_label = 'NP'
     serialised = self.as_str.format(label=expected_label)

     parsed = NonTerminalAction.from_string(serialised)

     assert isinstance(parsed, NonTerminalAction)
     assert parsed.label == expected_label
Example #15
0
 def test_eq(self):
     """Equal labels compare equal; a different label or a different action type compares unequal."""
     a = NonTerminalAction('NP')
     same_label = NonTerminalAction(a.label)
     assert a == same_label
     other_label = NonTerminalAction('asdf')
     assert a != other_label
     other_kind = ShiftAction()
     assert a != other_kind
Example #16
0
 def test_hash(self):
     """The action's hash must equal the hash of the ``self.as_str`` template filled with its label."""
     label = 'NP'
     action = NonTerminalAction(label)
     action_hash = hash(action)
     assert action_hash == hash(self.as_str.format(label=label))
Example #17
0
 def test_to_string(self):
     """``str(action)`` must equal the ``self.as_str`` template filled with the action's label."""
     label = 'NP'
     rendered = str(NonTerminalAction(label))
     assert rendered == self.as_str.format(label=label)
Example #18
0
class TestDiscRNNGrammar:
    """Tests for ``DiscriminativeRnnGrammar``: construction, the NT/SHIFT/REDUCE actions, ``forward`` and termination."""

    word2id = {'John': 0, 'loves': 1, 'Mary': 2}
    pos2id = {'NNP': 0, 'VBZ': 1}
    nt2id = {'S': 2, 'NP': 1, 'VP': 0}
    action_store = ItemStore()
    actions = [
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        NonTerminalAction('VP'),
        ShiftAction(),
        ReduceAction(),
    ]
    for a in actions:
        action_store.add(a)

    @staticmethod
    def _store_of(items):
        """Return a new ``ItemStore`` holding ``items`` in the given order."""
        store = ItemStore()
        for item in items:
            store.add(item)
        return store

    def _new_parser(self, word2id=None, pos2id=None, nt2id=None, action_store=None):
        """Build a parser, using the class-level fixture for every argument left as ``None``."""
        return DiscriminativeRnnGrammar(
            self.word2id if word2id is None else word2id,
            self.pos2id if pos2id is None else pos2id,
            self.nt2id if nt2id is None else nt2id,
            self.action_store if action_store is None else action_store,
        )

    @staticmethod
    def _start(parser, words, pos_tags):
        """(Re)initialise ``parser`` with a freshly built list of (word, POS tag) pairs."""
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))

    def _started_parser(self, words, pos_tags):
        """Return a parser built from the class fixtures and initialised with the given sentence."""
        parser = self._new_parser()
        self._start(parser, words, pos_tags)
        return parser

    def test_init(self):
        """A new parser has empty stack, buffer and history and is neither finished nor started."""
        parser = self._new_parser()

        for attr in ('stack_buffer', 'input_buffer', 'action_history'):
            assert len(getattr(parser, attr)) == 0
        assert not parser.finished
        assert not parser.started

    def test_init_no_shift_action(self):
        """An action store without SHIFT makes construction raise ``ValueError``."""
        store = self._store_of([
            NonTerminalAction('S'),
            NonTerminalAction('NP'),
            NonTerminalAction('VP'),
            ReduceAction(),
        ])

        with pytest.raises(ValueError):
            self._new_parser(action_store=store)

    def test_init_no_reduce_action(self):
        """An action store without REDUCE makes construction raise ``ValueError``."""
        store = self._store_of([
            NonTerminalAction('S'),
            NonTerminalAction('NP'),
            NonTerminalAction('VP'),
            ShiftAction(),
        ])

        with pytest.raises(ValueError):
            self._new_parser(action_store=store)

    def test_init_word_id_out_of_range(self):
        """A word id equal to the vocabulary size, or negative, makes construction raise ``ValueError``."""
        word2id = dict(self.word2id)

        for bad_id in (len(word2id), -1):
            word2id['John'] = bad_id
            with pytest.raises(ValueError):
                self._new_parser(word2id=word2id)

    def test_init_pos_id_out_of_range(self):
        """A POS id equal to the tag-set size, or negative, makes construction raise ``ValueError``."""
        pos2id = dict(self.pos2id)

        for bad_id in (len(pos2id), -1):
            pos2id['NNP'] = bad_id
            with pytest.raises(ValueError):
                self._new_parser(pos2id=pos2id)

    def test_init_non_terminal_id_out_of_range(self):
        """A nonterminal id equal to the label-set size, or negative, makes construction raise ``ValueError``."""
        nt2id = dict(self.nt2id)

        for bad_id in (len(nt2id), -1):
            nt2id['S'] = bad_id
            with pytest.raises(ValueError):
                self._new_parser(nt2id=nt2id)

    def test_initialise_stacks_and_buffers(self):
        """After initialisation the buffer equals the words and the parser is started but not finished."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._new_parser()

        self._start(parser, words, pos_tags)

        assert len(parser.stack_buffer) == 0
        assert parser.input_buffer == words
        assert len(parser.action_history) == 0
        assert not parser.finished
        assert parser.started

    def test_initialise_with_empty_tagged_words(self):
        """Initialising with an empty sentence raises ``ValueError``."""
        parser = self._new_parser()

        with pytest.raises(ValueError):
            parser.initialise_stacks_and_buffers([])

    def test_initalise_with_invalid_word_or_pos(self):
        """Initialising with a word or a POS tag missing from the vocabularies raises ``ValueError``."""
        parser = self._new_parser()

        # First an unknown word, then an unknown POS tag.
        for tagged_words in ([('Bob', 'NNP')], [('John', 'VBD')]):
            with pytest.raises(ValueError):
                parser.initialise_stacks_and_buffers(tagged_words)

    def test_do_non_terminal_action(self):
        """Pushing ``S`` puts an empty ``S`` tree on the stack and records the action; the buffer compares unchanged."""
        parser = self._started_parser(['John', 'loves', 'Mary'], ['NNP', 'VBZ', 'NNP'])
        buffer_before = parser.input_buffer

        parser.push_non_terminal('S')

        assert len(parser.stack_buffer) == 1
        top = parser.stack_buffer[-1]
        assert isinstance(top, Tree)
        assert top.label() == 'S'
        assert len(top) == 0
        assert parser.input_buffer == buffer_before
        assert len(parser.action_history) == 1
        assert parser.action_history[-1] == NonTerminalAction('S')
        assert not parser.finished

    def test_do_illegal_push_non_terminal_action(self):
        """NT is illegal once the buffer is empty, and after 100 nonterminals have been pushed."""
        words = ['John']
        pos_tags = ['NNP']
        parser = self._new_parser()

        # Buffer is empty
        self._start(parser, words, pos_tags)
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.push_non_terminal('NP')

        # More than 100 open nonterminals
        self._start(parser, words, pos_tags)
        for _ in range(100):
            parser.push_non_terminal('S')
        with pytest.raises(IllegalActionError):
            parser.push_non_terminal('NP')

    def test_push_unknown_non_terminal(self):
        """Pushing a label that is not in ``nt2id`` raises ``KeyError``."""
        parser = self._started_parser(['John'], ['NNP'])

        with pytest.raises(KeyError):
            parser.push_non_terminal('asdf')

    def test_push_known_non_terminal_but_unknown_action(self):
        """Pushing ``S`` raises ``KeyError`` when the action store has no ``NT(S)`` action."""
        # The store deliberately leaves out NonTerminalAction('S').
        store = self._store_of([
            NonTerminalAction('NP'),
            NonTerminalAction('VP'),
            ShiftAction(),
            ReduceAction(),
        ])
        parser = self._new_parser(action_store=store)
        self._start(parser, ['John'], ['NNP'])

        with pytest.raises(KeyError):
            parser.push_non_terminal('S')

    def test_do_shift_action(self):
        """SHIFT moves the first buffered word onto the stack and records the action."""
        words = ['John', 'loves', 'Mary']
        parser = self._started_parser(words, ['NNP', 'VBZ', 'NNP'])
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')

        parser.shift()

        assert len(parser.stack_buffer) == 3
        top = parser.stack_buffer[-1]
        assert top == 'John'
        assert parser.input_buffer == words[1:]
        assert len(parser.action_history) == 3
        assert parser.action_history[-1] == ShiftAction()
        assert not parser.finished

    def test_do_illegal_shift_action(self):
        """SHIFT is illegal without an open nonterminal, and when the buffer is empty."""
        words = ['John']
        pos_tags = ['NNP']
        parser = self._new_parser()

        # No open nonterminal
        self._start(parser, words, pos_tags)
        with pytest.raises(IllegalActionError):
            parser.shift()

        # Buffer is empty
        self._start(parser, words, pos_tags)
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.shift()

    def test_do_reduce_action(self):
        """REDUCE replaces the open ``NP`` and the shifted word by an ``NP`` tree containing that word."""
        parser = self._started_parser(['John', 'loves', 'Mary'], ['NNP', 'VBZ', 'NNP'])
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        buffer_before = parser.input_buffer

        parser.reduce()

        assert len(parser.stack_buffer) == 2
        top = parser.stack_buffer[-1]
        assert isinstance(top, Tree)
        assert top.label() == 'NP'
        assert len(top) == 1
        assert top[0] == 'John'
        assert parser.input_buffer == buffer_before
        assert len(parser.action_history) == 4
        assert parser.action_history[-1] == ReduceAction()
        assert not parser.finished

    def test_do_illegal_reduce_action(self):
        """REDUCE is illegal on a just-opened nonterminal, and when it would finish with words left in the buffer."""
        words = ['John', 'loves']
        pos_tags = ['NNP', 'VBZ']
        parser = self._new_parser()

        # Top of stack is an open nonterminal
        self._start(parser, words, pos_tags)
        parser.push_non_terminal('S')
        with pytest.raises(IllegalActionError):
            parser.reduce()

        # Buffer is not empty and REDUCE will finish parsing
        self._start(parser, words, pos_tags)
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.reduce()

    def test_do_action_when_not_started(self):
        """Every action raises ``RuntimeError`` before the parser has been initialised."""
        parser = self._new_parser()

        with pytest.raises(RuntimeError):
            parser.push_non_terminal('S')
        for take_action in (parser.shift, parser.reduce):
            with pytest.raises(RuntimeError):
                take_action()

    def test_forward(self):
        """Calling the parser returns a ``Variable`` of log-probabilities, one per action, that sum to ~1."""
        parser = self._started_parser(['John', 'loves', 'Mary'], ['NNP', 'VBZ', 'NNP'])
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()

        log_probs = parser()

        assert isinstance(log_probs, Variable)
        assert log_probs.size() == (len(self.action_store),)
        total_prob = log_probs.exp().sum().data[0]
        assert 0.999 <= total_prob <= 1.001

    def test_forward_with_illegal_actions(self):
        """Right after initialisation, NT actions get positive probability and SHIFT/REDUCE get ~0."""
        parser = self._started_parser(['John', 'loves', 'Mary'], ['NNP', 'VBZ', 'NNP'])

        probs = parser().exp().data

        for label in ('S', 'NP', 'VP'):
            assert probs[self.action_store[NonTerminalAction(label)]] > 0.
        for action_cls in (ShiftAction, ReduceAction):
            assert -0.001 <= probs[self.action_store[action_cls()]] <= 0.001

    def test_forward_when_not_started(self):
        """Calling the parser before initialisation raises ``RuntimeError``."""
        parser = self._new_parser()

        with pytest.raises(RuntimeError):
            parser()

    def test_finished(self):
        """After the full gold sequence the parser is finished, holds the expected tree and rejects further calls."""
        parser = self._started_parser(['John', 'loves', 'Mary'], ['NNP', 'VBZ', 'NNP'])
        subject = Tree('NP', ['John'])
        predicate = Tree('VP', ['loves', Tree('NP', ['Mary'])])
        expected_tree = Tree('S', [subject, predicate])

        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()
        parser.push_non_terminal('VP')
        parser.shift()
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()
        parser.reduce()
        parser.reduce()

        assert parser.finished
        final_tree = parser.stack_buffer[-1]
        assert final_tree == expected_tree
        # Once finished, forward() and every action must be rejected.
        with pytest.raises(RuntimeError):
            parser()
        with pytest.raises(RuntimeError):
            parser.push_non_terminal('NP')
        for take_action in (parser.shift, parser.reduce):
            with pytest.raises(RuntimeError):
                take_action()