def test_serial_resized(shape2, shape1=(5, ), precision=4):
    """Round-trip a serial codec whose message head is resized partway through."""
    symbols_a = list(rng.randint(precision, size=(7,) + shape1, dtype="uint64"))
    symbols_b = list(rng.randint(precision, size=(20,) + shape2, dtype="uint64"))
    uniform = cs.Uniform(precision)
    uniform_push, uniform_pop = uniform

    def push_with_resize(message, symbol):
        # Head still has the post-resize shape before we shrink it back.
        assert message[0].shape == shape2
        resized = cs.reshape_head(message, shape1)
        return uniform_push(resized, symbol)

    def pop_with_resize(message):
        assert message[0].shape == shape1
        message, symbol = uniform_pop(message)
        return cs.reshape_head(message, shape2), symbol

    resizing = cs.Codec(push_with_resize, pop_with_resize)
    # The resizing codec replaces the last shape1-sized push, then all of
    # symbols_b are coded at shape2.
    codecs = ([cs.Uniform(precision) for _ in symbols_a[:-1]]
              + [resizing]
              + [cs.Uniform(precision) for _ in symbols_b])
    check_codec(shape2, cs.serial(codecs), symbols_a + symbols_b)
def test_flatten_rate():
    """Vectorised Bernoulli pushes should cost about the same bits as scalar ones."""
    n = 1000
    # Seed the message with random content so the rate comparison is fair.
    init_symbols = np.random.randint(1 << 16, size=8 * n, dtype='uint64')
    init_message = cs.base_message((1, ))
    for symbol in init_symbols:
        init_message = cs.Uniform(16).push(init_message, symbol)
    l_init = len(cs.flatten(init_message))

    probs = np.random.rand(n, 1)
    samples = np.random.rand(n, 1) < probs

    # Scalar path: one Bernoulli push per datum on a length-1 head.
    message = init_message
    for prob, sample in zip(probs, samples):
        message = cs.Bernoulli(prob, 14).push(message, sample)
    l_scalar = len(cs.flatten(message))

    # Vectorised path: a single push on a length-n head.
    message = cs.reshape_head(init_message, (n, 1))
    message = cs.Bernoulli(probs, 14).push(message, samples)
    l_vector = len(cs.flatten(message))

    # Vectorised rate must be within 0.1% of the scalar rate.
    # NOTE(review): no abs() here, so a vector rate *smaller* than scalar
    # always passes — confirm the one-sided check is intended.
    assert (l_vector - l_init) / (l_scalar - l_init) - 1 < 0.001
def test_serial(precision=16):
    """Serially compose Uniform codecs with Benford64 codecs and round-trip."""
    shape = (2, 3, 4)
    uniform_symbols = rng.randint(precision, size=(7,) + shape, dtype="uint64")
    benford_symbols = rng.randint(2 ** 31, 2 ** 63, size=(5,) + shape, dtype="uint64")
    codecs = ([cs.Uniform(precision) for _ in uniform_symbols]
              + [cs.Benford64 for _ in benford_symbols])
    check_codec(shape, cs.serial(codecs),
                list(uniform_symbols) + list(benford_symbols))
def test_multiset_codec():
    """Push a multiset with an 8-bit symbol codec, then pop and compare."""
    multiset = cs.build_multiset([0, 255, 128, 128])
    message = rans.base_message(shape=(1,))
    codec = cs.Multiset(cs.Uniform(8))
    message, = codec.push(message, multiset)
    message, decoded = codec.pop(message, multiset_size=4)
    # Multisets are unordered, so compare via the dedicated equality helper.
    assert cs.check_multiset_equality(multiset, decoded)
def test_parallel():
    """Round-trip several Uniform codecs running side by side on one head."""
    precisions = [1, 2, 4, 8, 16]
    sizes = [2, 3, 4, 5, 6]
    codecs = [cs.Uniform(p) for p in precisions]
    # Each codec views a contiguous slice of the flat message head; the
    # inner lambda binds the slice via the outer call to avoid late binding.
    offsets = np.cumsum([0] + sizes)
    view_funs = [
        (lambda slc: lambda head: head[slc])(slice(lo, hi))
        for lo, hi in zip(offsets[:-1], offsets[1:])
    ]
    data = [rng.randint(1 << p, size=sz, dtype='uint64')
            for p, sz in zip(precisions, sizes)]
    check_codec(sum(sizes), cs.parallel(codecs, view_funs), data)
