def mh(alpha, beta, traj_length, seq_dist, jt_traj=None, debug=False):
    """ A Metropolis-Hastings implementation for approximating distributions
    over junction trees.

    The first junction tree is drawn directly with ``jtlib.sample``; every
    subsequent tree is obtained from the previous one through ``trans_sample``
    (the transition kernel, defined elsewhere in this module). The underlying
    graph of each tree is appended to the returned trajectory.

    Args:
        alpha (float): sparsity parameter for the Christmas tree algorithm
        beta (float): sparsity parameter for the Christmas tree algorithm
        traj_length (int): Number of Metropolis-Hastings iterations (samples)
        seq_dist (SequentialJTDistributions): the distribution to be sampled from
        jt_traj: Unused; kept only for signature compatibility.
        debug (bool): Unused; kept only for signature compatibility.

    Returns:
        mcmctraj.Trajectory: Markov chain of the underlying graphs of the
        junction trees sampled by M-H. For each sample, the elapsed
        ``time.time()`` difference of that iteration is passed to
        ``add_sample`` alongside the graph.
    """
    graph_traj = mcmctraj.Trajectory()
    graph_traj.set_sequential_distribution(seq_dist)
    prev_tree = None
    for i in tqdm(range(traj_length), desc="Metropolis-Hastings samples"):
        start_time = time.time()
        if i == 0:
            # No previous state yet: draw the initial tree directly.
            tree = jtlib.sample(seq_dist.p, alpha, beta)
        else:
            # Move the chain forward from the previous state.
            tree = trans_sample(prev_tree, alpha, beta, seq_dist)
        end_time = time.time()
        graph_traj.add_sample(jtlib.graph(tree), end_time - start_time)
        prev_tree = tree
    return graph_traj
def read_all_trajectories_in_dir(directory):
    """Load every ``*.json`` trajectory file found in *directory*.

    Each matching file is announced on stdout, parsed into a
    ``trilearn.graph.trajectory.Trajectory`` via ``read_file``, and the
    resulting list is handed to ``group_trajectories_by_setting``.

    Args:
        directory (str): Directory to scan (non-recursively) for ``.json`` files.

    Returns:
        Whatever ``group_trajectories_by_setting`` returns for the loaded
        trajectories.
    """
    from trilearn.graph import trajectory as gtraj

    def _load_one(path):
        # Announce, construct and populate a single trajectory.
        print("Loading: " + str(path))
        loaded = gtraj.Trajectory()
        loaded.read_file(path)
        return loaded

    json_pattern = directory + "/*.json"
    loaded_trajectories = [_load_one(path) for path in glob.glob(json_pattern)]
    return group_trajectories_by_setting(loaded_trajectories)
