示例#1
0
def mdst_score_v2(t: gt.Graph,
                  root,
                  weight='MDST_v_weight',
                  weight_e='MDST_e_weight'):  # Use Old formula for now
    """Sum a multiplicative 'PercentLeft' value over a BFS started at ``root``.

    The root receives its own vertex weight.  For every edge yielded by the
    breadth-first traversal, the second endpoint receives
    ``vertex weight * edge weight * value of the first endpoint``.  Each value
    is written to the ``'PercentLeft'`` vertex property obtained through
    ``vp_map``.

    :param t: graph that is traversed.
    :param root: vertex at which the traversal starts.
    :param weight: name of the vertex property holding the vertex weights.
    :param weight_e: name of the edge property holding the edge weights.
    :return: the root's value plus the value of every vertex reached.
    """
    remaining = vp_map(t, 'PercentLeft', 'float')
    vertex_weights = vp_map(t, weight)
    edge_weights = ep_map(t, weight_e)

    remaining[root] = vertex_weights[root]
    # Read the value back from the property map, as stored.
    running_total = remaining[root]

    for edge in gt_s.bfs_iterator(t, root):
        parent, child = edge
        contribution = (vertex_weights[child] * edge_weights[edge]
                        * remaining[parent])
        remaining[child] = contribution
        running_total += contribution

    return running_total
示例#2
0
def mdst_score(t: gt.Graph, n_idx, e_idx, used_stuff):
    """Score the rooted tree ``t`` from attribute frequencies in the indexes.

    Every vertex gets a 'PercentLeft' value.  For the root it is the fraction
    ``len(index[attribute]) / index['size']`` of its attribute (or 1 when the
    root is in ``used_stuff``).  For every other vertex reached by a BFS from
    the root it is ``vertex fraction * edge fraction * parent value``.  An edge
    found in ``used_stuff`` contributes 1 for both itself and its end vertex;
    an end vertex found in ``used_stuff`` contributes 1 for the vertex.

    :param t: tree to score; the traversal starts at its first vertex with
        in-degree 0.
    :param n_idx: mapping ``{attribute name: {attribute value: matches, ...,
        'size': total}}`` for vertices; only the first attribute name is used.
    :param e_idx: same layout as ``n_idx``, for edges.
    :param used_stuff: container of edge / vertex descriptors that count as
        fully matched.
    :return: sum of the 'PercentLeft' values of the root and all reached
        vertices.
    :raises ValueError: if ``t`` has no vertex with in-degree 0.
    """
    # ``dict.keys()`` returns a view, not an iterator, so ``next`` cannot be
    # applied to it directly; go through ``iter`` to get the first key.
    n_att_name = next(iter(n_idx))
    e_att_name = next(iter(e_idx))

    percent_left = vp_map(t, 'PercentLeft', 'float')
    v_attribute_p = vp_map(t, n_att_name, 'int')
    e_attribute_p = ep_map(t, e_att_name, 'int')

    node_index = n_idx[n_att_name]
    edge_index = e_idx[e_att_name]

    def node_fraction(v):
        # Share of indexed vertices carrying the same attribute value as ``v``.
        return len(node_index[v_attribute_p[v]]) / node_index['size']

    # Choose a vertex with input degree = 0.  Mask the vertex array itself
    # (instead of taking positions with np.where) so the result is a real
    # vertex index even on filtered views, and keep a single vertex rather
    # than an array so it can be used as a property-map key.
    vertices = t.get_vertices()
    candidates = vertices[t.get_in_degrees(vertices) == 0]
    if len(candidates) == 0:
        raise ValueError('mdst_score: the tree has no vertex with in-degree 0')
    # Use a descriptor so the ``used_stuff`` test matches the one applied to
    # ``end_node`` below.
    root = t.vertex(int(candidates[0]))

    percent_left[root] = 1 if root in used_stuff else node_fraction(root)

    total_score = percent_left[root]
    for e in gt_s.bfs_iterator(t, root):
        start_node, end_node = e
        if e in used_stuff:
            # A used edge counts as fully matched together with its end vertex.
            edge_score = 1
            end_node_score = 1
        else:
            edge_score = len(
                edge_index[e_attribute_p[e]]) / edge_index['size']
            if end_node in used_stuff:
                end_node_score = 1
            else:
                end_node_score = node_fraction(end_node)
        percent_left[
            end_node] = end_node_score * percent_left[start_node] * edge_score
        total_score += percent_left[end_node]

    return total_score
