Пример #1
0
def test_make_example_from_disc_oracle():
    """Verify that ``make_example`` exposes the oracle's data as example attributes.

    The resulting example must be an ``Example`` whose ``actions``,
    ``pos_tags`` and ``words`` equal the oracle inputs, and whose
    ``nonterms`` are the labels of the NT actions, in order.
    """
    actions = [
        NT('S'), NT('NP'), SHIFT, REDUCE,
        NT('VP'), SHIFT, NT('NP'), SHIFT, REDUCE, REDUCE, REDUCE,
    ]
    pos_tags = 'NNP VBZ NNP'.split()
    words = 'John loves Mary'.split()
    oracle = Oracle(actions, pos_tags, words)
    field_names = ('actions', 'nonterms', 'pos_tags', 'words')
    fields = [(name, Field()) for name in field_names]

    example = make_example(oracle, fields)

    assert isinstance(example, Example)
    assert example.actions == actions
    # Only NT actions carry a nonterminal label.
    expected_nonterms = [get_nonterm(action) for action in actions if is_nt(action)]
    assert example.nonterms == expected_nonterms
    assert example.pos_tags == pos_tags
    assert example.words == words
Пример #2
0
 def test_forward_with_push_nt_when_buffer_is_empty(self):
     """Opening NT('NP') after three SHIFTs must get a probability of ~0.

     Per the test name, the buffer is expected to be empty at that point.
     """
     words = self.make_words()
     action_seq = [NT('S')] + [SHIFT] * 3 + [NT('NP')]
     actions = self.make_actions(action_seq)
     llh = self.make_parser()(words, actions)
     # The parser returns a log-likelihood; exponentiate to get a probability.
     probability = llh.exp().data[0]
     assert probability == pytest.approx(0, abs=1e-7)
Пример #3
0
    def test_build_vocab(self):
        """Check the layout of the action vocabulary built from the nonterm vocabulary.

        REDUCE and SHIFT must map to the parser's fixed ids, and every NT
        action must sit two positions after its nonterminal's id.
        """
        labels = 'S NP VP'.split()
        field = self.make_action_field()
        nonterm_field = field.nonterm_field
        nonterm_field.build_vocab([labels])
        field.build_vocab()

        # The action vocabulary has exactly two more entries than the nonterm one.
        assert len(field.vocab) == len(nonterm_field.vocab) + 2
        action_stoi = field.vocab.stoi
        assert action_stoi[REDUCE] == DiscRNNG.REDUCE_ID
        assert action_stoi[SHIFT] == DiscRNNG.SHIFT_ID
        for label in labels:
            nonterm_id = nonterm_field.vocab.stoi[label]
            assert action_stoi[NT(label)] == nonterm_id + 2
        assert NT(nonterm_field.unk_token) in action_stoi
Пример #4
0
    def test_numericalize_with_unknown_nt_action(self):
        """An NT action whose label was not in the vocab maps to the NT(<unk>) id."""
        field = self.make_action_field()
        field.nonterm_field.build_vocab(['S NP VP'.split()])
        field.build_vocab()
        # 'PP' is not among the nonterminals the vocabulary was built from.
        unseen_actions = [NT('PP')]

        tensor = field.numericalize([unseen_actions], device=-1)

        observed_ids = tensor.squeeze(dim=1).data.tolist()
        unk_action = NT(field.nonterm_field.unk_token)
        assert observed_ids == [field.vocab.stoi[unk_action]]
Пример #5
0
 def test_forward_with_push_nt_when_maximum_number_of_open_nt_is_reached(self):
     """Opening more nonterminals than ``MAX_OPEN_NT`` must get a probability of ~0.

     The limit is lowered to 2 for the duration of this test and then
     restored, so the changed class attribute does not leak into other
     tests that run afterwards.
     """
     saved_limit = DiscRNNG.MAX_OPEN_NT
     DiscRNNG.MAX_OPEN_NT = 2
     try:
         words = self.make_words()
         # One NT push more than the configured limit allows.
         actions = self.make_actions([NT('S')] * (DiscRNNG.MAX_OPEN_NT + 1))
         parser = self.make_parser()
         llh = parser(words, actions)
         assert llh.exp().data[0] == pytest.approx(0, abs=1e-7)
     finally:
         # Undo the class-level change even when the assertion fails.
         DiscRNNG.MAX_OPEN_NT = saved_limit
