def test_random_DAG(N=1):
    """Check that ``random_DAG`` produces directed acyclic graphs.

    Reseeds NumPy's global RNG to 12345, then repeatedly draws an edge
    probability uniformly from [0.25, 1) and a vertex count from [5, 50),
    builds a graph with ``random_DAG``, converts it with ``to_networkx`` and
    asserts that networkx considers it a DAG. Prints "PASSED" after each
    successful trial; runs while the trial counter is below ``N``.
    """
    np.random.seed(12345)
    n_done = 0
    while n_done < N:
        edge_prob = np.random.uniform(0.25, 1)
        n_nodes = np.random.randint(5, 50)

        nx_graph = to_networkx(random_DAG(n_nodes, edge_prob))
        assert nx.is_directed_acyclic_graph(nx_graph)

        print("PASSED")
        n_done += 1
def test_topological_ordering(N=1):
    """Check that ``topological_ordering`` returns a valid linearization.

    Reseeds NumPy's global RNG to 12345, then draws random graphs (edge
    probability in [0.25, 1), vertex count in [5, 10)) until ``N`` of them
    are confirmed acyclic by networkx. Draws that are not DAGs are skipped
    and do not count toward ``N``. For each DAG, the ordering returned by
    ``G.topological_ordering()`` is walked and the test asserts that no
    vertex has a neighbor that was already visited (including itself).
    Prints "PASSED" after each checked DAG.
    """
    np.random.seed(12345)
    i = 0
    while i < N:
        p = np.random.uniform(0.25, 1)
        n_v = np.random.randint(5, 10)

        G = random_DAG(n_v, p)
        G_nx = to_networkx(G)

        # Only graphs networkx confirms to be acyclic count toward N.
        if not nx.is_directed_acyclic_graph(G_nx):
            continue

        topo_order = G.topological_ordering()

        # In a valid ordering every edge points forward: once a vertex is
        # visited, none of its neighbors may already be in `seen_it`.
        # NOTE(review): assumes get_neighbors returns out-neighbors
        # (successors) for a directed graph - confirm against DiGraph.
        seen_it = set()
        for n_i in topo_order:
            seen_it.add(n_i)
            assert not any(c_i in seen_it for c_i in G.get_neighbors(n_i))

        print("PASSED")
        i += 1
def plot_ucb1_gaussian_shortest_path():
    """
    Plot the UCB1 policy on a graph shortest path problem where each edge
    weight is drawn from an independent univariate Gaussian.

    A random DAG is generated and every edge is given a Gaussian weight
    whose mean and variance are each drawn uniformly from [0, 1). Vertices
    are then randomly swapped until ``G.path_exists(V[0], V[-1])`` holds.
    A UCB1 policy is trained on the resulting ``ShortestPathBandit`` with
    ``BanditTrainer``.

    Raises
    ------
    ValueError
        If the random DAG has no edges. With no edges no start/end pair can
        presumably be connected, so the vertex-swapping loop would never
        terminate.
    """
    np.random.seed(12345)

    ep_length = 1
    n_duplicates = 5
    n_episodes = 5000
    p = np.random.rand()
    n_vertices = np.random.randint(5, 15)

    Gaussian = namedtuple("Gaussian", ["mean", "variance", "EV", "sample"])

    # create randomly-weighted edges
    print("Building graph")
    E = []
    G = random_DAG(n_vertices, p)
    V = G.vertices
    for e in G.edges:
        mean, var = np.random.uniform(0, 1), np.random.uniform(0, 1)
        # Bind this edge's parameters as default arguments. A bare closure
        # would look up `mean`/`var` at call time, so every edge would sample
        # from the last edge's distribution. np.random.normal's second
        # argument is a standard deviation, hence sqrt(var) to match the
        # recorded variance.
        w = lambda mu=mean, sd=np.sqrt(var): np.random.normal(mu, sd)  # noqa: E731
        rv = Gaussian(mean, var, mean, w)
        E.append(Edge(e.fr, e.to, rv))

    if not E:
        raise ValueError(
            "random_DAG produced a graph with no edges; cannot find a "
            "connected start/end pair"
        )

    G = DiGraph(V, E)
    # Swap a random vertex into the last slot (idx == 0 swaps the start and
    # end vertices) until the two endpoints are connected.
    while not G.path_exists(V[0], V[-1]):
        print("Skipping")
        idx = np.random.randint(0, len(V))
        V[idx], V[-1] = V[-1], V[idx]

    mab = ShortestPathBandit(G, V[0], V[-1])
    policy = UCB1(C=1, ev_prior=0.5)
    policy = BanditTrainer().train(policy, mab, ep_length, n_episodes,
                                   n_duplicates)