def posterior_pop(message):
    # pop top-down: the top latent's posterior is conditioned on the top
    # context; its prior is unconditional (gen_net_top takes no args).
    (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
    (prior_mean, prior_stdd), h_gen = gen_net_top()
    codec = cs.substack(
        cs.DiagGaussian_GaussianBins(post_mean, post_stdd, prior_mean,
                                     prior_stdd, latent_prec, prior_prec),
        z_view)
    message, latent = codec.pop(message)
    # Each entry pairs the popped bin indices with the prior params they
    # were popped under, so the continuous value can be recovered later.
    latents = [(latent, (prior_mean, prior_stdd))]
    # Walk the remaining levels top-down (rec_nets/gen_nets/contexts are
    # ordered bottom-up, hence the reversed zip).
    for rec_net, gen_net, context in reversed(
            list(zip(rec_nets, gen_nets, contexts[:-1]))):
        # Continuous value of the latent popped on the previous iteration,
        # computed from the bin centres and that level's prior params.
        previous_latent_val = prior_mean + \
            cs.std_gaussian_centres(prior_prec)[latents[-1][0]] * prior_stdd
        (post_mean, post_stdd), h_rec = rec_net(h_rec, previous_latent_val,
                                                context)
        (prior_mean, prior_stdd), h_gen = gen_net(h_gen, previous_latent_val)
        codec = cs.substack(
            cs.DiagGaussian_GaussianBins(post_mean, post_stdd, prior_mean,
                                         prior_stdd, latent_prec, prior_prec),
            z_view)
        message, latent = codec.pop(message)
        latents.append((latent, (prior_mean, prior_stdd)))
    # Return latents re-ordered bottom-up; h_gen is carried along for the
    # observation model (see the matching likelihood function).
    return message, (latents[::-1], h_gen)
def posterior_push(message, latents):
    # first run the model top-down to get the params and latent vals
    latents, _ = latents  # discard h_gen; only latent indices/params needed
    (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
    post_params = [(post_mean, post_stdd)]
    # latents is ordered bottom-up; latents[1:] skips the bottom latent
    # (its posterior params come from rec_net_top above).
    for rec_net, latent, context in reversed(
            list(zip(rec_nets, latents[1:], contexts[:-1]))):
        previous_latent, (prior_mean, prior_stdd) = latent
        # Reconstruct the continuous latent value from its bin indices.
        previous_latent_val = prior_mean + \
            cs.std_gaussian_centres(prior_prec)[previous_latent] * prior_stdd
        (post_mean, post_stdd), h_rec = rec_net(h_rec, previous_latent_val,
                                                context)
        post_params.append((post_mean, post_stdd))
    # now append bottom up (the exact reverse of the pop order, so that
    # posterior_pop inverts this push)
    for latent, post_param in zip(latents, reversed(post_params)):
        latent, (prior_mean, prior_stdd) = latent
        post_mean, post_stdd = post_param
        codec = cs.substack(
            cs.DiagGaussian_GaussianBins(post_mean, post_stdd, prior_mean,
                                         prior_stdd, latent_prec, prior_prec),
            z_view)
        message = codec.push(message, latent)
    return message
def likelihood(latents):
    """Observation codec conditioned on the bottom latent and hidden state."""
    latent_list, h = latents
    # get the z1 vals to condition on
    bottom_idxs, (bottom_mean, bottom_stdd) = latent_list[0]
    centres = cs.std_gaussian_centres(prior_prec)
    z1_vals = bottom_mean + centres[bottom_idxs] * bottom_stdd
    return cs.substack(obs_codec(h, z1_vals), x_view)
def test_substack():
    """Round-trip a substack codec and check the head outside the view
    is left untouched by push, and that pop exactly inverts push."""
    num_symbols = 100
    precision = 4
    head, tail = cs.base_message((4, 4))
    head = np.split(head, 2)
    msg = head, tail
    symbols = rng.randint(1 << precision, size=(num_symbols, 2, 4),
                          dtype='uint64')
    append, pop = cs.substack(cs.repeat(cs.Uniform(precision), num_symbols),
                              lambda h: h[0])
    pushed = append(msg, symbols)
    # Only the first half of the head is in the view; the second half must
    # be byte-identical after push.
    np.testing.assert_array_equal(pushed[0][1], msg[0][1])
    pushed, decoded = pop(pushed)
    np.testing.assert_equal(msg, pushed)
    np.testing.assert_equal(symbols, decoded)
def codec_from_shape(shape):
    """Build a ResNetVAE craystack codec for data of the given shape.

    Restores the pretrained TF model into a fresh graph/session, extracts
    its parts as numpy functions, and wires them into a ResNetVAE codec
    operating on a (latent, observation) view of the message head.
    """
    print("Creating codec for shape " + str(shape))
    # NOTE(review): indexing shape[2], shape[3] assumes NCHW layout -- confirm.
    hps.image_size = (shape[2], shape[3])
    z_shape = latent_shape(hps)
    z_size = np.prod(z_shape)
    graph = tf.Graph()
    with graph.as_default():
        with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
            x = tf.placeholder(tf.float32, shape, 'x')
            model = CVAE1(hps, "eval", x)
            stepwise_model = LayerwiseCVAE(model)
        # Restore the averaged (EMA) weights only.
        saver = tf.train.Saver(model.avg_dict)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=4,
                            inter_op_parallelism_threads=4)
    sess = tf.Session(config=config, graph=graph)
    saver.restore(sess, restore_path())
    run_all_contexts, run_top_prior, runs_down_prior, run_top_posterior, \
        runs_down_posterior, run_reconstruction = \
        stepwise_model.get_model_parts_as_numpy_functions(sess)

    # Setup codecs
    def vae_view(head):
        # Split the flat head into (latent, observation) views.
        return ag_tuple((np.reshape(head[:z_size], z_shape),
                         np.reshape(head[z_size:], shape)))

    # Discretised-logistic observation codec conditioned on (h, z1).
    obs_codec = lambda h, z1: cs.Logistic_UnifBins(
        *run_reconstruction(h, z1), obs_precision,
        bin_prec=8, bin_lb=-0.5, bin_ub=0.5)
    return cs.substack(
        ResNetVAE(run_all_contexts, run_top_posterior, runs_down_posterior,
                  run_top_prior, runs_down_prior, obs_codec,
                  prior_precision, q_precision),
        vae_view)
## Setup codecs # VAE codec model = BinaryVAE(hidden_dim=100, latent_dim=40) model.load_state_dict(torch.load('vae_params')) rec_net = torch_fun_to_numpy_fun(model.encode) gen_net = torch_fun_to_numpy_fun(model.decode) obs_codec = lambda p: cs.Bernoulli(p, bernoulli_precision) def vae_view(head): return ag_tuple((np.reshape(head[:latent_size], latent_shape), np.reshape(head[latent_size:], obs_shape))) vae_append, vae_pop = cs.repeat(cs.substack( bb_ans.VAE(gen_net, rec_net, obs_codec, prior_precision, q_precision), vae_view), num_batches) ## Load mnist images images = datasets.MNIST(sys.argv[1], train=False, download=True).data.numpy() images = np.uint64(rng.random_sample(np.shape(images)) < images / 255.) images = np.split(np.reshape(images, (num_images, -1)), num_batches) ## Encode # Initialize message with some 'extra' bits encode_t0 = time.time() init_message = cs.base_message(obs_size + latent_size) # Encode the mnist images message, = vae_append(init_message, images)
# NOTE(review): this chunk begins mid-definition -- the line below is the
# tail of a codec-constructor call whose opening is outside this view.
        np.shape(images[0]) + (256,), obs_elem_idxs, obs_elem_codec)

    def pop_(msg):
        # Drop the second element of the popped pair; callers only need data.
        msg, (data, _) = pop(msg)
        return msg, data

    return append, pop_

