Example #1
0
def test_greedy_decode(mocker):
    """``greedy_decode`` should reproduce the gold action sequence.

    ``parser.forward`` is patched so that every call returns a one-hot score
    vector over ``action_store`` whose hot entry is the next gold action. The
    decoder should therefore pick exactly the gold actions, in order.
    """
    tagged_words = list(zip(['John', 'loves', 'Mary'], ['NNP', 'VBZ', 'NNP']))
    gold_actions = [
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        ShiftAction(),
        ReduceAction(),
        NonTerminalAction('VP'),
        ShiftAction(),
        NonTerminalAction('NP'),
        ShiftAction(),
        ReduceAction(),
        ReduceAction(),
        ReduceAction(),
    ]

    # One score vector per decoding step: zeros everywhere except a 1 at the
    # index of that step's gold action.
    num_actions = len(action_store)
    step_scores = []
    for gold in gold_actions:
        one_hot = torch.zeros(num_actions)
        one_hot.scatter_(0, torch.LongTensor([action_store[gold]]), 1)
        step_scores.append(Variable(one_hot))

    parser = DiscriminativeRnnGrammar(word2id, pos2id, nt2id, action_store)
    parser.initialise_stacks_and_buffers(tagged_words)
    # With an iterable side_effect, each forward() call returns the next item.
    mocker.patch.object(parser, 'forward', side_effect=step_scores)

    decoded = greedy_decode(parser)

    assert len(decoded) == len(gold_actions)
    assert [action for action, _log_prob in decoded] == gold_actions
Example #2
0
    def get_actions(cls, tree: Tree) -> List[Action]:
        """Linearise ``tree`` into the list of actions that builds it.

        A preterminal (a node with exactly one child that is not a ``Tree``)
        becomes the single action returned by ``cls.get_action_at_pos_node``.
        Any other node becomes an NT action for its label, then the actions
        for each child in order, then a closing REDUCE.
        """
        is_preterminal = len(tree) == 1 and not isinstance(tree[0], Tree)
        if is_preterminal:
            return [cls.get_action_at_pos_node(tree)]

        opening = NonTerminalAction(tree.label())
        child_actions = [action
                         for child in tree
                         for action in cls.get_actions(child)]
        return [opening, *child_actions, ReduceAction()]
Example #3
0
class TestOracleDataset:
    """Tests for ``OracleDataset`` built from ``DiscriminativeOracle`` objects."""

    bracketed_sents = [
        '(S (NP (NNP John)) (VP (VBZ loves) (NP (NNP Mary))))',
        '(S (NP (NNP Mary)) (VP (VBZ hates) (NP (NNP John))))',  # poor John
    ]
    # Vocabularies the dataset is expected to collect from ``bracketed_sents``.
    words = {'John', 'loves', 'hates', 'Mary'}
    pos_tags = {'NNP', 'VBZ'}
    nt_labels = {'S', 'NP', 'VP'}
    actions = {
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        NonTerminalAction('VP'),
        ShiftAction(),
        ReduceAction()
    }

    def _make_oracles(self):
        """Build one ``DiscriminativeOracle`` per sentence in ``bracketed_sents``."""
        return [
            DiscriminativeOracle.from_parsed_sentence(Tree.fromstring(s))
            for s in self.bracketed_sents
        ]

    def test_init(self):
        """The dataset exposes an ``ItemStore`` per vocabulary with the expected items."""
        dataset = OracleDataset(self._make_oracles())

        assert isinstance(dataset.word_store, ItemStore)
        assert set(dataset.word_store) == self.words
        assert isinstance(dataset.pos_store, ItemStore)
        assert set(dataset.pos_store) == self.pos_tags
        assert isinstance(dataset.nt_store, ItemStore)
        assert set(dataset.nt_store) == self.nt_labels
        assert isinstance(dataset.action_store, ItemStore)
        assert set(dataset.action_store) == self.actions

    def test_getitem(self):
        """Indexing returns the very oracle objects passed in, in order."""
        oracles = self._make_oracles()

        dataset = OracleDataset(oracles)

        for i, oracle in enumerate(oracles):
            assert dataset[i] is oracle

    def test_len(self):
        """The dataset length equals the number of oracles."""
        oracles = self._make_oracles()

        dataset = OracleDataset(oracles)

        assert len(dataset) == len(oracles)
