示例#1
0
 def make_sub_tree(span):
     """Return the "X" node that stands for ``span``.

     A span whose end minus start equals 1 is handed to ``wrap_word``; every
     other span yields a fresh, childless ``ConstTree("X")`` whose
     ``word_span`` is set to ``span``.
     """
     node = ConstTree("X")
     node.word_span = span
     covers_single_word = span[1] - span[0] == 1
     return wrap_word(span) if covers_single_word else node
示例#2
0
 def convert_cfg_node(cls, node):
     """Convert ``node`` into a ``ConstTree`` carrying the same tag and span.

     A ``Lexicon`` is returned untouched.  For any other node a new
     ``ConstTree(node.tag)`` is created; children that are ``Lexicon``
     instances or have a truthy ``has_semantics`` are kept as they are, while
     every other child is replaced by the items of its ``generate_words()``.
     ``cls`` is not used by the body.
     """
     if isinstance(node, Lexicon):
         return node
     converted = ConstTree(node.tag)
     for sub_node in node.child:
         if not isinstance(sub_node, Lexicon) and not sub_node.has_semantics:
             # Flatten this child down to the words it yields.
             converted.child.extend(sub_node.generate_words())
         else:
             converted.child.append(sub_node)
     converted.span = node.span
     return converted
示例#3
0
 def random_merge(node):
     """Binarise ``node`` in place by merging random adjacent child pairs.

     ``ConstTree`` children are processed recursively first; a child of any
     other type must be the only child of its parent.  Afterwards, while the
     node has more than two children, a random neighbouring pair is replaced
     by a new "X" node spanning from the left member's start to the right
     member's end.  Randomness comes from ``random_obj`` in the enclosing
     scope.
     """
     siblings = node.child
     for sub_node in siblings:
         if not isinstance(sub_node, ConstTree):
             assert len(siblings) == 1
             continue
         random_merge(sub_node)
     while len(siblings) > 2:
         left_pos = random_obj.randint(0, len(siblings) - 2)
         left = siblings[left_pos]
         right = siblings[left_pos + 1]
         merged = ConstTree("X")
         merged.word_span = (left.word_span[0], right.word_span[1])
         merged.child = [left, right]
         # Put the merged node where the left member was, drop the right one.
         siblings[left_pos] = merged
         del siblings[left_pos + 1]
