Ejemplo n.º 1
0
def test_bvae_enc_dec():
    """Round-trip a single image through the binary-VAE append/pop codec.

    The test pushes some random 'other' bits onto a fresh rANS state,
    encodes the image on top of them with ``vae_append``, flattens and
    unflattens the message, then decodes with ``vae_pop`` and checks that
    the image, the extra bits and the initial state are all recovered
    exactly.
    """
    # load an mnist image, x_0
    # NOTE(review): pickle must only be used on trusted files; this path is
    # presumably a local test fixture -- confirm.
    with open('torch_vae/sample_mnist_image', 'rb') as f:
        image = pickle.load(f)
    # presumably pixel intensities lie in [0, 1], so rounding binarizes the
    # image -- TODO confirm against the fixture
    image = torch.round(torch.tensor(image)).float()

    # load vae params
    model = BinaryVAE()
    # map_location keeps the loaded storages where they are deserialized, so
    # the checkpoint can be read on a machine without the device it was
    # saved from.
    model.load_state_dict(
        torch.load('torch_vae/saved_params/torch_binary_vae_params',
                   map_location=lambda storage, location: storage))
    # Inference mode: encode and decode must give identical outputs during
    # append and pop, so any train-only stochastic layers are switched off.
    model.eval()

    latent_shape = (20,)

    rec_net = tvae_utils.torch_fun_to_numpy_fun(model.encode)
    gen_net = tvae_utils.torch_fun_to_numpy_fun(model.decode)

    obs_append = tvae_utils.bernoulli_obs_append(obs_precision)
    obs_pop = tvae_utils.bernoulli_obs_pop(obs_precision)

    vae_append = util.vae_append(
        latent_shape, gen_net, rec_net, obs_append,
        prior_precision, q_precision)

    vae_pop = util.vae_pop(
        latent_shape, gen_net, rec_net, obs_pop,
        prior_precision, q_precision)

    # randomly generate some 'other' bits
    # (20 values drawn from [0, 2**16), pushed with 16-bit uniform coding)
    other_bits = rng.randint(1 << 16, size=20, dtype=np.uint32)
    state = rans.x_init
    state = util.uniforms_append(16)(state, other_bits)

    # ---------------------------- ENCODE ------------------------------------
    state = vae_append(state, image)
    compressed_message = rans.flatten(state)

    # Each element of the flattened message is counted as 32 bits; the
    # length of other_bits is subtracted before reporting.
    extra_words = len(compressed_message) - len(other_bits)
    print("Used " + str(32 * extra_words) + " bits.")

    # ---------------------------- DECODE ------------------------------------
    state = rans.unflatten(compressed_message)

    state, image_ = vae_pop(state)
    assert all(image == image_)

    #  recover the other bits from q(y|x_0)
    state, recovered_bits = util.uniforms_pop(16, 20)(state)
    assert all(other_bits == recovered_bits)
    # After popping everything, the coder must be back at its initial state.
    assert state == rans.x_init
Ejemplo n.º 2
0
# Set-up for compressing MNIST images with the beta-binomial VAE: build the
# model, wrap its networks for numpy, construct the append/pop codecs, load
# the test images and seed the rANS state with random bits.
latent_dim = 50
latent_shape = (1, latent_dim)

# Restore the trained parameters. The map_location callable hands back the
# storage it receives unchanged.
model = BetaBinomialVAE(hidden_dim=200, latent_dim=latent_dim)
model.load_state_dict(torch.load(
    'torch_vae/saved_params/torch_vae_beta_binomial_params',
    map_location=lambda storage, location: storage))
model.eval()

# numpy-facing wrappers around the recognition (encode) and generative
# (decode) networks
rec_net = tvae_utils.torch_fun_to_numpy_fun(model.encode)
gen_net = tvae_utils.torch_fun_to_numpy_fun(model.decode)

# observation codecs; 255 is presumably the largest pixel value -- confirm
obs_append = tvae_utils.beta_binomial_obs_append(255, obs_precision)
obs_pop = tvae_utils.beta_binomial_obs_pop(255, obs_precision)

# full VAE codecs: append encodes a data point, pop decodes it
vae_append = util.vae_append(
    latent_shape, gen_net, rec_net, obs_append,
    prior_precision, q_precision)
vae_pop = util.vae_pop(
    latent_shape, gen_net, rec_net, obs_pop,
    prior_precision, q_precision)

# load some mnist images
mnist = datasets.MNIST(
    'data/mnist',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]))

# Take the first num_images raw test images and turn each into a float
# tensor with a leading dimension of 1 (all remaining elements flattened).
# NOTE(review): `test_data` is the legacy attribute name; newer torchvision
# releases expose the same tensor as `.data` -- confirm the pinned version.
images = [image.float().view(1, -1)
          for image in mnist.test_data[:num_images]]

# randomly generate some 'other' bits
# (50 uint32 values in [2**16, 2**31), used directly as the initial message)
other_bits = rng.randint(
    low=1 << 16, high=1 << 31, size=50, dtype=np.uint32)
state = rans.unflatten(other_bits)