Ejemplo n.º 1
0
def test_serial_resized(shape2, shape1=(5, ), precision=4):
    """Round-trip a serial codec whose message head changes shape mid-chain.

    Seven symbols are coded with head shape ``shape1``; a bridging codec
    then resizes the head so the remaining twenty use ``shape2``.
    """
    first = rng.randint(precision, size=(7,) + shape1, dtype="uint64")
    second = rng.randint(precision, size=(20,) + shape2, dtype="uint64")
    symbols = list(first) + list(second)

    uniform = cs.Uniform(precision)
    u_push, u_pop = uniform

    def push_resize(message, symbol):
        # The push direction reaches this codec with the head already in
        # shape2; narrow it back to shape1 before pushing.
        assert message[0].shape == shape2
        narrowed = cs.reshape_head(message, shape1)
        return u_push(narrowed, symbol)

    def pop_resize(message):
        # The pop direction decodes with shape1, then widens to shape2.
        assert message[0].shape == shape1
        popped, symbol = u_pop(message)
        return cs.reshape_head(popped, shape2), symbol

    bridge = cs.Codec(push_resize, pop_resize)

    chain = [uniform] * (len(first) - 1) + [bridge] + [uniform] * len(second)
    check_codec(shape2, cs.serial(chain), symbols)
Ejemplo n.º 2
0
def test_flatten_rate():
    """Vectorised Bernoulli coding should match the scalar bit-rate closely."""
    n = 1000

    # Seed the message with random content so the flattened lengths below
    # measure real coding cost rather than initial-state overhead.
    init_data = np.random.randint(1 << 16, size=8 * n, dtype='uint64')
    init_message = cs.base_message((1, ))
    for datum in init_data:
        init_message = cs.Uniform(16).push(init_message, datum)
    l_init = len(cs.flatten(init_message))

    ps = np.random.rand(n, 1)
    data = np.random.rand(n, 1) < ps

    # Scalar path: one Bernoulli push per symbol.
    message = init_message
    for p, datum in zip(ps, data):
        message = cs.Bernoulli(p, 14).push(message, datum)
    l_scalar = len(cs.flatten(message))

    # Vector path: resize the head and push every symbol in a single call.
    message = cs.reshape_head(init_message, (n, 1))
    message = cs.Bernoulli(ps, 14).push(message, data)
    l_vector = len(cs.flatten(message))

    # Vector cost must not exceed scalar cost by more than 0.1%.
    assert (l_vector - l_init) / (l_scalar - l_init) - 1 < 0.001
Ejemplo n.º 3
0
def test_serial(precision=16):
    """Round-trip a serial chain of Uniform codecs followed by Benford64."""
    shape = (2, 3, 4)
    uniform_data = rng.randint(precision, size=(7,) + shape, dtype="uint64")
    benford_data = rng.randint(2 ** 31, 2 ** 63, size=(5,) + shape,
                               dtype="uint64")
    symbols = list(uniform_data) + list(benford_data)

    chain = ([cs.Uniform(precision)] * len(uniform_data)
             + [cs.Benford64] * len(benford_data))
    check_codec(shape, cs.serial(chain), symbols)
Ejemplo n.º 4
0
def test_multiset_codec():
    """Encode a small multiset, decode it, and compare as multisets."""
    multiset = cs.build_multiset([0, 255, 128, 128])

    ans_state = rans.base_message(shape=(1,))
    codec = cs.Multiset(cs.Uniform(8))

    # Multiset push returns a 1-tuple; unpack the updated ANS state.
    ans_state, = codec.push(ans_state, multiset)
    ans_state, decoded = codec.pop(ans_state, multiset_size=4)

    assert cs.check_multiset_equality(multiset, decoded)
Ejemplo n.º 5
0
def test_parallel():
    """Round-trip cs.parallel over codecs of differing precision and size."""
    precs = [1, 2, 4, 8, 16]
    szs = [2, 3, 4, 5, 6]
    u_codecs = [cs.Uniform(p) for p in precs]

    def make_view(slc):
        # Bind the slice now to avoid the late-binding-closure trap.
        return lambda head: head[slc]

    starts = [0]
    for size in szs[:-1]:
        starts.append(starts[-1] + size)
    view_funs = [make_view(slice(s, s + n)) for s, n in zip(starts, szs)]

    data = [rng.randint(1 << p, size=n, dtype='uint64')
            for p, n in zip(precs, szs)]
    check_codec(sum(szs), cs.parallel(u_codecs, view_funs), data)