Example #4
0
    def test_push_known_non_terminal_but_unknown_action(self):
        """Pushing an NT whose NT action is missing from the store raises KeyError."""
        # The store deliberately lacks NonTerminalAction('S'); per the test
        # name, 'S' is still a known nonterminal label in self.nt2id.
        action_store = ItemStore()
        for action in (NonTerminalAction('NP'), NonTerminalAction('VP'),
                       ShiftAction(), ReduceAction()):
            action_store.add(action)
        tagged_words = [('John', 'NNP')]
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, action_store)
        parser.initialise_stacks_and_buffers(tagged_words)

        with pytest.raises(KeyError):
            parser.push_non_terminal('S')
Example #5
0
    def test_forward_with_illegal_actions(self):
        """Before any action, SHIFT/REDUCE get ~0 probability and NT actions get some."""
        tagged_words = list(zip(['John', 'loves', 'Mary'], ['NNP', 'VBZ', 'NNP']))
        parser = DiscriminativeRnnGrammar(
            self.word2id, self.pos2id, self.nt2id, self.action_store)
        parser.initialise_stacks_and_buffers(tagged_words)

        probs = parser().exp().data

        def prob_of(action):
            return probs[self.action_store[action]]

        # Opening a nonterminal is allowed on an empty stack.
        for label in ('S', 'NP', 'VP'):
            assert prob_of(NonTerminalAction(label)) > 0.
        # SHIFT and REDUCE are illegal here, so they must be (numerically) zero.
        for illegal in (ShiftAction(), ReduceAction()):
            assert -0.001 <= prob_of(illegal) <= 0.001
Example #6
0
    def test_from_string(self):
        """Parse a discriminative oracle: words line, POS line, then one action per line."""
        expected_words = ['asdf', 'fdsa']
        expected_pos_tags = ['NNP', 'VBZ']
        expected_actions = [
            NonTerminalAction('S'),
            ShiftAction(),
            ShiftAction(),
            ReduceAction(),
        ]

        oracle = DiscriminativeOracle.from_string(
            'asdf fdsa\nNNP VBZ\nNT(S)\nSHIFT\nSHIFT\nREDUCE')

        assert isinstance(oracle, DiscriminativeOracle)
        assert oracle.words == expected_words
        assert oracle.pos_tags == expected_pos_tags
        assert oracle.actions == expected_actions
Example #7
0
    def test_from_string(self):
        """Parse a generative oracle: a POS line followed by one action per line.

        There is no separate words line; the words are recovered from the
        ``GEN(...)`` actions.
        """
        expected_words = ['asdf', 'fdsa']
        expected_pos_tags = ['NNP', 'VBZ']
        expected_actions = [NonTerminalAction('S')]
        expected_actions.extend(GenerateAction(w) for w in expected_words)
        expected_actions.append(ReduceAction())

        oracle = GenerativeOracle.from_string(
            'NNP VBZ\nNT(S)\nGEN(asdf)\nGEN(fdsa)\nREDUCE')

        assert isinstance(oracle, GenerativeOracle)
        assert oracle.words == expected_words
        assert oracle.pos_tags == expected_pos_tags
        assert oracle.actions == expected_actions
Example #8
0
    def test_from_parsed_sent(self):
        """A parsed tree is linearised into top-down NT/SHIFT/REDUCE actions."""
        tree = Tree.fromstring(
            '(S (NP (NNP John)) (VP (VBZ loves) (NP (NNP Mary))))')

        oracle = DiscriminativeOracle.from_parsed_sentence(tree)

        assert isinstance(oracle, DiscriminativeOracle)
        assert oracle.actions == [
            NonTerminalAction('S'),
            NonTerminalAction('NP'),   # (NP John)
            ShiftAction(),
            ReduceAction(),
            NonTerminalAction('VP'),   # (VP loves (NP Mary))
            ShiftAction(),
            NonTerminalAction('NP'),
            ShiftAction(),
            ReduceAction(),            # closes the inner NP
            ReduceAction(),            # closes VP
            ReduceAction(),            # closes S
        ]
        assert oracle.pos_tags == ['NNP', 'VBZ', 'NNP']
        assert oracle.words == ['John', 'loves', 'Mary']
