Example #1
0
def mh(alpha, beta, traj_length, seq_dist, jt_traj=None, debug=False):
    """A Metropolis-Hastings implementation for approximating distributions over
    junction trees.

    The first state is drawn with ``jtlib.sample``; every later state is drawn
    from the previous one with ``trans_sample``. The underlying graph of each
    state is recorded together with the wall-clock time spent producing it.

    Args:
        alpha (float): sparsity parameter for the Christmas tree algorithm
        beta (float): sparsity parameter for the Christmas tree algorithm
        traj_length (int): number of Metropolis-Hastings iterations (samples)
        seq_dist (SequentialJTDistributions): the distribution to be sampled from
        jt_traj: unused; accepted for interface compatibility.
        debug (bool): unused; accepted for interface compatibility.

    Returns:
        mcmctraj.Trajectory: Markov chain of the underlying graphs of the
        junction trees sampled by M-H.
    """
    graph_traj = mcmctraj.Trajectory()
    graph_traj.set_sequential_distribution(seq_dist)

    tree = None  # current state of the chain
    for i in tqdm(range(traj_length), desc="Metropolis-Hastings samples"):
        start_time = time.time()
        if i == 0:
            # Initial state of the chain.
            tree = jtlib.sample(seq_dist.p, alpha, beta)
        else:
            # Next state drawn from the current one.
            tree = trans_sample(tree, alpha, beta, seq_dist)
        # Timing covers only drawing the tree, not extracting its graph.
        elapsed = time.time() - start_time
        graph_traj.add_sample(jtlib.graph(tree), elapsed)

    return graph_traj
Example #2
0
def read_all_trajectories_in_dir(directory):
    """Load every ``*.json`` trajectory file directly inside ``directory``.

    Files are loaded in sorted filename order, so the result does not depend on
    the (arbitrary) order in which the filesystem lists directory entries.

    Args:
        directory (str): directory to scan (non-recursive). An empty string
            means the current working directory.

    Returns:
        The value of ``group_trajectories_by_setting`` applied to the list of
        loaded ``trilearn.graph.trajectory.Trajectory`` objects.
    """
    import glob
    import os

    from trilearn.graph import trajectory as gtraj

    # os.path.join avoids the "/*.json" (filesystem root) pattern that plain
    # concatenation produced for an empty directory string.
    pattern = os.path.join(directory, "*.json")

    trajlist = []
    for filename in sorted(glob.glob(pattern)):
        print("Loading: " + str(filename))
        t = gtraj.Trajectory()
        t.read_file(filename)
        trajlist.append(t)

    return group_trajectories_by_setting(trajlist)
Example #3
0
def _sample_particle_index(log_w, p, n_particles):
    """Draw one particle index with probability proportional to exp(final log-weight).

    Args:
        log_w: log-weights as returned by ``approximate``/``approximate_cond``.
            ``np.array(log_w.T)[p - 1]`` is taken as the per-particle log-weights
            of the last step (presumably one row per particle and one column
            per step -- TODO confirm against ``approximate``).
        p (int): number of steps; selects the final step's weights.
        n_particles (int): number of particles to choose from.

    Returns:
        The sampled particle index.
    """
    log_w_final = np.array(log_w.T)[p - 1]
    # Shift by the maximum before exponentiating so the largest weight is 1
    # and exp() cannot overflow.
    w = np.exp(log_w_final - np.max(log_w_final))
    return np.random.choice(n_particles, size=1, p=w / np.sum(w))[0]


