def remove_redundant_edges_from_tree(g, tree, r, terminals):
    """Prune a tree to the edges needed to connect the root to the terminals.

    The edges listed in ``tree`` are first selected on ``g`` through an edge
    filter.  Then, for every terminal other than ``r``, the edges returned by
    ``shortest_path`` from ``r`` to that terminal (searched inside the
    filtered view only) are marked as kept.  Edges that lie on none of these
    paths are dropped from the result.

    Args:
        g: graph that the tree edges belong to.
        tree: iterable of ``(u, v)`` pairs; each pair is resolved with
            ``g.edge(u, v)``.
        r: root vertex (as accepted by ``GraphView.vertex``).
        terminals: iterable of vertices that the pruned tree has to cover.

    Returns:
        A ``GraphView`` of ``g`` whose edge filter keeps only the edges that
        lie on a path from ``r`` to one of the terminals.

    Raises:
        AssertionError: if the path found from ``r`` to some terminal is
            empty, i.e. the terminal is not reachable through the tree edges.
    """
    # restrict g to the edges of the input tree
    tree_efilt = g.new_edge_property('bool')
    tree_efilt.a = False
    for u, v in tree:
        tree_efilt[g.edge(u, v)] = True
    tree_view = GraphView(g, efilt=tree_efilt)

    # keep only the edges that appear on some root-to-terminal path
    min_tree_efilt = g.new_edge_property('bool')
    min_tree_efilt.a = False
    for o in terminals:
        if o == r:
            # the root is trivially covered
            continue
        _, edge_list = shortest_path(tree_view,
                                     source=tree_view.vertex(r),
                                     target=tree_view.vertex(o))
        # raised explicitly (rather than via ``assert``) so that the check
        # still runs when Python is started with -O
        if len(edge_list) == 0:
            raise AssertionError('unable to reach {} from {}'.format(o, r))
        for e in edge_list:
            min_tree_efilt[e] = True
    return GraphView(g, efilt=min_tree_efilt)
def is_order_respected(tree, root, obs_nodes, infection_times):
    """Check that infection times never decrease along the tree's paths.

    Observed nodes with out-degree 0 in ``tree`` are treated as leaves.  For
    each leaf, the path between ``root`` and the leaf is rebuilt from the
    predecessor map filled in by ``pbfs_search``; the observed nodes on that
    path are then compared pairwise in path order.

    Args:
        tree: graph (or view) holding the tree; it is wrapped in a fresh
            ``GraphView`` so that the caller's object is not modified by the
            vertex filter set here.
        root: root vertex of the tree.
        obs_nodes: iterable of observed nodes.
        infection_times: mapping indexable by node giving a comparable
            infection time.

    Returns:
        ``False`` as soon as two consecutive observed nodes on a path have a
        strictly decreasing infection time, ``True`` otherwise.
    """
    tree = GraphView(tree)
    obs_set = set(obs_nodes)

    # install an all-pass vertex filter on the private view
    vfilt = tree.new_vertex_property('bool')
    vfilt.a = True
    tree.set_vertex_filter(vfilt)

    leaves = [o for o in obs_nodes if tree.vertex(o).out_degree() == 0]
    vis = init_visitor(tree, root)
    pbfs_search(tree, root, terminals=leaves, visitor=vis, count_threshold=-1)

    for leaf in leaves:
        edges = extract_edges_from_pred(tree, root, leaf, vis.pred)
        # reversed so that the path is walked in the opposite direction of
        # what the helper returns (presumably root -> leaf; confirm against
        # extract_edges_from_pred)
        edges = edges[::-1]
        # NOTE(review): edges[0] raises IndexError when the helper returns an
        # empty list (e.g. if a leaf equals the root) -- confirm whether that
        # input can occur.
        path = list(edges[0]) + [u for _, u in edges[1:]]
        observed_on_path = [v for v in path if v in obs_set]

        # compare each observed node with the next observed node on the path
        for earlier, later in zip(observed_on_path, observed_on_path[1:]):
            if infection_times[earlier] > infection_times[later]:
                return False
    return True