def sample_trajectory(smc_N, alpha, beta, radius, n_samples, seq_dist,
                      jt_traj=None, debug=False, reset_cache=True):
    """Run particle Gibbs to approximate a distribution over junction trees.

    Iteration 0 runs an unconditional SMC sweep (``approximate``). Every later
    iteration samples a backward permutation trajectory and a matching
    backward junction-tree trajectory from the previously selected tree, then
    runs conditional SMC (``approximate_cond``) on it. In every iteration a
    single particle is drawn with probability proportional to its final-step
    weight, and its underlying graph is recorded in the trajectory together
    with the iteration's elapsed ``time.time()`` difference.

    Args:
        smc_N (int): Number of particles in SMC in each Gibbs iteration.
        alpha (float): Sparsity parameter for the Christmas tree algorithm.
        beta (float): Sparsity parameter for the Christmas tree algorithm.
        radius (float): Defines the radius within which new nodes are
            selected; passed to the SMC samplers and to
            ``sp.backward_perm_traj_sample``.
        n_samples (int): Number of Gibbs iterations (samples).
        seq_dist (SequentialJTDistributions): The distribution to be sampled from.
        jt_traj: Unused; kept only for signature compatibility.
        debug (bool): Unused; kept only for signature compatibility.
        reset_cache (bool): When it is exactly ``True``, ``seq_dist.cache`` is
            replaced by an empty dict at the start of every iteration.

    Returns:
        mcmctraj.Trajectory: Markov chain of the underlying graphs of the
        junction trees sampled by particle Gibbs.
    """
    chain = mcmctraj.Trajectory()
    chain.set_sampling_method({
        "method": "pgibbs",
        "params": {
            "N": smc_N,
            "alpha": alpha,
            "beta": beta,
            "radius": radius,
        },
    })
    chain.set_sequential_distribution(seq_dist)

    # Neighbour-set cache shared by every SMC sweep of this run.
    neighbour_cache = {}
    # Tree selected in the previous iteration; later sweeps condition on it.
    reference_tree = None

    for step in tqdm(range(n_samples), desc="Particle Gibbs samples"):
        if reset_cache is True:
            seq_dist.cache = {}
        tic = time.time()

        if step > 0:
            # Backward-sample a trajectory consistent with the reference tree
            # and run SMC conditioned on it (the third result is not needed).
            perm_traj = sp.backward_perm_traj_sample(seq_dist.p, radius)
            tree_traj = trilearn.graph.junction_tree_collapser.backward_jt_traj_sample(
                perm_traj, reference_tree)
            particles, log_weights, _ = approximate_cond(
                smc_N, alpha, beta, radius, seq_dist, tree_traj, perm_traj,
                neig_set_cache=neighbour_cache)
        else:
            # First iteration: plain, unconditional SMC sweep.
            particles, log_weights = approximate(
                smc_N, alpha, beta, radius, seq_dist,
                neig_set_cache=neighbour_cache)

        # Select one particle using row (p - 1) of the transposed log-weights;
        # shifting by the maximum keeps np.exp from overflowing.
        final_log_w = np.array(log_weights.T)[seq_dist.p - 1]
        unnormalised = np.exp(final_log_w - max(final_log_w))
        probabilities = unnormalised / sum(unnormalised)
        chosen = np.random.choice(smc_N, size=1, p=probabilities)[0]
        reference_tree = particles[chosen]

        sampled_graph = jtlib.graph(reference_tree)
        chain.add_sample(sampled_graph, time.time() - tic)

    return chain