def sample_trajectory(smc_N,
                      alpha,
                      beta,
                      radius,
                      n_samples,
                      seq_dist,
                      jt_traj=None,
                      debug=False,
                      reset_cache=True):
    """A particle Gibbs implementation for approximating distributions over
    junction trees.

    The first iteration runs ``approximate``; later iterations run
    ``approximate_cond`` conditioned on a backward trajectory of the previously
    selected tree. Each iteration then selects one tree from the particle set
    with probability proportional to its final weight.

    Args:
        smc_N (int): Number of particles in SMC in each Gibbs iteration
        alpha (float): sparsity parameter for the Christmas tree algorithm
        beta (float): sparsity parameter for the Christmas tree algorithm
        radius (float): neighbourhood radius passed to the SMC routines and to
            ``sp.backward_perm_traj_sample``
        n_samples (int): Number of Gibbs iterations (samples)
        seq_dist (SequentialJTDistributions): the distribution to be sampled from
        jt_traj: unused; accepted for interface compatibility.
        debug (bool): unused; accepted for interface compatibility.
        reset_cache (bool): when exactly ``True``, ``seq_dist.cache`` is
            replaced by an empty dict at the start of every iteration.

    Returns:
        Trajectory: Markov chain of the underlying graphs of the junction trees
        sampled by pgibbs.
    """
    graph_traj = mcmctraj.Trajectory()
    graph_traj.set_sampling_method({
        "method": "pgibbs",
        "params": {
            "N": smc_N,
            "alpha": alpha,
            "beta": beta,
            "radius": radius
        }
    })
    graph_traj.set_sequential_distribution(seq_dist)
    # Shared across iterations by both SMC routines.
    neig_set_cache = {}
    prev_tree = None
    for i in tqdm(range(n_samples), desc="Particle Gibbs samples"):

        # Identity check kept deliberately: only the literal True resets.
        if reset_cache is True:
            seq_dist.cache = {}

        start_time = time.time()
        if i == 0:
            trees, log_w = approximate(smc_N,
                                       alpha,
                                       beta,
                                       radius,
                                       seq_dist,
                                       neig_set_cache=neig_set_cache)
        else:
            # Backward trajectory of the previously selected tree, used as
            # the conditioning path for the conditional SMC.
            perm_traj = sp.backward_perm_traj_sample(seq_dist.p, radius)
            T_traj = trilearn.graph.junction_tree_collapser.backward_jt_traj_sample(
                perm_traj, prev_tree)
            trees, log_w, _ = approximate_cond(smc_N,
                                               alpha,
                                               beta,
                                               radius,
                                               seq_dist,
                                               T_traj,
                                               perm_traj,
                                               neig_set_cache=neig_set_cache)

        prev_tree = trees[_sample_particle_index(log_w, seq_dist.p, smc_N)]
        graph = jtlib.graph(prev_tree)
        end_time = time.time()
        graph_traj.add_sample(graph, end_time - start_time)
    return graph_traj
def _mh_log_target(jt, sd):
    """Return ``(graph, log_target)`` for the current junction tree ``jt``.

    ``log_target`` is ``sd.log_likelihood(graph)`` minus
    ``jtlib.log_n_junction_trees(jt, jtlib.separators(jt))``; this is the value
    the connect/disconnect M-H sampler stores per sample.
    """
    seps = jtlib.separators(jt)
    graph = jtlib.graph(jt)
    return graph, sd.log_likelihood(graph) - jtlib.log_n_junction_trees(jt, seps)


def _mh_accept(log_p1, log_p2, log_q12, log_q21):
    """Draw a Metropolis-Hastings accept/reject decision.

    The acceptance probability is
    ``min(1, exp(log_p2 + log_q21 - log_p1 - log_q12))``. The log ratio is
    clipped at 0 before exponentiating so a large positive ratio cannot
    overflow ``np.exp``. A NaN ratio still produces a NaN probability,
    because ``min`` keeps its first argument when the comparison is False.

    Returns:
        bool: True if the proposal is accepted.
    """
    alpha = np.exp(min(log_p2 + log_q21 - log_p1 - log_q12, 0.0))
    samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
    return bool(samp[0] == 1)


def _mh_connect_step(jt, sd, log_p1):
    """Propose a connect move on ``jt`` (in place), then accept or undo it.

    Returns:
        ``(graph, log_target)`` of the proposed state if it was accepted, or
        ``None`` if ``aglib.connect_move`` returned a falsy value (no move) or
        the proposal was rejected (the move is then reverted on ``jt``).

    Raises:
        ValueError: if ``aglib.connect_move`` reports an unknown case.
    """
    num_cliques = jt.order()  # clique count before the move
    conn = aglib.connect_move(jt)
    if not conn:
        return None

    graph_prop, log_p2 = _mh_log_target(jt, sd)
    case = conn[0]
    C_disconn = conn[2] | conn[3] | conn[4]  # X | Y | S

    if case == "a":
        (_, log_q12, X, Y, S, CX_disconn, CY_disconn, XSneig,
         YSneig) = conn
        (_, _, N_disconn) = aglib.disconnect_get_neighbors(
            jt, C_disconn, X, Y)
        log_q21 = aglib.disconnect_logprob_a(num_cliques - 1, X, Y, S,
                                             N_disconn)
        undo = aglib.disconnect_a
        undo_args = (jt, C_disconn, X, Y, CX_disconn, CY_disconn, XSneig,
                     YSneig)
    elif case in ("b", "c", "d"):
        (_, log_q12, X, Y, S, CX_disconn, CY_disconn) = conn
        # The clique count given to the reverse-move probability differs
        # for case "d".
        n_cliques_rev = num_cliques + 1 if case == "d" else num_cliques
        log_q21 = aglib.disconnect_logprob_bcd(n_cliques_rev, X, Y, S)
        undo = {"b": aglib.disconnect_b,
                "c": aglib.disconnect_c,
                "d": aglib.disconnect_d}[case]
        undo_args = (jt, C_disconn, X, Y, CX_disconn, CY_disconn)
    else:
        raise ValueError("Unknown connect move case: %r" % (case,))

    if _mh_accept(log_p1, log_p2, log_q12, log_q21):
        return graph_prop, log_p2
    undo(*undo_args)
    return None


