예제 #1
0
def sample(order, alpha=0.5, beta=0.5):
    """Generate a random decomposable graph using the Christmas tree algorithm.

    Args:
        order (int or list): Either the number of nodes, in which case the
            nodes ``0, ..., order - 1`` are used in a random order, or an
            explicit list of nodes that is passed on to the sampler unchanged.
        alpha (float): Subtree kernel parameter.
        beta (float): Subtree kernel parameter.

    Returns:
        NetworkX graph: The graph underlying the sampled junction tree, or
        ``None`` when ``order`` is neither an int nor a list.
    """
    if isinstance(order, int):
        # range() is an immutable sequence on Python 3; it has to be
        # materialised as a list before random.shuffle can permute it in place.
        nodes = list(range(order))
        random.shuffle(nodes)
    elif isinstance(order, list):
        nodes = order
    else:
        # Unsupported input types yield None, as they always have.
        return None

    tree = libj.sample(nodes, alpha, beta)
    return jtlib.graph(tree)
예제 #2
0
def mh(alpha, beta, traj_length, seq_dist,
    jt_traj=None, debug=False):
    """ Approximate a distribution over junction trees with a
    Metropolis-Hastings chain.

    Args:
        alpha (float): sparsity parameter for the Christmas tree algorithm
        beta (float): sparsity parameter for the Christmas tree algorithm
        traj_length (int): Number of samples in the returned trajectory
        seq_dist (SequentialJTDistributions): the distribution to be sampled from
        jt_traj: Not used by this function.
        debug (bool): Not used by this function.

    Returns:
        mcmctraj.Trajectory: Markov chain of the underlying graphs of the junction trees sampled by M-H.
    """
    chain = mcmctraj.Trajectory()
    chain.set_sequential_distribution(seq_dist)

    current = None
    for step in tqdm(range(traj_length), desc="Metropolis-Hastings samples"):
        tic = time.time()
        # The first state is sampled from scratch; every later state is
        # generated from its predecessor by the transition kernel.
        if step == 0:
            current = jtlib.sample(seq_dist.p, alpha, beta)
        else:
            current = trans_sample(current, alpha, beta, seq_dist)
        toc = time.time()
        # Only the sampling itself is timed, not the tree -> graph conversion.
        chain.add_sample(jtlib.graph(current), toc - tic)

    return chain
예제 #3
0
def smc_ggm_graphs(N, alpha, beta, radius, X, D, delta):
    """ Run SMC for a GGM junction-tree posterior and return weighted graphs.

    A ``seqdist.GGMJTPosterior`` is initialised from ``X``, ``D`` and
    ``delta`` with a fresh, empty cache and handed to ``approximate``.

    Args:
        N: Forwarded to ``approximate`` as its first argument
            (presumably the number of particles -- confirm).
        alpha: Forwarded to ``approximate``.
        beta: Forwarded to ``approximate``.
        radius: Forwarded to ``approximate``.
        X: Forwarded to ``GGMJTPosterior.init_model``.
        D: Forwarded to ``GGMJTPosterior.init_model``.
        delta: Forwarded to ``GGMJTPosterior.init_model``.

    Returns:
        tuple: ``(graphs, norm_w)`` where ``graphs`` holds one graph per
        returned tree and ``norm_w`` are the weights taken from row
        ``p - 1`` of the transposed log-weight array, normalised to sum to 1.
    """
    posterior = seqdist.GGMJTPosterior()
    posterior.init_model(X, D, delta, {})
    trees, log_w = approximate(N, alpha, beta, radius, posterior)

    # Subtract the maximum before exponentiating so the largest weight is
    # exactly exp(0) = 1.
    final_log_w = np.array(log_w.T)[posterior.p - 1]
    shifted = final_log_w - max(final_log_w)
    unnormalised = np.exp(shifted)
    norm_w = unnormalised / sum(unnormalised)

    graphs = list(map(jtlib.graph, trees))
    return (graphs, norm_w)
