def test_bvae_enc_dec():
    """Round-trip one MNIST image through the binary-VAE bits-back codec.

    First it seeds the ANS state with 20 random 16-bit 'other' values. It then
    encodes a sample image with ``vae_append`` and flattens/unflattens the
    message. Finally it decodes with ``vae_pop`` and checks that the image, the
    'other' values and the initial ANS state are all recovered exactly.

    Relies on names defined elsewhere in this file: ``rng``,
    ``obs_precision``, ``prior_precision``, ``q_precision`` and the
    ``util`` / ``rans`` / ``tvae_utils`` / ``BinaryVAE`` imports.
    """
    # Load a sample MNIST image, x_0. The ``with`` block closes the file
    # handle (the bare ``open(...)`` call leaked it). pickle is acceptable
    # only because this is a trusted local fixture.
    with open('torch_vae/sample_mnist_image', 'rb') as f:
        image = pickle.load(f)
    image = torch.round(torch.tensor(image)).float()

    # Load the VAE parameters. map_location returns each storage unchanged,
    # i.e. keeps tensors on the CPU. Parameters saved from a GPU run can then
    # still be loaded on a CPU-only machine.
    model = BinaryVAE()
    model.load_state_dict(
        torch.load('torch_vae/saved_params/torch_binary_vae_params',
                   map_location=lambda storage, location: storage))
    # Evaluation mode disables training-only behaviour (e.g. dropout) if the
    # model has any, so the encoder and decoder evaluate the same network.
    model.eval()

    latent_shape = (20,)
    rec_net = tvae_utils.torch_fun_to_numpy_fun(model.encode)
    gen_net = tvae_utils.torch_fun_to_numpy_fun(model.decode)

    obs_append = tvae_utils.bernoulli_obs_append(obs_precision)
    obs_pop = tvae_utils.bernoulli_obs_pop(obs_precision)

    vae_append = util.vae_append(
        latent_shape, gen_net, rec_net, obs_append,
        prior_precision, q_precision)
    vae_pop = util.vae_pop(
        latent_shape, gen_net, rec_net, obs_pop,
        prior_precision, q_precision)

    # Randomly generate some 'other' bits: values in [0, 2**precision),
    # pushed onto the initial ANS state as uniforms of that precision.
    other_bits_precision = 16
    num_other_bits = 20
    other_bits = rng.randint(1 << other_bits_precision, size=num_other_bits,
                             dtype=np.uint32)
    state = rans.x_init
    state = util.uniforms_append(other_bits_precision)(state, other_bits)

    # ---------------------------- ENCODE ------------------------------------
    state = vae_append(state, image)
    compressed_message = rans.flatten(state)

    # Approximate net cost of the image. The flattened message is counted
    # in 32-bit words, but each 'other' value only contributed
    # other_bits_precision (16) bits of information. Subtract those bits,
    # not one whole 32-bit word per value.
    net_bits = (32 * len(compressed_message)
                - other_bits_precision * len(other_bits))
    print("Used " + str(net_bits) + " bits.")

    # ---------------------------- DECODE ------------------------------------
    state = rans.unflatten(compressed_message)

    state, image_ = vae_pop(state)
    # NOTE(review): builtin all() iterates the elementwise comparison, which
    # assumes the image is 1-D -- confirm the fixture's shape.
    assert all(image == image_)

    # recover the other bits from q(y|x_0)
    state, recovered_bits = util.uniforms_pop(
        other_bits_precision, num_other_bits)(state)
    assert all(other_bits == recovered_bits)
    # Everything pushed has now been popped, so the ANS state must be back
    # at its initial value.
    assert state == rans.x_init
# Beta-binomial VAE set-up: a 50-dimensional latent space and a model with
# hidden_dim=200, restored from saved parameters.
# NOTE(review): obs_precision, prior_precision, q_precision, num_images and
# rng are read here but defined elsewhere in this file.
latent_dim = 50
latent_shape = (1, latent_dim)
model = BetaBinomialVAE(hidden_dim=200, latent_dim=latent_dim)
# map_location returns each storage unchanged, so tensors saved from a GPU
# run are loaded onto the CPU.
model.load_state_dict(
    torch.load('torch_vae/saved_params/torch_vae_beta_binomial_params',
               map_location=lambda storage, location: storage))
# Evaluation mode: disables training-only behaviour (e.g. dropout) if the
# model has any.
model.eval()
# Wrap the torch encode/decode methods for use by the NumPy-based codec
# helpers (per the helper's name; see tvae_utils.torch_fun_to_numpy_fun).
rec_net = tvae_utils.torch_fun_to_numpy_fun(model.encode)
gen_net = tvae_utils.torch_fun_to_numpy_fun(model.decode)
# Pixel observation codecs. The 255 is presumably the number of
# beta-binomial trials, i.e. pixel values 0..255 -- confirm in tvae_utils.
obs_append = tvae_utils.beta_binomial_obs_append(255, obs_precision)
obs_pop = tvae_utils.beta_binomial_obs_pop(255, obs_precision)
# VAE codec pair: vae_append(state, image) -> state pushes an image,
# vae_pop(state) -> (state, image) recovers it.
vae_append = util.vae_append(latent_shape, gen_net, rec_net, obs_append,
                             prior_precision, q_precision)
vae_pop = util.vae_pop(latent_shape, gen_net, rec_net, obs_pop,
                       prior_precision, q_precision)
# load some mnist images (the first num_images of the MNIST test split)
mnist = datasets.MNIST('data/mnist', train=False, download=True,
                       transform=transforms.Compose([transforms.ToTensor()]))
# NOTE(review): `test_data` is deprecated in newer torchvision releases
# (renamed to `data`); left unchanged for compatibility with older versions.
images = mnist.test_data[:num_images]
# Convert each image to float and flatten it to shape (1, N).
images = [image.float().view(1, -1) for image in images]
# randomly generate some 'other' bits: 50 32-bit words in [2**16, 2**31),
# used directly as the initial ANS state. The lower bound presumably keeps
# the unflattened state valid for rANS -- confirm against rans.unflatten.
other_bits = rng.randint(low=1 << 16, high=1 << 31, size=50, dtype=np.uint32)
state = rans.unflatten(other_bits)