Ejemplo n.º 1
0
def test_get_oracle_actions():
    """Smoke-test oracle sequence generation for the dependency parser.

    Builds a parser over the ``annot_tuples`` fixture (defined outside this
    function), registers the transition actions required by its projectivized
    tree, and asks the transition system for an oracle sequence. Nothing is
    asserted about the returned sequence; the test passes as long as no
    exception is raised along the way.
    """
    # Only words, tags, heads and deps are consumed below. The id and entity
    # columns are still unpacked so that every row must have six fields.
    words, tags, heads, deps = [], [], [], []
    for _id, word, tag, head, dep, _ent in annot_tuples:
        words.append(word)
        tags.append(tag)
        heads.append(head)
        deps.append(dep)

    doc = Doc(Vocab(), words=words)
    cfg = {"model": DEFAULT_PARSER_MODEL}
    model = registry.resolve(cfg, validate=True)["model"]
    parser = DependencyParser(doc.vocab, model)

    # Label-independent actions.
    # NOTE(review): action 1 with an empty label is registered twice; kept
    # exactly as in the original -- confirm whether the duplicate is intended.
    parser.moves.add_action(0, "")
    parser.moves.add_action(1, "")
    parser.moves.add_action(1, "")
    parser.moves.add_action(4, "ROOT")

    # Register one labelled action per arc of the projectivized tree: action 2
    # when the head lies to the right of the token, action 3 when it lies to
    # the left (presumably the left-arc / right-arc moves -- verify against
    # the transition system). Tokens that are their own head add nothing.
    heads, deps = projectivize(heads, deps)
    for i, (head, dep) in enumerate(zip(heads, deps)):
        if head > i:
            parser.moves.add_action(2, dep)
        elif head < i:
            parser.moves.add_action(3, dep)

    example = Example.from_dict(
        doc,
        {
            "words": words,
            "tags": tags,
            "heads": heads,
            "deps": deps,
        },
    )
    parser.moves.get_oracle_sequence(example)
Ejemplo n.º 2
0
def test_parser_pseudoprojectivity(en_vocab):
    """Exercise the ``nonproj`` helpers end to end.

    Covers label decomposition / decoration checks, ``_lift``, detection of
    the smallest non-projective arc, ``projectivize``, and the round trip
    back through ``deprojectivize`` (including two malformed-decoration
    scenarios).
    """

    def deprojectivize(proj_heads, deco_labels):
        # Wrap the heads/labels in a throwaway Doc, run the in-place
        # nonproj.deprojectivize on it, then read heads and labels back out.
        placeholder_words = ["whatever " for _ in proj_heads]
        doc = Doc(
            en_vocab,
            words=placeholder_words,
            deps=deco_labels,
            heads=proj_heads,
        )
        nonproj.deprojectivize(doc)
        head_indices = [token.head.i for token in doc]
        dep_labels = [token.dep_ for token in doc]
        return head_indices, dep_labels

    # fmt: off
    # --- fixtures -----------------------------------------------------------
    small_tree = [1, 2, 2]
    heads_a = [1, 2, 2, 4, 5, 2, 7, 4, 2]
    heads_b = [9, 1, 3, 1, 5, 6, 9, 8, 6, 1, 6, 12, 13, 10, 1]
    labels_a = ["det", "nsubj", "root", "det", "dobj", "aux", "nsubj", "acl", "punct"]
    labels_b = ["advmod", "root", "det", "nsubj", "advmod", "det", "dobj", "det", "nmod", "aux", "nmod", "advmod", "det", "amod", "punct"]

    # --- label decoration helpers ------------------------------------------
    assert nonproj.decompose("X||Y") == ("X", "Y")
    assert nonproj.decompose("X") == ("X", "")
    assert nonproj.is_decorated("X||Y") is True
    assert nonproj.is_decorated("X") is False

    # --- lifting (mutates the list it is given) ----------------------------
    nonproj._lift(0, small_tree)
    assert small_tree == [2, 2, 2]

    # --- smallest non-projective arc ---------------------------------------
    assert nonproj.get_smallest_nonproj_arc_slow(heads_a) == 7
    assert nonproj.get_smallest_nonproj_arc_slow(heads_b) == 10

    # --- projectivize + round trip, first tree -----------------------------
    expected_proj_heads_a = [1, 2, 2, 4, 5, 2, 7, 5, 2]
    expected_deco_labels_a = ["det", "nsubj", "root", "det", "dobj", "aux", "nsubj", "acl||dobj", "punct"]
    proj_heads, deco_labels = nonproj.projectivize(heads_a, labels_a)
    assert proj_heads == expected_proj_heads_a
    assert deco_labels == expected_deco_labels_a
    restored_heads, restored_labels = deprojectivize(proj_heads, deco_labels)
    assert restored_heads == heads_a
    assert restored_labels == labels_a

    # --- projectivize + round trip, second tree ----------------------------
    expected_proj_heads_b = [1, 1, 3, 1, 5, 6, 9, 8, 6, 1, 9, 12, 13, 10, 1]
    expected_deco_labels_b = ["advmod||aux", "root", "det", "nsubj", "advmod", "det", "dobj", "det", "nmod", "aux", "nmod||dobj", "advmod", "det", "amod", "punct"]
    proj_heads, deco_labels = nonproj.projectivize(heads_b, labels_b)
    assert proj_heads == expected_proj_heads_b
    assert deco_labels == expected_deco_labels_b
    restored_heads, restored_labels = deprojectivize(proj_heads, deco_labels)
    assert restored_heads == heads_b
    assert restored_labels == labels_b

    # --- bad decoration: no head carries the requested label ---------------
    # The structure must be left as-is and the label merely undecorated.
    proj_heads = [1, 2, 2, 4, 5, 2, 7, 5, 2]
    deco_labels = ["det", "nsubj", "root", "det", "dobj", "aux", "nsubj", "acl||iobj", "punct"]
    expected_plain_labels = ["det", "nsubj", "root", "det", "dobj", "aux", "nsubj", "acl", "punct"]
    restored_heads, restored_labels = deprojectivize(proj_heads, deco_labels)
    assert restored_heads == proj_heads
    assert restored_labels == expected_plain_labels

    # --- ambiguous decoration: two candidate heads -------------------------
    # The first candidate wins, even when it is the wrong one.
    proj_heads = [1, 1, 3, 1, 5, 6, 9, 8, 6, 1, 9, 12, 13, 10, 1]
    deco_labels = ["advmod||aux", "root", "det", "aux", "advmod", "det", "dobj", "det", "nmod", "aux", "nmod||dobj", "advmod", "det", "amod", "punct"]
    expected_heads = [3, 1, 3, 1, 5, 6, 9, 8, 6, 1, 6, 12, 13, 10, 1]
    expected_plain_labels = ["advmod", "root", "det", "aux", "advmod", "det", "dobj", "det", "nmod", "aux", "nmod", "advmod", "det", "amod", "punct"]
    restored_heads, restored_labels = deprojectivize(proj_heads, deco_labels)
    assert restored_heads == expected_heads
    assert restored_labels == expected_plain_labels
    # fmt: on