Пример #1
0
        def generate_derivation(hg  # type: HyperGraph
                                ):
            """Yield ``(node_rename_map, hrg_rule)`` for each rule that has semantics.

            Every word of ``cfg`` first receives its span from ``spans``.  The
            rules of ``cfg`` are then visited in the order produced by
            ``cfg.generate_rules()``; for each one ``detect_func`` is asked for
            the matching part of the current graph.  When it returns ``None``
            the rule is flagged ``has_semantics = False`` and skipped.
            Otherwise the matched edges are replaced by one new hyperedge
            labelled with the rule tag, the HRG rule is extracted and yielded,
            and the reduced graph becomes the working graph for the next rule.

            When ``draw`` is set, one picture per step (plus a final one) is
            appended to ``pics``.
            """
            words = list(cfg.generate_words())
            assert len(words) == len(spans)
            tree_rules = list(cfg.generate_rules())

            for word_span, word in zip(spans, words):
                word.span = word_span

            step = 1
            previous_edge = None

            for tree_rule in tree_rules:
                # A rule spans from the start of its first child to the end
                # of its last child.
                rule_span = (tree_rule.child[0].span[0],
                             tree_rule.child[-1].span[1])
                tree_rule.span = rule_span

                detected = detect_func(hg, tree_rule)
                tree_rule.has_semantics = detected is not None
                if detected is None:
                    continue
                all_edges, internal_nodes, external_nodes = detected

                replacement = HyperEdge(external_nodes, tree_rule.tag, False,
                                        rule_span)
                # Contract the matched subgraph into the single new edge.
                reduced_graph = HyperGraph(
                    hg.nodes - internal_nodes,
                    (hg.edges - all_edges) | {replacement})
                node_rename_map, hrg_rule = HRGRule.extract(
                    all_edges, internal_nodes, external_nodes, tree_rule.tag)

                if draw:
                    # The picture shows the graph *before* the contraction.
                    pics.append(
                        cls.draw(hg,
                                 "/tmp/a3/{}/{}".format(sent_id, step),
                                 all_edges,
                                 internal_nodes,
                                 external_nodes,
                                 previous_edge,
                                 draw_format=draw_format))

                hg, previous_edge = reduced_graph, replacement
                step += 1
                hrg_rule.cfg = cls.convert_cfg_node(tree_rule)
                yield node_rename_map, hrg_rule

            if draw:
                # Final state of the graph after all contractions.
                pics.append(
                    cls.draw(hg,
                             "/tmp/a3/{}/{}".format(sent_id, step),
                             last_new_edge=previous_edge,
                             draw_format=draw_format))