예제 #4
0
def est_log_norm_consts(order,
                        n_particles,
                        sequential_distribution,
                        alpha=0.5,
                        beta=0.5,
                        n_smc_estimates=1,
                        debug=False):
    """ Estimate log normalising constants from repeated SMC runs.

    For every run the weights are exponentiated and, for ``n = 1 .. order - 1``,
    the log of the mean of weight column ``n`` is accumulated into a running
    sum; entry 0 of every row stays 0.

    Args:
        order (int): Number of entries in each estimate row.
        n_particles: Forwarded to ``approximate`` as its first argument.
        sequential_distribution: Distribution handed to ``approximate``; its
            ``p`` attribute is also passed as the radius argument.
        alpha (float): Forwarded to ``approximate``.
        beta (float): Forwarded to ``approximate``.
        n_smc_estimates (int): Number of independent SMC runs.
        debug (bool): When true, print the number of distinct junction trees
            and distinct graphs among the particles of every run.

    Returns:
        numpy.ndarray: Array of shape ``(n_smc_estimates, order)``, flattened
        to one dimension when ``n_smc_estimates == 1``.
    """
    estimates = np.zeros((n_smc_estimates, order))

    for run in tqdm(range(n_smc_estimates), desc="Const estimates"):
        trees, log_w = approximate(n_particles, alpha, beta,
                                   sequential_distribution.p,
                                   sequential_distribution)
        weights = np.exp(log_w)
        # Column 0 is left at zero; each later entry adds the log of the
        # mean weight of that step to the previous entry.
        for n in range(1, order):
            step_log_mean = np.log(np.mean(weights[:, n]))
            estimates[run, n] = estimates[run, n - 1] + step_log_mean

        if not debug:
            continue

        # Trees are made hashable as (node set, set of edge sets).
        distinct_trees = {
            (frozenset(tree.nodes()),
             frozenset(frozenset(edge) for edge in tree.edges()))
            for tree in trees
        }
        print("Sampled unique junction trees: " + str(len(distinct_trees)))

        distinct_graphs = {glib.hash_graph(jtlib.graph(tree)) for tree in trees}
        print("Sampled unique chordal graphs: {n_unique_chordal_graphs}".format(
            n_unique_chordal_graphs=len(distinct_graphs)))

    if n_smc_estimates == 1:
        return estimates.flatten()
    return estimates
예제 #5
0
def sample_dec_graph(internal_nodes, alpha=0.5, beta=0.5, directory='.'):
    """ Draw a random decomposable graph with the Christmas tree algorithm.

    A junction tree is sampled with ``libj.sample`` and converted to its
    underlying graph with ``libj.graph``.

    Args:
        internal_nodes (list): list of internal nodes in the generated graph.
        alpha (float): Subtree kernel parameter
        beta (float): Subtree kernel parameter
        directory (string): Not used by this function.

    Returns:
        NetworkX graph: A decomposable graph.

    Example (the result is random, so the edges shown are illustrative):
        >>> g = dlib.sample_dec_graph(5)
        >>> g.edges
        EdgeView([(0, 1), (0, 3), (1, 3), (2, 3)])
        >>> g.nodes
        NodeView((0, 1, 2, 3, 4))

    """
    return libj.graph(libj.sample(internal_nodes, alpha=alpha, beta=beta))
예제 #6
0
def sample_trajectory(smc_N,
                      alpha,
                      beta,
                      radius,
                      n_samples,
                      seq_dist,
                      jt_traj=None,
                      debug=False,
                      reset_cache=True):
    """ Particle Gibbs sampler for distributions over junction trees.

    The first iteration runs a plain SMC pass (``approximate``). Every later
    iteration samples a backward trajectory from the previously selected
    tree and runs a conditional SMC pass (``approximate_cond``). In both
    cases one tree is then drawn according to the normalised weights in row
    ``seq_dist.p - 1`` of the transposed log-weight array.

    Args:
        smc_N (int): Number of particles in SMC in each Gibbs iteration
        alpha (float): sparsity parameter for the Christmas tree algorithm
        beta (float): sparsity parameter for the Christmas tree algorithm
        radius (float): defines the radius within which new nodes are selected
        n_samples (int): Number of Gibbs iterations (samples)
        seq_dist (SequentialJTDistributions): the distribution to be sampled from
        jt_traj: Not used by this function.
        debug (bool): Not used by this function.
        reset_cache (bool): When exactly ``True``, ``seq_dist.cache`` is
            replaced by an empty dict at the start of every iteration.

    Returns:
        Trajectory: Markov chain of the underlying graphs of the junction trees sampled by pgibbs.
    """
    chain = mcmctraj.Trajectory()
    method_info = {
        "method": "pgibbs",
        "params": {
            "N": smc_N,
            "alpha": alpha,
            "beta": beta,
            "radius": radius
        }
    }
    chain.set_sampling_method(method_info)
    chain.set_sequential_distribution(seq_dist)

    # Shared between all SMC passes of this trajectory.
    neig_set_cache = {}
    current_tree = None
    for sweep in tqdm(range(n_samples), desc="Particle Gibbs samples"):

        if reset_cache is True:
            seq_dist.cache = {}

        tic = time.time()
        if sweep == 0:
            # Unconditional SMC pass to initialise the chain.
            trees, log_w = approximate(smc_N,
                                       alpha,
                                       beta,
                                       radius,
                                       seq_dist,
                                       neig_set_cache=neig_set_cache)
        else:
            # Sample backwards trajectories from the previously chosen tree
            # and condition the SMC pass on them.
            perm_traj = sp.backward_perm_traj_sample(seq_dist.p, radius)
            T_traj = trilearn.graph.junction_tree_collapser.backward_jt_traj_sample(
                perm_traj, current_tree)
            trees, log_w, _ = approximate_cond(smc_N,
                                               alpha,
                                               beta,
                                               radius,
                                               seq_dist,
                                               T_traj,
                                               perm_traj,
                                               neig_set_cache=neig_set_cache)

        # Draw one tree in proportion to its weight; the maximum is
        # subtracted before exponentiating so the largest weight is 1.
        final_log_w = np.array(log_w.T)[seq_dist.p - 1]
        weights = np.exp(final_log_w - max(final_log_w))
        probs = weights / sum(weights)
        picked = np.random.choice(smc_N, size=1, p=probs)[0]
        current_tree = trees[picked]

        graph = jtlib.graph(current_tree)
        toc = time.time()
        chain.add_sample(graph, toc - tic)
    return chain
