def test_rans_simple():
    """Round-trip a small batch of symbols through the non-uniform codec.

    Pushes ``n_data`` rows of symbols onto a fresh message, flattens and
    unflattens it, pops the rows back off, and checks that both the decoded
    rows and the emptied message match what the test started with.
    """
    size = 3
    tail_capacity = 100
    precision = 24
    n_data = 10
    data = rng.integers(0, 4, size=(n_data, size))

    m = m_init = rans.base_message(size, tail_capacity)

    def enc_fun(x):
        # Symbol -> (start, freq).  Symbols 0..2 have frequencies 1, 2 and 3;
        # symbol 3 takes the remainder so that the four frequencies sum to
        # 1 << precision.
        starts = jnp.choose(x, [0, 1, 3, 6])
        freqs = jnp.choose(x, [1, 2, 3, (1 << precision) - 6])
        return starts, freqs

    def dec_fun(cf):
        # Cumulative frequency -> symbol: values below 6 belong to symbols
        # 0..2 (looked up in the table), everything else is symbol 3.
        return jnp.where(
            cf < 6,
            jnp.choose(cf, [0, 1, 1, 2, 2, 2], mode='clip'),
            3)

    codec_push, codec_pop = rans.NonUniform(enc_fun, dec_fun, precision)

    # Encode: push rows last-to-first so that popping yields them in their
    # original order.
    for x in reversed(data):
        m = codec_push(m, x)
    coded_arr = rans.flatten(m)
    assert coded_arr.dtype == np.uint8

    # Decode
    m = rans.unflatten(coded_arr, size, tail_capacity)
    data_decoded = []
    for _ in range(n_data):
        m, x = codec_pop(m)
        data_decoded.append(x)

    assert rans.message_equal(m, m_init)
    assert_equal(data, data_decoded)
def test_rans():
    """Encode 1000 random (start, freq) intervals and decode them again.

    Intervals are appended to the initial state, the state is flattened and
    unflattened, and the intervals are then popped in reverse order.  The
    test passes when every popped cumulative frequency lies inside the
    expected interval and the state ends up as ``(rans.head_min, ())``.
    """
    scale_bits = 8

    starts = rng.randint(0, 256, size=1000)
    freqs = rng.randint(1, 256, size=1000) % (256 - starts)
    # The modulo can produce zero; every interval needs a positive width.
    freqs[freqs == 0] = 1
    assert np.all(starts + freqs <= 256)
    print("Exact entropy: " + str(np.sum(np.log2(256 / freqs))) + " bits.")

    def make_statfun(lo, width):
        # Build the callback handed to rans.pop for one expected interval.
        def statfun(cf):
            assert lo <= cf < lo + width
            return None, (lo, width)
        return statfun

    # Encode
    state = rans.x_init
    for lo, width in zip(starts, freqs):
        state = rans.append(state, lo, width, scale_bits)
    coded_arr = rans.flatten(state)
    assert coded_arr.dtype == np.uint32
    print("Actual output size: " + str(32 * len(coded_arr)) + " bits.")

    # Decode: pop in the opposite order to the appends.
    state = rans.unflatten(coded_arr)
    for lo, width in zip(starts[::-1], freqs[::-1]):
        state, _ = rans.pop(state, make_statfun(lo, width), scale_bits)

    assert state == (rans.head_min, ())
def test_flatten_unflatten():
    """Check that ``rans.flatten`` and ``rans.unflatten`` round-trip a state."""
    original = rans.x_init
    payload = rng.randint(1 << 8, size=5)
    for byte in payload:
        original = rans.append(original, byte, 1, 8)

    flat_first = rans.flatten(original)
    restored = rans.unflatten(flat_first)
    flat_second = rans.flatten(restored)

    # Both the flat form and the structured state must survive the round trip.
    assert np.all(flat_first == flat_second)
    assert original == restored