Пример #2
0
    def merge(
            cls,
            cfg_node,
            sync_rule,
            left_sub_graph,  # type: SubGraph
            right_sub_graph,  # type: SubGraph
    ):
        """Combine two child sub-graphs under ``sync_rule`` into one sub-graph.

        Each abstract node of the rule's right-hand-side graph gets a fresh
        concrete node, except the nodes attached to the left/right rule edges,
        which are unified with the external nodes of the corresponding child
        sub-graph.  The remaining rule edges are instantiated through
        ``cls.transform_edge`` and must all be terminal.

        :raises Exception: if an instantiated edge is not terminal.
        :rtype: SubGraph
        """
        rule_graph = sync_rule.hrg.rhs

        # Fresh concrete node for every abstract node of the rule graph.
        nodes_mapping = {abstract: GraphNode() for abstract in rule_graph.nodes}

        # Abstract nodes of the child edges map onto the children's
        # external nodes instead (right overrides left on a clash).
        external_nodes_map = {}
        _, left_edge = sync_rule.rhs[0]
        if left_edge is not None:
            assert len(left_sub_graph.external_nodes) == len(left_edge.nodes)
            external_nodes_map.update(
                zip(left_edge.nodes, left_sub_graph.external_nodes))

        _, right_edge = sync_rule.rhs[1]
        if right_edge is not None:
            assert len(right_sub_graph.external_nodes) == len(right_edge.nodes)
            external_nodes_map.update(
                zip(right_edge.nodes, right_sub_graph.external_nodes))
        nodes_mapping.update(external_nodes_map)

        # Instantiate every rule edge except the two child placeholders.
        instantiated = []
        for abstract_edge in rule_graph.edges:
            if abstract_edge != left_edge and abstract_edge != right_edge:
                instantiated.append(
                    cls.transform_edge(nodes_mapping, abstract_edge))
        edges = frozenset(instantiated)

        non_terminals = [edge for edge in edges if not edge.is_terminal]
        if non_terminals:
            raise Exception(
                "Non-terminals {} found by rule {} in node {}".format(
                    non_terminals, sync_rule, cfg_node))

        # Unary edges without a span inherit the span of the tree node.
        for concrete_edge in edges:
            if len(concrete_edge.nodes) != 1:
                continue
            if concrete_edge.span is None:
                concrete_edge.span = cfg_node.extra["DelphinSpan"]

        nodes = frozenset(nodes_mapping.values())
        # Absorb the child graphs, left first, then right.
        for child_edge, child in ((left_edge, left_sub_graph),
                                  (right_edge, right_sub_graph)):
            if child_edge is not None:
                edges |= child.graph.edges
                nodes |= child.graph.nodes

        external_nodes = tuple(nodes_mapping[abstract]
                               for abstract in sync_rule.hrg.lhs.nodes)
        merged_graph = HyperGraph(nodes, edges)
        return SubGraph(merged_graph, external_nodes)
Пример #3
0
 def recover_rule(self, rule, lexicon, tag):
     """Wrap ``rule`` into a one-word ``CFGRule`` for ``lexicon``.

     The resulting rule has ``tag`` as its left-hand side, ``lexicon`` as
     its only right-hand-side item, and a copy of ``rule`` whose edges have
     all been passed through ``self.transform_edge(edge, lexicon)``.
     """
     cfg_rhs = ((lexicon, None),)
     hrg_lhs = rule.lhs
     graph_nodes = rule.rhs.nodes
     graph_edges = frozenset(
         self.transform_edge(edge, lexicon) for edge in rule.rhs.edges)
     graph = HyperGraph(nodes=graph_nodes, edges=graph_edges)
     hrg = HRGRule(lhs=hrg_lhs, rhs=graph)
     return CFGRule(lhs=tag, rhs=cfg_rhs, hrg=hrg)
Пример #4
0
 def create_leaf_graph(cls, cfg_node, sync_rule):
     """Instantiate the graph of ``sync_rule`` for a leaf tree node.

     Every abstract node of the rule graph gets a fresh ``GraphNode`` and
     every rule edge is instantiated through ``cls.transform_edge``.  Each
     unary edge is given the ``"DelphinSpan"`` of ``cfg_node``.

     :rtype: SubGraph
     """
     rule_graph = sync_rule.hrg.rhs
     nodes_mapping = {abstract: GraphNode() for abstract in rule_graph.nodes}
     edges = frozenset(
         cls.transform_edge(nodes_mapping, abstract_edge)
         for abstract_edge in rule_graph.edges)

     for concrete_edge in edges:
         if len(concrete_edge.nodes) != 1:
             continue
         concrete_edge.span = cfg_node.extra["DelphinSpan"]

     leaf_graph = HyperGraph(frozenset(nodes_mapping.values()), edges)
     external_nodes = tuple(nodes_mapping[abstract]
                            for abstract in sync_rule.hrg.lhs.nodes)
     return SubGraph(leaf_graph, external_nodes)