# Setup codecs
def vae_view(head):
    # Split the flat message head into (latent, observation) views.
    return ag_tuple((np.reshape(head[:latent_size], latent_shape),
                     np.reshape(head[latent_size:], (batch_size,))))

vae_append, vae_pop = cs.repeat(cs.substack(
    bb_ans.VAE(gen_net, rec_net, obs_codec, prior_precision, q_precision),
    vae_view), num_batches)

# Codec for adding extra bits to the start of the chain (necessary for bits
# back).
p = prior_precision
other_bits_depth = 10
other_bits_append, _ = cs.substack(
    cs.repeat(codecs.Uniform(p), other_bits_depth),
    lambda h: vae_view(h)[0])

## Encode
# Initialize message with some 'extra' bits
encode_t0 = time.time()
init_message = vrans.x_init(batch_size + latent_size)
other_bits = rng.randint(1 << p, size=(other_bits_depth,) + latent_shape,
                         dtype=np.uint64)
def ResNetVAE(up_pass, rec_net_top, rec_nets, gen_net_top, gen_nets,
              obs_codec, prior_prec, latent_prec):
    """
    Codec for a ResNetVAE.
    Assume that the posterior is bidirectional -
    i.e. has a deterministic upper pass but top down sampling.
    Further assume that all latent conditionals are factorised Gaussians,
    both in the generative network p(z_n|z_{n-1})
    and in the inference network q(z_n|x, z_{n-1})

    Assume that everything is ordered bottom up
    """
    # Views selecting the latent / observation halves of the message head.
    z_view = lambda head: head[0]
    x_view = lambda head: head[1]

    prior_codec = cs.substack(cs.Uniform(prior_prec), z_view)

    def prior_push(message, latents):
        # push bottom-up; the uniform prior codec needs no per-level params
        latents, _ = latents
        for latent in latents:
            latent, _ = latent
            message = prior_codec.push(message, latent)
        return message

    def prior_pop(message):
        # pop top-down, running the generative nets to get each level's
        # prior params as we descend
        (prior_mean, prior_stdd), h_gen = gen_net_top()
        message, latent = prior_codec.pop(message)
        latents = [(latent, (prior_mean, prior_stdd))]
        for gen_net in reversed(gen_nets):
            # Continuous value of the latent just popped, under the prior
            # params it was popped with.
            previous_latent_val = prior_mean + cs.std_gaussian_centres(
                prior_prec)[latent] * prior_stdd
            (prior_mean, prior_stdd), h_gen = gen_net(h_gen,
                                                      previous_latent_val)
            message, latent = prior_codec.pop(message)
            latents.append((latent, (prior_mean, prior_stdd)))
        # Re-order bottom-up; h_gen feeds the observation model.
        return message, (latents[::-1], h_gen)

    def posterior(data):
        # run deterministic upper-pass
        contexts = up_pass(data)

        def posterior_push(message, latents):
            # first run the model top-down to get the params and latent vals
            latents, _ = latents
            (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
            post_params = [(post_mean, post_stdd)]
            for rec_net, latent, context in reversed(
                    list(zip(rec_nets, latents[1:], contexts[:-1]))):
                previous_latent, (prior_mean, prior_stdd) = latent
                # Reconstruct the continuous latent value from bin indices.
                previous_latent_val = prior_mean + \
                    cs.std_gaussian_centres(prior_prec)[previous_latent] \
                    * prior_stdd
                (post_mean, post_stdd), h_rec = rec_net(h_rec,
                                                        previous_latent_val,
                                                        context)
                post_params.append((post_mean, post_stdd))
            # now append bottom up (exact reverse of the pop order below)
            for latent, post_param in zip(latents, reversed(post_params)):
                latent, (prior_mean, prior_stdd) = latent
                post_mean, post_stdd = post_param
                codec = cs.substack(
                    cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                                 prior_mean, prior_stdd,
                                                 latent_prec, prior_prec),
                    z_view)
                message = codec.push(message, latent)
            return message

        def posterior_pop(message):
            # pop top-down; posterior params come from the recognition nets,
            # prior params from the generative nets
            (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
            (prior_mean, prior_stdd), h_gen = gen_net_top()
            codec = cs.substack(
                cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                             prior_mean, prior_stdd,
                                             latent_prec, prior_prec),
                z_view)
            message, latent = codec.pop(message)
            latents = [(latent, (prior_mean, prior_stdd))]
            for rec_net, gen_net, context in reversed(
                    list(zip(rec_nets, gen_nets, contexts[:-1]))):
                # Continuous value of the previously popped latent.
                previous_latent_val = prior_mean + \
                    cs.std_gaussian_centres(prior_prec)[latents[-1][0]] \
                    * prior_stdd
                (post_mean, post_stdd), h_rec = rec_net(h_rec,
                                                        previous_latent_val,
                                                        context)
                (prior_mean, prior_stdd), h_gen = gen_net(h_gen,
                                                          previous_latent_val)
                codec = cs.substack(
                    cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                                 prior_mean, prior_stdd,
                                                 latent_prec, prior_prec),
                    z_view)
                message, latent = codec.pop(message)
                latents.append((latent, (prior_mean, prior_stdd)))
            return message, (latents[::-1], h_gen)

        return cs.Codec(posterior_push, posterior_pop)

    def likelihood(latents):
        # get the z1 vals to condition on
        latents, h = latents
        z1_idxs, (prior_mean, prior_stdd) = latents[0]
        z1_vals = prior_mean + cs.std_gaussian_centres(
            prior_prec)[z1_idxs] * prior_stdd
        return cs.substack(obs_codec(h, z1_vals), x_view)

    return BBANS(cs.Codec(prior_push, prior_pop), likelihood, posterior)