Example #9
0
    def test_do_reduce_action(self):
        """REDUCE closes the most recent open NT over the items pushed after it."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, self.action_store)
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        # Take a copy: if ``input_buffer`` returns a live reference to the
        # parser's internal list, comparing against it after ``reduce()``
        # would always succeed and prove nothing.
        prev_input_buffer = list(parser.input_buffer)

        parser.reduce()

        # Stack was [S, NP, 'John']; NP is now a closed tree over 'John'.
        assert len(parser.stack_buffer) == 2
        last = parser.stack_buffer[-1]
        assert isinstance(last, Tree)
        assert last.label() == 'NP'
        assert len(last) == 1
        assert last[0] == 'John'
        # REDUCE must not consume input.
        assert parser.input_buffer == prev_input_buffer
        assert len(parser.action_history) == 4
        assert parser.action_history[-1] == ReduceAction()
        assert not parser.finished
Example #10
0
 def reduce(self) -> None:
     """Perform a REDUCE transition and record it in the action history.

     Raises:
         IllegalActionError: if ``self.can_reduce()`` reports that REDUCE is
             not allowed in the current parser state.
     """
     if self.can_reduce():
         self._reduce()
         self._append_history(ReduceAction())
     else:
         raise IllegalActionError("Illegal REDUCE action attempted.")
Example #11
0
from torch.autograd import Variable

from rnng.models import DiscriminativeRnnGrammar
from rnng.utils import ItemStore
from rnng.actions import ShiftAction, ReduceAction, NonTerminalAction
from rnng.decoding import greedy_decode

# Toy vocabularies: IDs are assigned in listing order, starting at 0.
word2id = {word: idx for idx, word in enumerate(['John', 'loves', 'Mary'])}
pos2id = {tag: idx for idx, tag in enumerate(['NNP', 'VBZ'])}
nt2id = {label: idx for idx, label in enumerate(['S', 'NP', 'VP'])}
# One NT action per nonterminal label, then SHIFT and REDUCE.
actions = [NonTerminalAction(label) for label in ('S', 'NP', 'VP')]
actions += [ShiftAction(), ReduceAction()]
action_store = ItemStore()
for a in actions:
    action_store.add(a)


def test_greedy_decode(mocker):
    words = ['John', 'loves', 'Mary']
    pos_tags = ['NNP', 'VBZ', 'NNP']
    correct_actions = [
        NonTerminalAction('S'),
        NonTerminalAction('NP'),
        ShiftAction(),
        ReduceAction(),
        NonTerminalAction('VP'),
Example #12
0
 def test_eq(self):
     """GenerateActions are equal iff they generate the same word."""
     action = GenerateAction('asdf')
     assert action == GenerateAction(action.word)
     # A different word, or a different action type, must compare unequal.
     for other in (GenerateAction('fdsa'), ReduceAction()):
         assert action != other
Example #13
0
 def test_from_invalid_string(self):
     """A string that is not a REDUCE token is rejected with ValueError."""
     bogus = 'asdf'
     with pytest.raises(ValueError):
         ReduceAction.from_string(bogus)
Example #14
0
 def test_from_string(self):
     """Parsing the class's ``as_str`` yields a ReduceAction."""
     parsed = ReduceAction.from_string(self.as_str)
     assert isinstance(parsed, ReduceAction)
Example #15
0
 def test_eq(self):
     """Any two ReduceActions are equal; a ReduceAction differs from SHIFT."""
     first, second = ReduceAction(), ReduceAction()
     assert first == second
     assert first != ShiftAction()
Example #16
0
 def test_hash(self):
     """A ReduceAction hashes like its string form ``as_str``."""
     action = ReduceAction()
     expected = hash(self.as_str)
     assert hash(action) == expected
Example #17
0
 def test_to_string(self):
     """``str()`` of a ReduceAction is the class's ``as_str``."""
     rendered = str(ReduceAction())
     assert rendered == self.as_str