Пример #5
0
def mapper(options):
    """Extract synchronous derivations for every sentence of one bank file.

    ``options`` is a 6-tuple
    ``(main_dir, bank, strip_tree, is_train, graph_type, detect_func_name)``:

    * ``main_dir + bank`` is the text file to read.  It holds pairs of
      lines: ``#<sentence id>`` followed by a tree literal.
    * ``strip_tree`` selects the label-stripping mode (compared against the
      module-level ``STRIP_*`` / ``FUZZY_TREE`` constants).
    * ``is_train`` decides what happens to non-binary trees: in training a
      message is printed and the sentence is kept; otherwise an exception
      is raised, which the per-sentence handler below reports and skips.
    * ``graph_type`` is ``"eds"`` or ``"dmrs"``.
    * ``detect_func_name`` is ``"small"``, ``"large"`` or ``"lexicalized"``.

    Returns ``(bank, result)`` where ``result`` is a list of
    ``(sent_id, derivations, header, header + original_cfg,
    header + additional_cfg)`` tuples, or ``None`` when ``bank`` names a
    hidden entry (starts with ``"."``).

    Any exception raised while processing a single sentence is printed with
    its traceback and that sentence is skipped; processing continues.
    """
    main_dir, bank, strip_tree, is_train, graph_type, detect_func_name = options

    # Hidden entries are not bank files: bail out before opening anything
    # or resolving the detection function.
    if bank.startswith("."):
        return None

    detect_func = {
        "small": HRGDerivation.detect_small,
        "large": HRGDerivation.detect_large,
        "lexicalized": HRGDerivation.detect_lexicalized
    }[detect_func_name]
    result = []
    with open(main_dir + bank, encoding="utf-8") as f:
        while True:
            sent_id = f.readline().strip()
            if not sent_id:
                break
            assert sent_id.startswith("#")
            sent_id = sent_id[1:]
            tree_literal = f.readline().strip()
            try:
                with gzip.open(
                        deepbank_export_path + bank + "/" + sent_id + ".gz",
                        "rb") as f_gz:
                    contents = f_gz.read().decode("utf-8")
                cfg = ConstTree.from_java_code_deepbank_1_1(
                    tree_literal, contents)

                # strip labels
                if strip_tree == STRIP_ALL_LABELS or strip_tree == STRIP_INTERNAL_LABELS:
                    if strip_tree == STRIP_ALL_LABELS:
                        strip_label(cfg)
                    else:
                        strip_label_internal(cfg)
                    strip_unary(cfg)
                elif strip_tree == STRIP_TO_UNLABEL or strip_tree == FUZZY_TREE:
                    strip_to_unlabel(cfg)

                cfg = cfg.condensed_unary_chain()
                cfg.populate_spans_internal()
                fix_punct_hyphen(cfg)
                # The export is split on blank lines; the graph literals are
                # taken by position from the end.
                fields = contents.strip().split("\n\n")
                if graph_type == "eds":
                    eds_literal = fields[-2]
                    # Drop everything between the first "{" and the last "}"
                    # on each line before parsing.
                    eds_literal = re.sub(r"\{.*\}", "", eds_literal)
                    eds_graph = eds.loads_one(eds_literal)
                    hg = HyperGraph.from_eds(eds_graph)
                elif graph_type == "dmrs":
                    mrs_literal = fields[-3]
                    mrs_obj = simplemrs.loads_one(mrs_literal)
                    hg = HyperGraph.from_mrs(mrs_obj)
                else:
                    raise Exception("Invalid graph type!")
                names, args = extract_features(hg, cfg)
                # NOTE(review): 3 is presumably the value of FUZZY_TREE --
                # confirm and replace the literal with the constant.
                if strip_tree == 3:
                    cfg = fuzzy_cfg(cfg, names)
                derivations = CFGRule.extract(
                    hg,
                    cfg,
                    # draw=True,
                    sent_id=sent_id,
                    detect_func=detect_func)
                sent_id_info = "# ID: " + sent_id + "\n"
                span_info = "# DelphinSpans: " + repr(
                    [i.span for i in cfg.generate_words()]) + "\n"
                args_info = "# Args: " + repr(list(args)) + "\n"
                names_info = "# Names: " + repr(list(names)) + "\n"
                header = cfg.get_words()
                original_cfg = cfg.to_string(with_comma=False).replace(
                    "+++", "+!+")
                rules = list(cfg.generate_rules())
                assert len(derivations) == len(rules)
                # Suffix every tree tag with the number of external nodes of
                # its HRG rule (0 when the rule carries no graph part).
                for syn_rule, cfg_rule in zip(derivations, rules):
                    assert cfg_rule.tag == syn_rule.lhs
                    arity = (len(syn_rule.hrg.lhs.nodes)
                             if syn_rule.hrg is not None else 0)
                    cfg_rule.tag = "{}#{}".format(cfg_rule.tag, arity)
                additional_cfg = cfg.to_string(with_comma=False).replace(
                    "+++", "+!+")
                if any(len(rule.child) > 2 for rule in cfg.generate_rules()):
                    if is_train:
                        print("{} Not binary tree!".format(sent_id))
                    else:
                        raise Exception("Not binary tree!")
                result.append((sent_id, derivations, header,
                               header + original_cfg, header + additional_cfg))
            except Exception as e:
                # Best effort: report the failing sentence and carry on with
                # the next one.
                print(sent_id)
                print(e.__class__.__name__)
                traceback.print_exc()
    return bank, result