示例#4
0
def mapper(options):
    """Extract grammar derivations for every sentence listed in one bank file.

    The listing file ``main_dir + bank`` is read as pairs of lines: a
    ``#<sent_id>`` line followed by a tree literal.  For each pair the gzipped
    export ``deepbank_export_path + bank + "/" + sent_id + ".gz"`` is loaded,
    the tree is built and optionally stripped, a hypergraph is built from the
    EDS or MRS field of the export, and ``CFGRule.extract`` is run.  Any
    exception raised while processing a single sentence is printed (sentence
    id, exception class name and traceback) and that sentence is skipped.

    Args:
        options: A 6-tuple ``(main_dir, bank, strip_tree, is_train,
            graph_type, detect_func_name)``.  ``graph_type`` must be ``"eds"``
            or ``"dmrs"``; ``detect_func_name`` must be ``"small"``,
            ``"large"`` or ``"lexicalized"``.

    Returns:
        ``None`` when ``bank`` starts with ``"."``; otherwise
        ``(bank, result)`` where ``result`` holds one
        ``(sent_id, derivations, header, header + original_cfg,
        header + additional_cfg)`` tuple per successfully processed sentence.

    Raises:
        KeyError: If ``detect_func_name`` is not one of the known names.
        AssertionError: If a sentence-id line does not start with ``"#"``.
    """
    main_dir, bank, strip_tree, is_train, graph_type, detect_func_name = options
    detect_func = {
        "small": HRGDerivation.detect_small,
        "large": HRGDerivation.detect_large,
        "lexicalized": HRGDerivation.detect_lexicalized
    }[detect_func_name]

    # Skip hidden entries before touching the file system, so that a hidden
    # directory (for which open() would raise) is ignored instead of crashing.
    if bank.startswith("."):
        return None

    result = []
    with open(main_dir + bank, encoding="utf-8") as f:
        while True:
            sent_id = f.readline().strip()
            if not sent_id:
                # An empty line or end of file terminates the listing.
                break
            assert sent_id.startswith("#")
            sent_id = sent_id[1:]
            tree_literal = f.readline().strip()
            try:
                with gzip.open(
                        deepbank_export_path + bank + "/" + sent_id + ".gz",
                        "rb") as f_gz:
                    contents = f_gz.read().decode("utf-8")
                cfg = ConstTree.from_java_code_deepbank_1_1(
                    tree_literal, contents)

                # strip labels
                if strip_tree == STRIP_ALL_LABELS or strip_tree == STRIP_INTERNAL_LABELS:
                    if strip_tree == STRIP_ALL_LABELS:
                        strip_label(cfg)
                    elif strip_tree == STRIP_INTERNAL_LABELS:
                        strip_label_internal(cfg)
                    strip_unary(cfg)
                elif strip_tree == STRIP_TO_UNLABEL or strip_tree == FUZZY_TREE:
                    strip_to_unlabel(cfg)

                cfg = cfg.condensed_unary_chain()
                cfg.populate_spans_internal()
                fix_punct_hyphen(cfg)

                # The export is split on blank lines; the EDS literal is the
                # second-to-last field and the MRS literal the third-to-last.
                fields = contents.strip().split("\n\n")
                if graph_type == "eds":
                    eds_literal = fields[-2]
                    # Raw string: same pattern as before, without relying on
                    # invalid escape sequences in a normal string literal.
                    eds_literal = re.sub(r"\{.*\}", "", eds_literal)
                    e = eds.loads_one(eds_literal)
                    hg = HyperGraph.from_eds(e)
                elif graph_type == "dmrs":
                    mrs_literal = fields[-3]
                    mrs_obj = simplemrs.loads_one(mrs_literal)
                    hg = HyperGraph.from_mrs(mrs_obj)
                else:
                    raise Exception("Invalid graph type!")

                names, args = extract_features(hg, cfg)
                # NOTE(review): 3 is presumably the value of FUZZY_TREE --
                # confirm and replace the literal with the constant.
                if strip_tree == 3:
                    cfg = fuzzy_cfg(cfg, names)
                derivations = CFGRule.extract(
                    hg,
                    cfg,
                    # draw=True,
                    sent_id=sent_id,
                    detect_func=detect_func)

                # NOTE(review): these four header strings are built but never
                # used below; they are kept so evaluation order and possible
                # exceptions stay as before -- confirm and remove.
                sent_id_info = "# ID: " + sent_id + "\n"
                span_info = "# DelphinSpans: " + repr(
                    [i.span for i in cfg.generate_words()]) + "\n"
                args_info = "# Args: " + repr(list(args)) + "\n"
                names_info = "# Names: " + repr(list(names)) + "\n"

                header = cfg.get_words()
                original_cfg = cfg.to_string(with_comma=False).replace(
                    "+++", "+!+")

                # Rename every rule node to "<tag>#<n>", where n is the number
                # of nodes on the left-hand side of the paired HRG rule (0
                # when the derivation has no HRG rule).
                rules = list(cfg.generate_rules())
                assert len(derivations) == len(rules)
                for syn_rule, cfg_rule in zip(derivations, rules):
                    assert cfg_rule.tag == syn_rule.lhs
                    if syn_rule.hrg is not None:
                        node_count = len(syn_rule.hrg.lhs.nodes)
                    else:
                        node_count = 0
                    cfg_rule.tag = "{}#{}".format(cfg_rule.tag, node_count)
                additional_cfg = cfg.to_string(with_comma=False).replace(
                    "+++", "+!+")

                # Test the predicate itself rather than the truthiness of the
                # rule objects that satisfy it.
                if any(len(rule.child) > 2 for rule in cfg.generate_rules()):
                    if is_train:
                        print("{} Not binary tree!".format(sent_id))
                    else:
                        raise Exception("Not binary tree!")
                result.append((sent_id, derivations, header,
                               header + original_cfg, header + additional_cfg))
            except Exception as e:
                # Best effort: report the failing sentence and carry on.
                print(sent_id)
                print(e.__class__.__name__)
                traceback.print_exc()
    return bank, result
示例#5
0
 def wrap_word(span):
     """Return a new "X" node whose only child is the word at ``span[0]``.

     The word is looked up in ``words`` from the enclosing scope and the
     node's ``word_span`` is set to ``span``.
     """
     word_index = span[0]
     wrapper = ConstTree("X")
     wrapper.word_span = span
     wrapper.child.append(words[word_index])
     return wrapper