示例#3
0
def sgm_match(t_graph: gt.Graph, g_graph: gt.Graph, delta, tau, n_idx,
              e_idx) -> gt.Graph:
    """Build and prune a graph of candidate matches for a query tree.

    Each vertex of the match graph is a pair (query-tree vertex, matched
    element taken from the indexes); each edge is a matched query-tree edge.
    Scores are propagated from the leaves towards the root, vertices whose
    score falls more than ``delta`` below the best possible score are dropped,
    and the surviving partial solutions are collected per vertex in the
    ``'path_list'`` vertex property.

    :param t_graph: the query tree (T).  Its root is taken to be the first
        vertex with in-degree 0.
    :param g_graph: described by the original comments as the query graph (G);
        in this function it is only passed to ``original_vp``.
    :param delta: the score delta that we can accept from a perfect match.
    :param tau: how far off this tree is from the graph, at most.
        NOTE(review): not referenced anywhere in this function body.
    :param n_idx: index of node attributes; ``n_idx[name][value]`` is the
        collection of matches for that attribute value.  Only the first
        attribute name is used.
    :param e_idx: index of edge attributes with the same layout; each match is
        indexed with ``[0]`` and ``[1]`` for its two endpoints.
    :return: the match graph returned by ``save_root_children``, with the
        ``'path_list'`` vertex property filled in.
    """
    # root_match = [n for n, d in list(T.in_degree().items()) if d == 0]

    # The root is the first tree vertex without incoming edges.
    root_match = [v for v in t_graph.vertices() if v.in_degree() == 0]
    root = root_match[0]
    # Only the first attribute name of each index is consulted.
    n_keys = list(n_idx.keys())[0]
    e_keys = list(e_idx.keys())[0]

    #    print 'Building matching graph'

    print('Printing MDST Graph')
    print(root)
    print_graph(t_graph)

    # Step 1: Get all the matches for the nodes
    # (an attribute value missing from the index yields an empty set).
    node_matches = dict()
    for v in t_graph.vertices():
        if t_graph.vp[n_keys][v] in list(n_idx[n_keys].keys()):
            node_matches[v] = n_idx[n_keys][t_graph.vp[n_keys][v]]
        else:
            node_matches[v] = set()

    # Step 2: Get all the edge matches for the node
    edge_matches = dict()
    for e in t_graph.edges():
        if t_graph.ep[e_keys][e] in list(e_idx[e_keys].keys()):
            edge_matches[e] = e_idx[e_keys][t_graph.ep[e_keys][e]]
        else:
            edge_matches[e] = set()
        # Make sure you count just the ones that have matching nodes too.
        edge_matches[e] = set([
            em for em in edge_matches[e] if em[0] in node_matches[e.source()]
            and em[1] in node_matches[e.target()]
        ])

    # Scoring, initially, is going to be super-simple:
    # You get a 1 if you match, and a 0 if you don't.  Everything's created equal.

    # Score everything and put it in a graph.

    # NOTE(review): this loop has no effect -- its body is only ``pass``.
    for k in list(edge_matches.keys()):
        if len(edge_matches[k]) == 0:
            pass
            # stop_here = 1

    match_graph = gt.Graph(directed=True)
    #    for nT in T.nodes():
    #        for nG in node_matches[nT]:
    #            MatchGraph.add_node(tuple([nT,nG]),score=1,solo_score=1)
    # Collect the (tree vertex, matched endpoint) pairs and the links between
    # them in plain sets first, so every pair becomes exactly one vertex.
    mg_edges = set()
    mg_vertices = set()
    mg_vertices_to_index = {}
    for eT in t_graph.edges():
        for eG in edge_matches[eT]:
            v1 = (eT.source(), eG[0])
            v2 = (eT.target(), eG[1])
            mg_vertices.add(v1)
            mg_vertices.add(v2)
            mg_edges.add((v1, v2))

    # match_graph.add_edge([(eT.source(), eG.source()), (eT.target(), eG.target())])
    # 'zero_id' holds the first component of the pair (the tree vertex) and
    # 'one_id' the second (the matched endpoint).
    zero_id = vp_map(match_graph, 'zero_id')
    one_id = vp_map(match_graph, 'one_id')

    for tup in mg_vertices:
        v = match_graph.add_vertex()
        zero_id[v], one_id[v] = tup
        mg_vertices_to_index[tup] = v

    # it = iter(mg_vertices)
    # for v in match_graph.vertices():
    #     tup = next(it)
    #     zero_id[v], one_id[v] = tup
    #     mg_vertices_to_index[tup] = v

    for t1, t2 in mg_edges:
        match_graph.add_edge(mg_vertices_to_index[t1],
                             mg_vertices_to_index[t2])

    # debug_match_graph(match_graph)

    solo_score_vp = vp_map(match_graph, 'solo_score', 'int')
    score_vp = vp_map(match_graph, 'score_v', 'int')
    score_ep = ep_map(match_graph, 'score_e', 'int')
    path_vp = vp_map(match_graph, 'path', 'object')

    g_graph_original = original_vp(g_graph)
    t_graph_original = original_vp(t_graph)

    # Every match starts with score 1 and a path holding only itself.
    for v in match_graph.vertices():
        solo_score_vp[v] = 1
        score_vp[v] = 1

        # Here we insert original nodes
        d = coll.deque()
        d.append((t_graph_original[zero_id[v]], g_graph_original[one_id[v]]))
        path_vp[v] = d

    for e in match_graph.edges():
        score_ep[e] = 1

    # gt_draw.graph_draw(match_graph, vprops={'text': zero_id})

    # Get rid of anybody flying solo
    match_graph = clear_unconnected(match_graph,
                                    root)  # this is clearly not working.

    # Now acquire/organize all hypotheses with scores above Max_Score - tau - delta

    # Figure out how much score you could possibly get at every node in the query.
    max_score_v = vp_map(t_graph, 'max_score_v', 'int')
    max_score_e = ep_map(t_graph, 'max_score_e', 'int')
    # match_graph was rebound above, so the property maps are fetched again
    # from the new graph object.
    score_vp = vp_map(match_graph, 'score_v', 'int')
    score_ep = ep_map(match_graph, 'score_e', 'int')
    path_vp = vp_map(match_graph, 'path', 'object')
    zero_id = vp_map(match_graph, 'zero_id')

    # gt_draw.graph_draw(match_graph, vprops={'text': zero_id})

    for n in t_graph.vertices():
        max_score_v[n] = 1
    for e in t_graph.edges():
        max_score_e[e] = 1

    bfs_edges = list(gt_s.bfs_iterator(t_graph, source=root))
    reversed_bfs_edges = list(reversed(bfs_edges))

    t_index = t_graph.vertex_index

    # debug_match_graph(match_graph)

    for e in reversed_bfs_edges:  # Reverse BFS search - should do leaf nodes first.
        # What's the best score we could get at this node?
        v1, v2 = e

        max_score_v[v1] += max_score_v[v2] + max_score_e[e]

        # Find all the edges equivalent to this one in the match graph
        edge_matches = [
            (eG1, eG2) for eG1, eG2 in match_graph.edges()
            if zero_id[eG1] == t_index[v1] and zero_id[eG2] == t_index[v2]
        ]

        parent_nodes = set([eM1 for eM1, eM2 in edge_matches])

        for p in parent_nodes:
            child_nodes = [eM2 for eM1, eM2 in edge_matches if eM1 == p]
            # First, check if the bottom node has a score
            best_score = 0
            # best_node = None
            c_path = None
            for c in child_nodes:
                c_edge = match_graph.edge(p, c)
                c_score = score_vp[c] + score_ep[c_edge]
                # NOTE(review): c_path is overwritten for every child, so after
                # the loop it is the path of the LAST child visited, not of the
                # best-scoring one; the commented-out ``best_child_path`` below
                # hints that the best child's path may have been intended --
                # confirm.
                c_path = path_vp[c]

                if c_score > best_score:
                    best_score = c_score
                    # best_child_path = c_path
            score_vp[p] += best_score
            # Each element is pushed on the left of the parent's deque.
            for pathNode in c_path:
                path_vp[p].appendleft(pathNode)

    leave_prop = match_graph.new_vertex_property('bool')

    # CLEAN IT UP.
    # Keep only matches whose score is within ``delta`` of the best score
    # possible at the corresponding tree vertex.
    for n in match_graph.vertices():
        leave_prop[n] = score_vp[n] >= max_score_v[t_graph.vertex(
            zero_id[n])] - delta

    sub = gt.GraphView(match_graph, leave_prop)
    new_match_graph = create_q_graph(sub, add_back_reference=False)

    # Get rid of anybody flying solo
    match_graph = save_root_children(new_match_graph, root)
    zero_id = vp_map(match_graph, 'zero_id')
    one_id = vp_map(match_graph, 'one_id')
    path_list_vp = vp_map(match_graph, 'path_list', 'object')
    # Every surviving match starts with a single one-element path.
    for n in match_graph.vertices():
        d = coll.deque()
        d.append((t_graph_original[zero_id[n]], g_graph_original[one_id[n]]))
        path_list_vp[n] = [d]

    # Get a list of solutions alive in the graph
    for e in reversed_bfs_edges:
        v1, v2 = e
        edge_matches = [
            (eG1, eG2) for eG1, eG2 in match_graph.edges()
            if zero_id[eG1] == t_index[v1] and zero_id[eG2] == t_index[v2]
        ]

        parent_nodes = set([eM1 for eM1, eM2 in edge_matches])

        for p in parent_nodes:
            child_nodes = [eM2 for eM1, eM2 in edge_matches if eM1 == p]
            # First, check if the bottom node has a score
            # Replace the parent's paths by every (parent path + child path)
            # concatenation over all matching children.
            tmpList = []
            for c in child_nodes:
                for _p in path_list_vp[p]:
                    for _c in path_list_vp[c]:
                        tmpList.append(_p + _c)
            path_list_vp[p] = tmpList

    # debug_match_graph(match_graph)

    # Score the root solutions
    return match_graph
