Ejemplo n.º 1
0
def test_rans_simple():
    """Round-trip a batch of categorical symbols through ``rans.NonUniform``.

    Uses a high ``precision`` (24 bits) with a deliberately skewed model in
    which the last symbol takes almost all of the probability mass. Checks
    that the model's frequencies sum to ``2 ** precision``, that the
    flattened message is ``uint8``, that decoding restores the initial
    message, and that the decoded symbols equal the encoded ones.
    """
    size = 3
    tail_capacity = 100
    precision = 24
    n_data = 10
    data = rng.integers(0, 4, size=(n_data, size))

    # x ~ Categorical(1, 2, 3, 2 ** precision - 6) / 2 ** precision:
    # symbols 0-2 are rare, symbol 3 takes almost all of the mass.
    # enc_fun maps a symbol to its (start, freq) pair; each start is the sum
    # of the freqs of the preceding symbols.
    m = m_init = rans.base_message(size, tail_capacity)
    enc_fun = (lambda x: (jnp.choose(x, [0, 1, 3, 6]),
                          jnp.choose(x, [1, 2, 3, (1 << precision) - 6])))

    def dec_fun(cf):
        # Inverse of enc_fun: cumulative frequencies 0..5 belong to symbols
        # 0, 1, 1, 2, 2, 2; every cf >= 6 belongs to symbol 3.
        return jnp.where(
            cf < 6,
            jnp.choose(cf, [0, 1, 1, 2, 2, 2], mode='clip'),
            3)

    codec_push, codec_pop = rans.NonUniform(enc_fun, dec_fun, precision)

    # Sanity-check the model itself: the frequencies must cover exactly
    # 2 ** precision.
    _, freqs = enc_fun(jnp.arange(4))
    assert int(jnp.sum(freqs)) == 1 << precision

    # Encode (in reverse, so that decoding yields data in its original order)
    for x in reversed(data):
        m = codec_push(m, x)
    coded_arr = rans.flatten(m)
    assert coded_arr.dtype == np.uint8

    # Decode
    m = rans.unflatten(coded_arr, size, tail_capacity)
    data_decoded = []
    for _ in range(n_data):
        m, x = codec_pop(m)
        data_decoded.append(x)
    assert rans.message_equal(m, m_init)
    assert_equal(data, data_decoded)
Ejemplo n.º 2
0
def test_rans():
    """Push random (start, freq) intervals onto a rANS state and pop them back.

    All intervals lie inside [0, 256) because ``scale_bits`` is 8. Popping
    happens in reverse push order; each decoded cumulative frequency is
    checked against the interval it must fall into, and the final state is
    compared with ``(rans.head_min, ())``.
    """
    scale_bits = 8
    n_symbols = 1000
    starts = rng.randint(0, 256, size=n_symbols)
    # Reduce modulo the room left above each start so start + freq <= 256;
    # a remainder of 0 is bumped to 1 so that every interval is non-empty.
    freqs = rng.randint(1, 256, size=n_symbols) % (256 - starts)
    freqs[freqs == 0] = 1
    assert np.all(starts + freqs <= 256)
    print("Exact entropy: " + str(np.sum(np.log2(256 / freqs))) + " bits.")

    intervals = list(zip(starts, freqs))

    # Encode
    state = rans.x_init
    for lo, width in intervals:
        state = rans.append(state, lo, width, scale_bits)
    coded_arr = rans.flatten(state)
    assert coded_arr.dtype == np.uint32
    print("Actual output size: " + str(32 * len(coded_arr)) + " bits.")

    # Decode: last interval pushed is the first one popped.
    state = rans.unflatten(coded_arr)
    for lo, width in intervals[::-1]:
        # Bind lo/width as defaults so each statfun checks its own interval.
        def statfun(cf, lo=lo, width=width):
            assert lo <= cf < lo + width
            return None, (lo, width)

        state, _ = rans.pop(state, statfun, scale_bits)
    assert state == (rans.head_min, ())
Ejemplo n.º 3
0
def test_flatten_unflatten():
    """``rans.flatten`` and ``rans.unflatten`` must round-trip a rANS state."""
    # Build a non-trivial state by pushing five random 8-bit values, each as
    # an interval of width 1 at scale_bits = 8.
    state = rans.x_init
    for symbol in rng.randint(1 << 8, size=5):
        state = rans.append(state, symbol, 1, 8)

    flat = rans.flatten(state)
    restored = rans.unflatten(flat)
    # Both the flat arrays and the structured states must agree.
    assert np.all(flat == rans.flatten(restored))
    assert state == restored