Example #18
0
    def __init__(self,
                 word2id: Mapping[Word, WordId],
                 pos2id: Mapping[POSTag, POSId],
                 non_terminal2id: Mapping[NonTerminalLabel, NonTerminalId],
                 action_store: ItemStore,
                 word_dim: int = 32,
                 pos_dim: int = 12,
                 non_terminal_dim: int = 60,
                 action_dim: int = 16,
                 input_dim: int = 128,
                 hidden_dim: int = 128,
                 num_layers: int = 2,
                 dropout: float = 0.) -> None:
        """Validate the vocabularies and build the parser's layers.

        Args:
            word2id: Mapping from word to ID; every ID must lie in
                ``[0, len(word2id))``.
            pos2id: Mapping from POS tag to ID; every ID must lie in
                ``[0, len(pos2id))``.
            non_terminal2id: Mapping from nonterminal label to ID; every ID
                must lie in ``[0, len(non_terminal2id))``.
            action_store: Store of known actions. It must contain
                ``ShiftAction()`` and ``ReduceAction()``. Its length is the
                output size of ``summary2actions`` and of the action embedding.
            word_dim: Word embedding size.
            pos_dim: POS tag embedding size.
            non_terminal_dim: Nonterminal embedding size.
            action_dim: Action embedding size.
            input_dim: Size of the vectors fed to the stack/buffer/history
                StackLSTMs and the composition LSTMs. Words (word + POS
                embedding), nonterminals and actions are projected to this size.
            hidden_dim: Second constructor argument of each StackLSTM
                (presumably its hidden size). The summary layer takes
                ``3 * hidden_dim`` inputs.
            num_layers: Number of layers of every LSTM built here.
            dropout: Dropout passed to every LSTM, and also applied before the
                summary projection.

        Raises:
            ValueError: if SHIFT or REDUCE is missing from ``action_store``,
                or any word/POS/nonterminal ID is out of range.
        """

        # SHIFT and REDUCE must always be predictable.
        if ShiftAction() not in action_store:
            raise ValueError('SHIFT action ID must be specified')
        if ReduceAction() not in action_store:
            raise ValueError('REDUCE action ID must be specified')

        num_words = len(word2id)
        num_pos = len(pos2id)
        num_non_terminals = len(non_terminal2id)
        num_actions = len(action_store)

        # IDs index the embedding tables below, so they must be in [0, size).
        for wid in word2id.values():
            if wid < 0 or wid >= num_words:
                raise ValueError(f'word ID of {wid} is out of range')
        for pid in pos2id.values():
            if pid < 0 or pid >= num_pos:
                raise ValueError(f'POS tag ID of {pid} is out of range')
        for nid in non_terminal2id.values():
            if nid < 0 or nid >= num_non_terminals:
                raise ValueError(f'nonterminal ID of {nid} is out of range')

        # Module initialisation happens only after validation, so invalid
        # arguments fail before any layer is built.
        super().__init__()
        self.word2id = word2id
        self.pos2id = pos2id
        self.nt2id = non_terminal2id
        self.action_store = action_store
        self.num_words = num_words
        self.num_pos = num_pos
        self.num_non_terminals = num_non_terminals
        self.num_actions = num_actions
        self.word_dim = word_dim
        self.pos_dim = pos_dim
        self.non_terminal_dim = non_terminal_dim
        self.action_dim = action_dim
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout

        # Parser states (empty until the stacks and buffers are initialised)
        self._stack = []  # type: List[StackElement]
        self._buffer = []  # type: List[Word]
        self._history = []  # type: List[Action]
        self._num_open_non_terminals = 0
        self._started = False

        # Parser state encoders
        # NOTE: creation order below fixes submodule/parameter registration
        # order (and RNG consumption for initial weights); keep it stable.
        self.stack_lstm = StackLSTM(input_dim,
                                    hidden_dim,
                                    num_layers=num_layers,
                                    dropout=dropout)
        # can use an LSTM, but this is easier.
        self.buffer_lstm = StackLSTM(input_dim,
                                     hidden_dim,
                                     num_layers=num_layers,
                                     dropout=dropout)
        # can use LSTM, but this is more efficient
        self.history_lstm = StackLSTM(input_dim,
                                      hidden_dim,
                                      num_layers=num_layers,
                                      dropout=dropout)

        # Composition: forward and backward LSTMs, both input_dim -> input_dim.
        self.compose_fwd_lstm = nn.LSTM(input_dim,
                                        input_dim,
                                        num_layers=num_layers,
                                        dropout=dropout)
        self.compose_bwd_lstm = nn.LSTM(input_dim,
                                        input_dim,
                                        num_layers=num_layers,
                                        dropout=dropout)

        # Transformations (each is Linear followed by ReLU)
        # word + POS embedding -> LSTM input
        self.word2lstm = nn.Sequential(
            OrderedDict([('linear', nn.Linear(word_dim + pos_dim, input_dim)),
                         ('relu', nn.ReLU())]))
        self.nt2lstm = nn.Sequential(
            OrderedDict([('linear', nn.Linear(non_terminal_dim, input_dim)),
                         ('relu', nn.ReLU())]))
        self.action2lstm = nn.Sequential(
            OrderedDict([('linear', nn.Linear(action_dim, input_dim)),
                         ('relu', nn.ReLU())]))
        # Concatenated forward/backward composition outputs -> one vector.
        self.fwdbwd2composed = nn.Sequential(
            OrderedDict([('linear', nn.Linear(2 * input_dim, input_dim)),
                         ('relu', nn.ReLU())]))
        self.lstms2summary = nn.Sequential(
            OrderedDict([  # Stack LSTMs to parser summary
                ('dropout', nn.Dropout(dropout)),
                ('linear', nn.Linear(3 * hidden_dim, hidden_dim)),
                ('relu', nn.ReLU())
            ]))
        # Parser summary -> one score per action in action_store.
        self.summary2actions = nn.Linear(hidden_dim, num_actions)

        # Embeddings
        self.word_embedding = nn.Embedding(num_words, word_dim)
        self.pos_embedding = nn.Embedding(num_pos, pos_dim)
        self.non_terminal_embedding = nn.Embedding(num_non_terminals,
                                                   non_terminal_dim)
        self.action_embedding = nn.Embedding(num_actions, action_dim)

        # Guard parameters for stack, buffer, and action history
        # NOTE(review): torch.Tensor(input_dim) allocates uninitialised
        # memory; presumably these are initialised elsewhere (e.g. a
        # parameter-reset method) — confirm.
        self.stack_guard = nn.Parameter(torch.Tensor(input_dim))
        self.buffer_guard = nn.Parameter(torch.Tensor(input_dim))
        self.history_guard = nn.Parameter(torch.Tensor(input_dim))

        # Final embeddings
        self._word_emb = {}  # type: Dict[WordId, Variable]
        # NOTE(review): the next two are initialised to empty dicts although
        # their type comments say Variable — confirm which is intended.
        self._nt_emb = {}  # type: Variable
        self._action_emb = {}  # type: Variable