def test_substack():
    """A substack codec must only touch the viewed part of the head."""
    n_data, prec = 100, 4
    head, tail = cs.base_message((4, 4))
    # Split the head in two; the codec will operate on the first half only.
    message = np.split(head, 2), tail
    symbols = rng.randint(1 << prec, size=(n_data, 2, 4), dtype='uint64')
    push, pop = cs.substack(cs.repeat(cs.Uniform(prec), n_data),
                            lambda h: h[0])
    pushed = push(message, symbols)
    # The half of the head outside the view is untouched by the push.
    np.testing.assert_array_equal(pushed[0][1], message[0][1])
    popped, decoded = pop(pushed)
    np.testing.assert_equal(message, popped)
    np.testing.assert_equal(symbols, decoded)
def test_reshape_head(old_shape, new_shape, depth=1000):
    """reshape_head must be invertible: resizing there and back is lossless."""
    np.random.seed(0)
    precision = 8
    symbols = np.random.randint(1 << precision, size=(depth,) + old_shape,
                                dtype=np.uint64)
    push_all, _ = cs.repeat(cs.Uniform(precision), depth)
    message = push_all(cs.empty_message(old_shape), symbols)
    round_tripped = cs.reshape_head(cs.reshape_head(message, new_shape),
                                    old_shape)
    assert_message_equal(message, round_tripped)
def test_flatten_unflatten(shape, depth=1000):
    """unflatten(flatten(m), shape) must reconstruct the original message."""
    np.random.seed(0)
    precision = 8
    symbols = np.random.randint(1 << precision, size=(depth, ) + shape,
                                dtype=np.uint64)
    push_all, _ = cs.repeat(cs.Uniform(precision), depth)
    message = push_all(cs.base_message(shape), symbols)
    reconstructed = cs.unflatten(cs.flatten(message), shape)
    assert_message_equal(message, reconstructed)
def rvae_variable_size_codec(codec_from_shape, latent_from_image_shape, image_count, dimensions=4, dimension_bits=16, previous_dims=0):
    """Codec for a sequence of `image_count` variably-shaped arrays.

    Each array is stored together with its shape: on push the array is
    encoded first and its shape last, so that on pop the shape comes off
    first and can be used to rebuild the matching codec.

    Params:
        codec_from_shape: callable mapping an image shape tuple to a codec.
        latent_from_image_shape: maps an image shape to the corresponding
            latent shape, used to size the message head.
        image_count: number of arrays to code.
        dimensions: number of entries in each shape (default 4).
        dimension_bits: bits used to encode each shape dimension.
        previous_dims: forwarded to rvae_serial_with_progress (progress
            reporting offset).
    """
    size_codec = cs.repeat(cs.Uniform(dimension_bits), dimensions)

    def push(message, symbol):
        """push sizes and array in alternating order"""
        assert len(symbol.shape) == dimensions
        codec = codec_from_shape(symbol.shape)
        # Head must hold the latent and the observation side by side.
        head_size = np.prod(latent_from_image_shape(symbol.shape)) + np.prod(
            symbol.shape)
        message = cs.reshape_head(message, (head_size, ))
        message = codec.push(message, symbol)
        message = cs.reshape_head(message, (1, ))
        message = size_codec.push(message, np.array(symbol.shape))
        return message

    def pop(message):
        # Decode the shape first, then the array itself.
        message, size = size_codec.pop(message)
        # TODO make codec 0 dimensional:
        size = np.array(size)[:, 0]
        assert size.shape == (dimensions, )
        # BUGFIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin int is the documented replacement.
        size = size.astype(int)
        head_size = np.prod(latent_from_image_shape(size)) + np.prod(size)
        codec = codec_from_shape(tuple(size))
        message = cs.reshape_head(message, (head_size, ))
        message, symbol = codec.pop(message)
        message = cs.reshape_head(message, (1, ))
        return message, symbol

    return rvae_serial_with_progress([cs.Codec(push, pop)] * image_count,
                                     previous_dims)
def test_repeat():
    """Round-trip a repeated Uniform codec over a stack of symbol arrays."""
    precision, n_data = 4, 7
    shape = (2, 3, 5)
    symbols = rng.randint(1 << precision, size=(n_data, ) + shape,
                          dtype="uint64")
    check_codec(shape, cs.repeat(cs.Uniform(precision), n_data), symbols)
