def prior_pop(message):
    """Pop a two-level latent ``(z1, z2)`` from ``message`` under the prior.

    ``z2`` is popped first, then ``z1``. Both are bucket indices into the
    standard-Gaussian discretisation at precision ``prior_prec``. Their bucket
    centre values are handed to ``get_theta`` to build ``theta1``.

    Returns ``(message, ((z1, z2), theta1))``.
    """
    message, z2 = prior_z2_pop(message)
    message, z1 = prior_z1_pop(message)
    # The same centres table serves both layers; look it up once.
    centres = codecs.std_gaussian_centres(prior_prec)
    theta1 = get_theta(centres[z1], centres[z2])
    return message, ((z1, z2), theta1)
def posterior_pop(message):
    """Pop a hierarchy of latents from ``message`` under the posterior.

    Layers are popped top-down:

    * The top layer uses ``rec_net_top(contexts[-1])`` for its posterior
      parameters and ``gen_net_top()`` for its prior parameters.
    * Each lower layer walks ``zip(rec_nets, gen_nets, contexts[:-1])`` in
      reverse. Both networks are conditioned on the value of the layer just
      popped, reconstructed as ``prior_mean + centre[idxs] * prior_stdd``.

    Every layer is decoded with a ``DiagGaussian_GaussianBins`` codec
    restricted to ``z_view``.

    Returns ``(message, (layers, h_gen))``:

    * ``layers`` is ordered bottom-up. Each entry is
      ``(latent_idxs, (prior_mean, prior_stdd))``.
    * ``h_gen`` is the last value returned by the generative networks.
    """
    def pop_layer(msg, post_mean, post_stdd, prior_mean, prior_stdd):
        # Discretised diagonal-Gaussian posterior for one layer, binned
        # according to that layer's prior parameters.
        _, pop = codecs.substack(
            codecs.DiagGaussian_GaussianBins(
                post_mean, post_stdd, prior_mean, prior_stdd,
                latent_prec, prior_prec),
            z_view)
        return pop(msg)

    (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
    (prior_mean, prior_stdd), h_gen = gen_net_top()
    message, idxs = pop_layer(
        message, post_mean, post_stdd, prior_mean, prior_stdd)
    top_down = [(idxs, (prior_mean, prior_stdd))]

    for rec_net, gen_net, context in list(
            zip(rec_nets, gen_nets, contexts[:-1]))[::-1]:
        # Continuous value of the layer popped on the previous step.
        above_val = prior_mean + \
            codecs.std_gaussian_centres(prior_prec)[idxs] * prior_stdd
        (post_mean, post_stdd), h_rec = rec_net(h_rec, above_val, context)
        (prior_mean, prior_stdd), h_gen = gen_net(h_gen, above_val)
        message, idxs = pop_layer(
            message, post_mean, post_stdd, prior_mean, prior_stdd)
        top_down.append((idxs, (prior_mean, prior_stdd)))

    return message, (top_down[::-1], h_gen)
def posterior_append(message, latents):
    """Push a hierarchy of latents onto ``message`` under the posterior.

    ``latents`` is a pair. Its first element is the bottom-up list of
    ``(latent_idxs, (prior_mean, prior_stdd))`` layer entries; its second
    element is ignored.

    The work happens in two passes:

    1. Top-down: the recognition networks are re-run (``rec_net_top``, then
       ``rec_nets`` in reverse, paired with ``contexts``) to recover every
       layer's posterior ``(mean, stdd)``. Each network is conditioned on the
       value of the layer above, rebuilt from its bucket indices and stored
       prior parameters.
    2. Bottom-up: each layer's indices are pushed with a
       ``DiagGaussian_GaussianBins`` codec on ``z_view``. The push order is
       therefore the reverse of the top-down order.

    Returns the updated message.
    """
    layers, _ = latents

    # Pass 1: recover posterior parameters, top layer first.
    (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
    top_down_posts = [(post_mean, post_stdd)]
    conditioning = list(zip(rec_nets, layers[1:], contexts[:-1]))
    for rec_net, (above_idxs, (prior_mean, prior_stdd)), context in \
            reversed(conditioning):
        above_val = prior_mean + \
            codecs.std_gaussian_centres(prior_prec)[above_idxs] * prior_stdd
        (post_mean, post_stdd), h_rec = rec_net(h_rec, above_val, context)
        top_down_posts.append((post_mean, post_stdd))

    # Pass 2: push the bottom layer first, finishing with the top layer.
    for (idxs, (prior_mean, prior_stdd)), (post_mean, post_stdd) in zip(
            layers, top_down_posts[::-1]):
        append, _ = codecs.substack(
            codecs.DiagGaussian_GaussianBins(
                post_mean, post_stdd, prior_mean, prior_stdd,
                latent_prec, prior_prec),
            z_view)
        message = append(message, idxs)
    return message
def likelihood(latents):
    """Return the observation codec (on ``x_view``) for a hierarchical latent.

    ``latents`` is a pair ``(layers, h)``. Only the bottom layer
    ``layers[0] = (z1_idxs, (prior_mean, prior_stdd))`` and ``h`` are used.
    The bottom layer's continuous value is
    ``prior_mean + centre[z1_idxs] * prior_stdd``.
    """
    layers, h = latents
    bottom_idxs, (prior_mean, prior_stdd) = layers[0]
    centres = codecs.std_gaussian_centres(prior_prec)
    z1_vals = prior_mean + centres[bottom_idxs] * prior_stdd
    observation_codec = obs_codec(h, z1_vals)
    return codecs.substack(observation_codec, x_view)
def posterior_pop(message):
    """Pop a two-level latent from ``message`` under the posterior.

    ``z2`` is popped with ``post_z2_pop``. ``z1`` is then popped from a
    ``post1_codec`` conditioned on z2's bucket-centre values and the
    enclosing ``mu1`` / ``sig1``, restricted to ``z1_view``. The z1 pop also
    yields ``theta1``, which is passed through to the caller.

    Returns ``(message, ((z1, z2), theta1))``.
    """
    message, z2 = post_z2_pop(message)
    z2_vals = codecs.std_gaussian_centres(prior_prec)[z2]
    z1_codec = post1_codec(z2_vals, mu1, sig1)
    _push, pop_z1 = codecs.substack(z1_codec, z1_view)
    message, (z1, theta1) = pop_z1(message)
    return message, ((z1, z2), theta1)
def likelihood(latent):
    """Return ``(append, pop)`` for the observation codec on ``x_view``.

    ``latent`` is ``((z1, z2), theta1)``; ``z2`` is not used. The last axis of
    ``theta1`` must have exactly four channels. Channels 2 and 3 are taken as
    z1's prior mean and standard deviation. They rescale z1's
    standard-Gaussian bucket centres before the values are fed through
    ``gen_net2_partial``.
    """
    (z1, _z2), theta1 = latent
    _ch0, _ch1, mu1_prior, sig1_prior = np.moveaxis(theta1, -1, 0)
    z1_vals = mu1_prior + \
        sig1_prior * codecs.std_gaussian_centres(prior_prec)[z1]
    observation_codec = obs_codec(gen_net2_partial(z1_vals))
    push, pop = codecs.substack(observation_codec, x_view)
    return push, pop
def posterior_append(message, latents):
    """Push a two-level latent ``((z1, z2), theta1)`` onto ``message``.

    ``z1`` is pushed first, then ``z2``:

    * ``z1`` uses a ``post1_codec`` conditioned on z2's bucket-centre values
      and the enclosing ``mu1`` / ``sig1``, restricted to ``z1_view``.
      ``theta1`` (with its channels 0 and 1 set to ``mu1`` and ``sig1``) is
      passed as a side argument.
    * ``z2`` is pushed with ``post_z2_append``.

    The caller's ``latents`` is left unmodified: the parameters are written
    into a copy of ``theta1``.

    Returns the updated message.
    """
    (z1, z2), theta1 = latents
    z2_vals = codecs.std_gaussian_centres(prior_prec)[z2]
    post_z1_append, _ = codecs.substack(
        post1_codec(z2_vals, mu1, sig1), z1_view)
    # Copy before writing: assigning into theta1 in place would silently
    # change the array inside the caller's latent tuple.
    theta1 = theta1.copy()
    theta1[..., 0] = mu1
    theta1[..., 1] = sig1
    message = post_z1_append(message, z1, theta1)
    message = post_z2_append(message, z2)
    return message
def prior_pop(message):
    """Pop a hierarchy of latents from ``message`` under the prior.

    Layers are popped top-down, all with the second element of
    ``prior_codec``. The top layer's prior parameters come from
    ``gen_net_top()``. Each lower layer's prior parameters come from the next
    entry of ``reversed(gen_nets)``, conditioned on the value of the layer
    just popped (``prior_mean + centre[idxs] * prior_stdd``).

    Returns ``(message, (layers, h_gen))``:

    * ``layers`` is ordered bottom-up. Each entry is
      ``(latent_idxs, (prior_mean, prior_stdd))``.
    * ``h_gen`` is the last value returned by the generative networks.
    """
    (prior_mean, prior_stdd), h_gen = gen_net_top()
    _, pop = prior_codec
    message, idxs = pop(message)
    top_down = [(idxs, (prior_mean, prior_stdd))]

    for gen_net in reversed(gen_nets):
        # Value of the layer popped on the previous step, used to condition
        # the next generative network.
        above_val = prior_mean + \
            codecs.std_gaussian_centres(prior_prec)[idxs] * prior_stdd
        (prior_mean, prior_stdd), h_gen = gen_net(h_gen, above_val)
        message, idxs = pop(message)
        top_down.append((idxs, (prior_mean, prior_stdd)))

    top_down.reverse()  # return layers bottom-up
    return message, (top_down, h_gen)
def likelihood(latent_idxs):
    """Return the observation codec on ``x_view`` for the given latent indices.

    The indices select standard-Gaussian bucket centres at precision
    ``prior_prec``. Those values are decoded by ``gen_net``, and the result
    parameterises ``obs_codec``.
    """
    centres = std_gaussian_centres(prior_prec)
    decoded = gen_net(centres[latent_idxs])
    return substack(obs_codec(decoded), x_view)