import numpy as np

def test_deep_kl_lag(m, p=.99, trials=100, eps=1e-3, n_samples=300, pre_burnin=100):
    net = m_deep_bistable(m+1, p)  # using m+1 since the leaf node is used for evidence

    # compute steady state
    model_nodes = net._nodes[:-1]
    evidence_node = net._nodes[-1]
    ss_distributions = steady_state(net, {evidence_node: 1}, model_nodes, eps)

    # accumulate data (KL divergence) in nodes x time_steps x trials
    data = np.zeros((m, n_samples, trials))
    for t in range(trials):
        counts = {node: eps * np.ones(node.size()) for node in model_nodes}

        def do_divergence(i, net):
            for j, n in enumerate(model_nodes):
                counts[n][n.state_index()] += 1
                # compute divergence between the current estimate and the steady-state distribution
                data[j, i, t] = KL(ss_distributions[n], counts[n] / counts[n].sum())

        # prepare the net by burning in on 0-evidence
        gibbs_sample(net, {evidence_node: 0}, None, 0, pre_burnin)
        # collect data on 1-evidence (no burn-in this time; demonstrating lag)
        gibbs_sample(net, {evidence_node: 1}, do_divergence, n_samples, 0)
    return data
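
# test_deep_kl_lag assumes a KL(p, q) helper comparing two discrete
# distributions; its definition is not shown here. A minimal sketch of such a
# helper (the project's actual implementation may differ):
def _kl_sketch(p, q):
    """Discrete KL divergence D(p || q) = sum_i p[i] * log(p[i] / q[i]).

    Assumes p and q are 1-D probability vectors over the same states; the
    eps-smoothed counts above keep q strictly positive, avoiding log(0).
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0  # terms with p[i] == 0 contribute nothing
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))
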
def steady_state(net, evidence, nodes, eps=0, M=10000, burnin=100):
    """Computes the steady-state distribution for each node."""
    # eps allows for a small count at each state (to avoid zero-probability states)
    counts = {node: eps * np.ones(node.size()) for node in nodes}

    def do_count(i, net):
        for n in nodes:
            counts[n][n.state_index()] += 1

    gibbs_sample(net, evidence, do_count, M, burnin)

    # normalize each count vector in place into a probability distribution
    for c in counts.values():
        c /= c.sum()
    return counts
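
# Example usage (a sketch; assumes the m_deep_bistable constructor and node
# layout used in the tests above, with the leaf node clamped as evidence):
def _steady_state_example():
    net = m_deep_bistable(6, 0.99)  # 5 model nodes plus one evidence leaf
    ss = steady_state(net, {net._nodes[-1]: 1}, net._nodes[:-1], eps=1e-3)
    # ss maps each node to its estimated marginal distribution under 1-evidence
    return ss
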
def sample_marginal_states(net, evidence, samples, when=None):
    """Computes S[i] = the marginal probability that the net is in the state with id i.

    If given, when(net) is evaluated to decide whether each sample is included.
    """
    n_states = count_states(net)

    # estimate the distribution over states by sampling
    S = np.zeros(n_states)

    def do_count_state(i, net):
        if when is None or when(net):
            S[state_to_id(net, net.state_vector())] += 1

    gibbs_sample(net, evidence, do_count_state, samples, 1)
    return normalized(S, order=1)
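
# sample_marginal_states relies on a normalized(v, order) helper that is not
# shown here. A minimal sketch of the assumed behavior (order=1 rescales the
# nonnegative counts into a probability vector):
def _normalized_sketch(v, order=1):
    v = np.asarray(v, dtype=float)
    norm = np.linalg.norm(v, ord=order)
    return v / norm if norm > 0 else v
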
def test_deep_likelihood_lag(m, p=.99, trials=100, eps=1e-3, n_samples=300, pre_burnin=100):
    net = m_deep_bistable(m+1, p)

    # compute steady state
    model_nodes = net._nodes[:-1]
    evidence_node = net._nodes[-1]
    ss_distributions = steady_state(net, {evidence_node: 1}, model_nodes, eps)

    # accumulate data (sample likelihood) in nodes x time_steps x trials
    # (model_nodes has m entries, so the first axis must be m, not m-1)
    data = np.zeros((m, n_samples, trials))
    for t in range(trials):
        def do_likelihood(i, net):
            for j, n in enumerate(model_nodes):
                data[j, i, t] = ss_distributions[n][n.state_index()]

        # prepare the net by burning in on 0-evidence
        gibbs_sample(net, {evidence_node: 0}, None, 0, pre_burnin)
        # collect data on 1-evidence (no burn-in this time; demonstrating lag)
        gibbs_sample(net, {evidence_node: 1}, do_likelihood, n_samples, 0)
    return data
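
# The returned array is indexed (node, time step, trial); averaging over the
# trial axis gives one lag curve per node. A sketch of such a summary:
def _likelihood_lag_summary(m=5):
    data = test_deep_likelihood_lag(m)
    return data.mean(axis=2)  # shape (m, n_samples): mean likelihood over trials
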
def compute_distribution():
    # count visited states via a callback over the sampling run, then normalize
    counter = SwitchedFunction()
    gibbs_sample(net, {}, counter, args.samples, args.burnin)
    return counter.distribution()
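
# SwitchedFunction is used above as a gibbs_sample callback that accumulates a
# state distribution; its definition is not shown. A hypothetical stand-in
# consistent with that usage (invoked as f(i, net) on each sample):
class _StateCounter(object):
    """Callable counter: tallies state ids seen during sampling."""

    def __init__(self):
        self._counts = {}

    def __call__(self, i, net):
        sid = state_to_id(net, net.state_vector())
        self._counts[sid] = self._counts.get(sid, 0) + 1

    def distribution(self):
        total = float(sum(self._counts.values()))
        return {sid: c / total for sid, c in self._counts.items()}
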