def test_uniform():
    """Round-trip a Uniform codec over its full symbol range."""
    precision = 4
    shape = (2, 3, 5)
    # Draw symbols from [0, 2**precision), the codec's whole alphabet,
    # matching test_repeat. The previous rng.randint(precision, ...) only
    # exercised symbols 0..precision-1.
    data = rng.randint(1 << precision, size=shape, dtype="uint64")
    check_codec(shape, cs.Uniform(precision), data)
def ResNetVAE(up_pass, rec_net_top, rec_nets, gen_net_top, gen_nets, obs_codec, prior_prec, latent_prec):
    """Codec for a ResNetVAE.

    Assume that the posterior is bidirectional - i.e. has a deterministic
    upper pass but top down sampling. Further assume that all latent
    conditionals are factorised Gaussians, both in the generative network
    p(z_n|z_{n-1}) and in the inference network q(z_n|x, z_{n-1}).

    Assume that everything is ordered bottom up.
    """
    # Views selecting the latent (z) and observation (x) slots of the head.
    z_view = lambda head: head[0]
    x_view = lambda head: head[1]
    prior_codec = cs.substack(cs.Uniform(prior_prec), z_view)

    def prior_push(message, latents):
        # push bottom-up
        latents, _ = latents
        for latent in latents:
            latent, _ = latent
            message = prior_codec.push(message, latent)
        return message

    def prior_pop(message):
        # pop top-down
        (prior_mean, prior_stdd), h_gen = gen_net_top()
        message, latent = prior_codec.pop(message)
        latents = [(latent, (prior_mean, prior_stdd))]
        for gen_net in reversed(gen_nets):
            # Map the popped bin index back to a latent value before
            # conditioning the next generative layer on it.
            previous_latent_val = prior_mean + cs.std_gaussian_centres(
                prior_prec)[latent] * prior_stdd
            (prior_mean, prior_stdd), h_gen = gen_net(h_gen, previous_latent_val)
            message, latent = prior_codec.pop(message)
            latents.append((latent, (prior_mean, prior_stdd)))
        # Latents were collected top-down; return them bottom-up.
        return message, (latents[::-1], h_gen)

    def posterior(data):
        # run deterministic upper-pass
        contexts = up_pass(data)

        def posterior_push(message, latents):
            # first run the model top-down to get the params and latent vals
            latents, _ = latents
            (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
            post_params = [(post_mean, post_stdd)]
            for rec_net, latent, context in reversed(
                    list(zip(rec_nets, latents[1:], contexts[:-1]))):
                previous_latent, (prior_mean, prior_stdd) = latent
                previous_latent_val = prior_mean + \
                    cs.std_gaussian_centres(prior_prec)[previous_latent] * prior_stdd
                (post_mean, post_stdd), h_rec = rec_net(h_rec,
                                                        previous_latent_val,
                                                        context)
                post_params.append((post_mean, post_stdd))
            # now append bottom up
            # (post_params were gathered top-down, so pair them reversed)
            for latent, post_param in zip(latents, reversed(post_params)):
                latent, (prior_mean, prior_stdd) = latent
                post_mean, post_stdd = post_param
                codec = cs.substack(
                    cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                                 prior_mean, prior_stdd,
                                                 latent_prec, prior_prec),
                    z_view)
                message = codec.push(message, latent)
            return message

        def posterior_pop(message):
            # pop top-down
            (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
            (prior_mean, prior_stdd), h_gen = gen_net_top()
            codec = cs.substack(
                cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                             prior_mean, prior_stdd,
                                             latent_prec, prior_prec),
                z_view)
            message, latent = codec.pop(message)
            latents = [(latent, (prior_mean, prior_stdd))]
            for rec_net, gen_net, context in reversed(
                    list(zip(rec_nets, gen_nets, contexts[:-1]))):
                # Condition both networks on the value of the latent popped
                # in the previous iteration (latents[-1]).
                previous_latent_val = prior_mean + \
                    cs.std_gaussian_centres(prior_prec)[latents[-1][0]] * prior_stdd
                (post_mean, post_stdd), h_rec = rec_net(h_rec,
                                                        previous_latent_val,
                                                        context)
                (prior_mean, prior_stdd), h_gen = gen_net(h_gen,
                                                          previous_latent_val)
                codec = cs.substack(
                    cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                                 prior_mean, prior_stdd,
                                                 latent_prec, prior_prec),
                    z_view)
                message, latent = codec.pop(message)
                latents.append((latent, (prior_mean, prior_stdd)))
            return message, (latents[::-1], h_gen)

        return cs.Codec(posterior_push, posterior_pop)

    def likelihood(latents):
        # get the z1 vals to condition on
        latents, h = latents
        z1_idxs, (prior_mean, prior_stdd) = latents[0]
        z1_vals = prior_mean + cs.std_gaussian_centres(
            prior_prec)[z1_idxs] * prior_stdd
        return cs.substack(obs_codec(h, z1_vals), x_view)

    return BBANS(cs.Codec(prior_push, prior_pop), likelihood, posterior)