Пример #6
0
        def generate_derivation(hg  # type: HyperGraph
                                ):
            """Yield one ``CFGRule`` per rule of ``cfg``, contracting ``hg`` as it goes.

            For every rule produced by ``cfg.generate_rules()`` the closure
            variable ``detect_func`` is asked for the matching part of the
            current graph:

            * ``None`` -> the rule has no semantics; a ``CFGRule`` with no
              HRG part is yielded and the graph is left untouched.
            * otherwise -> the matched edges are replaced by a single new
              hyperedge labelled with the rule tag, the HRG rule is
              extracted, and a ``CFGRule`` pairing the tree children with
              their hyperedges is yielded.

            If a child's tag disagrees with the label of its hyperedge, a
            message is written to stderr and ``CFGRule(rule.tag, None, None)``
            is yielded for that rule.

            When ``draw`` is set, one picture per contraction (plus a final
            one) is appended to ``pics``.
            """
            rules = list(cfg.generate_rules())  # root last

            # ``count`` only numbers the pictures; ``last_new_edge`` is the
            # edge created by the previous contraction (passed to draw()).
            count = 1
            last_new_edge = None

            for rule in rules:
                # A rule spans from the start of its first child to the end
                # of its last child.
                new_span = (rule.child[0].span[0], rule.child[-1].span[1])
                rule.span = new_span

                result = detect_func(hg, rule)

                # null semantic node
                if result is None:
                    rule.has_semantics = False
                    if lexicalize_null_semantic:
                        # Right-hand side is the list of words under the rule.
                        cfg_rhs = tuple((j, None) for j in rule.generate_words(
                        ))  # type: Tuple[Tuple[Lexicon, None]]
                    else:
                        # Right-hand side keeps the direct children: the
                        # Lexicon itself for words, the tag for subtrees.
                        cfg_rhs = tuple(
                            (i if isinstance(i, Lexicon) else i.tag, None)
                            for i in rule.child)
                    yield CFGRule(rule.tag, cfg_rhs, None)
                    continue
                else:
                    rule.has_semantics = True
                    all_edges, internal_nodes, external_nodes = result

                # The edge that stands in for the matched subgraph.
                new_edge = HyperEdge(external_nodes, rule.tag, False, new_span)

                new_nodes = hg.nodes - internal_nodes
                new_edges = (hg.edges - all_edges) | {new_edge}

                hg_new = HyperGraph(new_nodes, new_edges)
                node_rename_map, hrg_rule = HRGRule.extract(
                    all_edges, internal_nodes, external_nodes, rule.tag, rule)

                if draw:
                    # The picture shows the graph before the contraction.
                    pic_path = "/tmp/a3/{}/{}".format(sent_id, count)
                    pics.append(
                        HRGDerivation.draw(hg,
                                           pic_path,
                                           all_edges,
                                           internal_nodes,
                                           external_nodes,
                                           last_new_edge,
                                           draw_format=draw_format))

                hg = hg_new
                last_new_edge = new_edge
                count += 1

                if isinstance(rule.child[0], Lexicon):
                    # leaf node
                    assert len(rule.child) == 1
                    cfg_rhs = ((rule.child[0], None), )
                else:
                    # internal node
                    assert all(isinstance(i, ConstTree) for i in rule.child)
                    cfg_rhs = []
                    for i in rule.child:
                        if not i.has_semantics:
                            # Child without semantics: no hyperedge attached.
                            if lexicalize_null_semantic:
                                cfg_rhs.extend(
                                    (j, None) for j in i.generate_words())
                            else:
                                cfg_rhs.append((i.tag, None))
                        else:
                            # find corresponding hyperedge in hrg rule for this tree node
                            # (edges are matched to children by identical span)
                            target_edges = [
                                j for j in all_edges if j.span == i.span
                            ]
                            assert len(target_edges) == 1
                            if target_edges[0].label != i.tag:
                                print("Non-consistent CFG and HRG: ",
                                      " ".join(j.string
                                               for j in rule.generate_words()),
                                      file=sys.stderr)
                                # ``None`` marks the whole rule as unusable;
                                # it is checked right after the loop.
                                cfg_rhs = None
                                break
                            # Same edge, expressed with the renamed nodes of
                            # the extracted HRG rule (no span is passed).
                            target_edges_r = HyperEdge(
                                (node_rename_map[node]
                                 for node in target_edges[0].nodes),
                                target_edges[0].label,
                                target_edges[0].is_terminal)
                            cfg_rhs.append((i.tag, target_edges_r))

                if cfg_rhs is not None:
                    yield CFGRule(rule.tag, tuple(cfg_rhs), hrg_rule)
                else:
                    # Inconsistent rule: yielded with neither rhs nor HRG part.
                    yield CFGRule(rule.tag, cfg_rhs, None)

            if draw:
                # Final state of the graph after all contractions.
                pic_path = "/tmp/a3/{}/{}".format(sent_id, count)
                pics.append(
                    HRGDerivation.draw(hg,
                                       pic_path,
                                       last_new_edge=last_new_edge,
                                       draw_format=draw_format))