Example #19
0
    def test_init_no_shift_action(self):
        """The constructor rejects an action store that lacks SHIFT."""
        without_shift = [NonTerminalAction(label) for label in ('S', 'NP', 'VP')]
        without_shift.append(ReduceAction())
        action_store = ItemStore()
        for action in without_shift:
            action_store.add(action)

        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, action_store)
Example #20
0
class TestDiscRNNGrammar:
    """Tests for ``DiscriminativeRnnGrammar``: construction, transitions and ``forward``."""

    word2id = {'John': 0, 'loves': 1, 'Mary': 2}
    pos2id = {'NNP': 0, 'VBZ': 1}
    nt2id = {'S': 2, 'NP': 1, 'VP': 0}
    # Action store shared by the tests below; none of them mutate it.
    action_store = ItemStore()
    actions = [NonTerminalAction('S'), NonTerminalAction('NP'), NonTerminalAction('VP'),
               ShiftAction(), ReduceAction()]
    for a in actions:
        action_store.add(a)

    def _make_parser(self):
        """Return a fresh, not-yet-started parser over the class-level vocabularies."""
        return DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, self.action_store)

    def test_init(self):
        """A new parser has empty state and is neither started nor finished."""
        parser = self._make_parser()

        assert len(parser.stack_buffer) == 0
        assert len(parser.input_buffer) == 0
        assert len(parser.action_history) == 0
        assert not parser.finished
        assert not parser.started

    def test_init_no_shift_action(self):
        """The constructor rejects an action store without SHIFT."""
        action_store = ItemStore()
        actions = [NonTerminalAction('S'), NonTerminalAction('NP'), NonTerminalAction('VP'), ReduceAction()]
        for a in actions:
            action_store.add(a)

        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, action_store)

    def test_init_no_reduce_action(self):
        """The constructor rejects an action store without REDUCE."""
        action_store = ItemStore()
        actions = [NonTerminalAction('S'), NonTerminalAction('NP'), NonTerminalAction('VP'), ShiftAction()]
        for a in actions:
            action_store.add(a)

        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, action_store)

    def test_init_word_id_out_of_range(self):
        """Word IDs outside [0, len(word2id)) are rejected."""
        word2id = dict(self.word2id)

        word2id['John'] = len(word2id)
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(word2id, self.pos2id, self.nt2id, self.action_store)

        word2id['John'] = -1
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(word2id, self.pos2id, self.nt2id, self.action_store)

    def test_init_pos_id_out_of_range(self):
        """POS IDs outside [0, len(pos2id)) are rejected."""
        pos2id = dict(self.pos2id)

        pos2id['NNP'] = len(pos2id)
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, pos2id, self.nt2id, self.action_store)

        pos2id['NNP'] = -1
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, pos2id, self.nt2id, self.action_store)

    def test_init_non_terminal_id_out_of_range(self):
        """Nonterminal IDs outside [0, len(nt2id)) are rejected."""
        nt2id = dict(self.nt2id)

        nt2id['S'] = len(nt2id)
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, self.pos2id, nt2id, self.action_store)

        nt2id['S'] = -1
        with pytest.raises(ValueError):
            DiscriminativeRnnGrammar(self.word2id, self.pos2id, nt2id, self.action_store)

    def test_initialise_stacks_and_buffers(self):
        """Initialising fills the input buffer with the words and marks the parser started."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_parser()

        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))

        assert len(parser.stack_buffer) == 0
        assert parser.input_buffer == words
        assert len(parser.action_history) == 0
        assert not parser.finished
        assert parser.started

    def test_initialise_with_empty_tagged_words(self):
        """An empty sentence is rejected."""
        parser = self._make_parser()

        with pytest.raises(ValueError):
            parser.initialise_stacks_and_buffers([])

    def test_initalise_with_invalid_word_or_pos(self):
        """Unknown words or POS tags are rejected."""
        parser = self._make_parser()

        with pytest.raises(ValueError):
            parser.initialise_stacks_and_buffers([('Bob', 'NNP')])

        with pytest.raises(ValueError):
            parser.initialise_stacks_and_buffers([('John', 'VBD')])

    def test_do_non_terminal_action(self):
        """NT(X) pushes an empty open tree labelled X without consuming input."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_parser()
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        # Copy, so the comparison below is meaningful even if input_buffer
        # returns a live reference to the parser's internal list.
        prev_input_buffer = list(parser.input_buffer)

        parser.push_non_terminal('S')

        assert len(parser.stack_buffer) == 1
        last = parser.stack_buffer[-1]
        assert isinstance(last, Tree)
        assert last.label() == 'S'
        assert len(last) == 0
        assert parser.input_buffer == prev_input_buffer
        assert len(parser.action_history) == 1
        assert parser.action_history[-1] == NonTerminalAction('S')
        assert not parser.finished

    def test_do_illegal_push_non_terminal_action(self):
        """NT is illegal once the buffer is empty or 100 NTs are already open."""
        words = ['John']
        pos_tags = ['NNP']
        parser = self._make_parser()

        # Buffer is empty
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.push_non_terminal('NP')

        # 100 nonterminals already open: one more push must be rejected
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        for i in range(100):
            parser.push_non_terminal('S')
        with pytest.raises(IllegalActionError):
            parser.push_non_terminal('NP')

    def test_push_unknown_non_terminal(self):
        """Pushing a label absent from nt2id raises KeyError."""
        words = ['John']
        pos_tags = ['NNP']
        parser = self._make_parser()
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))

        with pytest.raises(KeyError):
            parser.push_non_terminal('asdf')

    def test_push_known_non_terminal_but_unknown_action(self):
        """Pushing a known label whose NT action is not in the store raises KeyError."""
        actions = [NonTerminalAction('NP'), NonTerminalAction('VP'), ShiftAction(), ReduceAction()]
        action_store = ItemStore()
        for a in actions:
            action_store.add(a)
        words = ['John']
        pos_tags = ['NNP']
        parser = DiscriminativeRnnGrammar(self.word2id, self.pos2id, self.nt2id, action_store)
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))

        with pytest.raises(KeyError):
            parser.push_non_terminal('S')

    def test_do_shift_action(self):
        """SHIFT moves the next word from the input buffer onto the stack."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_parser()
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')

        parser.shift()

        assert len(parser.stack_buffer) == 3
        last = parser.stack_buffer[-1]
        assert last == 'John'
        assert parser.input_buffer == words[1:]
        assert len(parser.action_history) == 3
        assert parser.action_history[-1] == ShiftAction()
        assert not parser.finished

    def test_do_illegal_shift_action(self):
        """SHIFT is illegal with no open NT or with an empty buffer."""
        words = ['John']
        pos_tags = ['NNP']
        parser = self._make_parser()

        # No open nonterminal
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        with pytest.raises(IllegalActionError):
            parser.shift()

        # Buffer is empty
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.shift()

    def test_do_reduce_action(self):
        """REDUCE closes the most recent open NT over the items pushed after it."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_parser()
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        # Copy, so the comparison below is meaningful even if input_buffer
        # returns a live reference to the parser's internal list.
        prev_input_buffer = list(parser.input_buffer)

        parser.reduce()

        assert len(parser.stack_buffer) == 2
        last = parser.stack_buffer[-1]
        assert isinstance(last, Tree)
        assert last.label() == 'NP'
        assert len(last) == 1
        assert last[0] == 'John'
        assert parser.input_buffer == prev_input_buffer
        assert len(parser.action_history) == 4
        assert parser.action_history[-1] == ReduceAction()
        assert not parser.finished

    def test_do_illegal_reduce_action(self):
        """REDUCE is illegal on an open NT, or if it would finish with input left."""
        words = ['John', 'loves']
        pos_tags = ['NNP', 'VBZ']
        parser = self._make_parser()

        # Top of stack is an open nonterminal
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        with pytest.raises(IllegalActionError):
            parser.reduce()

        # Buffer is not empty and REDUCE will finish parsing
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.shift()
        with pytest.raises(IllegalActionError):
            parser.reduce()

    def test_do_action_when_not_started(self):
        """Every transition raises RuntimeError before initialisation."""
        parser = self._make_parser()

        with pytest.raises(RuntimeError):
            parser.push_non_terminal('S')
        with pytest.raises(RuntimeError):
            parser.shift()
        with pytest.raises(RuntimeError):
            parser.reduce()

    def test_forward(self):
        """forward() returns log-probabilities over all actions that sum to 1."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_parser()
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()

        action_logprobs = parser()

        assert isinstance(action_logprobs, Variable)
        assert action_logprobs.size() == (len(self.action_store),)
        # NOTE(review): `.data[0]` assumes sum() yields a 1-element tensor
        # (old PyTorch); on PyTorch >= 0.4 it is 0-dim and `.item()` is
        # needed — confirm the targeted version.
        sum_prob = action_logprobs.exp().sum().data[0]
        assert 0.999 <= sum_prob <= 1.001

    def test_forward_with_illegal_actions(self):
        """Illegal actions (SHIFT/REDUCE on an empty stack) get ~0 probability."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_parser()
        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))

        action_probs = parser().exp().data

        assert action_probs[self.action_store[NonTerminalAction('S')]] > 0.
        assert action_probs[self.action_store[NonTerminalAction('NP')]] > 0.
        assert action_probs[self.action_store[NonTerminalAction('VP')]] > 0.
        assert -0.001 <= action_probs[self.action_store[ShiftAction()]] <= 0.001
        assert -0.001 <= action_probs[self.action_store[ReduceAction()]] <= 0.001

    def test_forward_when_not_started(self):
        """forward() raises RuntimeError before initialisation."""
        parser = self._make_parser()

        with pytest.raises(RuntimeError):
            parser()

    def test_finished(self):
        """A full gold derivation finishes with the expected tree and freezes the parser."""
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        parser = self._make_parser()
        exp_parse_tree = Tree('S', [Tree('NP', ['John']),
                                    Tree('VP', ['loves', Tree('NP', ['Mary'])])])

        parser.initialise_stacks_and_buffers(list(zip(words, pos_tags)))
        parser.push_non_terminal('S')
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()
        parser.push_non_terminal('VP')
        parser.shift()
        parser.push_non_terminal('NP')
        parser.shift()
        parser.reduce()
        parser.reduce()
        parser.reduce()

        assert parser.finished
        parse_tree = parser.stack_buffer[-1]
        assert parse_tree == exp_parse_tree
        # Once finished, neither forward() nor any transition is allowed.
        with pytest.raises(RuntimeError):
            parser()
        with pytest.raises(RuntimeError):
            parser.push_non_terminal('NP')
        with pytest.raises(RuntimeError):
            parser.shift()
        with pytest.raises(RuntimeError):
            parser.reduce()