Ejemplo n.º 6
0
def test_substack():
    """substack must codec through its view and leave the rest untouched."""
    n_data = 100
    prec = 4
    head, tail = cs.base_message((4, 4))
    # Split the head so the view can select exactly one half of it.
    message = np.split(head, 2), tail
    data = rng.randint(1 << prec, size=(n_data, 2, 4), dtype='uint64')
    push, pop = cs.substack(cs.repeat(cs.Uniform(prec), n_data),
                            lambda h: h[0])
    encoded = push(message, data)
    # The half of the head outside the view is byte-identical after pushing.
    np.testing.assert_array_equal(encoded[0][1], message[0][1])
    decoded, recovered = pop(encoded)
    np.testing.assert_equal(message, decoded)
    np.testing.assert_equal(data, recovered)
Ejemplo n.º 7
0
def test_reshape_head(old_shape, new_shape, depth=1000):
    """reshape_head followed by the inverse reshape must be lossless."""
    np.random.seed(0)
    precision = 8
    payload = np.random.randint(1 << precision, size=(depth,) + old_shape,
                                dtype=np.uint64)

    message = cs.empty_message(old_shape)
    push, _ = cs.repeat(cs.Uniform(precision), depth)
    message = push(message, payload)

    round_tripped = cs.reshape_head(
        cs.reshape_head(message, new_shape), old_shape)
    assert_message_equal(message, round_tripped)
Ejemplo n.º 8
0
def test_flatten_unflatten(shape, depth=1000):
    """unflatten(flatten(message)) must reproduce the message exactly."""
    np.random.seed(0)
    precision = 8
    payload = np.random.randint(1 << precision, size=(depth, ) + shape,
                                dtype=np.uint64)

    message = cs.base_message(shape)
    push, _ = cs.repeat(cs.Uniform(precision), depth)
    message = push(message, payload)

    reconstructed = cs.unflatten(cs.flatten(message), shape)
    assert_message_equal(message, reconstructed)
Ejemplo n.º 9
0
def rvae_variable_size_codec(codec_from_shape,
                             latent_from_image_shape,
                             image_count,
                             dimensions=4,
                             dimension_bits=16,
                             previous_dims=0):
    """Serial codec for ``image_count`` variably-shaped images.

    Each image is coded together with its shape: push resizes the message
    head to fit latent + image, pushes the image with a shape-specific
    codec, then pushes the shape itself (``dimensions`` integers of
    ``dimension_bits`` bits each) so that pop can recover the shape first
    and rebuild the matching codec.
    """
    size_codec = cs.repeat(cs.Uniform(dimension_bits), dimensions)

    def push(message, symbol):
        """push sizes and array in alternating order"""
        assert len(symbol.shape) == dimensions

        codec = codec_from_shape(symbol.shape)
        # Head must hold both the latent and the image elements.
        head_size = np.prod(latent_from_image_shape(symbol.shape)) + np.prod(
            symbol.shape)
        message = cs.reshape_head(message, (head_size, ))
        message = codec.push(message, symbol)
        # Shrink back to a scalar head before pushing the shape, so pop can
        # always decode the shape from a (1,)-shaped head.
        message = cs.reshape_head(message, (1, ))
        message = size_codec.push(message, np.array(symbol.shape))
        return message

    def pop(message):
        message, size = size_codec.pop(message)
        # TODO make codec 0 dimensional:
        size = np.array(size)[:, 0]
        assert size.shape == (dimensions, )
        # Fix: `np.int` was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin `int` is the documented replacement.
        size = size.astype(int)
        head_size = np.prod(latent_from_image_shape(size)) + np.prod(size)
        codec = codec_from_shape(tuple(size))

        message = cs.reshape_head(message, (head_size, ))
        message, symbol = codec.pop(message)
        message = cs.reshape_head(message, (1, ))

        return message, symbol

    return rvae_serial_with_progress([cs.Codec(push, pop)] * image_count,
                                     previous_dims)
Ejemplo n.º 10
0
def test_repeat():
    """Round-trip cs.repeat over a stack of Uniform symbols."""
    precision, n_data = 4, 7
    shape = (2, 3, 5)
    symbols = rng.randint(1 << precision, size=(n_data, ) + shape,
                          dtype="uint64")
    check_codec(shape, cs.repeat(cs.Uniform(precision), n_data), symbols)
Ejemplo n.º 11
0
def test_uniform():
    """Round-trip a single Uniform codec over a multi-dimensional head."""
    precision, shape = 4, (2, 3, 5)
    symbols = rng.randint(precision, size=shape, dtype="uint64")
    check_codec(shape, cs.Uniform(precision), symbols)