def find_tree_by_closure(g,
                         root,
                         infection_times,
                         terminals,
                         closure_builder=build_closure_with_order,
                         strictly_smaller=True,
                         return_closure=False,
                         k=-1,
                         debug=False,
                         verbose=True):
    """Find a Steiner tree through the transitive closure of the terminals.

    Steps performed:

    1. ``closure_builder`` produces a closure graph ``gc`` and its edge
       weights ``eweight``.
    2. ``find_minimum_branching`` is run on ``gc`` from ``root``; the edges it
       returns are exposed as the filtered view ``mst_tree``.
    3. The terminals other than ``root`` are sorted by
       ``(infection time, BFS index in mst_tree)``.
    4. Each terminal that is not covered yet is connected to the already
       collected nodes by searching ``g`` with ``cpbfs_search``; the edges of
       the recovered path are reversed and added to the result.
    5. A new directed ``Graph`` with ``g.num_vertices()`` vertices receives
       the collected edges; its vertex filter shows only edge endpoints.

    Args:
        g: the original graph, passed to ``closure_builder`` and searched by
            ``cpbfs_search``.
        root: root node (integer index).
        infection_times: mapping indexable by node giving a comparable time.
        terminals: iterable of nodes that should be covered.
        closure_builder: callable returning ``(gc, eweight)``; it is called
            with ``(g, root, terminals, infection_times)`` plus the keyword
            arguments ``strictly_smaller``, ``k``, ``return_r2pred=False``,
            ``debug`` and ``verbose``.
        strictly_smaller: forwarded unchanged to ``closure_builder``.
        return_closure: if true, also return ``gc`` and ``mst_tree``.
        k: forwarded unchanged to ``closure_builder``.
        debug: print details of the reconstruction.
        verbose: print progress messages.

    Returns:
        The tree ``t``, or the tuple ``(t, gc, mst_tree)`` when
        ``return_closure`` is true.  Note that only endpoints of collected
        edges are visible in ``t``; if no edge was collected, no vertex is
        visible.

    Raises:
        TreeNotFound: if a ``KeyError`` occurs while sorting the terminals
            (a terminal missing from the BFS index or from
            ``infection_times``).
        ValueError: from ``min()`` if no already collected node has a
            positive recorded distance after the search.
    """
    gc, eweight = closure_builder(g,
                                  root,
                                  terminals,
                                  infection_times,
                                  strictly_smaller=strictly_smaller,
                                  k=k,
                                  return_r2pred=False,
                                  debug=debug,
                                  verbose=verbose)

    # get the minimum spanning arborescence
    # graph_tool does not provide minimum_spanning_arborescence
    if verbose:
        print('getting mst')
    tree_edges = find_minimum_branching(gc, [root], weights=eweight)

    # expose the branching as an edge-filtered view of the closure graph
    efilt = gc.new_edge_property('bool')
    efilt.a = False
    for u, v in tree_edges:
        efilt[gc.edge(u, v)] = True

    mst_tree = GraphView(gc, efilt=efilt)

    if verbose:
        print('extract edges from original graph')

    # extract the edges from the original graph

    # sort observations by time
    # and also topological order
    # why doing this: we want to start collecting the edges
    # for nodes with higher order
    # (the index is the position of the edge reaching the node in a BFS of
    # mst_tree started at root; root itself gets no entry)
    topological_index = {}
    for i, e in enumerate(bfs_iterator(mst_tree, source=root)):
        topological_index[int(e.target())] = i

    try:
        sorted_obs = sorted(set(terminals) - {root},
                            key=lambda o:
                            (infection_times[o], topological_index[o]))
    except KeyError:
        raise TreeNotFound(
            "it's likely that the input cannot produce a feasible solution, " +
            "because the topological sort on terminals does not visit all terminals"
        )

    # next, we start reconstructing the minimum steiner arborescence
    # NOTE: tree_edges is reused from here on for the edges of the result,
    # no longer for the branching edges
    tree_nodes = {root}
    tree_edges = set()
    for u in sorted_obs:
        if u in tree_nodes:
            # already picked up as an intermediate node of an earlier path
            if debug:
                print('{} covered already'.format(u))
            continue
        # take the first in-edge of u in mst_tree and unpack it as (v, u),
        # which also re-binds u as a plain int
        v, u = map(int, next(mst_tree.vertex(u).in_edges()))  # v is ancestor
        # v is added before the search so that it is among the search targets
        tree_nodes.add(v)

        # terminals strictly later than u are handed to the search as
        # forbidden nodes
        late_nodes = [
            n for n in terminals if infection_times[n] > infection_times[u]
        ]
        vis = init_visitor(g, u)
        # from child to any tree node, including v

        cpbfs_search(g,
                     source=u,
                     terminals=list(tree_nodes),
                     forbidden_nodes=late_nodes,
                     visitor=vis,
                     count_threshold=1)
        # keep nodes with a positive recorded distance (presumably this
        # excludes the source and unvisited nodes -- depends on init_visitor,
        # TODO confirm), then pick the closest one that is already collected
        node_set = {v for v, d in vis.dist.items() if d > 0}
        reachable_tree_nodes = node_set.intersection(tree_nodes)
        ancestor = min(reachable_tree_nodes, key=vis.dist.__getitem__)

        edges = extract_edges_from_pred(g, u, ancestor, vis.pred)
        edges = {(j, i) for i, j in edges}  # need to reverse it
        if debug:
            print('tree_nodes', tree_nodes)
            print('connecting {} to {}'.format(v, u))
            print('using ancestor {}'.format(ancestor))
            print('adding edges {}'.format(edges))
        # every node on the new path now counts as covered
        tree_nodes |= {u for e in edges for u in e}

        tree_edges |= edges

    # materialise the result as a standalone directed graph that shares the
    # vertex indices of g
    t = Graph(directed=True)
    t.add_vertex(g.num_vertices())

    for u, v in tree_edges:
        t.add_edge(t.vertex(u), t.vertex(v))

    # only endpoints of collected edges stay visible (tree_nodes is rebuilt
    # from the edges, so root is visible only if some edge touches it)
    tree_nodes = {u for e in tree_edges for u in e}
    vfilt = t.new_vertex_property('bool')
    vfilt.a = False
    for v in tree_nodes:
        vfilt[t.vertex(v)] = True

    t.set_vertex_filter(vfilt)

    if return_closure:
        return t, gc, mst_tree
    else:
        return t