# NOTE(review): this chunk begins mid-definition -- `push` and `codec`
# referenced below are defined above this view.
    def pop(message):
        # Wrap the popped array in ArraySymbol so it is hashable for the
        # multiset codec.
        message, x = codec.pop(message)
        return message, ArraySymbol(x)

    return cs.Codec(push, pop)

def vae_view(head):
    # Split the flat message head into (latent, observation) views.
    return ag_tuple((np.reshape(head[:latent_size], latent_shape),
                     np.reshape(head[latent_size:], obs_shape)))

vae_append, vae_pop = cs.Multiset(
    cs.substack(
        ArrayCodec(
            bb_ans.VAE(gen_net, rec_net, obs_codec, prior_precision,
                       q_precision)),
        vae_view))

## Load mnist images
images = datasets.MNIST(sys.argv[1], train=False, download=True).data.numpy()
# Binarize stochastically: pixel intensity acts as a Bernoulli probability.
images = np.uint64(rng.random_sample(np.shape(images)) < images / 255.)
images = np.split(np.reshape(images, (num_images, -1)), num_batches)
images = list(map(ArraySymbol, images))

## Encode
# Initialize message with some 'extra' bits
encode_t0 = time.time()
init_message = cs.base_message(obs_size + latent_size)

# Build multiset
def post1_elem_codec(params, idx):
    """Posterior codec for a single latent element, acting on head[idx].

    params packs (post_mean, post_stdd, prior_mean, prior_stdd) along the
    last axis.
    """
    post_mean, post_stdd = params[..., 0], params[..., 1]
    prior_mean, prior_stdd = params[..., 2], params[..., 3]
    elem_view = lambda head: head[idx]
    gaussian_codec = codecs.DiagGaussian_GaussianBins(
        post_mean, post_stdd, prior_mean, prior_stdd,
        q_precision, prior_precision)
    return cs.substack(gaussian_codec, elem_view)