def _mh_disconnect_step(jt, sd, log_p1):
    """Propose a disconnect move on ``jt`` (in place), then accept or undo it.

    Returns:
        ``(graph, log_target)`` of the proposed state if it was accepted, or
        ``None`` if ``aglib.disconnect_move`` returned ``False`` (no move) or
        the proposal was rejected (the move is then reverted on ``jt``).

    Raises:
        ValueError: if ``aglib.disconnect_move`` reports an unknown case.
    """
    num_seps = jt.size()  # edge count of jt before the move
    disconnect = aglib.disconnect_move(jt)
    if disconnect is False:
        return None

    graph_prop, log_p2 = _mh_log_target(jt, sd)
    (case, log_q12, X, Y, S, CX_conn, CY_conn) = disconnect

    # Per case: offset applied to num_seps for the reverse-move probability,
    # and the connect function that undoes the move.
    seps_offset = {"a": 1, "b": 0, "c": 0, "d": -1}
    undo_moves = {"a": aglib.connect_a,
                  "b": aglib.connect_b,
                  "c": aglib.connect_c,
                  "d": aglib.connect_d}
    if case not in undo_moves:
        raise ValueError("Unknown disconnect move case: %r" % (case,))

    log_q21 = aglib.connect_logprob(num_seps + seps_offset[case], X, Y,
                                    CX_conn, CY_conn)
    if _mh_accept(log_p1, log_p2, log_q12, log_q21):
        return graph_prop, log_p2
    undo_moves[case](jt, S, X, Y, CX_conn, CY_conn)
    return None


def sample_trajectory(n_samples, randomize, sd):
    """Metropolis-Hastings sampler over graphs using connect/disconnect moves
    on a junction tree.

    The chain starts from the graph on ``sd.p`` nodes with no edges. At each
    iteration a connect move or a disconnect move is proposed with equal
    probability. The proposal is accepted with the M-H probability built from
    the per-state log target of ``_mh_log_target`` and the move
    log-probabilities reported by ``aglib``. Rejected moves are undone in
    place.

    Args:
        n_samples (int): length of the trajectory, including the initial
            state; must be at least 1.
        randomize (int): ``jtlib.randomize(jt)`` is called on every iteration
            ``i`` with ``i % randomize == 0``; must be non-zero.
        sd: distribution exposing ``p`` and ``log_likelihood(graph)``.

    Returns:
        mcmctraj.Trajectory: trajectory whose samples are the graphs of the
        chain and whose ``logl`` is the list of per-sample log targets.
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(sd.p))
    jt = dlib.junction_tree(graph)
    assert jtlib.is_junction_tree(jt)

    graphs = [None] * n_samples
    log_prob_traj = [None] * n_samples

    gtraj = mcmctraj.Trajectory()
    gtraj.set_sampling_method({
        "method": "mh",
        "params": {
            "samples": n_samples,
            "randomize_interval": randomize
        }
    })
    gtraj.set_sequential_distribution(sd)

    graphs[0], log_prob_traj[0] = _mh_log_target(jt, sd)

    for i in tqdm(range(1, n_samples), desc="Metropolis-Hastings samples"):
        if i % randomize == 0:
            jtlib.randomize(jt)

        log_p1 = log_prob_traj[i - 1]
        if np.random.randint(2) == 0:
            proposal = _mh_connect_step(jt, sd, log_p1)
        else:
            proposal = _mh_disconnect_step(jt, sd, log_p1)

        if proposal is None:
            # No move or rejected: the chain stays where it was.
            graphs[i] = graphs[i - 1]
            log_prob_traj[i] = log_prob_traj[i - 1]
        else:
            graphs[i], log_prob_traj[i] = proposal

    gtraj.set_trajectory(graphs)
    gtraj.logl = log_prob_traj
    return gtraj