示例#4
0
def calculate_mdst_v2(g: gt.Graph, n_idx, e_idx, used_stuff=None):
    """Weight ``g`` from the attribute indexes and pick the best-rooted tree.

    Vertex and edge weights are the fraction
    ``len(index[attribute]) / index['size']`` (0 when the attribute value is
    not in the index; 1 for edges found in ``used_stuff``).  A minimum
    spanning tree is computed on the edge weights, and every vertex is tried
    as root; the rooting with the lowest ``mdst_score_v2`` wins.

    :param g: graph to process; the weights are written to its
        ``'MDST_v_weight'`` / ``'MDST_e_weight'`` properties.
    :param n_idx: mapping ``{attribute name: {attribute value: matches, ...,
        'size': total}}`` for vertices; only the first attribute name is used.
    :param e_idx: same layout as ``n_idx``, for edges.
    :param used_stuff: optional container of edges whose weight is forced to
        1.  Defaults to an empty set.
    :return: ``(best_t, best_score)`` -- the graph built by ``create_q_graph``
        for the best root and its score; ``best_t`` stays ``None`` if no root
        scores below infinity.
    """
    # Avoid a mutable default argument: a ``set()`` default would be one
    # object shared by every call.
    if used_stuff is None:
        used_stuff = set()

    # Step 1: Figure out the weights.
    n_att_name = list(n_idx.keys())[0]
    e_att_name = list(e_idx.keys())[0]
    node_index = n_idx[n_att_name]
    edge_index = e_idx[e_att_name]

    # Create an MDSTWeight vector on the nodes and edges.
    v_weight = vp_map(g, 'MDST_v_weight', 'float')
    e_weight = ep_map(g, 'MDST_e_weight', 'float')

    v_a_map = vp_map(g, n_att_name)
    e_a_map = ep_map(g, e_att_name)

    # Membership is tested directly on the index mappings instead of on a
    # list copy of their keys, which needed a linear scan per element.
    for n in g.vertices():
        attribute = v_a_map[n]
        if attribute in node_index:
            v_weight[n] = len(node_index[attribute]) / node_index['size']
        else:
            v_weight[n] = 0

    for e in g.edges():
        if e in used_stuff:
            e_weight[e] = 1
            continue
        attribute = e_a_map[e]
        if attribute in edge_index:
            e_weight[e] = len(edge_index[attribute]) / edge_index['size']
        else:
            e_weight[e] = 0

    # Step 2: Calculate the MST.
    t_map = gt_top.min_spanning_tree(g, e_weight, g.vertex(0))

    # Step 3: Figure out which root results in us doing the least work.
    t = gt.GraphView(g, efilt=t_map, directed=False)

    best_t = None
    best_score = np.inf
    for root in t.vertices():
        # Generate a new tree from the BFS edges seen from this root.
        nodes = []
        edges = []
        for e in gt_s.bfs_iterator(t, root):
            edges.append(e)
            nodes.extend([e.source(), e.target()])
        nodes = np.unique(nodes)
        new_t = create_q_graph(t, q_nodes=nodes, q_edges=edges, directed=True)

        # NOTE(review): the score is computed on the undirected view ``t``
        # rooted at ``root``, not on ``new_t`` -- kept as in the original.
        new_t_score = mdst_score_v2(t, root)
        if new_t_score < best_score:
            best_t = new_t
            best_score = new_t_score

    return best_t, best_score
def steiner_tree_mst(g, root, infection_times, source, terminals,
                     closure_builder=build_closure,
                     strictly_smaller=True,
                     return_closure=False,
                     k=-1,
                     debug=False,
                     verbose=True):
    """Build a tree over ``terminals`` guided by a spanning arborescence.

    A closure graph is built with ``closure_builder`` and its minimum spanning
    arborescence is computed with networkx.  The terminals (without ``root``)
    are then processed in order of infection time and BFS position in that
    arborescence; each one not yet covered is connected, through edges of
    ``g``, to the closest tree node found by ``cpbfs_search``.

    :param g: original graph from which the tree edges are extracted.
    :param root: root vertex of the tree.
    :param infection_times: indexable by vertex; gives the infection time.
    :param source: not used in this function body.
    :param terminals: vertices the tree has to connect.
    :param closure_builder: callable returning ``(closure graph, edge weights,
        r2pred)``.
    :param strictly_smaller: forwarded to ``closure_builder``.
    :param return_closure: when true, also return the closure graph and the
        arborescence graph.
    :param k: forwarded to ``closure_builder``.
    :param debug: print diagnostic messages.
    :param verbose: print progress messages; also forwarded to
        ``closure_builder``.
    :return: the vertex-filtered tree ``t``, or ``(t, closure, mst_tree)``
        when ``return_closure`` is set.  If no arborescence exists the result
        is ``None`` or ``(None, closure, None)`` respectively.
    """
    closure, closure_weight, _r2pred = closure_builder(
        g, root, terminals, infection_times,
        strictly_smaller=strictly_smaller,
        k=k,
        debug=debug,
        verbose=verbose)

    # graph_tool does not provide minimum_spanning_arborescence, so the
    # closure is handed over to networkx for that step.
    if verbose:
        print('getting mst')
    nx_closure = gt2nx(closure, root, terminals,
                       edge_attrs={'weight': closure_weight})
    try:
        arborescence = nx.minimum_spanning_arborescence(nx_closure, 'weight')
    except nx.exception.NetworkXException:
        if debug:
            print('fail to find mst')
        return (None, closure, None) if return_closure else None

    if verbose:
        print('returning tree')

    # Mirror the arborescence as a graph_tool graph over all vertices of g.
    mst_tree = Graph(directed=True)
    for _ in range(g.num_vertices()):
        mst_tree.add_vertex()
    for tail, head in arborescence.edges():
        mst_tree.add_edge(tail, head)

    if verbose:
        print('extract edges from original graph')

    # Position of each vertex in a BFS of the arborescence, used to break
    # ties between observations with equal infection time.
    bfs_rank = {
        int(edge.target()): position
        for position, edge in enumerate(bfs_iterator(mst_tree, source=root))
    }

    def observation_order(obs):
        return infection_times[obs], bfs_rank[obs]

    pending = sorted(set(terminals) - {root}, key=observation_order)

    tree_nodes = {root}
    tree_edges = set()
    for obs in pending:
        if obs in tree_nodes:
            if debug:
                print('{} covered already'.format(obs))
            continue

        # The arborescence parent of this observation is its ancestor.
        incoming = next(mst_tree.vertex(obs).in_edges())
        parent, child = map(int, incoming)
        tree_nodes.add(parent)

        child_time = infection_times[child]
        later_terminals = [
            n for n in terminals if infection_times[n] > child_time
        ]
        visitor = init_visitor(g, child)

        # Search from the child towards any tree node, including the parent.
        cpbfs_search(g, source=child, terminals=list(tree_nodes),
                     forbidden_nodes=later_terminals,
                     visitor=visitor,
                     count_threshold=1)

        visited = {n for n, dist in visitor.dist.items() if dist > 0}
        reachable = visited.intersection(tree_nodes)
        ancestor = min(reachable, key=lambda n: visitor.dist[n])

        # The extracted edges are reversed before being stored.
        path_edges = {
            (head, tail)
            for tail, head in extract_edges_from_pred(
                g, child, ancestor, visitor.pred)
        }
        if debug:
            print('tree_nodes', tree_nodes)
            print('connecting {} to {}'.format(parent, child))
            print('using ancestor {}'.format(ancestor))
            print('adding edges {}'.format(path_edges))

        for path_edge in path_edges:
            tree_nodes.update(path_edge)
        tree_edges.update(path_edges)

    t = Graph(directed=True)
    for _ in range(g.num_vertices()):
        t.add_vertex()
    for tail, head in tree_edges:
        t.add_edge(t.vertex(tail), t.vertex(head))

    # Only vertices that are an endpoint of some tree edge stay visible.
    keep = t.new_vertex_property('bool')
    keep.a = False
    for tree_edge in tree_edges:
        for endpoint in tree_edge:
            keep[t.vertex(endpoint)] = True
    t.set_vertex_filter(keep)

    if return_closure:
        return t, closure, mst_tree
    return t