예제 #7
0
def sample_trajectory(n_samples, randomize, sd):
    """Run a Metropolis-Hastings chain over junction trees using
    connect / disconnect moves.

    The chain starts from the junction tree of the graph on ``sd.p`` nodes
    that has no edges. In every iteration a connect or a disconnect move is
    chosen with equal probability. A successful move is accepted with
    probability ``min(1, exp(log_p2 + log_q21 - log_p1 - log_q12))``, where
    ``log_p1`` / ``log_p2`` are the log-scores of the previous / proposed
    state and ``log_q12`` / ``log_q21`` the forward / reverse proposal
    log-probabilities. On rejection the matching inverse ``aglib`` routine is
    called on the junction tree and the previous state is repeated.

    The log-score of a state is ``sd.log_likelihood(graph)`` minus
    ``jtlib.log_n_junction_trees(jt, separators)``.

    Args:
        n_samples (int): Length of the trajectory including the initial
            state; must be at least 1 because index 0 is written
            unconditionally.
        randomize (int): ``jtlib.randomize`` is called on the junction tree
            in every iteration ``i`` with ``i % randomize == 0``; must be
            non-zero.
        sd: Sequential distribution. This function reads ``sd.p`` and calls
            ``sd.log_likelihood``.

    Returns:
        mcmctraj.Trajectory: Trajectory whose samples are the graphs of the
        chain and whose ``logl`` attribute is the list of log-scores.

    Note:
        ``accept_traj`` and ``MAP_graph`` are maintained during the run but
        are neither returned nor stored on the trajectory.
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(sd.p))
    jt = dlib.junction_tree(graph)
    assert (jtlib.is_junction_tree(jt))
    jt_traj = [None] * n_samples
    graphs = [None] * n_samples
    jt_traj[0] = jt
    graphs[0] = jtlib.graph(jt)
    log_prob_traj = [None] * n_samples

    gtraj = mcmctraj.Trajectory()
    gtraj.set_sampling_method({
        "method": "mh",
        "params": {
            "samples": n_samples,
            "randomize_interval": randomize
        }
    })

    gtraj.set_sequential_distribution(sd)

    # The 0.0 placeholder is overwritten by the next statement.
    log_prob_traj[0] = 0.0
    log_prob_traj[0] = sd.log_likelihood(jtlib.graph(jt_traj[0]))
    log_prob_traj[0] += -jtlib.log_n_junction_trees(
        jt_traj[0], jtlib.separators(jt_traj[0]))

    # 1 at index i when the proposal of iteration i was accepted.
    accept_traj = [0] * n_samples

    # Best (graph, log-score) pair seen so far.
    MAP_graph = (graphs[0], log_prob_traj[0])

    for i in tqdm(range(1, n_samples), desc="Metropolis-Hastings samples"):
        if log_prob_traj[i - 1] > MAP_graph[1]:
            MAP_graph = (graphs[i - 1], log_prob_traj[i - 1])

        if i % randomize == 0:
            jtlib.randomize(jt)
            graphs[i] = jtlib.graph(jt)  # TODO: Improve.
            log_prob_traj[i] = sd.log_likelihood(
                graphs[i]) - jtlib.log_n_junction_trees(
                    jt, jtlib.separators(jt))

        r = np.random.randint(2)  # Connect / disconnect move
        # NOTE(review): jt.size() / jt.order() are presumably the networkx
        # edge and node counts, i.e. the number of separators and cliques
        # before the move -- confirm against the junction tree class.
        num_seps = jt.size()
        log_p1 = log_prob_traj[i - 1]
        if r == 0:
            # Connect move
            num_cliques = jt.order()
            conn = aglib.connect_move(
                jt)  # need to move to calculate posterior
            seps_prop = jtlib.separators(jt)
            log_p2 = sd.log_likelihood(
                jtlib.graph(jt)) - jtlib.log_n_junction_trees(jt, seps_prop)

            # A falsy result means no move was made: repeat the previous state.
            if not conn:
                log_prob_traj[i] = log_prob_traj[i - 1]
                graphs[i] = graphs[i - 1]
                continue
            C_disconn = conn[2] | conn[3] | conn[4]
            # conn[0] names the case ("a".."d"); it selects the reverse
            # proposal probability and the undo routine used on rejection.
            if conn[0] == "a":
                (case, log_q12, X, Y, S, CX_disconn, CY_disconn, XSneig,
                 YSneig) = conn
                (NX_disconn, NY_disconn,
                 N_disconn) = aglib.disconnect_get_neighbors(
                     jt, C_disconn, X, Y)  # TODO: could this be done faster?
                log_q21 = aglib.disconnect_logprob_a(num_cliques - 1, X, Y, S,
                                                     N_disconn)
                #print log_p2, log_q21, log_p1, log_q12
                # Metropolis-Hastings acceptance probability.
                alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                #print alpha
                samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                if samp == 1:
                    # print "Accept"
                    accept_traj[i] = 1
                    log_prob_traj[i] = log_p2
                    graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                else:
                    #print "Reject"
                    # Undo the connect move on jt and repeat the previous state.
                    aglib.disconnect_a(jt, C_disconn, X, Y, CX_disconn,
                                       CY_disconn, XSneig, YSneig)
                    log_prob_traj[i] = log_prob_traj[i - 1]
                    graphs[i] = graphs[i - 1]
                    continue

            elif conn[0] == "b":
                (case, log_q12, X, Y, S, CX_disconn, CY_disconn) = conn
                log_q21 = aglib.disconnect_logprob_bcd(num_cliques, X, Y, S)
                alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                if samp == 1:
                    #print "Accept"
                    accept_traj[i] = 1
                    log_prob_traj[i] = log_p2
                    graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                else:
                    #print "Reject"
                    aglib.disconnect_b(jt, C_disconn, X, Y, CX_disconn,
                                       CY_disconn)
                    log_prob_traj[i] = log_prob_traj[i - 1]
                    graphs[i] = graphs[i - 1]
                    continue

            elif conn[0] == "c":
                (case, log_q12, X, Y, S, CX_disconn, CY_disconn) = conn
                log_q21 = aglib.disconnect_logprob_bcd(num_cliques, X, Y, S)
                alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                if samp == 1:
                    accept_traj[i] = 1
                    #print "Accept"
                    log_prob_traj[i] = log_p2
                    graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                else:
                    #print "Reject"
                    aglib.disconnect_c(jt, C_disconn, X, Y, CX_disconn,
                                       CY_disconn)
                    log_prob_traj[i] = log_prob_traj[i - 1]
                    graphs[i] = graphs[i - 1]
                    continue

            elif conn[0] == "d":
                (case, log_q12, X, Y, S, CX_disconn, CY_disconn) = conn
                log_q21 = aglib.disconnect_logprob_bcd(num_cliques + 1, X, Y,
                                                       S)
                alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                if samp == 1:
                    accept_traj[i] = 1
                    #print "Accept"
                    log_prob_traj[i] = log_p2
                    graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                else:
                    #print "Reject"
                    aglib.disconnect_d(jt, C_disconn, X, Y, CX_disconn,
                                       CY_disconn)
                    log_prob_traj[i] = log_prob_traj[i - 1]
                    graphs[i] = graphs[i - 1]
                    continue

        elif r == 1:
            # Disconnect move
            disconnect = aglib.disconnect_move(
                jt)  # need to move to calculate posterior
            seps_prop = jtlib.separators(jt)
            log_p2 = sd.log_likelihood(
                jtlib.graph(jt)) - jtlib.log_n_junction_trees(jt, seps_prop)

            #assert(jtlib.is_junction_tree(jt))
            #print "disconnect"
            # Unlike the connect branch, only the literal False signals that
            # no move was made.
            if disconnect is not False:
                if disconnect[0] == "a":
                    (case, log_q12, X, Y, S, CX_conn, CY_conn) = disconnect
                    log_q21 = aglib.connect_logprob(num_seps + 1, X, Y,
                                                    CX_conn, CY_conn)
                    alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                    samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                    if samp == 1:
                        accept_traj[i] = 1
                        #print "Accept"
                        log_prob_traj[i] = log_p2
                        graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                    else:
                        #print "Reject"
                        # Undo the disconnect move on jt and repeat the
                        # previous state.
                        aglib.connect_a(jt, S, X, Y, CX_conn, CY_conn)
                        log_prob_traj[i] = log_prob_traj[i - 1]
                        graphs[i] = graphs[i - 1]
                        continue

                elif disconnect[0] == "b":
                    (case, log_q12, X, Y, S, CX_conn, CY_conn) = disconnect
                    log_q21 = aglib.connect_logprob(num_seps, X, Y, CX_conn,
                                                    CY_conn)
                    alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                    samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                    if samp == 1:
                        accept_traj[i] = 1
                        #print "Accept"
                        log_prob_traj[i] = log_p2
                        graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                    else:
                        #print "Reject"
                        aglib.connect_b(jt, S, X, Y, CX_conn, CY_conn)
                        log_prob_traj[i] = log_prob_traj[i - 1]
                        graphs[i] = graphs[i - 1]
                        continue

                elif disconnect[0] == "c":
                    (case, log_q12, X, Y, S, CX_conn, CY_conn) = disconnect
                    log_q21 = aglib.connect_logprob(num_seps, X, Y, CX_conn,
                                                    CY_conn)
                    alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                    samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                    if samp == 1:
                        accept_traj[i] = 1
                        #print "Accept"
                        log_prob_traj[i] = log_p2
                        graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                    else:
                        #print "Reject"
                        aglib.connect_c(jt, S, X, Y, CX_conn, CY_conn)
                        log_prob_traj[i] = log_prob_traj[i - 1]
                        graphs[i] = graphs[i - 1]
                        continue

                elif disconnect[0] == "d":
                    (case, log_q12, X, Y, S, CX_conn, CY_conn) = disconnect
                    log_q21 = aglib.connect_logprob(num_seps - 1, X, Y,
                                                    CX_conn, CY_conn)
                    alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                    samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                    if samp == 1:
                        #print "Accept"
                        accept_traj[i] = 1
                        log_prob_traj[i] = log_p2
                        graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                    else:
                        #print "Reject"
                        aglib.connect_d(jt, S, X, Y, CX_conn, CY_conn)
                        log_prob_traj[i] = log_prob_traj[i - 1]
                        graphs[i] = graphs[i - 1]
                        continue
            else:
                log_prob_traj[i] = log_prob_traj[i - 1]
                graphs[i] = graphs[i - 1]
                continue
        #print(np.mean(accept_traj[:i]))
    gtraj.set_trajectory(graphs)
    gtraj.logl = log_prob_traj
    return gtraj
예제 #8
0
def uniform_dec_samples(order, n_particles, alpha=0.5, beta=0.5, debug=False):
    """ Run SMC against a ``CondUniformJTDistribution`` and return the graphs.

    Args:
        order: Passed to ``seqdist.CondUniformJTDistribution``.
        n_particles: Forwarded to ``approximate`` as its first argument.
        alpha (float): Forwarded to ``approximate``.
        beta (float): Forwarded to ``approximate``.
        debug (bool): Not used by this function.

    Returns:
        list: One graph per tree returned by ``approximate``; the weights
        are discarded.
    """
    target = seqdist.CondUniformJTDistribution(order)
    # The distribution's own ``p`` is used as the radius argument.
    trees, _ = approximate(n_particles, alpha, beta, target.p, target)
    return list(map(jtlib.graph, trees))