def sample(order, alpha=0.5, beta=0.5):
    """ Generates a random decomposable graph using the Christmas tree algorithm.

    Args:
        order (int or list): number of nodes, or the list of nodes, of the generated graph.
        alpha (float): Subtree kernel parameter
        beta (float): Subtree kernel parameter

    Returns:
        NetworkX graph: A decomposable graph.

    Example:
        >>> g = sample(5)
        >>> g.nodes
        NodeView((0, 1, 2, 3, 4))
    """
    if type(order) is int:
        nodes = list(range(order))  # list() so the nodes can be shuffled in Python 3
        random.shuffle(nodes)
        tree = libj.sample(nodes, alpha, beta)
        return jtlib.graph(tree)
    elif type(order) is list:
        tree = libj.sample(order, alpha, beta)
        return jtlib.graph(tree)
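# Usage sketch (illustrative, not part of the library API): both call forms of
# `sample` return a decomposable (chordal) NetworkX graph, which can be verified
# directly with networkx.
#
#     import networkx as nx
#     g1 = sample(10)                                    # order given as an int
#     g2 = sample(list(range(10)), alpha=0.3, beta=0.9)  # order given as a node list
#     assert nx.is_chordal(g1) and nx.is_chordal(g2)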
def mh(alpha, beta, traj_length, seq_dist, jt_traj=None, debug=False):
    """ A Metropolis-Hastings implementation for approximating distributions over
    junction trees.

    Args:
        traj_length (int): Number of Metropolis-Hastings iterations (samples)
        alpha (float): sparsity parameter for the Christmas tree algorithm
        beta (float): sparsity parameter for the Christmas tree algorithm
        seq_dist (SequentialJTDistributions): the distribution to be sampled from

    Returns:
        mcmctraj.Trajectory: Markov chain of the underlying graphs of the junction
        trees sampled by M-H.
    """
    graph_traj = mcmctraj.Trajectory()
    graph_traj.set_sequential_distribution(seq_dist)

    prev_tree = None
    for i in tqdm(range(traj_length), desc="Metropolis-Hastings samples"):
        tree = None
        start_time = time.time()
        if i == 0:
            tree = jtlib.sample(seq_dist.p, alpha, beta)
        else:
            # Sample backwards trajectories
            tree = trans_sample(prev_tree, alpha, beta, seq_dist)  # Sample T from T_1..p
        end_time = time.time()
        graph_traj.add_sample(jtlib.graph(tree), end_time - start_time)
        prev_tree = tree
    return graph_traj
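# Usage sketch (assumes `sd` is an already initialised sequential junction tree
# distribution, e.g. a seqdist.GGMJTPosterior as constructed in `smc_ggm_graphs`
# below):
#
#     traj = mh(alpha=0.5, beta=0.5, traj_length=10000, seq_dist=sd)
#     # `traj` is an mcmctraj.Trajectory; each sample is the decomposable graph
#     # underlying the junction tree drawn at that iteration.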
def smc_ggm_graphs(N, alpha, beta, radius, X, D, delta):
    cache = {}
    seq_dist = seqdist.GGMJTPosterior()
    seq_dist.init_model(X, D, delta, cache)
    (trees, log_w) = approximate(N, alpha, beta, radius, seq_dist)
    # Normalize the importance weights of the final (p:th) SMC step.
    log_w_last = np.array(log_w.T)[seq_dist.p - 1]
    log_w_rescaled = log_w_last - max(log_w_last)
    norm_w = np.exp(log_w_rescaled) / sum(np.exp(log_w_rescaled))
    graphs = [jtlib.graph(tree) for tree in trees]
    return (graphs, norm_w)
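# Usage sketch (X, D and delta are placeholders for the data matrix and the
# hyperparameters expected by seqdist.GGMJTPosterior.init_model; here the radius
# is set to the full number of variables, illustratively):
#
#     graphs, weights = smc_ggm_graphs(N=200, alpha=0.5, beta=0.5,
#                                      radius=X.shape[1], X=X, D=D, delta=delta)
#     map_particle = graphs[int(np.argmax(weights))]  # highest-weight particle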
def est_log_norm_consts(order, n_particles, sequential_distribution,
                        alpha=0.5, beta=0.5, n_smc_estimates=1, debug=False):
    """ Estimates the log normalizing constants of the sequential distributions
    of orders 1,..,order from the SMC importance weights.

    Returns an array of shape (n_smc_estimates, order), flattened if
    n_smc_estimates is 1.
    """
    log_consts = np.zeros(n_smc_estimates * order).reshape(n_smc_estimates, order)

    def estimate_norm_const(order, weights):
        # Accumulate the log mean weights over the orders.
        log_consts = np.zeros(order)
        for n in range(1, order):
            log_consts[n] = log_consts[n - 1] + np.log(np.mean(weights[:, n]))
        return log_consts

    for t in tqdm(range(n_smc_estimates), desc="Const estimates"):
        (trees, log_w) = approximate(n_particles, alpha, beta,
                                     sequential_distribution.p,
                                     sequential_distribution)
        w = np.exp(log_w)
        log_consts[t, :] = estimate_norm_const(order, w)
        if debug:
            unique_trees = set()
            for tree in trees:
                tree_alt = (frozenset(tree.nodes()),
                            frozenset([frozenset(e) for e in tree.edges()]))
                unique_trees.add(tree_alt)
            print("Sampled unique junction trees: " + str(len(unique_trees)))
            unique_graphs = set(
                [glib.hash_graph(jtlib.graph(tree)) for tree in trees])
            print("Sampled unique chordal graphs: {n_unique_chordal_graphs}".format(
                n_unique_chordal_graphs=len(unique_graphs)))

    if n_smc_estimates == 1:
        log_consts = log_consts.flatten()
    return log_consts
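# Usage sketch: repeated SMC estimates of the log normalizing constants for the
# conditionally uniform junction tree distribution (the same target used by
# `uniform_dec_samples` below).
#
#     sd = seqdist.CondUniformJTDistribution(5)
#     log_consts = est_log_norm_consts(order=5, n_particles=1000,
#                                      sequential_distribution=sd,
#                                      n_smc_estimates=10)
#     # log_consts has shape (10, 5): one row per SMC run, one column per order.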
def sample_dec_graph(internal_nodes, alpha=0.5, beta=0.5, directory='.'):
    """ Generates a random decomposable graph using the Christmas tree algorithm.

    Args:
        internal_nodes (list): list of internal nodes in the generated graph.
        alpha (float): Subtree kernel parameter
        beta (float): Subtree kernel parameter
        directory (string): Path to where the plots should be saved (currently unused).

    Returns:
        NetworkX graph: A decomposable graph.

    Example:
        >>> g = dlib.sample_dec_graph(5)
        >>> g.edges
        EdgeView([(0, 1), (0, 3), (1, 3), (2, 3)])
        >>> g.nodes
        NodeView((0, 1, 2, 3, 4))
    """
    T = libj.sample(internal_nodes, alpha=alpha, beta=beta)
    return libj.graph(T)
def sample_trajectory(smc_N, alpha, beta, radius, n_samples, seq_dist,
                      jt_traj=None, debug=False, reset_cache=True):
    """ A particle Gibbs implementation for approximating distributions over
    junction trees.

    Args:
        smc_N (int): Number of particles in SMC in each Gibbs iteration
        n_samples (int): Number of Gibbs iterations (samples)
        alpha (float): sparsity parameter for the Christmas tree algorithm
        beta (float): sparsity parameter for the Christmas tree algorithm
        radius (float): defines the radius within which new nodes are selected
        seq_dist (SequentialJTDistributions): the distribution to be sampled from

    Returns:
        Trajectory: Markov chain of the underlying graphs of the junction trees
        sampled by particle Gibbs.
    """
    graph_traj = mcmctraj.Trajectory()
    graph_traj.set_sampling_method({
        "method": "pgibbs",
        "params": {
            "N": smc_N,
            "alpha": alpha,
            "beta": beta,
            "radius": radius
        }
    })
    graph_traj.set_sequential_distribution(seq_dist)

    neig_set_cache = {}
    (trees, log_w) = (None, None)
    prev_tree = None
    for i in tqdm(range(n_samples), desc="Particle Gibbs samples"):
        if reset_cache is True:
            seq_dist.cache = {}
        start_time = time.time()
        if i == 0:
            #start_graph = nx.Graph()
            #start_graph.add_nodes_from(range(seqdist.p))
            #start_tree = dlib.junction_tree(start_graph)
            (trees, log_w) = approximate(smc_N, alpha, beta, radius, seq_dist,
                                         neig_set_cache=neig_set_cache)
        else:
            # Sample backwards trajectories
            perm_traj = sp.backward_perm_traj_sample(seq_dist.p, radius)
            T_traj = trilearn.graph.junction_tree_collapser.backward_jt_traj_sample(
                perm_traj, prev_tree)
            (trees, log_w, Is) = approximate_cond(smc_N, alpha, beta, radius,
                                                  seq_dist, T_traj, perm_traj,
                                                  neig_set_cache=neig_set_cache)

        # Sample T from T_1..p
        log_w_array = np.array(log_w.T)[seq_dist.p - 1]
        log_w_rescaled = log_w_array - max(log_w_array)
        w_rescaled = np.exp(log_w_rescaled)
        norm_w = w_rescaled / sum(w_rescaled)
        I = np.random.choice(smc_N, size=1, p=norm_w)[0]
        T = trees[I]
        prev_tree = T
        graph = jtlib.graph(T)
        end_time = time.time()
        graph_traj.add_sample(graph, end_time - start_time)

    return graph_traj
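# Usage sketch (assumes an initialised sequential distribution `sd`, e.g. a GGM
# posterior as in `smc_ggm_graphs`; the parameter values are illustrative only):
#
#     traj = sample_trajectory(smc_N=50, alpha=0.5, beta=0.5, radius=sd.p,
#                              n_samples=5000, seq_dist=sd)
#     # `traj` is an mcmctraj.Trajectory of decomposable graphs, one per
#     # particle Gibbs iteration.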
def sample_trajectory(n_samples, randomize, sd):
    """ A Metropolis-Hastings sampler over junction trees based on
    connect/disconnect moves, returning the trajectory of underlying
    decomposable graphs.

    Args:
        n_samples (int): Number of Metropolis-Hastings iterations (samples)
        randomize (int): interval at which the junction tree representation is randomized
        sd: the distribution to be sampled from

    Returns:
        mcmctraj.Trajectory: Markov chain of the underlying graphs of the junction trees.
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(sd.p))
    jt = dlib.junction_tree(graph)
    assert jtlib.is_junction_tree(jt)

    jt_traj = [None] * n_samples
    graphs = [None] * n_samples
    jt_traj[0] = jt
    graphs[0] = jtlib.graph(jt)
    log_prob_traj = [None] * n_samples

    gtraj = mcmctraj.Trajectory()
    gtraj.set_sampling_method({
        "method": "mh",
        "params": {
            "samples": n_samples,
            "randomize_interval": randomize
        }
    })
    gtraj.set_sequential_distribution(sd)

    log_prob_traj[0] = sd.log_likelihood(jtlib.graph(jt_traj[0]))
    log_prob_traj[0] += -jtlib.log_n_junction_trees(
        jt_traj[0], jtlib.separators(jt_traj[0]))
    accept_traj = [0] * n_samples

    MAP_graph = (graphs[0], log_prob_traj[0])

    for i in tqdm(range(1, n_samples), desc="Metropolis-Hastings samples"):
        if log_prob_traj[i - 1] > MAP_graph[1]:
            MAP_graph = (graphs[i - 1], log_prob_traj[i - 1])

        if i % randomize == 0:
            # Randomize the junction tree representation of the current graph.
            jtlib.randomize(jt)
            graphs[i] = jtlib.graph(jt)  # TODO: Improve.
            log_prob_traj[i] = sd.log_likelihood(
                graphs[i]) - jtlib.log_n_junction_trees(
                    jt, jtlib.separators(jt))

        r = np.random.randint(2)  # Connect / disconnect move
        num_seps = jt.size()
        log_p1 = log_prob_traj[i - 1]
        if r == 0:
            # Connect move
            num_cliques = jt.order()
            conn = aglib.connect_move(jt)  # need to move to calculate posterior
            seps_prop = jtlib.separators(jt)
            log_p2 = sd.log_likelihood(
                jtlib.graph(jt)) - jtlib.log_n_junction_trees(jt, seps_prop)
            if not conn:
                log_prob_traj[i] = log_prob_traj[i - 1]
                graphs[i] = graphs[i - 1]
                continue
            C_disconn = conn[2] | conn[3] | conn[4]
            if conn[0] == "a":
                (case, log_q12, X, Y, S, CX_disconn, CY_disconn,
                 XSneig, YSneig) = conn
                (NX_disconn, NY_disconn, N_disconn) = aglib.disconnect_get_neighbors(
                    jt, C_disconn, X, Y)  # TODO: could this be done faster?
                log_q21 = aglib.disconnect_logprob_a(num_cliques - 1, X, Y, S,
                                                     N_disconn)
                alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                if samp == 1:
                    # Accept
                    accept_traj[i] = 1
                    log_prob_traj[i] = log_p2
                    graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                else:
                    # Reject: undo the move.
                    aglib.disconnect_a(jt, C_disconn, X, Y, CX_disconn,
                                       CY_disconn, XSneig, YSneig)
                    log_prob_traj[i] = log_prob_traj[i - 1]
                    graphs[i] = graphs[i - 1]
                continue
            elif conn[0] == "b":
                (case, log_q12, X, Y, S, CX_disconn, CY_disconn) = conn
                log_q21 = aglib.disconnect_logprob_bcd(num_cliques, X, Y, S)
                alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                if samp == 1:
                    # Accept
                    accept_traj[i] = 1
                    log_prob_traj[i] = log_p2
                    graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                else:
                    # Reject: undo the move.
                    aglib.disconnect_b(jt, C_disconn, X, Y, CX_disconn, CY_disconn)
                    log_prob_traj[i] = log_prob_traj[i - 1]
                    graphs[i] = graphs[i - 1]
                continue
            elif conn[0] == "c":
                (case, log_q12, X, Y, S, CX_disconn, CY_disconn) = conn
                log_q21 = aglib.disconnect_logprob_bcd(num_cliques, X, Y, S)
                alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                if samp == 1:
                    # Accept
                    accept_traj[i] = 1
                    log_prob_traj[i] = log_p2
                    graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                else:
                    # Reject: undo the move.
                    aglib.disconnect_c(jt, C_disconn, X, Y, CX_disconn, CY_disconn)
                    log_prob_traj[i] = log_prob_traj[i - 1]
                    graphs[i] = graphs[i - 1]
                continue
            elif conn[0] == "d":
                (case, log_q12, X, Y, S, CX_disconn, CY_disconn) = conn
                log_q21 = aglib.disconnect_logprob_bcd(num_cliques + 1, X, Y, S)
                alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                if samp == 1:
                    # Accept
                    accept_traj[i] = 1
                    log_prob_traj[i] = log_p2
                    graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                else:
                    # Reject: undo the move.
                    aglib.disconnect_d(jt, C_disconn, X, Y, CX_disconn, CY_disconn)
                    log_prob_traj[i] = log_prob_traj[i - 1]
                    graphs[i] = graphs[i - 1]
                continue
        elif r == 1:
            # Disconnect move
            disconnect = aglib.disconnect_move(jt)  # need to move to calculate posterior
            seps_prop = jtlib.separators(jt)
            log_p2 = sd.log_likelihood(
                jtlib.graph(jt)) - jtlib.log_n_junction_trees(jt, seps_prop)
            #assert(jtlib.is_junction_tree(jt))
            if disconnect is not False:
                if disconnect[0] == "a":
                    (case, log_q12, X, Y, S, CX_conn, CY_conn) = disconnect
                    log_q21 = aglib.connect_logprob(num_seps + 1, X, Y,
                                                    CX_conn, CY_conn)
                    alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                    samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                    if samp == 1:
                        # Accept
                        accept_traj[i] = 1
                        log_prob_traj[i] = log_p2
                        graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                    else:
                        # Reject: undo the move.
                        aglib.connect_a(jt, S, X, Y, CX_conn, CY_conn)
                        log_prob_traj[i] = log_prob_traj[i - 1]
                        graphs[i] = graphs[i - 1]
                    continue
                elif disconnect[0] == "b":
                    (case, log_q12, X, Y, S, CX_conn, CY_conn) = disconnect
                    log_q21 = aglib.connect_logprob(num_seps, X, Y,
                                                    CX_conn, CY_conn)
                    alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                    samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                    if samp == 1:
                        # Accept
                        accept_traj[i] = 1
                        log_prob_traj[i] = log_p2
                        graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                    else:
                        # Reject: undo the move.
                        aglib.connect_b(jt, S, X, Y, CX_conn, CY_conn)
                        log_prob_traj[i] = log_prob_traj[i - 1]
                        graphs[i] = graphs[i - 1]
                    continue
                elif disconnect[0] == "c":
                    (case, log_q12, X, Y, S, CX_conn, CY_conn) = disconnect
                    log_q21 = aglib.connect_logprob(num_seps, X, Y,
                                                    CX_conn, CY_conn)
                    alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                    samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                    if samp == 1:
                        # Accept
                        accept_traj[i] = 1
                        log_prob_traj[i] = log_p2
                        graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                    else:
                        # Reject: undo the move.
                        aglib.connect_c(jt, S, X, Y, CX_conn, CY_conn)
                        log_prob_traj[i] = log_prob_traj[i - 1]
                        graphs[i] = graphs[i - 1]
                    continue
                elif disconnect[0] == "d":
                    (case, log_q12, X, Y, S, CX_conn, CY_conn) = disconnect
                    log_q21 = aglib.connect_logprob(num_seps - 1, X, Y,
                                                    CX_conn, CY_conn)
                    alpha = min(np.exp(log_p2 + log_q21 - log_p1 - log_q12), 1)
                    samp = np.random.choice(2, 1, p=[(1 - alpha), alpha])
                    if samp == 1:
                        # Accept
                        accept_traj[i] = 1
                        log_prob_traj[i] = log_p2
                        graphs[i] = jtlib.graph(jt)  # TODO: Improve.
                    else:
                        # Reject: undo the move.
                        aglib.connect_d(jt, S, X, Y, CX_conn, CY_conn)
                        log_prob_traj[i] = log_prob_traj[i - 1]
                        graphs[i] = graphs[i - 1]
                    continue
            else:
                log_prob_traj[i] = log_prob_traj[i - 1]
                graphs[i] = graphs[i - 1]
                continue

    gtraj.set_trajectory(graphs)
    gtraj.logl = log_prob_traj
    return gtraj
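# Usage sketch for the Metropolis-Hastings sampler above (assumes an initialised
# distribution `sd` exposing `p` and `log_likelihood`, as used in the function body):
#
#     traj = sample_trajectory(n_samples=10000, randomize=100, sd=sd)
#     # traj.logl holds the log score recorded for the graph at each iteration.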
def uniform_dec_samples(order, n_particles, alpha=0.5, beta=0.5, debug=False):
    """ Samples decomposable graphs of the given order using SMC, with the
    conditionally uniform junction tree distribution as target. """
    sd = seqdist.CondUniformJTDistribution(order)
    (trees, log_w) = approximate(n_particles, alpha, beta, sd.p, sd)
    # Note: the importance weights (log_w) are discarded; only the particle
    # graphs are returned.
    graphs = [jtlib.graph(tree) for tree in trees]
    return graphs
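# Usage sketch: drawing decomposable graphs on 8 nodes and counting how many
# distinct edge sets appear among the returned particles.
#
#     graphs = uniform_dec_samples(order=8, n_particles=500)
#     n_distinct = len({frozenset(frozenset(e) for e in g.edges()) for g in graphs})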