示例#6
0
def run(
        input_file: KGTKFiles,
        output_file: KGTKFiles,
        root: typing.Optional[typing.List[str]],
        rootfile,
        rootfilecolumn,
        subject_column_name: typing.Optional[str],
        object_column_name: typing.Optional[str],
        predicate_column_name: typing.Optional[str],
        props: typing.Optional[typing.List[str]],
        props_file: typing.Optional[str],
        propsfilecolumn: typing.Optional[str],
        inverted: bool,
        inverted_props: typing.Optional[typing.List[str]],
        inverted_props_file: typing.Optional[str],
        invertedpropsfilecolumn: typing.Optional[str],
        undirected: bool,
        undirected_props: typing.Optional[typing.List[str]],
        undirected_props_file: typing.Optional[str],
        undirectedpropsfilecolumn: typing.Optional[str],
        label: str,
        selflink_bool: bool,
        show_properties: bool,
        breadth_first: bool,
        depth_limit: typing.Optional[int],
        errors_to_stdout: bool,
        errors_to_stderr: bool,
        show_options: bool,
        verbose: bool,
        very_verbose: bool,
        **kwargs,  # Whatever KgtkFileOptions and KgtkValueOptions want.
):
    import sys
    import csv
    from pathlib import Path
    import time
    from graph_tool.search import dfs_iterator, bfs_iterator, bfs_search, BFSVisitor
    # from graph_tool import load_graph_from_csv
    from graph_tool.util import find_edge
    from kgtk.exceptions import KGTKException
    from kgtk.cli_argparse import KGTKArgumentParser

    from kgtk.gt.gt_load import load_graph_from_kgtk
    from kgtk.io.kgtkwriter import KgtkWriter
    from kgtk.io.kgtkreader import KgtkReader, KgtkReaderOptions
    from kgtk.value.kgtkvalueoptions import KgtkValueOptions

    #Graph-tool names columns that are not subject or object c0, c1... This function finds the number that graph tool assigned to the predicate column
    def find_pred_position(sub, pred, obj):
        if pred < sub and pred < obj:
            return pred
        elif (pred > sub and pred < obj) or (pred < sub and pred > obj):
            return pred - 1
        else:
            return pred - 2

    def get_edges_by_edge_prop(g, p, v):
        return find_edge(g, prop=g.properties[('e', p)], match=v)

    input_kgtk_file: Path = KGTKArgumentParser.get_input_file(input_file)
    output_kgtk_file: Path = KGTKArgumentParser.get_output_file(output_file)

    # Select where to send error messages, defaulting to stderr.
    error_file: typing.TextIO = sys.stdout if errors_to_stdout else sys.stderr

    # Build the option structures.
    input_reader_options: KgtkReaderOptions = KgtkReaderOptions.from_dict(
        kwargs, who="input", fallback=True)
    root_reader_options: KgtkReaderOptions = KgtkReaderOptions.from_dict(
        kwargs, who="root", fallback=True)
    props_reader_options: KgtkReaderOptions = KgtkReaderOptions.from_dict(
        kwargs, who="props", fallback=True)
    undirected_props_reader_options: KgtkReaderOptions = KgtkReaderOptions.from_dict(
        kwargs, who="undirected_props", fallback=True)
    inverted_props_reader_options: KgtkReaderOptions = KgtkReaderOptions.from_dict(
        kwargs, who="inverted_props", fallback=True)
    value_options: KgtkValueOptions = KgtkValueOptions.from_dict(kwargs)

    if root is None:
        root = []  # This simplifies matters.

    if props is None:
        props = []  # This simplifies matters.

    if undirected_props is None:
        undirected_props = []  # This simplifies matters.

    if inverted_props is None:
        inverted_props = []  # This simplifies matters.

    if show_options:
        if root is not None:
            print("--root %s" % " ".join(root), file=error_file)
        if rootfile is not None:
            print("--rootfile=%s" % rootfile, file=error_file)
        if rootfilecolumn is not None:
            print("--rootfilecolumn=%s" % rootfilecolumn, file=error_file)
        if subject_column_name is not None:
            print("--subj=%s" % subject_column_name, file=error_file)
        if object_column_name is not None:
            print("--obj=%s" % object_column_name, file=error_file)
        if predicate_column_name is not None:
            print("--pred=%s" % predicate_column_name, file=error_file)

        if props is not None:
            print("--props=%s" % " ".join(props), file=error_file)
        if props_file is not None:
            print("--props-file=%s" % props_file, file=error_file)
        if propsfilecolumn is not None:
            print("--propsfilecolumn=%s" % propsfilecolumn, file=error_file)

        print("--inverted=%s" % str(inverted), file=error_file)
        if inverted_props is not None:
            print("--inverted-props=%s" % " ".join(inverted_props),
                  file=error_file)
        if inverted_props_file is not None:
            print("--inverted-props-file=%s" % inverted_props_file,
                  file=error_file)
        if invertedpropsfilecolumn is not None:
            print("--invertedpropsfilecolumn=%s" % invertedpropsfilecolumn,
                  file=error_file)

        print("--undirected=%s" % str(undirected), file=error_file)
        if undirected_props is not None:
            print("--undirected-props=%s" % " ".join(undirected_props),
                  file=error_file)
        if undirected_props_file is not None:
            print("--undirected-props-file=%s" % undirected_props_file,
                  file=error_file)
        if undirectedpropsfilecolumn is not None:
            print("--undirectedpropsfilecolumn=%s" % undirectedpropsfilecolumn,
                  file=error_file)

        print("--label=%s" % label, file=error_file)
        print("--selflink=%s" % str(selflink_bool), file=error_file)
        print("--breadth-first=%s" % str(breadth_first), file=error_file)
        if depth_limit is not None:
            print("--depth-limit=%d" % depth_limit, file=error_file)
        input_reader_options.show(out=error_file)
        root_reader_options.show(out=error_file)
        props_reader_options.show(out=error_file)
        undirected_props_reader_options.show(out=error_file)
        inverted_props_reader_options.show(out=error_file)
        value_options.show(out=error_file)
        KgtkReader.show_debug_arguments(errors_to_stdout=errors_to_stdout,
                                        errors_to_stderr=errors_to_stderr,
                                        show_options=show_options,
                                        verbose=verbose,
                                        very_verbose=very_verbose,
                                        out=error_file)
        print("=======", file=error_file, flush=True)

    if inverted and (len(inverted_props) > 0
                     or inverted_props_file is not None):
        raise KGTKException(
            "--inverted is not allowed with --inverted-props or --inverted-props-file"
        )

    if undirected and (len(undirected_props) > 0
                       or undirected_props_file is not None):
        raise KGTKException(
            "--undirected is not allowed with --undirected-props or --undirected-props-file"
        )

    if depth_limit is not None:
        if not breadth_first:
            raise KGTKException(
                "--depth-limit is not allowed without --breadth-first")
        if depth_limit <= 0:
            raise KGTKException("--depth-limit requires a positive argument")

    root_set: typing.Set = set()

    if rootfile is not None:
        if verbose:
            print("Reading the root file %s" % repr(rootfile),
                  file=error_file,
                  flush=True)
        try:
            root_kr: KgtkReader = KgtkReader.open(
                Path(rootfile),
                error_file=error_file,
                who="root",
                options=root_reader_options,
                value_options=value_options,
                verbose=verbose,
                very_verbose=very_verbose,
            )
        except SystemExit:
            raise KGTKException("Exiting.")

        rootcol: int
        if root_kr.is_edge_file:
            rootcol = int(
                rootfilecolumn
            ) if rootfilecolumn is not None and rootfilecolumn.isdigit(
            ) else root_kr.get_node1_column_index(rootfilecolumn)
        elif root_kr.is_node_file:
            rootcol = int(
                rootfilecolumn
            ) if rootfilecolumn is not None and rootfilecolumn.isdigit(
            ) else root_kr.get_id_column_index(rootfilecolumn)
        elif rootfilecolumn is not None:
            rootcol = int(
                rootfilecolumn
            ) if rootfilecolumn is not None and rootfilecolumn.isdigit(
            ) else root_kr.column_name_map.get(rootfilecolumn, -1)
        else:
            root_kr.close()
            raise KGTKException(
                "The root file is neither an edge nor a node file and the root column name was not supplied."
            )

        if rootcol < 0:
            root_kr.close()
            raise KGTKException("Unknown root column %s" %
                                repr(rootfilecolumn))

        for row in root_kr:
            rootnode: str = row[rootcol]
            root_set.add(rootnode)
        root_kr.close()

    if len(root) > 0:
        if verbose:
            print("Adding root nodes from the command line.",
                  file=error_file,
                  flush=True)
        root_group: str
        for root_group in root:
            r: str
            for r in root_group.split(','):
                if verbose:
                    print("... adding %s" % repr(r),
                          file=error_file,
                          flush=True)
                root_set.add(r)
    if len(root_set) == 0:
        print(
            "Warning: No nodes in the root set, the output file will be empty.",
            file=error_file,
            flush=True)
    elif verbose:
        print("%d nodes in the root set." % len(root_set),
              file=error_file,
              flush=True)

    property_set: typing.Set[str] = set()
    if props_file is not None:
        if verbose:
            print("Reading the root file %s" % repr(props_file),
                  file=error_file,
                  flush=True)
        try:
            props_kr: KgtkReader = KgtkReader.open(
                Path(props_file),
                error_file=error_file,
                who="props",
                options=props_reader_options,
                value_options=value_options,
                verbose=verbose,
                very_verbose=very_verbose,
            )
        except SystemExit:
            raise KGTKException("Exiting.")

        propscol: int
        if props_kr.is_edge_file:
            propscol = int(
                propsfilecolumn
            ) if propsfilecolumn is not None and propsfilecolumn.isdigit(
            ) else props_kr.get_node1_column_index(propsfilecolumn)
        elif props_kr.is_node_file:
            propscol = int(
                propsfilecolumn
            ) if propsfilecolumn is not None and propsfilecolumn.isdigit(
            ) else props_kr.get_id_column_index(propsfilecolumn)
        elif propsfilecolumn is not None:
            propscol = int(
                propsfilecolumn
            ) if propsfilecolumn is not None and propsfilecolumn.isdigit(
            ) else props_kr.column_name_map.get(propsfilecolumn, -1)
        else:
            props_kr.close()
            raise KGTKException(
                "The props file is neither an edge nor a node file and the root column name was not supplied."
            )

        if propscol < 0:
            props_kr.close()
            raise KGTKException("Unknown props column %s" %
                                repr(propsfilecolumn))

        for row in props_kr:
            property_name: str = row[propscol]
            property_set.add(property_name)
        props_kr.close()

    if len(props) > 0:
        # Filter the graph, G, to include only edges where the predicate (label)
        # column contains one of the selected properties.

        prop_group: str
        for prop_group in props:
            prop: str
            for prop in prop_group.split(','):
                property_set.add(prop)
    if verbose and len(property_set) > 0:
        print("property set=%s" % " ".join(sorted(list(property_set))),
              file=error_file,
              flush=True)

    undirected_property_set: typing.Set[str] = set()
    if undirected_props_file is not None:
        if verbose:
            print("Reading the undirected properties file %s" %
                  repr(undirected_props_file),
                  file=error_file,
                  flush=True)
        try:
            undirected_props_kr: KgtkReader = KgtkReader.open(
                Path(undirected_props_file),
                error_file=error_file,
                who="undirected_props",
                options=undirected_props_reader_options,
                value_options=value_options,
                verbose=verbose,
                very_verbose=very_verbose,
            )
        except SystemExit:
            raise KGTKException("Exiting.")

        undirected_props_col: int
        if undirected_props_kr.is_edge_file:
            undirected_props_col = int(
                undirectedpropsfilecolumn
            ) if undirectedpropsfilecolumn is not None and undirectedpropsfilecolumn.isdigit(
            ) else undirected_props_kr.get_node1_column_index(
                undirectedpropsfilecolumn)
        elif undirected_props_kr.is_node_file:
            undirected_props_col = int(
                undirectedpropsfilecolumn
            ) if undirectedpropsfilecolumn is not None and undirectedpropsfilecolumn.isdigit(
            ) else undirected_props_kr.get_id_column_index(
                undirectedpropsfilecolumn)
        elif undirectedpropsfilecolumn is not None:
            undirected_props_col = int(
                undirectedpropsfilecolumn
            ) if undirectedpropsfilecolumn is not None and undirectedpropsfilecolumn.isdigit(
            ) else undirected_props_kr.column_name_map.get(
                undirectedpropsfilecolumn, -1)
        else:
            undirected_props_kr.close()
            raise KGTKException(
                "The undirected props file is neither an edge nor a node file and the root column name was not supplied."
            )

        if undirected_props_col < 0:
            undirected_props_kr.close()
            raise KGTKException("Unknown undirected properties column %s" %
                                repr(undirectedpropsfilecolumn))

        for row in undirected_props_kr:
            undirected_property_name: str = row[undirected_props_col]
            undirected_property_set.add(undirected_property_name)
        undirected_props_kr.close()
    if len(undirected_props) > 0:
        # Edges where the predicate (label) column contains one of the selected
        # properties will be treated as undirected links.

        und_prop_group: str
        for und_prop_group in undirected_props:
            und_prop: str
            for und_prop in und_prop_group.split(','):
                undirected_property_set.add(und_prop)
    if verbose and len(undirected_property_set) > 0:
        print("undirected property set=%s" %
              " ".join(sorted(list(undirected_property_set))),
              file=error_file,
              flush=True)

    inverted_property_set: typing.Set[str] = set()
    if inverted_props_file is not None:
        if verbose:
            print("Reading the inverted properties file %s" %
                  repr(inverted_props_file),
                  file=error_file,
                  flush=True)
        try:
            inverted_props_kr: KgtkReader = KgtkReader.open(
                Path(inverted_props_file),
                error_file=error_file,
                who="inverted_props",
                options=inverted_props_reader_options,
                value_options=value_options,
                verbose=verbose,
                very_verbose=very_verbose,
            )
        except SystemExit:
            raise KGTKException("Exiting.")

        inverted_props_col: int
        if inverted_props_kr.is_edge_file:
            inverted_props_col = int(
                invertedpropsfilecolumn
            ) if invertedpropsfilecolumn is not None and invertedpropsfilecolumn.isdigit(
            ) else inverted_props_kr.get_node1_column_index(
                invertedpropsfilecolumn)
        elif inverted_props_kr.is_node_file:
            inverted_props_col = int(
                invertedpropsfilecolumn
            ) if invertedpropsfilecolumn is not None and invertedpropsfilecolumn.isdigit(
            ) else inverted_props_kr.get_id_column_index(
                invertedpropsfilecolumn)
        elif invertedpropsfilecolumn is not None:
            inverted_props_col = int(
                invertedpropsfilecolumn
            ) if invertedpropsfilecolumn is not None and invertedpropsfilecolumn.isdigit(
            ) else inverted_props_kr.column_name_map.get(
                invertedpropsfilecolumn, -1)
        else:
            inverted_props_kr.close()
            raise KGTKException(
                "The inverted props file is neither an edge nor a node file and the root column name was not supplied."
            )

        if inverted_props_col < 0:
            inverted_props_kr.close()
            raise KGTKException("Unknown inverted properties column %s" %
                                repr(invertedpropsfilecolumn))

        for row in inverted_props_kr:
            inverted_property_name: str = row[inverted_props_col]
            inverted_property_set.add(inverted_property_name)
        inverted_props_kr.close()

    if len(inverted_props) > 0:
        # Edges where the predicate (label) column contains one of the selected
        # properties will have the source and target columns swapped.

        inv_prop_group: str
        for inv_prop_group in inverted_props:
            inv_prop: str
            for inv_prop in inv_prop_group.split(','):
                inverted_property_set.add(inv_prop)
    if verbose and len(inverted_property_set):
        print("inverted property set=%s" %
              " ".join(sorted(list(inverted_property_set))),
              file=error_file,
              flush=True)

    try:
        kr: KgtkReader = KgtkReader.open(
            input_kgtk_file,
            error_file=error_file,
            who="input",
            options=input_reader_options,
            value_options=value_options,
            verbose=verbose,
            very_verbose=very_verbose,
        )
    except SystemExit:
        raise KGTKException("Exiting.")

    sub: int = kr.get_node1_column_index(subject_column_name)
    if sub < 0:
        print("Unknown subject column %s" % repr(subject_column_name),
              file=error_file,
              flush=True)

    pred: int = kr.get_label_column_index(predicate_column_name)
    if pred < 0:
        print("Unknown predicate column %s" % repr(predicate_column_name),
              file=error_file,
              flush=True)

    obj: int = kr.get_node2_column_index(object_column_name)
    if obj < 0:
        print("Unknown object column %s" % repr(object_column_name),
              file=error_file,
              flush=True)

    if sub < 0 or pred < 0 or obj < 0:
        kr.close()
        raise KGTKException("Exiting due to unknown column.")

    if verbose:
        print("special columns: sub=%d pred=%d obj=%d" % (sub, pred, obj),
              file=error_file,
              flush=True)

    # G = load_graph_from_csv(filename,not(undirected),skip_first=not(header_bool),hashed=True,csv_options={'delimiter': '\t'},ecols=(sub,obj))
    G = load_graph_from_kgtk(kr,
                             directed=not undirected,
                             inverted=inverted,
                             ecols=(sub, obj),
                             pcol=pred,
                             pset=property_set,
                             upset=undirected_property_set,
                             ipset=inverted_property_set,
                             verbose=verbose,
                             out=error_file)

    name = G.vp[
        "name"]  # Get the vertex name property map (vertex to ndoe1 (subject) name)

    if show_properties:
        print("Graph name=%s" % repr(name), file=error_file, flush=True)
        print("Graph properties:", file=error_file, flush=True)
        key: typing.Any
        for key in G.properties:
            print("    %s: %s" % (repr(key), repr(G.properties[key])),
                  file=error_file,
                  flush=True)

    index_list = []
    for v in G.vertices():
        if name[v] in root_set:
            index_list.append(v)
    if len(index_list) == 0:
        print(
            "Warning: No root nodes found in the graph, the output file will be empty.",
            file=error_file,
            flush=True)
    elif verbose:
        print("%d root nodes found in the graph." % len(index_list),
              file=error_file,
              flush=True)

    output_header: typing.List[str] = ['node1', 'label', 'node2']

    try:
        kw: KgtkWriter = KgtkWriter.open(output_header,
                                         output_kgtk_file,
                                         mode=KgtkWriter.Mode.EDGE,
                                         require_all_columns=True,
                                         prohibit_extra_columns=True,
                                         fill_missing_columns=False,
                                         verbose=verbose,
                                         very_verbose=very_verbose)
    except SystemExit:
        raise KGTKException("Exiting.")

    for index in index_list:
        if selflink_bool:
            kw.writerow([name[index], label, name[index]])

        if breadth_first:
            if depth_limit is None:
                for e in bfs_iterator(G, G.vertex(index)):
                    kw.writerow([name[index], label, name[e.target()]])

            else:

                class DepthExceeded(Exception):
                    pass

                class DepthLimitedVisitor(BFSVisitor):
                    def __init__(self, name, pred, dist):
                        self.name = name
                        self.pred = pred
                        self.dist = dist

                    def tree_edge(self, e):
                        self.pred[e.target()] = int(e.source())
                        newdist = self.dist[e.source()] + 1
                        if depth_limit is not None and newdist > depth_limit:
                            raise DepthExceeded
                        self.dist[e.target()] = newdist
                        kw.writerow([name[index], label, name[e.target()]])

                dist = G.new_vertex_property("int")
                pred = G.new_vertex_property("int64_t")
                try:
                    bfs_search(G, G.vertex(index),
                               DepthLimitedVisitor(name, pred, dist))
                except DepthExceeded:
                    pass
        else:
            for e in dfs_iterator(G, G.vertex(index)):
                kw.writerow([name[index], label, name[e.target()]])

    kw.close()
    kr.close()