Пример #7
0
    def extract(
            cls,
            edges,  # type: Set[HyperEdge]
            internal_nodes,  # type: Set[GraphNode]
            external_nodes,  # type: Set[GraphNode]
            label,  # type: str
            cfg_rule=None):
        """Turn a matched subgraph into a rule with canonically renamed nodes.

        Nodes are ranked by an iteratively refined MD5 signature built from
        the labels of their incident edges, their position inside those
        edges, the signatures of the neighbouring nodes, and whether the
        node is external.  The rank becomes the new node name
        (``GraphNode("0")``, ``GraphNode("1")``, ...).

        The left-hand side is a non-terminal edge labelled ``label`` over
        the renamed external nodes.  For exactly two external nodes their
        order is taken, in this priority, from an edge connecting exactly
        those two nodes, from the unary edges carrying the spans of the two
        children of ``cfg_rule``, and finally from the numeric node names.

        Returns ``(node_rename_map, rule)`` where ``rule`` is ``cls(lhs, rhs)``.
        """
        nodes = internal_nodes.union(external_nodes)

        # node -> [(incident edge, position of the node inside that edge)]
        incidences = defaultdict(list)
        for edge in edges:
            for position, node in enumerate(edge.nodes):
                incidences[node].append((edge, position))

        def incidence_digest(
                current,  # type: Dict[GraphNode, bytes]
                edge,  # type: HyperEdge
                position  # type: int
        ):
            """Signature of one (edge, position) pair under ``current``."""
            hasher = hashlib.md5((edge.label + "#" + str(position)).encode())
            for neighbour in edge.nodes:
                hasher.update(current[neighbour] + b"#")
            return hasher.digest()

        def neighbourhood_digest(
                current,  # type: Dict[GraphNode, bytes]
                node  # type: GraphNode
        ):
            """Order-independent signature of all incidences of ``node``."""
            hasher = hashlib.md5()
            for digest in sorted(
                    incidence_digest(current, edge, position)
                    for edge, position in incidences[node]):
                hasher.update(digest)
            return hasher.digest()

        # Every node starts from the same seed signature.
        node_hashes = dict.fromkeys(nodes, hashlib.md5(b"13").digest())

        # Ten refinement rounds; each round only reads the previous one.
        for _ in range(10):
            node_hashes = {
                node: hashlib.md5(
                    neighbourhood_digest(node_hashes, node) +
                    (b'\x01' if node in external_nodes else b'\x00')).digest()
                for node in nodes
            }

        # Rank nodes by signature; the rank is the new name.
        ranked_nodes = sorted(node_hashes, key=node_hashes.__getitem__)
        node_rename_map = {
            node: GraphNode(str(rank))
            for rank, node in enumerate(ranked_nodes)
        }

        # Right-hand side: the same edges over the renamed nodes.
        renamed_edges = [
            HyperEdge((node_rename_map[node] for node in edge.nodes),
                      edge.label, edge.is_terminal) for edge in edges
        ]
        rhs = HyperGraph(frozenset(node_rename_map.values()),
                         frozenset(renamed_edges))

        def order_by_name():
            """Renamed external nodes sorted by their numeric name."""
            return sorted((node_rename_map[node] for node in external_nodes),
                          key=lambda renamed: int(renamed.name))

        def order_external_nodes():
            """Pick the order of the external nodes for the left-hand side."""
            if len(external_nodes) != 2:
                return order_by_name()

            # 1) An edge over exactly these two nodes fixes the order.
            for candidate in permutations(external_nodes):
                for edge in edges:
                    if edge.nodes == candidate:
                        return [node_rename_map[node] for node in candidate]

            if cfg_rule is None or len(cfg_rule.child) != 2:
                return order_by_name()

            # 2) Otherwise follow the left/right children of the tree rule.
            first_span = cfg_rule.child[0].span
            second_span = cfg_rule.child[1].span
            first_nodes = [
                edge.nodes[0] for edge in edges
                if len(edge.nodes) == 1 and edge.span == first_span
            ]
            second_nodes = [
                edge.nodes[0] for edge in edges
                if len(edge.nodes) == 1 and edge.span == second_span
            ]
            if first_nodes and second_nodes and {
                    first_nodes[0], second_nodes[0]
            } == external_nodes:
                return [
                    node_rename_map[first_nodes[0]],
                    node_rename_map[second_nodes[0]]
                ]

            # 3) Last resort: numeric order of the new names.
            return order_by_name()

        lhs = HyperEdge(order_external_nodes(),
                        label=label,
                        is_terminal=False)
        return node_rename_map, cls(lhs, rhs)