Ejemplo n.º 4
0
def test_bvae_enc_dec():
    """Round-trip one MNIST image through the binary-VAE rANS codec.

    Some random 16-bit 'other' values are pushed onto the state first. After
    encoding the image and decoding it again, the image, the other values
    and the initial state ``rans.x_init`` must all be recovered exactly.
    """
    # load an mnist image, x_0
    # NOTE: pickle.load can execute arbitrary code; acceptable only because
    # this is a trusted fixture file, never use it on external data.
    with open('torch_vae/sample_mnist_image', 'rb') as f:
        image = pickle.load(f)
    image = torch.round(torch.tensor(image)).float()

    # load vae params
    model = BinaryVAE()
    model.load_state_dict(
        torch.load('torch_vae/saved_params/torch_binary_vae_params'))

    latent_shape = (20,)

    rec_net = tvae_utils.torch_fun_to_numpy_fun(model.encode)
    gen_net = tvae_utils.torch_fun_to_numpy_fun(model.decode)

    obs_append = tvae_utils.bernoulli_obs_append(obs_precision)
    obs_pop = tvae_utils.bernoulli_obs_pop(obs_precision)

    vae_append = util.vae_append(
        latent_shape, gen_net, rec_net, obs_append,
        prior_precision, q_precision)

    vae_pop = util.vae_pop(
        latent_shape, gen_net, rec_net, obs_pop,
        prior_precision, q_precision)

    # randomly generate some 'other' 16-bit values and push them first
    n_other = 20
    other_bits = rng.randint(1 << 16, size=n_other, dtype=np.uint32)
    state = rans.x_init
    state = util.uniforms_append(16)(state, other_bits)
    # Message length (in 32-bit words) just before the image is encoded, so
    # the reported cost is exactly what encoding the image added. Subtracting
    # len(other_bits) instead would be wrong: those are 16-bit values, not
    # 32-bit words.
    init_len = len(rans.flatten(state))

    # ---------------------------- ENCODE ------------------------------------
    state = vae_append(state, image)
    compressed_message = rans.flatten(state)

    print("Used " + str(32 * (len(compressed_message) - init_len)) +
          " bits.")

    # ---------------------------- DECODE ------------------------------------
    state = rans.unflatten(compressed_message)

    state, image_ = vae_pop(state)
    # NOTE(review): `all` iterates over the first axis, so this presumes the
    # image is 1-D — confirm the fixture's shape.
    assert all(image == image_)

    #  recover the other bits from q(y|x_0)
    state, recovered_bits = util.uniforms_pop(16, n_other)(state)
    assert all(other_bits == recovered_bits)
    assert state == rans.x_init
Ejemplo n.º 5
0
def test_rans_jit():
    """Round-trip symbols through jitted NonUniform push/pop without retracing.

    ``enc_fun`` and ``dec_fun`` assert ``is_tracing``. The flag is True only
    around the explicit warm-up calls that trigger tracing. Any retrace
    during the encode or decode loops therefore fails an assertion.
    """
    size = 3
    tail_capacity = 100
    precision = 3
    n_data = 100
    data = rng.integers(0, 4, size=(n_data, size))

    # x ~ Categorical(1 / 8, 2 / 8, 3 / 8, 2 / 8)
    m = m_init = rans.base_message(size, tail_capacity)
    choose = partial(jnp.choose, mode='clip')

    def enc_fun(x):
        # Python body only runs while tracing (or when called directly).
        assert is_tracing
        return (choose(x, jnp.array([0, 1, 3, 6])),
                choose(x, jnp.array([1, 2, 3, 2])))

    def dec_fun(cf):
        assert is_tracing
        return choose(cf, jnp.array([0, 1, 1, 2, 2, 2, 3, 3]))

    codec_push, codec_pop = rans.NonUniform(enc_fun, dec_fun, precision)
    codec_push, codec_pop = map(jax.jit, (codec_push, codec_pop))

    is_tracing = True
    _, freqs = enc_fun(data)
    # Each symbol costs log2(2 ** precision / freq) bits.
    print("Exact entropy: "
          + str(np.sum(np.log2((1 << precision) / freqs))) + " bits.")

    # Encode. The warm-up call traces codec_push; the result is discarded.
    codec_push(m, data[0])
    is_tracing = False
    for x in reversed(data):
        m = codec_push(m, x)
    coded_arr = rans.flatten(m)
    assert coded_arr.dtype == np.uint8
    # coded_arr holds bytes, so its size in bits is 8 * nbytes.
    print("Actual output size: " + str(8 * coded_arr.nbytes) + " bits.")

    # Decode. Again, one warm-up call traces codec_pop.
    m = rans.unflatten(coded_arr, size, tail_capacity)
    is_tracing = True
    codec_pop(m)
    is_tracing = False
    data_decoded = []
    for _ in range(n_data):
        m, x = codec_pop(m)
        data_decoded.append(x)
    assert rans.message_equal(m, m_init)
    assert_equal(data, data_decoded)