Ejemplo n.º 12
0
def ResNetVAE(up_pass, rec_net_top, rec_nets, gen_net_top, gen_nets, obs_codec,
              prior_prec, latent_prec):
    """
    Codec for a ResNetVAE.
    Assume that the posterior is bidirectional -
    i.e. has a deterministic upper pass but top down sampling.
    Further assume that all latent conditionals are factorised Gaussians,
    both in the generative network p(z_n|z_{n-1})
    and in the inference network q(z_n|x, z_{n-1})

    Assume that everything is ordered bottom up
    """
    # Views selecting the latent (z) and observation (x) parts of the
    # message head; the head is assumed to be a pair — TODO confirm against
    # the caller's head layout.
    z_view = lambda head: head[0]
    x_view = lambda head: head[1]

    # Uniform prior over latent *indices* at prior_prec bits, restricted to
    # the z part of the head.
    prior_codec = cs.substack(cs.Uniform(prior_prec), z_view)

    def prior_push(message, latents):
        # push bottom-up
        # `latents` is (list of (index, (mean, stdd)) pairs, h_gen); only the
        # indices are pushed here.
        latents, _ = latents
        for latent in latents:
            latent, _ = latent
            message = prior_codec.push(message, latent)
        return message

    def prior_pop(message):
        # pop top-down
        (prior_mean, prior_stdd), h_gen = gen_net_top()
        message, latent = prior_codec.pop(message)
        latents = [(latent, (prior_mean, prior_stdd))]
        for gen_net in reversed(gen_nets):
            # Convert the previous latent index to a real value via the
            # standard-Gaussian bin centres, then condition the next layer.
            previous_latent_val = prior_mean + cs.std_gaussian_centres(
                prior_prec)[latent] * prior_stdd
            (prior_mean, prior_stdd), h_gen = gen_net(h_gen,
                                                      previous_latent_val)
            message, latent = prior_codec.pop(message)
            latents.append((latent, (prior_mean, prior_stdd)))
        # Reverse so latents are returned bottom-up, matching the module
        # docstring's ordering convention.
        return message, (latents[::-1], h_gen)

    def posterior(data):
        # run deterministic upper-pass
        contexts = up_pass(data)

        def posterior_push(message, latents):
            # first run the model top-down to get the params and latent vals
            latents, _ = latents

            (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
            post_params = [(post_mean, post_stdd)]

            for rec_net, latent, context in reversed(
                    list(zip(rec_nets, latents[1:], contexts[:-1]))):
                previous_latent, (prior_mean, prior_stdd) = latent
                previous_latent_val = prior_mean + \
                                      cs.std_gaussian_centres(prior_prec)[previous_latent] * prior_stdd

                (post_mean,
                 post_stdd), h_rec = rec_net(h_rec, previous_latent_val,
                                             context)
                post_params.append((post_mean, post_stdd))

            # now append bottom up
            # post_params was collected top-down, so reverse it to pair each
            # latent with its own posterior parameters.
            for latent, post_param in zip(latents, reversed(post_params)):
                latent, (prior_mean, prior_stdd) = latent
                post_mean, post_stdd = post_param
                codec = cs.substack(
                    cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                                 prior_mean, prior_stdd,
                                                 latent_prec, prior_prec),
                    z_view)
                message = codec.push(message, latent)
            return message

        def posterior_pop(message):
            # pop top-down
            (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
            (prior_mean, prior_stdd), h_gen = gen_net_top()
            codec = cs.substack(
                cs.DiagGaussian_GaussianBins(post_mean, post_stdd, prior_mean,
                                             prior_stdd, latent_prec,
                                             prior_prec), z_view)
            message, latent = codec.pop(message)
            latents = [(latent, (prior_mean, prior_stdd))]
            for rec_net, gen_net, context in reversed(
                    list(zip(rec_nets, gen_nets, contexts[:-1]))):
                # latents[-1][0] is the most recently popped latent index.
                previous_latent_val = prior_mean + \
                                      cs.std_gaussian_centres(prior_prec)[latents[-1][0]] * prior_stdd

                (post_mean,
                 post_stdd), h_rec = rec_net(h_rec, previous_latent_val,
                                             context)
                (prior_mean,
                 prior_stdd), h_gen = gen_net(h_gen, previous_latent_val)
                codec = cs.substack(
                    cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                                 prior_mean, prior_stdd,
                                                 latent_prec, prior_prec),
                    z_view)
                message, latent = codec.pop(message)
                latents.append((latent, (prior_mean, prior_stdd)))
            return message, (latents[::-1], h_gen)

        return cs.Codec(posterior_push, posterior_pop)

    def likelihood(latents):
        # get the z1 vals to condition on
        latents, h = latents
        z1_idxs, (prior_mean, prior_stdd) = latents[0]
        z1_vals = prior_mean + cs.std_gaussian_centres(
            prior_prec)[z1_idxs] * prior_stdd
        # Observation codec acts on the x part of the head only.
        return cs.substack(obs_codec(h, z1_vals), x_view)

    return BBANS(cs.Codec(prior_push, prior_pop), likelihood, posterior)
