コード例 #1
0
ファイル: test_example.py プロジェクト: mliu-dark-knight/RNNG
def test_make_example_from_disc_oracle():
    """make_example copies an oracle's actions, NT labels, tags and words."""
    # Discriminative oracle for (S (NP John) (VP loves (NP Mary))).
    actions = [
        NT('S'),
        NT('NP'), SHIFT, REDUCE,
        NT('VP'), SHIFT,
        NT('NP'), SHIFT, REDUCE,
        REDUCE,
        REDUCE,
    ]
    pos_tags = 'NNP VBZ NNP'.split()
    words = 'John loves Mary'.split()
    oracle = Oracle(actions, pos_tags, words)
    field_names = ['actions', 'nonterms', 'pos_tags', 'words']
    fields = [(name, Field()) for name in field_names]

    example = make_example(oracle, fields)

    assert isinstance(example, Example)
    assert example.actions == actions
    expected_nonterms = [get_nonterm(act) for act in actions if is_nt(act)]
    assert example.nonterms == expected_nonterms
    assert example.pos_tags == pos_tags
    assert example.words == words
コード例 #2
0
ファイル: oracle.py プロジェクト: mliu-dark-knight/RNNG
	def to_tree(self) -> Tree:
		"""Rebuild the parse tree by replaying ``self.actions`` on a stack.

		NT actions push their nonterminal label. REDUCE pops the completed
		subtrees on top of the stack and wraps them under the nearest open
		label. Any other action is treated as a shift: it pushes
		``Tree(pos_tag, [word])``, consuming ``self.pos_tags`` and
		``self.words`` in order.

		Returns:
			The single complete ``Tree`` left on the stack.

		Raises:
			ValueError: if a REDUCE has no children or no open nonterminal;
				if a shift occurs after all words are consumed; if some
				words/tags are never consumed; or if the actions do not
				leave exactly one complete tree on the stack.
		"""
		stack = []
		# Reversed so the next token can be taken from the end with pop().
		pos_tags = list(reversed(self.pos_tags))
		words = list(reversed(self.words))
		for a in self.actions:
			if is_nt(a):
				stack.append(get_nonterm(a))
			elif a == REDUCE:
				children = []
				# Completed subtrees sit on top; the first non-Tree below
				# them is the open nonterminal label they belong to.
				while stack and isinstance(stack[-1], Tree):
					children.append(stack.pop())
				if not children or not stack:
					raise ValueError(
						f'invalid {REDUCE} action, please check if the actions are correct')
				parent = stack.pop()
				# Children were popped right-to-left; restore sentence order.
				stack.append(Tree(parent, list(reversed(children))))
			else:
				# Report over-long shift sequences as malformed actions
				# rather than letting list.pop() raise IndexError.
				if not words or not pos_tags:
					raise ValueError('too many shift actions for the given words')
				stack.append(Tree(pos_tags.pop(), [words.pop()]))
		if words or pos_tags:
			raise ValueError('actions do not consume all words')
		# A lone unreduced NT label (e.g. actions == [NT('S')]) is a str,
		# not a Tree, so the length check alone is not sufficient.
		if len(stack) != 1 or not isinstance(stack[0], Tree):
			raise ValueError('actions do not produce a single parse tree')
		return stack[0]
コード例 #3
0
ファイル: test_models.py プロジェクト: uMiss/pytorch-rnng
 def action2id(self, action):
     """Map an action to its integer id.

     REDUCE maps to 0 and SHIFT to 1. Any other action is treated as an NT
     action: its nonterminal is looked up in ``self.nt2id`` and shifted by 2
     so it follows the two fixed ids.
     """
     if action == REDUCE:
         action_id = 0
     elif action == SHIFT:
         action_id = 1
     else:
         offset = 2  # ids 0 and 1 are reserved for REDUCE and SHIFT
         action_id = self.nt2id[get_nonterm(action)] + offset
     return action_id
コード例 #4
0
def make_example(oracle: Oracle, fields: List[Tuple[str, Field]]):
    """Build an ``Example`` from an oracle.

    The values passed to ``Example.fromlist`` are, in order: the oracle's
    actions, the nonterminal labels of its NT actions, its POS tags, and its
    words. ``fields`` must list the (name, Field) pairs in that same order.
    """
    actions = oracle.actions
    nonterms = [get_nonterm(action) for action in actions if is_nt(action)]
    values = [actions, nonterms, oracle.pos_tags, oracle.words]
    return Example.fromlist(values, fields)
コード例 #5
0
ファイル: test_actions.py プロジェクト: uMiss/pytorch-rnng
def test_get_nonterm_of_invalid_action():
    """get_nonterm rejects SHIFT with a ValueError naming the action."""
    with pytest.raises(ValueError) as exc_info:
        get_nonterm(SHIFT)
    message = str(exc_info.value)
    assert f'action {SHIFT} is not an NT action' in message
コード例 #6
0
ファイル: test_actions.py プロジェクト: uMiss/pytorch-rnng
def test_get_nonterm():
    """get_nonterm returns the label wrapped by an NT action."""
    label = 'NP'
    assert get_nonterm(NT(label)) == label