示例#6
0
def fuzzy_cfg(cfg, names):
    """Build an unlabelled ("X") binarised tree over the words of ``cfg``.

    The spans found in ``names`` are turned into "X" nodes and nested inside
    each other, the gaps between them are filled with single-word nodes, and
    every node with more than two children is binarised by merging random
    adjacent pairs.  The random generator is seeded with a fixed value, so
    the same input produces the same tree.

    Args:
        cfg: Tree providing ``generate_words()``; its words become the leaves.
        names: Iterable whose items carry a ``(start, end)`` word-index span
            as their first element.  May be empty, in which case the result
            is built from the words alone.

    Returns:
        The root ``ConstTree`` after ``populate_spans_internal()`` was called
        on it.
    """
    random_obj = Random(45)
    spans = {i[0] for i in names}
    words = list(cfg.generate_words())

    def wrap_word(span):
        """Return an "X" node holding the single word at ``span[0]``."""
        ret = ConstTree("X")
        ret.word_span = span
        ret.child.append(words[span[0]])
        return ret

    def make_sub_tree(span):
        """Return the "X" node for ``span`` (a word wrapper if one wide)."""
        if span[1] - span[0] == 1:
            return wrap_word(span)
        ret = ConstTree("X")
        ret.word_span = span
        return ret

    # Widest spans first, so the narrowest remaining span is always last.
    sub_trees = [make_sub_tree(i) for i in spans]
    sub_trees.sort(key=lambda x: x.word_span[1] - x.word_span[0], reverse=True)

    # NOTE(review): trees for which no parent is found are collected here but
    # never read again, so they are dropped from the result and their words
    # are re-added later as plain single-word nodes.  Confirm whether they
    # were meant to become children of the root.
    top_trees = []
    while len(sub_trees) > 1:
        this_tree = sub_trees[-1]
        parent_tree = None
        # Candidates are visited from widest to narrowest; a later candidate
        # replaces the current one when span_overlap() holds between them.
        # presumably this picks the tightest enclosing span -- verify against
        # span_overlap's definition.
        for other_tree in sub_trees[:-1]:
            if span_overlap(this_tree.word_span, other_tree.word_span):
                if parent_tree is None or span_overlap(other_tree.word_span,
                                                       parent_tree.word_span):
                    parent_tree = other_tree
        if parent_tree is None:
            top_trees.append(this_tree)
        else:
            parent_tree.child.append(this_tree)
        sub_trees.pop()

    # At most one tree (the widest span) is left in sub_trees.  The former
    # special case for an empty list indexed sub_trees[0] and therefore always
    # raised IndexError; an empty list now simply gives a childless root that
    # sort_and_fill_blank() populates with one node per word.
    root = ConstTree("X")
    root.word_span = (0, len(words))
    root.child = sub_trees

    def sort_and_fill_blank(node):
        """Order the children of ``node`` and wrap every uncovered word."""
        if not node.child:
            node.child = [
                wrap_word((i, i + 1)) for i in range(*node.word_span)
            ]
        elif isinstance(node.child[0], ConstTree):
            node.child.sort(key=lambda x: x.word_span)
            new_child_list = []
            # Words before the first child.
            for i in range(node.word_span[0], node.child[0].word_span[0]):
                new_child_list.append(wrap_word((i, i + 1)))
            # Each child, followed by the words up to the next child (or up
            # to the end of this node after the last child).
            for child_node, next_child_node in zip_longest(
                    node.child, node.child[1:]):
                new_child_list.append(child_node)
                end = next_child_node.word_span[
                    0] if next_child_node is not None else node.word_span[1]
                for i in range(child_node.word_span[1], end):
                    new_child_list.append(wrap_word((i, i + 1)))
            origin_children = node.child
            node.child = new_child_list
            # Only the original children can still need filling; the newly
            # created word wrappers are already complete.
            for child in origin_children:
                sort_and_fill_blank(child)

    sort_and_fill_blank(root)

    def random_merge(node):
        """Binarise ``node`` in place by merging random adjacent children."""
        children = node.child
        for child_node in children:
            if isinstance(child_node, ConstTree):
                random_merge(child_node)
            else:
                # A non-tree child must be the only child of its parent.
                assert len(children) == 1
        while len(children) > 2:
            idx = random_obj.randint(0, len(children) - 2)
            tree_a = children[idx]
            tree_b = children[idx + 1]
            new_tree = ConstTree("X")
            new_tree.word_span = (tree_a.word_span[0], tree_b.word_span[1])
            new_tree.child = [tree_a, tree_b]
            children[idx] = new_tree
            children.pop(idx + 1)

    random_merge(root)
    root.populate_spans_internal()
    return root