Ejemplo n.º 13
0
 def gen_factory():
     # Generator-driven codec factory: the value sent back into the
     # generator (presumably the first decoded symbol — confirm against
     # the consuming codec) selects the precision of the second codec.
     precs = (yield cs.Uniform(16))
     yield cs.Uniform(precs)
Ejemplo n.º 14
0
 def gen_factory():
     # Yield seven Uniform codecs (`precision` is taken from the enclosing
     # scope) followed by five Benford64 codecs.  Note cs.Benford64 is used
     # uncalled, so it is presumably a ready-made codec instance.
     for _ in range(7):
         yield cs.Uniform(precision)
     for _ in range(5):
         yield cs.Benford64
Ejemplo n.º 15
0
 def gen_factory():
     """Yield a single 8-bit Uniform codec."""
     codec = cs.Uniform(8)
     yield codec
Ejemplo n.º 16
0

def vae_view(head):
    """Split the flat message head into (latent, observation) arrays."""
    latent = np.reshape(head[:latent_size], latent_shape)
    obs = np.reshape(head[latent_size:], obs_shape)
    return ag_tuple((latent, obs))


# Batched VAE codec: repeat the substack'd BB-ANS VAE codec once per batch.
vae_append, vae_pop = cs.repeat(
    cs.substack(
        bb_ans.VAE(gen_net, rec_net, obs_codec, prior_precision, q_precision),
        vae_view), num_batches)

# Codec for adding extra bits to the start of the chain (necessary for bits
# back).  It only touches the latent part of the head: vae_view(h)[0].
other_bits_append, _ = cs.substack(cs.Uniform(q_precision),
                                   lambda h: vae_view(h)[0])

## Load mnist images
images = datasets.MNIST('mnist', train=False, download=True).data.numpy()
# Binarise stochastically: each pixel is 1 with probability pixel/255.
images = np.uint64(rng.random_sample(np.shape(images)) < images / 255.)
images = np.split(np.reshape(images, (num_images, -1)), num_batches)

## Encode
# Initialize message with some 'extra' bits
encode_t0 = time.time()
init_message = cs.base_message(obs_size + latent_size)

# Enough bits to pop a single latent
other_bits = rng.randint(1 << q_precision, size=latent_shape, dtype=np.uint64)
init_message = other_bits_append(init_message, other_bits)
Ejemplo n.º 17
0
import subprocess

import cv2
import numpy as np

from rvae.datasets import test_image
import craystack as cs

# Shell command lines for the external FLIF lossless image coder, reading
# from stdin and writing to stdout.  The original used f-strings with no
# placeholders (ruff F541); plain string literals are equivalent.
encode_command = 'flif -e - - --effort=100 --no-metadata --no-color-profile --no-crc'
decode_command = 'flif -d - -'
codec = cs.Uniform(8)        # 8-bit symbols: one byte of FLIF output each
len_codec = cs.Uniform(31)   # codes the length of the FLIF byte stream


def pop_varint(bytes):
    """Decode one variable-length integer from the front of *bytes*.

    Each byte contributes its low 7 bits, most-significant group first;
    a set high bit means another byte follows.  Returns the decoded value
    plus one (an offset the original author flagged with '?!') and the
    remaining, unconsumed bytes.
    """
    value = 0
    pos = 0
    while True:
        current = bytes[pos]
        pos += 1
        value = (value << 7) | (current & 0x7f)
        if not (current >> 7):
            # High bit clear: this was the final byte of the group.
            break
    return value + 1, bytes[pos:]


def im_transform(image):
    """Swap the first and last axes of ``image``, then reverse the new
    last axis.  Returns a view, not a copy."""
    swapped = np.swapaxes(image, 0, 2)
    return swapped[:, :, ::-1]