Пример #6
0
 def test_init_with_unequal_number_of_words_and_pos_tags(self):
     """DiscOracle must raise ValueError when given two POS tags but one word."""
     words = ['John']
     pos_tags = ['NNP', 'VBZ']
     actions = [NT('S'), SHIFT]
     with pytest.raises(ValueError) as excinfo:
         DiscOracle(actions, pos_tags, words)
     message = str(excinfo.value)
     assert 'number of POS tags should match number of words' in message
Пример #7
0
 def test_init_with_unequal_shift_count_and_number_of_words(self):
     """DiscOracle must raise ValueError when one word is given but no SHIFT action."""
     words = ['John']
     pos_tags = ['NNP']
     # The action sequence contains no SHIFT although there is one word.
     actions = [NT('S')]
     with pytest.raises(ValueError) as excinfo:
         DiscOracle(actions, pos_tags, words)
     message = str(excinfo.value)
     assert 'number of words should match number of SHIFT actions' in message
Пример #8
0
 def _actionstr2id(self, s: str) -> int:
     """Return the vocabulary id of the action string ``s``.

     A string that is not in the action vocabulary is treated as an
     unknown NT action and resolved to the id of ``NT(<unk>)``, where
     ``<unk>`` is the nonterminal field's unknown token.
     """
     stoi = self.vocab.stoi
     if s not in stoi:
         # Must be an unknown NT action, so fall back to NT(<unk>).
         fallback = NT(self.nonterm_field.unk_token)
         assert fallback in stoi
         return stoi[fallback]
     return stoi[s]
Пример #9
0
 def test_forward_with_reduce_when_tos_is_an_open_nt(self):
     """A REDUCE issued right after opening NT('S') must get a probability of ~0."""
     words = self.make_words()
     pos_tags = self.make_pos_tags()
     action_seq = [NT('S'), REDUCE]
     actions = self.make_actions(action_seq)
     llh = self.make_parser()(words, pos_tags, actions)
     # The parser returns a log-likelihood; exponentiate to get a probability.
     probability = llh.exp().data[0]
     assert probability == pytest.approx(0, abs=1e-7)
Пример #10
0
 def test_forward_with_reduce_when_only_single_open_nt_and_buffer_is_not_empty(self):
     """NT('S'), SHIFT, REDUCE must get a probability of ~0.

     Per the test name, the buffer is expected to still hold words when the
     only open nonterminal is reduced.
     """
     words = self.make_words()
     pos_tags = self.make_pos_tags()
     action_seq = [NT('S'), SHIFT, REDUCE]
     actions = self.make_actions(action_seq)
     llh = self.make_parser()(words, pos_tags, actions)
     probability = llh.exp().data[0]
     assert probability == pytest.approx(0, abs=1e-7)
Пример #11
0
    def test_numericalize(self):
        """Numericalizing one action sequence yields a (len, 1) tensor of vocab ids."""
        field = self.make_action_field()
        field.nonterm_field.build_vocab(['S NP VP'.split()])
        field.build_vocab()
        arr = [NT(label) for label in ('S', 'NP', 'VP')] + [SHIFT, REDUCE]

        tensor = field.numericalize([arr], device=-1)

        # One column for the single sequence in the batch.
        assert tensor.size() == (len(arr), 1)
        expected_ids = [field.vocab.stoi[action] for action in arr]
        assert tensor.squeeze().data.tolist() == expected_ids
Пример #12
0
    def make_actions(self, actions=None):
        """Turn a list of actions into a ``Variable`` holding their ids.

        Args:
            actions: Actions to convert with ``self.action2id``. When
                ``None``, a default eleven-action sequence is used.

        Returns:
            A ``Variable`` wrapping a ``torch.LongTensor`` of action ids.
        """
        if actions is None:
            actions = [
                NT('S'), NT('NP'), SHIFT, REDUCE,
                NT('VP'), SHIFT, NT('NP'), SHIFT, REDUCE, REDUCE, REDUCE,
            ]

        action_ids = []
        for action in actions:
            action_ids.append(self.action2id(action))
        id_tensor = torch.LongTensor(action_ids)
        return Variable(id_tensor)
Пример #13
0
    def get_actions(cls, tree: Tree) -> List[Action]:
        """Recursively convert ``tree`` into a flat list of actions.

        A node with exactly one child that is not itself a ``Tree`` is
        delegated to ``cls.get_action_at_pos_node``. Any other node yields
        ``NT(label)``, then the actions of each child in order, then
        ``REDUCE``.
        """
        is_pos_node = len(tree) == 1 and not isinstance(tree[0], Tree)
        if is_pos_node:
            return [cls.get_action_at_pos_node(tree)]

        collected: List[Action] = [NT(tree.label())]
        for subtree in tree:
            collected += cls.get_actions(subtree)
        # Close the nonterminal opened above.
        collected.append(REDUCE)
        return collected
