def test_make_example_from_disc_oracle():
    """make_example copies each oracle attribute onto the Example.

    The ``nonterms`` field is expected to hold the labels of the NT actions,
    in action order.
    """
    actions = [
        NT('S'),
        NT('NP'), SHIFT, REDUCE,
        NT('VP'), SHIFT,
        NT('NP'), SHIFT, REDUCE,
        REDUCE,
        REDUCE,
    ]
    pos_tags = ['NNP', 'VBZ', 'NNP']
    words = ['John', 'loves', 'Mary']
    oracle = Oracle(actions, pos_tags, words)
    field_names = ('actions', 'nonterms', 'pos_tags', 'words')
    fields = [(name, Field()) for name in field_names]

    example = make_example(oracle, fields)

    assert isinstance(example, Example)
    assert example.actions == actions
    expected_nonterms = [get_nonterm(act) for act in actions if is_nt(act)]
    assert example.nonterms == expected_nonterms
    assert example.pos_tags == pos_tags
    assert example.words == words
def test_forward_with_push_nt_when_buffer_is_empty(self):
    """Opening a nonterminal after every word has been shifted is illegal.

    The parser's output, exponentiated, must therefore be ~0.
    """
    words = self.make_words()
    illegal_sequence = [NT('S')] + [SHIFT] * 3 + [NT('NP')]
    actions = self.make_actions(illegal_sequence)
    parser = self.make_parser()

    llh = parser(words, actions)

    likelihood = llh.exp().data[0]
    assert likelihood == pytest.approx(0, abs=1e-7)
def test_build_vocab(self):
    """The action vocab is REDUCE, SHIFT, then one NT action per nonterminal.

    Each NT(x) must sit two slots after x's index in the nonterminal vocab.
    NT(<unk>) must also be present.
    """
    field = self.make_action_field()
    nonterms = ['S', 'NP', 'VP']
    field.nonterm_field.build_vocab([nonterms])
    field.build_vocab()

    action_stoi = field.vocab.stoi
    nonterm_stoi = field.nonterm_field.vocab.stoi
    assert len(field.vocab) == len(field.nonterm_field.vocab) + 2
    assert action_stoi[REDUCE] == DiscRNNG.REDUCE_ID
    assert action_stoi[SHIFT] == DiscRNNG.SHIFT_ID
    for label in nonterms:
        offset_id = nonterm_stoi[label] + 2
        assert action_stoi[NT(label)] == offset_id
    assert NT(field.nonterm_field.unk_token) in action_stoi
def test_numericalize_with_unknown_nt_action(self):
    """An NT action whose label is out of vocabulary numericalizes to NT(<unk>)."""
    field = self.make_action_field()
    field.nonterm_field.build_vocab([['S', 'NP', 'VP']])
    field.build_vocab()

    # 'PP' was not in the nonterminal vocabulary built above.
    batch = [[NT('PP')]]
    tensor = field.numericalize(batch, device=-1)

    unk_action_id = field.vocab.stoi[NT(field.nonterm_field.unk_token)]
    assert tensor.squeeze(dim=1).data.tolist() == [unk_action_id]
def test_forward_with_push_nt_when_maximum_number_of_open_nt_is_reached(self):
    """Opening more nonterminals than DiscRNNG.MAX_OPEN_NT is illegal.

    The cap is temporarily lowered to 2. The test then feeds
    MAX_OPEN_NT + 1 consecutive NT actions and expects the parser's output,
    exponentiated, to be ~0.
    """
    # MAX_OPEN_NT is a class attribute, so overriding it affects every other
    # test that runs afterwards in the same process; restore it when done.
    saved_max_open_nt = DiscRNNG.MAX_OPEN_NT
    DiscRNNG.MAX_OPEN_NT = 2
    try:
        words = self.make_words()
        actions = self.make_actions([NT('S')] * (DiscRNNG.MAX_OPEN_NT + 1))
        parser = self.make_parser()
        llh = parser(words, actions)
        assert llh.exp().data[0] == pytest.approx(0, abs=1e-7)
    finally:
        DiscRNNG.MAX_OPEN_NT = saved_max_open_nt
def test_init_with_unequal_number_of_words_and_pos_tags(self):
    """DiscOracle rejects inputs with more POS tags than words."""
    two_tags = ['NNP', 'VBZ']
    one_word = ['John']
    with pytest.raises(ValueError) as excinfo:
        DiscOracle([NT('S'), SHIFT], two_tags, one_word)
    message = str(excinfo.value)
    assert 'number of POS tags should match number of words' in message