def find_tree_by_closure(g,
                         root,
                         infection_times,
                         terminals,
                         closure_builder=build_closure_with_order,
                         strictly_smaller=True,
                         return_closure=False,
                         k=-1,
                         debug=False,
                         verbose=True):
    """Find a Steiner tree over `terminals` by way of a transitive closure.

    The work happens in four stages:

    1. `closure_builder` returns a closure graph `gc` and its edge weights.
    2. `find_minimum_branching` picks edges of `gc` starting from `root`;
       those edges are exposed as an edge-filtered view of `gc`.
    3. The terminals other than `root` are sorted by infection time, with
       ties broken by their BFS position in that view.
    4. Every terminal that is not covered yet is linked to the partial tree
       using the predecessor map filled in by `cpbfs_search` on the
       original graph `g`.

    Args:
        g: the original graph.
        root: the root node of the tree.
        infection_times: indexable by node, giving its infection time.
        terminals: the nodes that the tree has to cover.
        closure_builder: callable building `(gc, eweight)`.
        strictly_smaller, k, debug, verbose: forwarded to `closure_builder`;
            `debug` and `verbose` also enable progress printing here.
        return_closure: when true, also return `gc` and the branching view.

    Returns:
        A directed graph with `g.num_vertices()` vertices, vertex-filtered
        down to the endpoints of the selected edges; or the tuple
        `(tree, gc, mst_tree)` when `return_closure` is true.

    Raises:
        TreeNotFound: if sorting the terminals raises `KeyError`, e.g. a
            terminal that the BFS over the branching never reached.
    """
    gc, eweight = closure_builder(g,
                                  root,
                                  terminals,
                                  infection_times,
                                  strictly_smaller=strictly_smaller,
                                  k=k,
                                  return_r2pred=False,
                                  debug=debug,
                                  verbose=verbose)

    # graph_tool has no minimum_spanning_arborescence, hence the helper.
    if verbose:
        print('getting mst')
    branching = find_minimum_branching(gc, [root], weights=eweight)

    # Expose the branching as a view of the closure graph.
    keep_edge = gc.new_edge_property('bool')
    keep_edge.a = False
    for src, dst in branching:
        keep_edge[gc.edge(src, dst)] = True
    mst_tree = GraphView(gc, efilt=keep_edge)

    if verbose:
        print('extract edges from original graph')

    # Position of every node in a BFS over the branching; used together
    # with the infection time to decide in which order terminals are
    # attached to the tree.
    topological_index = {
        int(edge.target()): position
        for position, edge in enumerate(bfs_iterator(mst_tree, source=root))
    }

    def visit_order(obs):
        return infection_times[obs], topological_index[obs]

    try:
        sorted_obs = sorted(set(terminals) - {root}, key=visit_order)
    except KeyError:
        raise TreeNotFound(
            "it's likely that the input cannot produce a feasible solution, " +
            "because the topological sort on terminals does not visit all terminals"
        )

    # Rebuild the arborescence on the original graph, one terminal at a time.
    tree_nodes = {root}
    tree_edges = set()
    for obs in sorted_obs:
        if obs in tree_nodes:
            if debug:
                print('{} covered already'.format(obs))
            continue

        # The single incoming branching edge gives the ancestor of `obs`.
        in_edge = next(mst_tree.vertex(obs).in_edges())
        parent, child = int(in_edge.source()), int(in_edge.target())
        tree_nodes.add(parent)

        late_nodes = [
            n for n in terminals
            if infection_times[n] > infection_times[child]
        ]
        vis = init_visitor(g, child)

        # Search from the child towards any tree node, `parent` included.
        cpbfs_search(g,
                     source=child,
                     terminals=list(tree_nodes),
                     forbidden_nodes=late_nodes,
                     visitor=vis,
                     count_threshold=1)
        node_set = {node for node, d in vis.dist.items() if d > 0}
        reachable_tree_nodes = node_set.intersection(tree_nodes)
        ancestor = min(reachable_tree_nodes, key=vis.dist.__getitem__)

        path_edges = extract_edges_from_pred(g, child, ancestor, vis.pred)
        # The extracted pairs are reversed before being stored.
        path_edges = {(j, i) for i, j in path_edges}
        if debug:
            print('tree_nodes', tree_nodes)
            print('connecting {} to {}'.format(parent, child))
            print('using ancestor {}'.format(ancestor))
            print('adding edges {}'.format(path_edges))
        tree_nodes |= {node for edge in path_edges for node in edge}

        tree_edges |= path_edges

    # Materialise the result on a fresh graph with the same vertex count.
    t = Graph(directed=True)
    t.add_vertex(g.num_vertices())

    for src, dst in tree_edges:
        t.add_edge(t.vertex(src), t.vertex(dst))

    # Only endpoints of selected edges stay visible.
    tree_nodes = {node for edge in tree_edges for node in edge}
    visible = t.new_vertex_property('bool')
    visible.a = False
    for node in tree_nodes:
        visible[t.vertex(node)] = True

    t.set_vertex_filter(visible)

    if return_closure:
        return t, gc, mst_tree
    return t
