def test_fill_missing_time():
    """Check fill_missing_time on the four-node chain 0 - 1 - 2 - 3.

    The tree view passed in hides vertex 3 and the edge (2, 3).  The values
    returned for vertices 0..3 are expected to equal the input infection
    times.
    """
    num_nodes = 4
    chain_edges = [(0, 1), (1, 2), (2, 3)]

    g = Graph(directed=False)
    g.add_vertex(num_nodes)
    g.add_edge_list(chain_edges)

    # directed view of the chain; edge (2, 3) and vertex 3 are masked out below
    t = GraphView(g, directed=True)

    edge_mask = t.new_edge_property('bool')
    edge_mask.a = True
    edge_mask[t.edge(2, 3)] = False
    t.set_edge_filter(edge_mask)

    vertex_mask = t.new_vertex_property('bool')
    vertex_mask.a = True
    vertex_mask[3] = False
    t.set_vertex_filter(vertex_mask)

    source = 0
    observed = {0, 2}
    infection_times = [0, 1.5, 3, -1]

    pt = fill_missing_time(g, t, source, observed, infection_times,
                           debug=False)

    for node in range(num_nodes):
        assert pt[node] == infection_times[node]
def min_steiner_tree(g,
                     obs_nodes,
                     p=None,
                     return_type='tree',
                     debug=False,
                     verbose=False):
    """Build a Steiner tree of `g` that connects the terminals `obs_nodes`.

    The procedure is:

    1. build a closure graph over the terminals via `build_closure`,
    2. take a minimum spanning tree of that closure graph,
    3. expand every closure edge back into edges of `g` using the
       predecessor information returned by `build_closure`,
    4. (for ``return_type='tree'``) run a second minimum spanning tree on the
       induced subgraph to remove cycles.

    Parameters
    ----------
    g : graph
        The input graph (presumably a graph-tool ``Graph``; it must provide
        ``num_vertices``, ``edge`` and the ``new_*_property`` methods).
    obs_nodes : sized collection
        The terminal nodes; must be non-empty.
    p : edge property map or None
        Forwarded to `build_closure`.  When given, it is also used as the edge
        weight (``p[e]``) of the final cycle-removing spanning tree.
    return_type : {'tree', 'nodes', 'edges'}
        Selects the form of the result, see Returns.
    debug, verbose : bool
        Forwarded to `build_closure`.

    Returns
    -------
    list or GraphView
        * ``'nodes'``: list of the nodes touched by the recovered edges,
        * ``'edges'``: list of the recovered edges, each mapped through
          `edge2tuple`,
        * ``'tree'``: an undirected ``GraphView`` of `g` restricted to the
          tree nodes and tree edges.

    Raises
    ------
    ValueError
        If `return_type` is not one of the supported values.
    AssertionError
        If `obs_nodes` is empty, or if a closure edge cannot be expanded into
        any edge of `g`.
    """
    # Fail fast: previously an unknown return_type silently produced None
    # after all of the expensive work below had already been done.
    if return_type not in ('nodes', 'edges', 'tree'):
        raise ValueError(
            "unsupported return_type {!r}, expected one of "
            "'nodes', 'edges', 'tree'".format(return_type))

    assert len(obs_nodes) > 0, 'no terminals'

    if g.num_vertices() == len(obs_nodes):
        print('it\'s a minimum spanning tree problem')

    gc, eweight, r2pred = build_closure(g,
                                        obs_nodes,
                                        p=p,
                                        debug=debug,
                                        verbose=verbose)

    # spanning tree over the closure graph
    tree_map = min_spanning_tree(gc, eweight, root=None)
    tree = GraphView(gc, directed=False, efilt=tree_map)

    # expand each closure edge (u, v) into the edges of `g` it stands for
    tree_edges = set()
    for e in tree.edges():
        u, v = map(int, e)
        recovered_edges = extract_edges_from_pred(u, v, r2pred[u])
        assert recovered_edges, 'empty!'
        tree_edges.update((i, j) for i, j in recovered_edges)

    tree_nodes = list(set(itertools.chain.from_iterable(tree_edges)))

    if return_type == 'nodes':
        return tree_nodes
    if return_type == 'edges':
        return list(map(edge2tuple, tree_edges))

    # return_type == 'tree': mask `g` down to the recovered nodes and edges
    vfilt = g.new_vertex_property('bool')
    vfilt.set_value(False)
    for n in tree_nodes:
        vfilt[n] = True

    efilt = g.new_edge_property('bool')
    for i, j in tree_edges:
        efilt[g.edge(i, j)] = True
    subg = GraphView(g, efilt=efilt, vfilt=vfilt, directed=False)

    if p is not None:
        weights = subg.new_edge_property('float')
        for e in subg.edges():
            weights[e] = p[e]
    else:
        weights = None

    # the union of the expanded paths may contain cycles; a second spanning
    # tree over the induced subgraph removes them
    tree_map = min_spanning_tree(subg, weights, root=None)
    return GraphView(g, directed=False, vfilt=vfilt, efilt=tree_map)