Пример #8
0
    def predict(self, trees, return_derivations=False):
        """Generate a graph for every tree by expanding the best rule per node.

        For each tree, a synchronous rule is chosen for every tree node (in
        ``tree.root_first()`` order) via ``self.scorer_network``.  The graph
        of the root rule becomes the initial working graph; the remaining
        rules are then substituted breadth-first, each step replacing one
        edge of the working graph by the right-hand-side graph of the rule
        chosen for the corresponding child.

        This is a generator.  It yields ``(tree.extra["ID"], working_graph)``
        per tree, or ``(tree.extra["ID"], working_graph, derivations)`` when
        ``return_derivations`` is true.

        NOTE(review): ``derivations`` is created once, outside the loop over
        trees, so the list yielded for a later tree also contains the steps
        of all earlier trees -- confirm this is intended.
        """
        derivations = []
        for idx, tree in enumerate(trees):
            sentence_interface = tree.to_sentence()
            self.populate_delphin_spans(tree)
            self.span_ebd_network.init_special()
            span_features = self.span_ebd_network.get_span_features(
                sentence_interface)
            # Tree nodes in root-first order; ``syn_rules`` below is aligned
            # with this list index by index.
            r = [i for i in tree.root_first()]

            syn_rules = []
            for i in tree.root_first():
                correspondents = set(self.rule_lookup(i, False).items())
                # get_best_rule is consumed in two steps: the first next()
                # produces ``exprs`` (unused here), the second the result.
                best_rule_getter = self.scorer_network.get_best_rule(
                    span_features[i.span], correspondents, None)
                exprs = next(best_rule_getter)
                best_rule, this_loss, real_best_rule = next(best_rule_getter)
                syn_rules.append(best_rule)

            # tree node -> rule chosen for it
            rule_mapping = dict(zip(r, syn_rules))

            def transform_edge(mapping, edge, span):
                # Re-create ``edge`` over the nodes given by ``mapping`` and
                # attach ``span`` to it.
                return HyperEdge((mapping[i] for i in edge.nodes), edge.label,
                                 edge.is_terminal, span)

            # deal with root rule

            # create nodes in working graph
            nodes_mapping = {
                i: GraphNode()
                for i in syn_rules[0].hrg.rhs.nodes
            }

            # edge -> span
            span_mapping = {}
            for cfg_subnode, (name, edge) in zip(r[0], syn_rules[0].rhs):
                if edge is not None:
                    span_mapping[edge] = cfg_subnode.extra["DelphinSpan"]

            # create edges in working graph
            new_edges = frozenset(
                transform_edge(nodes_mapping, edge, span_mapping.get(edge))
                for edge in syn_rules[0].hrg.rhs.edges)

            # Unary edges that got no span from a child take the span of
            # the root node.
            for new_edge in new_edges:
                if len(new_edge.nodes) == 1 and new_edge.span is None:
                    new_edge.span = r[0].extra["DelphinSpan"]

            # ``step`` is incremented below but never read in this method.
            step = 0
            working_graph = HyperGraph(frozenset(nodes_mapping.values()),
                                       new_edges)
            derivations.append((working_graph, syn_rules[0]))

            # Pending substitutions: (tree node, its rule, edge to replace).
            queue = deque()

            if isinstance(tree.children[0], ConstTree):
                # add children nodes to queue
                for i, (_, j) in zip(tree.children, syn_rules[0].rhs):
                    if j is not None:
                        queue.append((i, rule_mapping[i],
                                      transform_edge(nodes_mapping, j,
                                                     span_mapping.get(j))))

            while queue:
                # each step substitutes one nonterminal edge with a subgraph,
                # and appends the child substitutions to the queue
                target_cfg_rule, target_sync_rule, target_edge = queue.popleft(
                )
                assert target_edge in working_graph.edges
                # The rule's external nodes are identified with the nodes of
                # the edge being replaced; all other rule nodes are fresh.
                target_nodes_mapping = dict(
                    zip(target_sync_rule.hrg.lhs.nodes, target_edge.nodes))
                for node in target_sync_rule.hrg.rhs.nodes:
                    if node not in target_nodes_mapping.keys():
                        target_nodes_mapping[node] = GraphNode()

                # edge -> span
                span_mapping = {}
                for cfg_subnode, (name, edge) in zip(target_cfg_rule,
                                                     target_sync_rule.rhs):
                    if edge is not None:
                        span_mapping[edge] = cfg_subnode.extra["DelphinSpan"]

                new_nodes = working_graph.nodes | frozenset(
                    target_nodes_mapping.values())
                new_edges_this_step = frozenset(
                    transform_edge(target_nodes_mapping, edge,
                                   span_mapping.get(edge))
                    for edge in target_sync_rule.hrg.rhs.edges)
                # Swap the replaced edge for the edges of the rule graph.
                new_edges = (working_graph.edges -
                             {target_edge}) | new_edges_this_step

                # Unary edges without a span take the span of the tree node
                # being expanded.
                for new_edge in new_edges_this_step:
                    if len(new_edge.nodes) == 1 and new_edge.span is None:
                        new_edge.span = target_cfg_rule.extra["DelphinSpan"]

                step += 1
                working_graph = HyperGraph(new_nodes, new_edges)
                derivations.append((working_graph, target_sync_rule))

                for i, (_, j) in zip(target_cfg_rule.children,
                                     target_sync_rule.rhs):
                    if j is not None:
                        queue.append((i, rule_mapping[i],
                                      transform_edge(target_nodes_mapping, j,
                                                     span_mapping.get(j))))
            if not return_derivations:
                yield tree.extra["ID"], working_graph
            else:
                yield tree.extra["ID"], working_graph, derivations
            # Runs only once the consumer asks for the next item.
            # NOTE(review): ``dn`` is presumably DyNet and this resets its
            # computation graph between trees -- confirm.
            dn.renew_cg()