示例#8
0
def run(
        input_file: KGTKFiles,
        output_file: KGTKFiles,
        root: typing.Optional[typing.List[str]],
        rootfile,
        rootfilecolumn,
        subject_column_name: typing.Optional[str],
        object_column_name: typing.Optional[str],
        predicate_column_name: typing.Optional[str],
        props: typing.Optional[typing.List[str]],
        undirected: bool,
        label: str,
        selflink_bool: bool,
        show_properties: bool,
        breadth_first: bool,
        errors_to_stdout: bool,
        errors_to_stderr: bool,
        show_options: bool,
        verbose: bool,
        very_verbose: bool,
        **kwargs,  # Whatever KgtkFileOptions and KgtkValueOptions want.
):
    """Write, for every root node, one edge per node reached from it.

    The root set is the union of the values read from `rootfile` (if given)
    and the comma-separated names in `root`.  The input KGTK file is loaded
    into a graph; when `props` is non-empty the graph's edges are replaced by
    those whose predicate column matches one of the listed values.  For each
    root found in the graph, a BFS or DFS iterator is run from it and a row
    `[root name, label, target name]` is written for every edge it yields.

    Args:
        input_file: the KGTK input file specification.
        output_file: the KGTK output file specification.
        root: root node names; each entry may hold several comma-separated
            names.  `None` is treated as an empty list.
        rootfile: optional path of a KGTK file listing root nodes.
        rootfilecolumn: column of `rootfile` holding the root names, either
            a column name or a string of digits giving the column index.
        subject_column_name: name of the subject column, passed to
            `get_node1_column_index`.
        object_column_name: name of the object column, passed to
            `get_node2_column_index`.
        predicate_column_name: name of the predicate column, passed to
            `get_label_column_index`.
        props: predicate values to keep; each entry may hold several
            comma-separated values.  `None` is treated as an empty list.
        undirected: load the graph as undirected when true.
        label: value written to the `label` column of every output row.
        selflink_bool: also write a `root -> root` row for every root.
        show_properties: print the graph's property maps to the error file.
        breadth_first: traverse with `bfs_iterator` instead of `dfs_iterator`.
        errors_to_stdout: send messages to stdout instead of stderr.
        errors_to_stderr: only forwarded to `show_debug_arguments`.
        show_options: print the effective options before running.
        verbose: print progress messages.
        very_verbose: forwarded to the KGTK readers and the writer.
        **kwargs: reader and value options.

    Raises:
        KGTKException: if the root column cannot be determined, or the
            subject, predicate or object column is unknown.
    """
    import sys
    from pathlib import Path
    from graph_tool.search import dfs_iterator, bfs_iterator
    # from graph_tool import load_graph_from_csv
    from graph_tool.util import find_edge
    from kgtk.exceptions import KGTKException
    from kgtk.cli_argparse import KGTKArgumentParser

    from kgtk.gt.gt_load import load_graph_from_kgtk
    from kgtk.io.kgtkwriter import KgtkWriter
    from kgtk.io.kgtkreader import KgtkReader, KgtkReaderOptions
    from kgtk.value.kgtkvalueoptions import KgtkValueOptions

    def get_edges_by_edge_prop(g, p, v):
        # Edges of g whose edge property named p equals v.
        return find_edge(g, prop=g.properties[('e', p)], match=v)

    input_kgtk_file: Path = KGTKArgumentParser.get_input_file(input_file)
    output_kgtk_file: Path = KGTKArgumentParser.get_output_file(output_file)

    # Select where to send error messages, defaulting to stderr.
    error_file: typing.TextIO = sys.stdout if errors_to_stdout else sys.stderr

    # Build the option structures.
    input_reader_options: KgtkReaderOptions = KgtkReaderOptions.from_dict(
        kwargs, who="input", fallback=True)
    root_reader_options: KgtkReaderOptions = KgtkReaderOptions.from_dict(
        kwargs, who="root", fallback=True)
    value_options: KgtkValueOptions = KgtkValueOptions.from_dict(kwargs)

    if root is None:
        root = []  # This simplifies matters.

    if props is None:
        props = []  # This simplifies matters.

    if show_options:
        # root and props are lists at this point, so they are always shown.
        print("--root %s" % " ".join(root), file=error_file)
        if rootfile is not None:
            print("--rootfile=%s" % rootfile, file=error_file)
        if subject_column_name is not None:
            print("--subj=%s" % subject_column_name, file=error_file)
        if object_column_name is not None:
            print("--obj=%s" % object_column_name, file=error_file)
        if predicate_column_name is not None:
            print("--pred=%s" % predicate_column_name, file=error_file)
        print("--props=%s" % " ".join(props), file=error_file)
        print("--undirected=%s" % str(undirected), file=error_file)
        print("--label=%s" % label, file=error_file)
        print("--selflink=%s" % str(selflink_bool), file=error_file)
        print("--breadth-first=%s" % str(breadth_first), file=error_file)
        input_reader_options.show(out=error_file)
        root_reader_options.show(out=error_file)
        value_options.show(out=error_file)
        KgtkReader.show_debug_arguments(errors_to_stdout=errors_to_stdout,
                                        errors_to_stderr=errors_to_stderr,
                                        show_options=show_options,
                                        verbose=verbose,
                                        very_verbose=very_verbose,
                                        out=error_file)
        print("=======", file=error_file, flush=True)

    root_set: typing.Set = set()

    if rootfile is not None:
        if verbose:
            print("Reading the root file %s" % repr(rootfile),
                  file=error_file,
                  flush=True)
        root_kr: KgtkReader = KgtkReader.open(
            Path(rootfile),
            error_file=error_file,
            who="root",
            options=root_reader_options,
            value_options=value_options,
            verbose=verbose,
            very_verbose=very_verbose,
        )
        # The finally clause closes the reader on every path, including
        # the error exits below and failures while iterating.
        try:
            rootcol: int
            if rootfilecolumn is not None and rootfilecolumn.isdigit():
                # A string of digits is taken as the column index itself,
                # whatever the kind of file.
                rootcol = int(rootfilecolumn)
            elif root_kr.is_edge_file:
                rootcol = root_kr.get_node1_column_index(rootfilecolumn)
            elif root_kr.is_node_file:
                rootcol = root_kr.get_id_column_index(rootfilecolumn)
            elif rootfilecolumn is not None:
                rootcol = root_kr.column_name_map.get(rootfilecolumn, -1)
            else:
                raise KGTKException(
                    "The root file is neither an edge nor a node file and the root column name was not supplied."
                )

            if rootcol < 0:
                raise KGTKException("Unknown root column %s" %
                                    repr(rootfilecolumn))

            for row in root_kr:
                root_set.add(row[rootcol])
        finally:
            root_kr.close()

    if len(root) > 0:
        if verbose:
            print("Adding root nodes from the command line.",
                  file=error_file,
                  flush=True)
        root_group: str
        for root_group in root:
            r: str
            for r in root_group.split(','):
                if verbose:
                    print("... adding %s" % repr(r),
                          file=error_file,
                          flush=True)
                root_set.add(r)
    if len(root_set) == 0:
        print(
            "Warning: No nodes in the root set, the output file will be empty.",
            file=error_file,
            flush=True)
    elif verbose:
        print("%d nodes in the root set." % len(root_set),
              file=error_file,
              flush=True)

    kr: KgtkReader = KgtkReader.open(
        input_kgtk_file,
        error_file=error_file,
        who="input",
        options=input_reader_options,
        value_options=value_options,
        verbose=verbose,
        very_verbose=very_verbose,
    )
    # From here on the input reader is closed by the finally clause at the
    # bottom, whether the run succeeds or fails part-way.
    try:
        sub: int = kr.get_node1_column_index(subject_column_name)
        if sub < 0:
            print("Unknown subject column %s" % repr(subject_column_name),
                  file=error_file,
                  flush=True)

        pred: int = kr.get_label_column_index(predicate_column_name)
        if pred < 0:
            print("Unknown predicate column %s" % repr(predicate_column_name),
                  file=error_file,
                  flush=True)

        obj: int = kr.get_node2_column_index(object_column_name)
        if obj < 0:
            print("Unknown object column %s" % repr(object_column_name),
                  file=error_file,
                  flush=True)

        # All three problems are reported above before giving up.
        if sub < 0 or pred < 0 or obj < 0:
            raise KGTKException("Exiting due to unknown column.")

        if verbose:
            print("special columns: sub=%d pred=%d obj=%d" % (sub, pred, obj),
                  file=error_file,
                  flush=True)

        # G = load_graph_from_csv(filename,not(undirected),skip_first=not(header_bool),hashed=True,csv_options={'delimiter': '\t'},ecols=(sub,obj))
        G = load_graph_from_kgtk(kr,
                                 directed=not undirected,
                                 ecols=(sub, obj),
                                 verbose=verbose,
                                 out=error_file)

        # The vertex property map named "name"; its values are compared
        # with the root set and written to the output below.
        name = G.vp["name"]

        if show_properties:
            print("Graph name=%s" % name, file=error_file, flush=True)
            print("Graph properties:", file=error_file, flush=True)
            key: typing.Any
            for key in G.properties:
                print("    %s: %s" % (repr(key), repr(G.properties[key])),
                      file=error_file,
                      flush=True)

        index_list = [v for v in G.vertices() if name[v] in root_set]
        if len(index_list) == 0:
            print(
                "Warning: No root nodes found in the graph, the output file will be empty.",
                file=error_file,
                flush=True)
        elif verbose:
            print("%d root nodes found in the graph." % len(index_list),
                  file=error_file,
                  flush=True)

        if len(props) > 0:
            # The predicate column's own name is used to look up the
            # matching edge property map.
            pred_label: str = kr.column_names[pred]
            if verbose:
                print("pred_label=%s" % repr(pred_label),
                      file=error_file,
                      flush=True)

            property_list: typing.List[str] = []
            prop_group: str
            for prop_group in props:
                property_list.extend(prop_group.split(','))
            if verbose:
                print("property list=%s" % " ".join(property_list),
                      file=error_file,
                      flush=True)

            edge_filter_set = set()
            prop: str
            for prop in property_list:
                edge_filter_set.update(
                    get_edges_by_edge_prop(G, pred_label, prop))
            # NOTE(review): the edge descriptors are collected before
            # clear_edges() and consumed by add_edge_list() afterwards, so
            # this relies on them still yielding their endpoints once the
            # edges are gone -- confirm against the graph-tool version in use.
            G.clear_edges()
            G.add_edge_list(list(edge_filter_set))

        output_header: typing.List[str] = ['node1', 'label', 'node2']

        kw: KgtkWriter = KgtkWriter.open(output_header,
                                         output_kgtk_file,
                                         mode=KgtkWriter.Mode.EDGE,
                                         require_all_columns=True,
                                         prohibit_extra_columns=True,
                                         fill_missing_columns=False,
                                         verbose=verbose,
                                         very_verbose=very_verbose)
        try:
            # Both iterators are called the same way, so pick one up front.
            search_iterator = bfs_iterator if breadth_first else dfs_iterator
            for index in index_list:
                root_name = name[index]
                if selflink_bool:
                    kw.writerow([root_name, label, root_name])

                for e in search_iterator(G, G.vertex(index)):
                    kw.writerow([root_name, label, name[e.target()]])
        finally:
            kw.close()
    finally:
        kr.close()
