def test_make_example_from_disc_oracle():
    """make_example should expose actions, derived nonterms, POS tags and words.

    The action sequence encodes (S (NP John) (VP loves (NP Mary))).
    """
    words = ['John', 'loves', 'Mary']
    pos_tags = ['NNP', 'VBZ', 'NNP']
    actions = [
        NT('S'),
        NT('NP'),
        SHIFT,
        REDUCE,
        NT('VP'),
        SHIFT,
        NT('NP'),
        SHIFT,
        REDUCE,
        REDUCE,
        REDUCE,
    ]
    field_names = ['actions', 'nonterms', 'pos_tags', 'words']
    fields = [(name, Field()) for name in field_names]

    example = make_example(Oracle(actions, pos_tags, words), fields)

    # Only the NT actions contribute a nonterminal label.
    expected_nonterms = [get_nonterm(act) for act in actions if is_nt(act)]
    assert isinstance(example, Example)
    assert example.actions == actions
    assert example.nonterms == expected_nonterms
    assert example.pos_tags == pos_tags
    assert example.words == words
def to_tree(self) -> Tree:
    """Rebuild the parse tree encoded by this oracle's action sequence.

    The actions are replayed on a stack:

    * an NT action pushes its nonterminal label (a plain value, not a Tree);
    * a SHIFT action pushes a leaf ``Tree(pos_tag, [word])`` built from the
      next POS tag and word, consumed left to right;
    * a REDUCE action pops every completed ``Tree`` above the most recent
      open label and wraps them, in surface order, under that label.

    Returns:
        The single ``Tree`` left on the stack after all actions are replayed.

    Raises:
        ValueError: if a REDUCE has no completed children or no open label
            to attach them to, if a SHIFT occurs after all words have been
            consumed, if an action is neither NT, REDUCE nor SHIFT, or if the
            actions do not leave exactly one complete tree.
    """
    stack = []
    # Reversed so that list.pop() hands out tokens left to right in O(1).
    pos_tags = list(reversed(self.pos_tags))
    words = list(reversed(self.words))
    for a in self.actions:
        if is_nt(a):
            # Open constituents are kept as bare labels; completed ones are Trees.
            stack.append(get_nonterm(a))
        elif a == REDUCE:
            children = []
            while stack and isinstance(stack[-1], Tree):
                children.append(stack.pop())
            if not children or not stack:
                raise ValueError(
                    f'invalid {REDUCE} action, please check if the actions are correct')
            parent = stack.pop()
            # Children were popped right-to-left; restore their surface order.
            tree = Tree(parent, list(reversed(children)))
            stack.append(tree)
        elif a == SHIFT:
            # Fail with a clear ValueError instead of a bare IndexError from pop().
            if not words or not pos_tags:
                raise ValueError(
                    f'invalid {SHIFT} action, no more words left to shift')
            tree = Tree(pos_tags.pop(), [words.pop()])
            stack.append(tree)
        else:
            raise ValueError(f'unknown action {a}')
    # A lone unreduced label (e.g. actions == [NT('S')]) is not a parse tree.
    if len(stack) != 1 or not isinstance(stack[0], Tree):
        raise ValueError('actions do not produce a single parse tree')
    return stack[0]
def action2id(self, action):
    """Map an action to its integer id.

    REDUCE maps to 0 and SHIFT maps to 1. An NT action maps to the index of
    its nonterminal in ``self.nt2id``, offset by 2 so it cannot collide with
    those two ids.
    """
    if action == REDUCE:
        action_id = 0
    elif action == SHIFT:
        action_id = 1
    else:
        action_id = self.nt2id[get_nonterm(action)] + 2
    return action_id
def make_example(oracle: Oracle, fields: List[Tuple[str, Field]]):
    """Build an ``Example`` from an oracle via ``Example.fromlist``.

    The values passed are, in order: the oracle's actions, the nonterminal
    labels of its NT actions, its POS tags, and its words. ``fields`` must
    list its (name, Field) pairs in that same order.
    """
    nonterms = list(map(get_nonterm, filter(is_nt, oracle.actions)))
    values = [oracle.actions, nonterms, oracle.pos_tags, oracle.words]
    return Example.fromlist(values, fields)
def test_get_nonterm_of_invalid_action():
    """get_nonterm should reject a non-NT action with a descriptive ValueError."""
    expected_msg = f'action {SHIFT} is not an NT action'
    with pytest.raises(ValueError) as exc_info:
        get_nonterm(SHIFT)
    assert expected_msg in str(exc_info.value)
def test_get_nonterm():
    """get_nonterm should return the label carried by an NT action."""
    assert get_nonterm(NT('NP')) == 'NP'