Пример #14
0
    def test_init(self):
        """Oracle must expose its constructor arguments as attributes."""
        words = ['John']
        pos_tags = ['NNP']
        actions = [NT('S'), SHIFT]

        oracle = Oracle(actions, pos_tags, words)

        expectations = (
            ('actions', actions),
            ('pos_tags', pos_tags),
            ('words', words),
        )
        for attr_name, expected in expectations:
            assert getattr(oracle, attr_name) == expected
Пример #15
0
    def test_to_tree(self):
        """GenOracle.to_tree must rebuild the bracketed tree from GEN actions and POS tags."""
        expected = '(S (NP (NNP John)) (VP (VBZ loves) (NP (NNP Mary))))'
        pos_tags = ['NNP', 'VBZ', 'NNP']
        actions = [
            NT('S'), NT('NP'), GEN('John'), REDUCE,
            NT('VP'), GEN('loves'), NT('NP'), GEN('Mary'), REDUCE, REDUCE, REDUCE,
        ]

        tree = GenOracle(actions, pos_tags).to_tree()

        assert str(tree) == expected
Пример #16
0
    def test_to_tree(self):
        """Oracle.to_tree must rebuild the bracketed tree from actions, POS tags and words."""
        expected = '(S (NP (NNP John)) (VP (VBZ loves) (NP (NNP Mary))))'
        words = ['John', 'loves', 'Mary']
        pos_tags = ['NNP', 'VBZ', 'NNP']
        actions = [
            NT('S'), NT('NP'), SHIFT, REDUCE,
            NT('VP'), SHIFT, NT('NP'), SHIFT, REDUCE, REDUCE, REDUCE,
        ]

        tree = Oracle(actions, pos_tags, words).to_tree()

        assert str(tree) == expected
Пример #17
0
    def test_from_tree(self):
        """Oracle.from_tree must derive actions, POS tags and words from a parsed tree."""
        bracketed = '(S (NP (NNP John)) (VP (VBZ loves) (NP (NNP Mary))))'
        tree = Tree.fromstring(bracketed)

        oracle = Oracle.from_tree(tree)

        assert isinstance(oracle, Oracle)
        assert oracle.actions == [
            NT('S'), NT('NP'), SHIFT, REDUCE,
            NT('VP'), SHIFT, NT('NP'), SHIFT, REDUCE, REDUCE, REDUCE,
        ]
        assert oracle.pos_tags == ['NNP', 'VBZ', 'NNP']
        assert oracle.words == ['John', 'loves', 'Mary']
Пример #18
0
def test_NT():
    """NT must render a nonterminal label as the string ``NT(<label>)``."""
    action = NT('NP')
    assert action == 'NT(NP)'
Пример #19
0
def test_is_nt():
    """is_nt must accept NT actions and reject REDUCE, SHIFT and GEN actions."""
    assert is_nt(NT('NP'))
    for other_action in (REDUCE, SHIFT, GEN('John')):
        assert not is_nt(other_action)
Пример #20
0
def test_is_gen():
    """is_gen must accept GEN actions and reject REDUCE, SHIFT and NT actions."""
    assert is_gen(GEN('John'))
    for other_action in (REDUCE, SHIFT, NT('NP')):
        assert not is_gen(other_action)
Пример #21
0
 def build_vocab(self) -> None:
     """Build the action vocabulary from the nonterminal vocabulary.

     The specials passed to ``Vocab`` are REDUCE, SHIFT, and then one NT
     action per key of ``self.nonterm_field.vocab.stoi``, in that
     mapping's iteration order. The result is stored in ``self.vocab``.
     """
     nt_actions = [NT(nonterm) for nonterm in self.nonterm_field.vocab.stoi]
     # No token counts are supplied; every action enters as a special.
     self.vocab = Vocab(Counter(), specials=[REDUCE, SHIFT] + nt_actions)
Пример #22
0
def test_get_nonterm():
    """get_nonterm must return the label that an NT action was built from."""
    assert get_nonterm(NT('NP')) == 'NP'
Пример #23
0
 def test_init_with_unequal_gen_count_and_number_of_pos_tags(self):
     """GenOracle must raise ValueError when one POS tag is given but no GEN action."""
     pos_tags = ['NNP']
     # The action sequence contains no GEN although there is one POS tag.
     actions = [NT('S')]
     with pytest.raises(ValueError) as excinfo:
         GenOracle(actions, pos_tags)
     message = str(excinfo.value)
     assert 'number of POS tags should match number of GEN actions' in message