def steiner_tree_mst(g,
                     root,
                     infection_times,
                     source,
                     terminals,
                     closure_builder=build_closure,
                     strictly_smaller=True,
                     return_closure=False,
                     k=-1,
                     debug=False,
                     verbose=True):
    """Build a Steiner tree over `terminals` from a closure graph's arborescence.

    `closure_builder` yields a closure graph `gc` with edge weights; it is
    converted with `gt2nx` and handed to
    `networkx.minimum_spanning_arborescence`.  The resulting edges are
    copied into a fresh directed graph, the terminals other than `root` are
    ordered by infection time (ties broken by BFS position in that graph),
    and each terminal not yet covered is linked to the partial tree using
    the predecessor map filled in by `cpbfs_search` on `g`.

    Args:
        g: the original graph.
        root: the root node of the tree.
        infection_times: indexable by node, giving its infection time.
        source: accepted but not used by this function.
        terminals: the nodes that the tree has to cover.
        closure_builder: callable building `(gc, eweight, r2pred)`; the
            third value is ignored here.
        strictly_smaller, k, debug, verbose: forwarded to `closure_builder`;
            `debug` and `verbose` also enable progress printing here.
        return_closure: when true, also return `gc` and the arborescence.

    Returns:
        A directed graph with `g.num_vertices()` vertices, vertex-filtered
        down to the endpoints of the selected edges; or the tuple
        `(tree, gc, mst_tree)` when `return_closure` is true.  If networkx
        raises `NetworkXException`, returns `None` (or `(None, gc, None)`).
    """
    gc, eweight, _ = closure_builder(g,
                                     root,
                                     terminals,
                                     infection_times,
                                     strictly_smaller=strictly_smaller,
                                     k=k,
                                     debug=debug,
                                     verbose=verbose)

    # graph_tool has no minimum_spanning_arborescence, so go through networkx.
    if verbose:
        print('getting mst')
    gx = gt2nx(gc, root, terminals, edge_attrs={'weight': eweight})
    try:
        nx_tree = nx.minimum_spanning_arborescence(gx, 'weight')
    except nx.exception.NetworkXException:
        if debug:
            print('fail to find mst')
        return (None, gc, None) if return_closure else None

    if verbose:
        print('returning tree')

    # Copy the arborescence into a graph with as many vertices as g.
    vertex_count = g.num_vertices()
    mst_tree = Graph(directed=True)
    for _ in range(vertex_count):
        mst_tree.add_vertex()

    for src, dst in nx_tree.edges():
        mst_tree.add_edge(src, dst)

    if verbose:
        print('extract edges from original graph')

    # Order the observations by infection time, then by BFS position.
    topological_index = {
        int(edge.target()): position
        for position, edge in enumerate(bfs_iterator(mst_tree, source=root))
    }

    def visit_order(obs):
        return infection_times[obs], topological_index[obs]

    sorted_obs = sorted(set(terminals) - {root}, key=visit_order)

    tree_nodes = {root}
    tree_edges = set()
    for obs in sorted_obs:
        if obs in tree_nodes:
            if debug:
                print('{} covered already'.format(obs))
            continue

        # The single incoming arborescence edge gives the ancestor of `obs`.
        in_edge = next(mst_tree.vertex(obs).in_edges())
        parent, child = int(in_edge.source()), int(in_edge.target())
        tree_nodes.add(parent)

        late_nodes = [
            n for n in terminals
            if infection_times[n] > infection_times[child]
        ]
        vis = init_visitor(g, child)

        # Search from the child towards any tree node, `parent` included.
        cpbfs_search(g,
                     source=child,
                     terminals=list(tree_nodes),
                     forbidden_nodes=late_nodes,
                     visitor=vis,
                     count_threshold=1)
        node_set = {node for node, d in vis.dist.items() if d > 0}
        reachable_tree_nodes = node_set.intersection(tree_nodes)
        ancestor = min(reachable_tree_nodes, key=vis.dist.__getitem__)

        path_edges = extract_edges_from_pred(g, child, ancestor, vis.pred)
        # The extracted pairs are reversed before being stored.
        path_edges = {(j, i) for i, j in path_edges}
        if debug:
            print('tree_nodes', tree_nodes)
            print('connecting {} to {}'.format(parent, child))
            print('using ancestor {}'.format(ancestor))
            print('adding edges {}'.format(path_edges))
        tree_nodes |= {node for edge in path_edges for node in edge}

        tree_edges |= path_edges

    # Materialise the result on a fresh graph with the same vertex count.
    t = Graph(directed=True)
    for _ in range(g.num_vertices()):
        t.add_vertex()

    for src, dst in tree_edges:
        t.add_edge(t.vertex(src), t.vertex(dst))

    # Only endpoints of selected edges stay visible.
    tree_nodes = {node for edge in tree_edges for node in edge}
    visible = t.new_vertex_property('bool')
    visible.a = False
    for node in tree_nodes:
        visible[t.vertex(node)] = True

    t.set_vertex_filter(visible)

    if return_closure:
        return t, gc, mst_tree
    return t