def gen_factory():
    """Yield a 16-bit Uniform codec; the value sent back into the generator
    selects the precision of a second Uniform codec."""
    precision = yield cs.Uniform(16)
    yield cs.Uniform(precision)
def gen_factory():
    """Yield seven Uniform codecs followed by five Benford64 codecs."""
    yield from (cs.Uniform(precision) for _ in range(7))
    yield from (cs.Benford64 for _ in range(5))
def gen_factory():
    """Yield a single fixed 8-bit Uniform codec."""
    byte_codec = cs.Uniform(8)
    yield byte_codec
def vae_view(head):
    # Split the flat message head into a (latent, observation) pair of views.
    # NOTE(review): latent_size, latent_shape and obs_shape are defined
    # earlier in this script, outside this excerpt.
    return ag_tuple(
        (np.reshape(head[:latent_size], latent_shape),
         np.reshape(head[latent_size:], obs_shape)))

# num_batches repetitions of the VAE codec, each operating on the
# (latent, observation) view of the head.
vae_append, vae_pop = cs.repeat(
    cs.substack(
        bb_ans.VAE(gen_net, rec_net, obs_codec, prior_precision, q_precision),
        vae_view),
    num_batches)

# Codec for adding extra bits to the start of the chain (necessary for bits
# back).
other_bits_append, _ = cs.substack(cs.Uniform(q_precision),
                                   lambda h: vae_view(h)[0])

## Load mnist images
images = datasets.MNIST('mnist', train=False, download=True).data.numpy()
# Binarise stochastically: pixel intensity / 255 is the 'on' probability.
images = np.uint64(rng.random_sample(np.shape(images)) < images / 255.)
images = np.split(np.reshape(images, (num_images, -1)), num_batches)

## Encode
# Initialize message with some 'extra' bits
encode_t0 = time.time()
init_message = cs.base_message(obs_size + latent_size)

# Enough bits to pop a single latent
other_bits = rng.randint(1 << q_precision, size=latent_shape, dtype=np.uint64)
init_message = other_bits_append(init_message, other_bits)
import subprocess
import cv2
import numpy as np
from rvae.datasets import test_image
import craystack as cs

# FLIF CLI invocations reading from stdin and writing to stdout ('-' '-').
# Plain strings: the former f-string prefixes had no placeholders (lint F541).
encode_command = 'flif -e - - --effort=100 --no-metadata --no-color-profile --no-crc'
decode_command = 'flif -d - -'

codec = cs.Uniform(8)       # byte-wise codec for the FLIF payload
len_codec = cs.Uniform(31)  # codec for the payload length


def pop_varint(bytes):
    """Decode one base-128 varint from the front of *bytes*.

    Bytes are consumed big-endian, 7 payload bits per byte; a set top bit
    means "more bytes follow".  Returns ``(value, remaining_bytes)``.

    The parameter name shadows the ``bytes`` builtin; it is kept unchanged
    for interface compatibility with existing callers.
    """
    leading_bit = 1
    i = 0
    content = 0
    mask = (1 << 7) - 1  # low seven bits of each byte
    while leading_bit:
        byte = bytes[i]
        i += 1
        leading_bit = byte >> 7          # continuation flag
        content = (content << 7) + (mask & byte)
    # The original author's '?!' comment: FLIF stores some header fields off
    # by one, hence the +1 — presumably intentional; confirm against the
    # FLIF format specification before changing.
    content += 1
    return content, bytes[i:]


def im_transform(image):
    """Swap axes 0 and 2 of a 3-d array and reverse the resulting last axis.

    NOTE(review): presumably converts between the model's channel-first
    layout and cv2's layout (incl. RGB<->BGR) — confirm against callers.
    """
    return np.swapaxes(image, 0, 2)[:, :, ::-1]