Example 1
        def posterior_pop(message):
            # pop top-down
            (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
            (prior_mean, prior_stdd), h_gen = gen_net_top()
            codec = cs.substack(
                cs.DiagGaussian_GaussianBins(post_mean, post_stdd, prior_mean,
                                             prior_stdd, latent_prec,
                                             prior_prec), z_view)
            message, latent = codec.pop(message)
            latents = [(latent, (prior_mean, prior_stdd))]
            for rec_net, gen_net, context in reversed(
                    list(zip(rec_nets, gen_nets, contexts[:-1]))):
                previous_latent_val = prior_mean + \
                                      cs.std_gaussian_centres(prior_prec)[latents[-1][0]] * prior_stdd

                (post_mean,
                 post_stdd), h_rec = rec_net(h_rec, previous_latent_val,
                                             context)
                (prior_mean,
                 prior_stdd), h_gen = gen_net(h_gen, previous_latent_val)
                codec = cs.substack(
                    cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                                 prior_mean, prior_stdd,
                                                 latent_prec, prior_prec),
                    z_view)
                message, latent = codec.pop(message)
                latents.append((latent, (prior_mean, prior_stdd)))
            return message, (latents[::-1], h_gen)
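A detail worth noting in the loop above: latents are stored as bin indices
under the prior, and cs.std_gaussian_centres(prior_prec) returns the centres
of the 2**prior_prec equal-mass bins of a standard Gaussian, which are then
scaled and shifted by the prior moments. A minimal sketch of that mapping
(the precision and moments here are illustrative):

import numpy as np
import craystack as cs

prior_prec = 8                     # 2**8 standard-Gaussian bins
prior_mean, prior_stdd = 1.5, 0.5  # illustrative prior moments

idxs = np.array([0, 128, 255])     # latent bin indices, as popped above
vals = prior_mean + cs.std_gaussian_centres(prior_prec)[idxs] * prior_stdd
# vals are the real-valued latents fed to rec_net/gen_net at the next step.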
Example 2
        def posterior_push(message, latents):
            # first run the model top-down to get the params and latent vals
            latents, _ = latents

            (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
            post_params = [(post_mean, post_stdd)]

            for rec_net, latent, context in reversed(
                    list(zip(rec_nets, latents[1:], contexts[:-1]))):
                previous_latent, (prior_mean, prior_stdd) = latent
                previous_latent_val = prior_mean + \
                                      cs.std_gaussian_centres(prior_prec)[previous_latent] * prior_stdd

                (post_mean,
                 post_stdd), h_rec = rec_net(h_rec, previous_latent_val,
                                             context)
                post_params.append((post_mean, post_stdd))

            # now push bottom-up: ANS is a stack, so pushing in reverse
            # order lets posterior_pop run top-down
            for latent, post_param in zip(latents, reversed(post_params)):
                latent, (prior_mean, prior_stdd) = latent
                post_mean, post_stdd = post_param
                codec = cs.substack(
                    cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                                 prior_mean, prior_stdd,
                                                 latent_prec, prior_prec),
                    z_view)
                message = codec.push(message, latent)
            return message
Example 3
def likelihood(latents):
    # get the z1 vals to condition on
    latents, h = latents
    z1_idxs, (prior_mean, prior_stdd) = latents[0]
    z1_vals = prior_mean + cs.std_gaussian_centres(
        prior_prec)[z1_idxs] * prior_stdd
    return cs.substack(obs_codec(h, z1_vals), x_view)
Example 4
import numpy as np
import craystack as cs

rng = np.random.RandomState(0)  # deterministic RNG so the test is reproducible


def test_substack():
    n_data = 100
    prec = 4
    head, tail = cs.base_message((4, 4))
    head = np.split(head, 2)
    message = head, tail
    data = rng.randint(1 << prec, size=(n_data, 2, 4), dtype='uint64')
    view_fun = lambda h: h[0]
    append, pop = cs.substack(cs.repeat(cs.Uniform(prec), n_data), view_fun)
    message_, = append(message, data)
    np.testing.assert_array_equal(message_[0][1], message[0][1])
    message_, data_ = pop(message_)
    np.testing.assert_equal(message, message_)
    np.testing.assert_equal(data, data_)
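The property being tested: substack confines all rANS activity to the slice
of the message head selected by view_fun, leaving the rest of the head
untouched. A smaller sketch of the same idea (shapes and the precision are
illustrative):

import numpy as np
import craystack as cs

rng = np.random.RandomState(0)

message = cs.base_message((2, 4))                   # head has shape (2, 4)
codec = cs.substack(cs.Uniform(8), lambda h: h[0])  # view the first row only

symbol = rng.randint(1 << 8, size=4, dtype='uint64')
message_, = codec.push(message, symbol)
np.testing.assert_array_equal(message_[0][1], message[0][1])  # row 1 untouched
message_, symbol_ = codec.pop(message_)
np.testing.assert_array_equal(symbol, symbol_)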
Example 5
    def codec_from_shape(shape):
        print("Creating codec for shape " + str(shape))

        hps.image_size = (shape[2], shape[3])

        z_shape = latent_shape(hps)
        z_size = np.prod(z_shape)

        graph = tf.Graph()
        with graph.as_default():
            with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
                x = tf.placeholder(tf.float32, shape, 'x')
                model = CVAE1(hps, "eval", x)
                stepwise_model = LayerwiseCVAE(model)

        saver = tf.train.Saver(model.avg_dict)
        config = tf.ConfigProto(allow_soft_placement=True,
                                intra_op_parallelism_threads=4,
                                inter_op_parallelism_threads=4)
        sess = tf.Session(config=config, graph=graph)
        saver.restore(sess, restore_path())

        run_all_contexts, run_top_prior, runs_down_prior, run_top_posterior, runs_down_posterior, \
        run_reconstruction = stepwise_model.get_model_parts_as_numpy_functions(sess)

        # Setup codecs
        def vae_view(head):
            return ag_tuple(
                (np.reshape(head[:z_size],
                            z_shape), np.reshape(head[z_size:], shape)))

        obs_codec = lambda h, z1: cs.Logistic_UnifBins(
            *run_reconstruction(h, z1), obs_precision,
            bin_prec=8, bin_lb=-0.5, bin_ub=0.5)

        return cs.substack(
            ResNetVAE(run_all_contexts, run_top_posterior, runs_down_posterior,
                      run_top_prior, runs_down_prior, obs_codec,
                      prior_precision, q_precision), vae_view)
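A hedged usage sketch for codec_from_shape (the (batch, channels, height,
width) layout is inferred from hps.image_size = (shape[2], shape[3]); the
concrete shape and the message/image names are illustrative):

codec = codec_from_shape((1, 3, 32, 32))  # one 3x32x32 image per message
message, = codec.push(message, image)     # encode
message, image_ = codec.pop(message)      # decode; image_ should match image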
Example 6
## Setup codecs
# VAE codec
model = BinaryVAE(hidden_dim=100, latent_dim=40)
model.load_state_dict(torch.load('vae_params'))

rec_net = torch_fun_to_numpy_fun(model.encode)
gen_net = torch_fun_to_numpy_fun(model.decode)

obs_codec = lambda p: cs.Bernoulli(p, bernoulli_precision)

def vae_view(head):
    return ag_tuple((np.reshape(head[:latent_size], latent_shape),
                     np.reshape(head[latent_size:], obs_shape)))

vae_append, vae_pop = cs.repeat(cs.substack(
    bb_ans.VAE(gen_net, rec_net, obs_codec, prior_precision, q_precision),
    vae_view), num_batches)

## Load mnist images
images = datasets.MNIST(sys.argv[1], train=False, download=True).data.numpy()
images = np.uint64(rng.random_sample(np.shape(images)) < images / 255.)
images = np.split(np.reshape(images, (num_images, -1)), num_batches)

## Encode
# Initialize message with some 'extra' bits
encode_t0 = time.time()
init_message = cs.base_message(obs_size + latent_size)

# Encode the mnist images
message, = vae_append(init_message, images)
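A natural counterpart to the encode step above is the decode, which pops the
image batches back off the message and checks the round trip. A minimal
sketch, assuming vae_pop from above and that pop returns (message, data):

## Decode
decode_t0 = time.time()
message, images_ = vae_pop(message)
print('Decode time: {:.2f}s'.format(time.time() - decode_t0))

# Bits back: decoding must restore the images exactly.
np.testing.assert_equal(images, images_)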
Example 7
                                            np.shape(images[0]) + (256,),
                                            obs_elem_idxs,
                                            obs_elem_codec)
        def pop_(msg):
            msg, (data, _) = pop(msg)
            return msg, data
        return append, pop_

    # Setup codecs
    def vae_view(head):
        return ag_tuple((np.reshape(head[:latent_size], latent_shape),
                         np.reshape(head[latent_size:], (batch_size,))))


    vae_append, vae_pop = cs.repeat(cs.substack(
        bb_ans.VAE(gen_net, rec_net, obs_codec, prior_precision, q_precision),
        vae_view), num_batches)

    # Codec for adding extra bits to the start of the chain (necessary for bits
    # back).
    p = prior_precision
    other_bits_depth = 10
    other_bits_append, _ = cs.substack(cs.repeat(codecs.Uniform(p), other_bits_depth),
                                       lambda h: vae_view(h)[0])

    ## Encode
    # Initialize message with some 'extra' bits
    encode_t0 = time.time()
    init_message = vrans.x_init(batch_size + latent_size)

    other_bits = rng.randint(1 << p, size=(other_bits_depth,) + latent_shape, dtype=np.uint64)
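The snippet stops just after sampling other_bits; the natural continuation is
to push them onto the fresh message so that the first bits-back pop has bits
to consume. A hedged sketch (whether push returns the message directly or as
a 1-tuple depends on the craystack version):

    # Push the 'extra' bits before encoding any images.
    init_message = other_bits_append(init_message, other_bits)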
Example 8
def ResNetVAE(up_pass, rec_net_top, rec_nets, gen_net_top, gen_nets, obs_codec,
              prior_prec, latent_prec):
    """
    Codec for a ResNetVAE.
    Assumes that the posterior is bidirectional -
    i.e. it has a deterministic upward pass but top-down sampling.
    Further assumes that all latent conditionals are factorised Gaussians,
    both in the generative network p(z_n|z_{n-1})
    and in the inference network q(z_n|x, z_{n-1}).

    Everything is assumed to be ordered bottom-up.
    """
    z_view = lambda head: head[0]
    x_view = lambda head: head[1]

    prior_codec = cs.substack(cs.Uniform(prior_prec), z_view)

    def prior_push(message, latents):
        # push bottom-up
        latents, _ = latents
        for latent in latents:
            latent, _ = latent
            message = prior_codec.push(message, latent)
        return message

    def prior_pop(message):
        # pop top-down
        (prior_mean, prior_stdd), h_gen = gen_net_top()
        message, latent = prior_codec.pop(message)
        latents = [(latent, (prior_mean, prior_stdd))]
        for gen_net in reversed(gen_nets):
            previous_latent_val = prior_mean + cs.std_gaussian_centres(
                prior_prec)[latent] * prior_stdd
            (prior_mean, prior_stdd), h_gen = gen_net(h_gen,
                                                      previous_latent_val)
            message, latent = prior_codec.pop(message)
            latents.append((latent, (prior_mean, prior_stdd)))
        return message, (latents[::-1], h_gen)

    def posterior(data):
        # run deterministic upper-pass
        contexts = up_pass(data)

        def posterior_push(message, latents):
            # first run the model top-down to get the params and latent vals
            latents, _ = latents

            (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
            post_params = [(post_mean, post_stdd)]

            for rec_net, latent, context in reversed(
                    list(zip(rec_nets, latents[1:], contexts[:-1]))):
                previous_latent, (prior_mean, prior_stdd) = latent
                previous_latent_val = prior_mean + \
                                      cs.std_gaussian_centres(prior_prec)[previous_latent] * prior_stdd

                (post_mean,
                 post_stdd), h_rec = rec_net(h_rec, previous_latent_val,
                                             context)
                post_params.append((post_mean, post_stdd))

            # now push bottom-up: ANS is a stack, so pushing in reverse
            # order lets posterior_pop run top-down
            for latent, post_param in zip(latents, reversed(post_params)):
                latent, (prior_mean, prior_stdd) = latent
                post_mean, post_stdd = post_param
                codec = cs.substack(
                    cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                                 prior_mean, prior_stdd,
                                                 latent_prec, prior_prec),
                    z_view)
                message = codec.push(message, latent)
            return message

        def posterior_pop(message):
            # pop top-down
            (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
            (prior_mean, prior_stdd), h_gen = gen_net_top()
            codec = cs.substack(
                cs.DiagGaussian_GaussianBins(post_mean, post_stdd, prior_mean,
                                             prior_stdd, latent_prec,
                                             prior_prec), z_view)
            message, latent = codec.pop(message)
            latents = [(latent, (prior_mean, prior_stdd))]
            for rec_net, gen_net, context in reversed(
                    list(zip(rec_nets, gen_nets, contexts[:-1]))):
                previous_latent_val = prior_mean + \
                                      cs.std_gaussian_centres(prior_prec)[latents[-1][0]] * prior_stdd

                (post_mean,
                 post_stdd), h_rec = rec_net(h_rec, previous_latent_val,
                                             context)
                (prior_mean,
                 prior_stdd), h_gen = gen_net(h_gen, previous_latent_val)
                codec = cs.substack(
                    cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                                 prior_mean, prior_stdd,
                                                 latent_prec, prior_prec),
                    z_view)
                message, latent = codec.pop(message)
                latents.append((latent, (prior_mean, prior_stdd)))
            return message, (latents[::-1], h_gen)

        return cs.Codec(posterior_push, posterior_pop)

    def likelihood(latents):
        # get the z1 vals to condition on
        latents, h = latents
        z1_idxs, (prior_mean, prior_stdd) = latents[0]
        z1_vals = prior_mean + cs.std_gaussian_centres(
            prior_prec)[z1_idxs] * prior_stdd
        return cs.substack(obs_codec(h, z1_vals), x_view)

    return BBANS(cs.Codec(prior_push, prior_pop), likelihood, posterior)
Example 9
    def pop(message):
        message, x = codec.pop(message)
        return message, ArraySymbol(x)

    return cs.Codec(push, pop)


def vae_view(head):
    return ag_tuple(
        (np.reshape(head[:latent_size],
                    latent_shape), np.reshape(head[latent_size:], obs_shape)))


vae_append, vae_pop = cs.Multiset(
    cs.substack(
        ArrayCodec(
            bb_ans.VAE(gen_net, rec_net, obs_codec, prior_precision,
                       q_precision)), vae_view))

## Load mnist images
images = datasets.MNIST(sys.argv[1], train=False, download=True).data.numpy()
images = np.uint64(rng.random_sample(np.shape(images)) < images / 255.)
images = np.split(np.reshape(images, (num_images, -1)), num_batches)
images = list(map(ArraySymbol, images))

## Encode
# Initialize message with some 'extra' bits
encode_t0 = time.time()

init_message = cs.base_message(obs_size + latent_size)

# Build multiset
Example 10
def post1_elem_codec(params, idx):
    return cs.substack(
        codecs.DiagGaussian_GaussianBins(params[..., 0], params[..., 1],
                                         params[..., 2], params[..., 3],
                                         q_precision, prior_precision),
        lambda head: head[idx])
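The view lambda head: head[idx] addresses the idx-th slot of the message
head, so each latent element gets its own substack codec. A hedged driving
loop (zs and message are illustrative names, not from the original source;
params holds the four moment channels as in the signature above):

for idx, z in enumerate(zs):
    codec = post1_elem_codec(params, idx)
    message, = codec.push(message, z)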