def sample_trajectory(n_samples, randomize, sd):
    """Metropolis-Hastings sampler over graphs using junction-tree moves.

    The chain starts from the junction tree (``dlib.junction_tree``) of the
    graph on ``sd.p`` nodes with no edges. Each iteration picks, with equal
    probability, a connect move (``aglib.connect_move``) or a disconnect move
    (``aglib.disconnect_move``). Both moves modify ``jt`` in place; if the
    proposal is rejected, the matching inverse move (``aglib.disconnect_*`` /
    ``aglib.connect_*``) is applied to restore the previous tree. Every
    ``randomize`` iterations ``jtlib.randomize(jt)`` is called first.

    The log target used in the acceptance ratio is
    ``sd.log_likelihood(graph) - jtlib.log_n_junction_trees(jt, separators)``.

    Args:
        n_samples (int): Length of the chain including the initial state.
            Must be >= 1 (``jt_traj[0]`` is assigned unconditionally).
        randomize (int): Interval, in iterations, between calls to
            ``jtlib.randomize``. NOTE(review): 0 raises ZeroDivisionError
            at ``i % randomize``.
        sd: Sequential distribution; used for ``sd.p``,
            ``sd.log_likelihood(graph)`` and stored on the trajectory.

    Returns:
        mcmctraj.Trajectory: with the sampled graphs set via
        ``set_trajectory`` and the per-sample log targets in ``logl``.
    """
    # Initial state: the edgeless graph on sd.p nodes and its junction tree.
    graph = nx.Graph()
    graph.add_nodes_from(range(sd.p))
    jt = dlib.junction_tree(graph)
    assert (jtlib.is_junction_tree(jt))
    # Only jt_traj[0] is ever filled; later junction trees are not stored.
    jt_traj = [None] * n_samples
    graphs = [None] * n_samples
    jt_traj[0] = jt
    graphs[0] = jtlib.graph(jt)
    log_prob_traj = [None] * n_samples

    gtraj = mcmctraj.Trajectory()
    gtraj.set_sampling_method({
        "method": "mh",
        "params": {
            "samples": n_samples,
            "randomize_interval": randomize
        }
    })
    gtraj.set_sequential_distribution(sd)

    # Log target of the initial state (the 0.0 is immediately overwritten).
    log_prob_traj[0] = 0.0
    log_prob_traj[0] = sd.log_likelihood(jtlib.graph(jt_traj[0]))
    log_prob_traj[0] += -jtlib.log_n_junction_trees(
        jt_traj[0], jtlib.separators(jt_traj[0]))

    # accept_traj[i] becomes 1 when the proposal at iteration i is accepted.
    # NOTE(review): never returned nor stored on gtraj.
    accept_traj = [0] * n_samples

    # Running best (graph, log target) among visited states.
    # NOTE(review): never returned nor stored, and the last state is never
    # compared against it.
    MAP_graph = (graphs[0], log_prob_traj[0])

    for i in tqdm(range(1, n_samples), desc="Metropolis-Hastings samples"):
        if log_prob_traj[i - 1] > MAP_graph[1]:
            MAP_graph = (graphs[i - 1], log_prob_traj[i - 1])

        # Periodic jtlib.randomize of the current junction tree. The
        # graphs[i] / log_prob_traj[i] written here are overwritten by every
        # move branch below.
        if i % randomize == 0:
            jtlib.randomize(jt)
            graphs[i] = jtlib.graph(jt)  # TODO: Improve.
            log_prob_traj[i] = sd.log_likelihood(
                graphs[i]) - jtlib.log_n_junction_trees(
                    jt, jtlib.separators(jt))

        r = np.random.randint(2)  # Connect / disconnect move
        num_seps = jt.size()  # edge count of jt, used by connect_logprob
        log_p1 = log_prob_traj[i - 1]  # log target of the current state
        if r == 0:
            # Connect move
            num_cliques = jt.order()  # node count of jt before the move
            conn = aglib.connect_move(
                jt)  # need to move to calculate posterior
            seps_prop = jtlib.separators(jt)
            # Log target of the proposed (already applied) state.
            # NOTE(review): also computed when conn is falsy, which is wasted work.
            log_p2 = sd.log_likelihood(
                jtlib.graph(jt)) - jtlib.log_n_junction_trees(jt, seps_prop)

            # No connect move was made: the chain stays at the current state.
            if not conn:
                log_prob_traj[i] = log_prob_traj[i - 1]
                graphs[i] = graphs[i - 1]
                continue
            # conn = (case, log_q12, X, Y, S, ...); C_disconn is X | Y | S,
            # needed to undo the move on rejection.
            C_disconn = conn[2] | conn[3] | conn[4]
            if conn[0] == "a":
                (case, log_q12, X, Y, S, CX_disconn, CY_disconn,
                 XSneig, YSneig) = conn
                (NX_disconn, NY_disconn,
                 N_disconn) = aglib.disconnect_get_neighbors(
                     jt, C_disconn, X, Y)  # TODO: could this be done faster?
                # Reverse-move log-probability (undoing this connect).
                log_q21 = aglib.disconnect_logprob_a(num_cliques - 1, X, Y, S,
                                                     N_disconn)
                # MH acceptance probability:
                # min(1, p2 * q21 / (p1 * q12)) computed in log space.
                alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                if samp == 1:
                    # Accept: keep the modified jt.
                    accept_traj[i] = 1
                    log_prob_traj[i] = log_p2
                    graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                else:
                    # Reject: apply the inverse move to restore jt.
                    aglib.disconnect_a(jt, C_disconn, X, Y, CX_disconn,
                                       CY_disconn, XSneig, YSneig)
                    log_prob_traj[i] = log_prob_traj[i - 1]
                    graphs[i] = graphs[i - 1]
                    continue
            elif conn[0] == "b":
                (case, log_q12, X, Y, S, CX_disconn, CY_disconn) = conn
                log_q21 = aglib.disconnect_logprob_bcd(num_cliques, X, Y, S)
                alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                if samp == 1:
                    # Accept
                    accept_traj[i] = 1
                    log_prob_traj[i] = log_p2
                    graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                else:
                    # Reject: undo with the case-"b" inverse move.
                    aglib.disconnect_b(jt, C_disconn, X, Y, CX_disconn,
                                       CY_disconn)
                    log_prob_traj[i] = log_prob_traj[i - 1]
                    graphs[i] = graphs[i - 1]
                    continue
            elif conn[0] == "c":
                (case, log_q12, X, Y, S, CX_disconn, CY_disconn) = conn
                log_q21 = aglib.disconnect_logprob_bcd(num_cliques, X, Y, S)
                alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                if samp == 1:
                    # Accept
                    accept_traj[i] = 1
                    log_prob_traj[i] = log_p2
                    graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                else:
                    # Reject: undo with the case-"c" inverse move.
                    aglib.disconnect_c(jt, C_disconn, X, Y, CX_disconn,
                                       CY_disconn)
                    log_prob_traj[i] = log_prob_traj[i - 1]
                    graphs[i] = graphs[i - 1]
                    continue
            elif conn[0] == "d":
                (case, log_q12, X, Y, S, CX_disconn, CY_disconn) = conn
                log_q21 = aglib.disconnect_logprob_bcd(num_cliques + 1, X, Y,
                                                       S)
                alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                if samp == 1:
                    # Accept
                    accept_traj[i] = 1
                    log_prob_traj[i] = log_p2
                    graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                else:
                    # Reject: undo with the case-"d" inverse move.
                    aglib.disconnect_d(jt, C_disconn, X, Y, CX_disconn,
                                       CY_disconn)
                    log_prob_traj[i] = log_prob_traj[i - 1]
                    graphs[i] = graphs[i - 1]
                    continue
        elif r == 1:
            # Disconnect move
            disconnect = aglib.disconnect_move(
                jt)  # need to move to calculate posterior
            seps_prop = jtlib.separators(jt)
            # Log target of the proposed (already applied) state.
            log_p2 = sd.log_likelihood(
                jtlib.graph(jt)) - jtlib.log_n_junction_trees(jt, seps_prop)
            if disconnect is not False:
                # disconnect = (case, log_q12, X, Y, S, CX_conn, CY_conn)
                if disconnect[0] == "a":
                    (case, log_q12, X, Y, S, CX_conn, CY_conn) = disconnect
                    # Reverse-move log-probability (reconnecting X and Y).
                    log_q21 = aglib.connect_logprob(num_seps + 1, X, Y,
                                                    CX_conn, CY_conn)
                    alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12),
                                1)
                    samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                    if samp == 1:
                        # Accept
                        accept_traj[i] = 1
                        log_prob_traj[i] = log_p2
                        graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                    else:
                        # Reject: undo with the case-"a" inverse move.
                        aglib.connect_a(jt, S, X, Y, CX_conn, CY_conn)
                        log_prob_traj[i] = log_prob_traj[i - 1]
                        graphs[i] = graphs[i - 1]
                        continue
                elif disconnect[0] == "b":
                    (case, log_q12, X, Y, S, CX_conn, CY_conn) = disconnect
                    log_q21 = aglib.connect_logprob(num_seps, X, Y, CX_conn,
                                                    CY_conn)
                    alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12),
                                1)
                    samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                    if samp == 1:
                        # Accept
                        accept_traj[i] = 1
                        log_prob_traj[i] = log_p2
                        graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                    else:
                        # Reject: undo with the case-"b" inverse move.
                        aglib.connect_b(jt, S, X, Y, CX_conn, CY_conn)
                        log_prob_traj[i] = log_prob_traj[i - 1]
                        graphs[i] = graphs[i - 1]
                        continue
                elif disconnect[0] == "c":
                    (case, log_q12, X, Y, S, CX_conn, CY_conn) = disconnect
                    log_q21 = aglib.connect_logprob(num_seps, X, Y, CX_conn,
                                                    CY_conn)
                    alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12),
                                1)
                    samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                    if samp == 1:
                        # Accept
                        accept_traj[i] = 1
                        log_prob_traj[i] = log_p2
                        graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                    else:
                        # Reject: undo with the case-"c" inverse move.
                        aglib.connect_c(jt, S, X, Y, CX_conn, CY_conn)
                        log_prob_traj[i] = log_prob_traj[i - 1]
                        graphs[i] = graphs[i - 1]
                        continue
                elif disconnect[0] == "d":
                    (case, log_q12, X, Y, S, CX_conn, CY_conn) = disconnect
                    log_q21 = aglib.connect_logprob(num_seps - 1, X, Y,
                                                    CX_conn, CY_conn)
                    alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12),
                                1)
                    samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                    if samp == 1:
                        # Accept
                        accept_traj[i] = 1
                        log_prob_traj[i] = log_p2
                        graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                    else:
                        # Reject: undo with the case-"d" inverse move.
                        aglib.connect_d(jt, S, X, Y, CX_conn, CY_conn)
                        log_prob_traj[i] = log_prob_traj[i - 1]
                        graphs[i] = graphs[i - 1]
                        continue
            else:
                # No disconnect move was made: the chain stays put.
                log_prob_traj[i] = log_prob_traj[i - 1]
                graphs[i] = graphs[i - 1]
                continue
    gtraj.set_trajectory(graphs)
    gtraj.logl = log_prob_traj
    return gtraj