Пример #9
0
    def sync_grammar_fallback_2(self, tree_node):
        """Build a fallback rule for a leaf whose tag has no usable rule.

        ``tree_node.tag`` has the form ``<rule name>#<node count>``.

        * Node count other than 1: the most common rule recorded for the tag
          in ``self.terminal_mapping`` is returned if there is one; otherwise
          a rule is synthesised whose graph links every external node to one
          extra node through a terminal edge labelled ``"???"``.
        * Node count 1: a single-node graph with one terminal edge is
          synthesised.  Its label is ``"card"`` when the word matches
          ``self.pattern_number``, ``"named"`` for generic proper names, a
          lemma-based predicate for the listed noun rule names, and otherwise
          whatever ``self.lexicon_mapping`` holds for the word (returned
          as-is when non-empty) or ``"named"``.

        Returns a ``Counter`` holding the single chosen rule, or the
        candidates object from ``self.lexicon_mapping``.
        """
        rule_name, main_node_count = tree_node.tag.rsplit("#", 1)
        word = tree_node.children[0].string
        main_node_count = int(main_node_count)

        if main_node_count != 1:
            known_rules = self.terminal_mapping.get(tree_node.tag)
            if known_rules:
                return Counter([known_rules.most_common(1)[0][0]])

            connected_nodes = [
                GraphNode(str(i)) for i in range(main_node_count)
            ]
            hub_node = GraphNode(str(main_node_count + 1))
            lhs_edge = HyperEdge(nodes=connected_nodes,
                                 label=rule_name,
                                 is_terminal=False)
            # One placeholder edge from the hub to each external node.
            spoke_edges = [
                HyperEdge(nodes=[hub_node, node],
                          label="???",
                          is_terminal=True) for node in connected_nodes
            ]
            star_graph = HyperGraph(
                nodes=frozenset(connected_nodes + [hub_node]),
                edges=frozenset(spoke_edges))
            star_rule = CFGRule(lhs=rule_name,
                                rhs=((tree_node.children[0], None),),
                                hrg=HRGRule(lhs=lhs_edge, rhs=star_graph))
            return Counter([star_rule])

        main_node = GraphNode("0")

        if self.pattern_number.match(word):
            label = "card"
        elif "generic_proper" in rule_name:
            label = "named"
        else:
            lemma = self.lemmatizer.lemmatize(word)
            if "n_-_c-pl-unk_le" in rule_name:
                label = "_{}/nns_u_unknown".format(lemma)
            elif "n_-_mc_le" in rule_name or "n_-_c_le" in rule_name:
                label = "_{}_n_1".format(lemma)
            elif "generic_mass_count_noun" in rule_name:
                label = "_{}/nn_u_unknown".format(lemma)
            else:
                candidates = self.lexicon_mapping[HLexicon(word),
                                                  main_node_count]
                if candidates:
                    return candidates
                label = "named"

        lhs_edge = HyperEdge(nodes=[main_node],
                             label=rule_name,
                             is_terminal=False)
        predicate_edge = HyperEdge(nodes=[main_node],
                                   label=label,
                                   is_terminal=True)
        single_node_graph = HyperGraph(nodes=frozenset([main_node]),
                                       edges=frozenset({predicate_edge}))
        fallback = CFGRule(lhs=rule_name,
                           rhs=((tree_node.children[0], None),),
                           hrg=HRGRule(lhs=lhs_edge, rhs=single_node_graph))
        return Counter([fallback])