def test_bvae_enc_dec():
    """Encode one image with the binary-VAE codec and decode it again.

    Some random 'other' bits are placed on the state first.  After the image
    has been appended, the state flattened and unflattened, and the image
    popped, the test checks that the image, the other bits and the initial
    state are all recovered exactly.
    """
    # load an mnist image, x_0
    # Use a context manager so the fixture file is closed deterministically.
    with open('torch_vae/sample_mnist_image', 'rb') as image_file:
        image = pickle.load(image_file)
    image = torch.round(torch.tensor(image)).float()

    # load vae params
    model = BinaryVAE()
    model.load_state_dict(
        torch.load('torch_vae/saved_params/torch_binary_vae_params'))

    latent_shape = (20,)

    rec_net = tvae_utils.torch_fun_to_numpy_fun(model.encode)
    gen_net = tvae_utils.torch_fun_to_numpy_fun(model.decode)

    obs_append = tvae_utils.bernoulli_obs_append(obs_precision)
    obs_pop = tvae_utils.bernoulli_obs_pop(obs_precision)

    vae_append = util.vae_append(
        latent_shape, gen_net, rec_net, obs_append,
        prior_precision, q_precision)
    vae_pop = util.vae_pop(
        latent_shape, gen_net, rec_net, obs_pop,
        prior_precision, q_precision)

    # randomly generate some 'other' bits
    other_bits = rng.randint(1 << 16, size=20, dtype=np.uint32)
    state = rans.x_init
    state = util.uniforms_append(16)(state, other_bits)

    # ---------------------------- ENCODE ------------------------------------
    state = vae_append(state, image)
    compressed_message = rans.flatten(state)

    print("Used " + str(32 * (len(compressed_message) - len(other_bits))) +
          " bits.")

    # ---------------------------- DECODE ------------------------------------
    state = rans.unflatten(compressed_message)
    state, image_ = vae_pop(state)
    assert all(image == image_)

    # recover the other bits from q(y|x_0)
    state, recovered_bits = util.uniforms_pop(16, 20)(state)
    assert all(other_bits == recovered_bits)
    assert state == rans.x_init
def test_rans_jit():
    """Round-trip symbols through the jitted non-uniform codec.

    ``enc_fun`` and ``dec_fun`` assert on the ``is_tracing`` flag, which is
    only True around the first call of each jitted function.  The asserts
    therefore fire if the Python bodies are executed again on later calls.
    """
    size = 3
    tail_capacity = 100
    precision = 3
    n_data = 100
    data = rng.integers(0, 4, size=(n_data, size))

    # x ~ Categorical(1 / 8, 2 / 8, 3 / 8, 2 / 8)
    m = m_init = rans.base_message(size, tail_capacity)
    choose = partial(jnp.choose, mode='clip')

    def enc_fun(x):
        # Symbol -> (start, freq).
        assert is_tracing
        return (choose(x, jnp.array([0, 1, 3, 6])),
                choose(x, jnp.array([1, 2, 3, 2])))

    def dec_fun(cf):
        # Cumulative frequency -> symbol.
        assert is_tracing
        return choose(cf, jnp.array([0, 1, 1, 2, 2, 2, 3, 3]))

    codec_push, codec_pop = rans.NonUniform(enc_fun, dec_fun, precision)
    codec_push, codec_pop = map(jax.jit, (codec_push, codec_pop))

    is_tracing = True
    _, freqs = enc_fun(data)
    print("Exact entropy: " + str(np.sum(np.log2(8 / freqs))) + " bits.")

    # Encode
    # First call happens with the flag set; its result is not needed.
    codec_push(m, data[0])
    is_tracing = False
    # Push rows last-to-first so that popping yields them in original order.
    for x in reversed(data):
        m = codec_push(m, x)
    coded_arr = rans.flatten(m)
    assert coded_arr.dtype == np.uint8
    # Elements are uint8 (asserted above), i.e. 8 bits each.
    print("Actual output size: " + str(8 * len(coded_arr)) + " bits.")

    # Decode
    m = rans.unflatten(coded_arr, size, tail_capacity)
    is_tracing = True
    # First call of the pop function, again with the flag set.
    codec_pop(m)
    is_tracing = False
    data_decoded = []
    for _ in range(n_data):
        m, x = codec_pop(m)
        data_decoded.append(x)

    assert rans.message_equal(m, m_init)
    assert_equal(data, data_decoded)