def test_init_with_unequal_shift_count_and_number_of_words(self):
    """DiscOracle rejects an action sequence with no SHIFT for a one-word input."""
    no_shift_actions = [NT('S')]
    with pytest.raises(ValueError) as excinfo:
        DiscOracle(no_shift_actions, ['NNP'], ['John'])
    message = str(excinfo.value)
    assert 'number of words should match number of SHIFT actions' in message
def _actionstr2id(self, s: str) -> int:
    """Map an action string to its index in ``self.vocab``.

    Strings present in the vocabulary map to their own index. Any other
    string is treated as an NT action with an out-of-vocabulary label and is
    mapped to ``NT(self.nonterm_field.unk_token)``.

    Raises:
        KeyError: if ``NT(<unk>)`` is itself missing from the vocabulary.
    """
    stoi = self.vocab.stoi
    if s in stoi:
        return stoi[s]
    # Must be an unknown NT action, so map it to NT(<unk>).
    # NOTE(review): non-NT strings that are not in the vocab also land here.
    unk_action = NT(self.nonterm_field.unk_token)
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`. If stoi is a defaultdict (as in some torchtext versions),
    # the lookup below would then silently return a default index.
    if unk_action not in stoi:
        raise KeyError(
            'unknown action {!r} and fallback {!r} is not in the action '
            'vocabulary'.format(s, unk_action))
    return stoi[unk_action]
def test_forward_with_reduce_when_tos_is_an_open_nt(self):
    """REDUCE right after opening a nonterminal (empty constituent) is illegal."""
    words = self.make_words()
    pos_tags = self.make_pos_tags()
    illegal_sequence = [NT('S'), REDUCE]
    actions = self.make_actions(illegal_sequence)
    parser = self.make_parser()

    llh = parser(words, pos_tags, actions)

    likelihood = llh.exp().data[0]
    assert likelihood == pytest.approx(0, abs=1e-7)
def test_forward_with_reduce_when_only_single_open_nt_and_buffer_is_not_empty(self):
    """Closing the only open nonterminal while words remain is illegal."""
    words = self.make_words()
    pos_tags = self.make_pos_tags()
    illegal_sequence = [NT('S'), SHIFT, REDUCE]
    actions = self.make_actions(illegal_sequence)
    parser = self.make_parser()

    llh = parser(words, pos_tags, actions)

    likelihood = llh.exp().data[0]
    assert likelihood == pytest.approx(0, abs=1e-7)
def test_numericalize(self):
    """Known actions numericalize to a (len, 1) tensor of their vocab ids."""
    field = self.make_action_field()
    field.nonterm_field.build_vocab([['S', 'NP', 'VP']])
    field.build_vocab()
    sequence = [NT('S'), NT('NP'), NT('VP'), SHIFT, REDUCE]

    tensor = field.numericalize([sequence], device=-1)

    expected_ids = [field.vocab.stoi[action] for action in sequence]
    assert tensor.size() == (len(sequence), 1)
    assert tensor.squeeze(dim=1).data.tolist() == expected_ids
def make_actions(self, actions=None):
    """Return a LongTensor Variable of action ids built via ``self.action2id``.

    With no argument, uses the oracle for "(S (NP John) (VP loves (NP Mary)))".
    """
    if actions is None:
        actions = [
            NT('S'),
            NT('NP'), SHIFT, REDUCE,
            NT('VP'), SHIFT,
            NT('NP'), SHIFT, REDUCE,
            REDUCE,
            REDUCE,
        ]
    ids = []
    for action in actions:
        ids.append(self.action2id(action))
    return Variable(torch.LongTensor(ids))
def get_actions(cls, tree: Tree) -> List[Action]:
    """Flatten *tree* into its top-down, left-to-right action sequence.

    A node with exactly one non-Tree child (a preterminal) yields the single
    action from ``cls.get_action_at_pos_node``. Any other node yields
    NT(label), the actions of each child in order, then REDUCE.
    """
    is_preterminal = len(tree) == 1 and not isinstance(tree[0], Tree)
    if is_preterminal:
        return [cls.get_action_at_pos_node(tree)]
    opening: Action = NT(tree.label())
    inner = [act for child in tree for act in cls.get_actions(child)]
    return [opening, *inner, REDUCE]
def test_init(self):
    """Oracle keeps its actions, POS tags and words as given."""
    expected = {
        'actions': [NT('S'), SHIFT],
        'pos_tags': ['NNP'],
        'words': ['John'],
    }
    oracle = Oracle(expected['actions'], expected['pos_tags'], expected['words'])
    assert oracle.actions == expected['actions']
    assert oracle.pos_tags == expected['pos_tags']
    assert oracle.words == expected['words']
def test_to_tree(self):
    """GenOracle rebuilds the bracketed tree from GEN actions and POS tags."""
    expected = '(S (NP (NNP John)) (VP (VBZ loves) (NP (NNP Mary))))'
    actions = [
        NT('S'),
        NT('NP'), GEN('John'), REDUCE,
        NT('VP'), GEN('loves'),
        NT('NP'), GEN('Mary'), REDUCE,
        REDUCE,
        REDUCE,
    ]
    oracle = GenOracle(actions, ['NNP', 'VBZ', 'NNP'])
    rebuilt = oracle.to_tree()
    assert str(rebuilt) == expected
def test_to_tree(self):
    """Oracle rebuilds the bracketed tree from SHIFT actions, POS tags and words."""
    expected = '(S (NP (NNP John)) (VP (VBZ loves) (NP (NNP Mary))))'
    actions = [
        NT('S'),
        NT('NP'), SHIFT, REDUCE,
        NT('VP'), SHIFT,
        NT('NP'), SHIFT, REDUCE,
        REDUCE,
        REDUCE,
    ]
    oracle = Oracle(actions, ['NNP', 'VBZ', 'NNP'], ['John', 'loves', 'Mary'])
    rebuilt = oracle.to_tree()
    assert str(rebuilt) == expected
def test_from_tree(self):
    """Oracle.from_tree extracts actions, POS tags and words from a parse tree."""
    bracketed = '(S (NP (NNP John)) (VP (VBZ loves) (NP (NNP Mary))))'
    tree = Tree.fromstring(bracketed)

    oracle = Oracle.from_tree(tree)

    assert isinstance(oracle, Oracle)
    assert oracle.actions == [
        NT('S'),
        NT('NP'), SHIFT, REDUCE,
        NT('VP'), SHIFT,
        NT('NP'), SHIFT, REDUCE,
        REDUCE,
        REDUCE,
    ]
    assert oracle.pos_tags == ['NNP', 'VBZ', 'NNP']
    assert oracle.words == ['John', 'loves', 'Mary']
def test_NT():
    """NT renders a nonterminal label as the string 'NT(<label>)'."""
    action = NT('NP')
    expected = 'NT(NP)'
    assert action == expected
def test_is_nt():
    """is_nt accepts NT actions and rejects REDUCE, SHIFT and GEN actions."""
    assert is_nt(NT('NP'))
    for other in (REDUCE, SHIFT, GEN('John')):
        assert not is_nt(other)
def test_is_gen():
    """is_gen accepts GEN actions and rejects REDUCE, SHIFT and NT actions."""
    assert is_gen(GEN('John'))
    for other in (REDUCE, SHIFT, NT('NP')):
        assert not is_gen(other)
def build_vocab(self) -> None:
    """Build ``self.vocab`` for actions from the nonterminal field's vocab.

    The specials list is REDUCE, SHIFT, then NT(x) for every nonterminal x in
    ``self.nonterm_field.vocab`` order. Assuming Vocab places specials first,
    in order (as torchtext's does), REDUCE is 0, SHIFT is 1, and NT(x) is
    x's nonterminal index + 2.
    """
    # Iterate ``itos`` (the authoritative index order) rather than ``stoi``.
    # In some torchtext versions ``stoi`` is a defaultdict, and looking up an
    # unknown token inserts it as a new key. That would add spurious NT
    # actions and shift the id layout.
    nonterm_itos = self.nonterm_field.vocab.itos
    specials = [REDUCE, SHIFT] + [NT(nonterm) for nonterm in nonterm_itos]
    self.vocab = Vocab(Counter(), specials=specials)
def test_get_nonterm():
    """get_nonterm recovers the label that was passed to NT."""
    label = 'NP'
    assert get_nonterm(NT(label)) == label
def test_init_with_unequal_gen_count_and_number_of_pos_tags(self):
    """GenOracle rejects a sequence with no GEN action for a one-tag input."""
    no_gen_actions = [NT('S')]
    with pytest.raises(ValueError) as excinfo:
        GenOracle(no_gen_actions, ['NNP'])
    message = str(excinfo.value)
    assert 'number of POS tags should match number of GEN actions' in message