Ejemplo n.º 6
0
def test_rans_lax_fori_loop():
    """Round-trip symbols through NonUniform inside ``lax.fori_loop``.

    Encoding pushes the rows of ``data`` in reverse order in one fori_loop.
    Decoding pops into row ``i`` of a preallocated buffer in a second
    fori_loop, so the decoded rows come back in their original order and
    must equal ``data``.
    """
    size = 3
    tail_capacity = 100
    precision = 3
    n_data = 100
    data = jnp.array(rng.integers(0, 4, size=(n_data, size)))

    # x ~ Categorical(1 / 8, 2 / 8, 3 / 8, 2 / 8)
    m = m_init = rans.base_message(size, tail_capacity)
    choose = partial(jnp.choose, mode='clip')

    def enc_fun(x):
        return (choose(x, jnp.array([0, 1, 3, 6])),
                choose(x, jnp.array([1, 2, 3, 2])))

    def dec_fun(cf):
        return choose(cf, jnp.array([0, 1, 1, 2, 2, 2, 3, 3]))

    codec_push, codec_pop = rans.NonUniform(enc_fun, dec_fun, precision)

    # Encode: push data[n_data - 1], ..., data[0] so data[0] is popped first.
    def push_body(i, m):
        return codec_push(m, data[n_data - i - 1])
    m = lax.fori_loop(0, n_data, push_body, m)
    coded_arr = rans.flatten(m)
    assert coded_arr.dtype == np.uint8

    # Decode: pop one row per iteration and write it into row i of xs.
    def pop_body(i, carry):
        m, xs = carry
        m, x = codec_pop(m)
        return m, lax.dynamic_update_index_in_dim(xs, x, i, 0)
    m = rans.unflatten(coded_arr, size, tail_capacity)
    m, data_decoded = lax.fori_loop(0, n_data, pop_body,
                                    (m, jnp.zeros((n_data, size), 'int32')))
    assert rans.message_equal(m, m_init)
    # The decoded symbols must match the input exactly. Both sides are
    # converted to NumPy because numpy.testing.assert_equal does not handle
    # JAX arrays.
    np.testing.assert_array_equal(np.asarray(data), np.asarray(data_decoded))
Ejemplo n.º 7
0
                             prior_precision, q_precision)
# Decoder counterpart of vae_append, built from the same networks and
# precisions (the vae_append call begins above this fragment).
vae_pop = util.vae_pop(latent_shape, gen_net, rec_net, obs_pop,
                       prior_precision, q_precision)

# load some mnist images (MNIST test split, downloaded into data/mnist)
# NOTE(review): indexing `test_data` directly presumably bypasses `transform`
# (torchvision applies it in __getitem__), and newer torchvision versions
# name this attribute `data` — confirm.
mnist = datasets.MNIST('data/mnist',
                       train=False,
                       download=True,
                       transform=transforms.Compose([transforms.ToTensor()]))
images = mnist.test_data[:num_images]

# Convert each image to float with shape (1, n_pixels). Pixel values are not
# rescaled or binarised here — presumably the model expects raw pixel values;
# confirm against the model definition.
images = [image.float().view(1, -1) for image in images]

# randomly generate some 'other' bits: 50 uint32 words drawn from
# [2**16, 2**31), used directly as the initial rANS message.
other_bits = rng.randint(low=1 << 16, high=1 << 31, size=50, dtype=np.uint32)
state = rans.unflatten(other_bits)

# Progress-report period and wall-clock start for timing the encode loop.
print_interval = 10
encode_start_time = time.time()
for i, image in enumerate(images):
    state = vae_append(state, image)
    if i == 1:
        break
    '''
    if not i % print_interval:
        print('Encoded {}'.format(i))

    compressed_length = 32 * (len(rans.flatten(state)) - len(other_bits)) / (i+1)
    compress_lengths.append(compressed_length)

print('\nAll encoded in {:.2f}s'.format(time.time() - encode_start_time))
Ejemplo n.º 8
0
    # load some mnist images
    mnist = datasets.MNIST('data/mnist', train=False, download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()]))
    images = mnist.test_data[:num_images]

    if randomise_data:
        images = Bernoulli(images.float() / 255.).sample()
    else:
        images = torch.round(images.float() / 255.)

    images = [image.view(-1) for image in images]

    # randomly generate some 'other' bits
    other_bits = rng.randint(low=1 << 16, high=1 << 31, size=20, dtype=np.uint32)
    state = rans.unflatten(other_bits)

    print_interval = 10
    encode_start_time = time.time()
    for i, image in enumerate(images):
        state = vae_append(state, image)

        if not i % print_interval:
            print('Encoded {}'.format(i))

        compressed_length = 32*(len(rans.flatten(state)) - len(other_bits)) / (i+1)
        compress_lengths.append(compressed_length)

    print('\nAll encoded in {:.2f}s'.format(time.time() - encode_start_time))
    compressed_message = rans.flatten(state)