def test_rans_lax_fori_loop():
    """Round-trip symbols through the codec with ``lax.fori_loop`` loops.

    Both the push loop and the pop loop are expressed as ``lax.fori_loop``
    bodies.  The test checks that the message returns to its initial value
    and that the decoded rows equal the original data.
    """
    size = 3
    tail_capacity = 100
    precision = 3
    n_data = 100
    data = jnp.array(rng.integers(0, 4, size=(n_data, size)))

    # x ~ Categorical(1 / 8, 2 / 8, 3 / 8, 2 / 8)
    m = m_init = rans.base_message(size, tail_capacity)
    choose = partial(jnp.choose, mode='clip')

    def enc_fun(x):
        # Symbol -> (start, freq).
        return (choose(x, jnp.array([0, 1, 3, 6])),
                choose(x, jnp.array([1, 2, 3, 2])))

    def dec_fun(cf):
        # Cumulative frequency -> symbol.
        return choose(cf, jnp.array([0, 1, 1, 2, 2, 2, 3, 3]))

    codec_push, codec_pop = rans.NonUniform(enc_fun, dec_fun, precision)

    # Encode
    def push_body(i, carry):
        # Walk the rows last-to-first so that popping yields them in order.
        m = carry
        m = codec_push(m, data[n_data - i - 1])
        return m

    m = lax.fori_loop(0, n_data, push_body, m)
    coded_arr = rans.flatten(m)
    assert coded_arr.dtype == np.uint8

    # Decode
    def pop_body(i, carry):
        # Pop one row and store it at index i of the output buffer.
        m, xs = carry
        m, x = codec_pop(m)
        return m, lax.dynamic_update_index_in_dim(xs, x, i, 0)

    m = rans.unflatten(coded_arr, size, tail_capacity)
    m, data_decoded = lax.fori_loop(0, n_data, pop_body,
                                    (m, jnp.zeros((n_data, size), 'int32')))

    assert rans.message_equal(m, m_init)
    # The decoded rows were previously never compared with the input.
    np.testing.assert_array_equal(data, data_decoded)
prior_precision, q_precision) vae_pop = util.vae_pop(latent_shape, gen_net, rec_net, obs_pop, prior_precision, q_precision) # load some mnist images mnist = datasets.MNIST('data/mnist', train=False, download=True, transform=transforms.Compose([transforms.ToTensor()])) images = mnist.test_data[:num_images] images = [image.float().view(1, -1) for image in images] # randomly generate some 'other' bits other_bits = rng.randint(low=1 << 16, high=1 << 31, size=50, dtype=np.uint32) state = rans.unflatten(other_bits) print_interval = 10 encode_start_time = time.time() for i, image in enumerate(images): state = vae_append(state, image) if i == 1: break ''' if not i % print_interval: print('Encoded {}'.format(i)) compressed_length = 32 * (len(rans.flatten(state)) - len(other_bits)) / (i+1) compress_lengths.append(compressed_length) print('\nAll encoded in {:.2f}s'.format(time.time() - encode_start_time))
# load some mnist images mnist = datasets.MNIST('data/mnist', train=False, download=True, transform=transforms.Compose( [transforms.ToTensor()])) images = mnist.test_data[:num_images] if randomise_data: images = Bernoulli(images.float() / 255.).sample() else: images = torch.round(images.float() / 255.) images = [image.view(-1) for image in images] # randomly generate some 'other' bits other_bits = rng.randint(low=1 << 16, high=1 << 31, size=20, dtype=np.uint32) state = rans.unflatten(other_bits) print_interval = 10 encode_start_time = time.time() for i, image in enumerate(images): state = vae_append(state, image) if not i % print_interval: print('Encoded {}'.format(i)) compressed_length = 32*(len(rans.flatten(state)) - len(other_bits)) / (i+1) compress_lengths.append(compressed_length) print('\nAll encoded in {:.2f}s'.format(time.time() - encode_start_time)) compressed_message = rans.flatten(state)