Exemplo n.º 1
0
def test_serial_resized(shape2, shape1=(5, ), precision=4):
    """Check a serial codec whose message head changes shape part-way through.

    The data consists of 7 random symbols of shape ``shape1`` followed by 20
    random symbols of shape ``shape2``. Every symbol uses the same uniform
    codec, except the last ``shape1`` symbol, whose codec also converts the
    message head between ``shape1`` and ``shape2``.
    """
    leading = rng.randint(precision, size=(7,) + shape1, dtype="uint64")
    trailing = rng.randint(precision, size=(20,) + shape2, dtype="uint64")
    data = [*leading, *trailing]

    uniform = cs.Uniform(precision)
    uniform_push, uniform_pop = uniform

    def push_into_shape1(message, symbol):
        # The head arrives in the shape2 layout and is converted before the
        # symbol is pushed.
        assert message[0].shape == shape2
        return uniform_push(cs.reshape_head(message, shape1), symbol)

    def pop_from_shape1(message):
        # Mirror image of the push: pop first, then restore the shape2 head.
        assert message[0].shape == shape1
        message, symbol = uniform_pop(message)
        return cs.reshape_head(message, shape2), symbol

    boundary_codec = cs.Codec(push_into_shape1, pop_from_shape1)

    # One codec per symbol; the boundary codec takes the place of the plain
    # codec for the final shape1 symbol.
    codecs = ([uniform] * (len(leading) - 1)
              + [boundary_codec]
              + [uniform] * len(trailing))
    check_codec(shape2, cs.serial(codecs), data)
Exemplo n.º 2
0
def rvae_variable_known_size_codec(codec_from_image_shape,
                                   latent_from_image_shape, shapes,
                                   previous_dims):
    """
    Chain one codec per entry of ``shapes``, resizing the message head to suit
    each symbol before it is pushed or popped.

    For a given shape the head is reshaped to a flat vector of length
    ``prod(latent_from_image_shape(shape)) + prod(shape)`` and the codec
    returned by ``codec_from_image_shape(shape)`` is then applied. The
    resulting codecs are combined with ``rvae_serial_with_progress``.
    """
    def flat_head_shape(shape):
        latent_elems = np.prod(latent_from_image_shape(shape))
        return (latent_elems + np.prod(shape), )

    def codec_for(shape):
        def push(message, symbol):
            message = cs.reshape_head(message, flat_head_shape(shape))
            return codec_from_image_shape(shape).push(message, symbol)

        def pop(message):
            message = cs.reshape_head(message, flat_head_shape(shape))
            message, symbol = codec_from_image_shape(shape).pop(message)
            return message, symbol

        return cs.Codec(push, pop)

    per_shape_codecs = [codec_for(shape) for shape in shapes]
    return rvae_serial_with_progress(per_shape_codecs, previous_dims)
def ArrayCodec(codec):
    """Adapt ``codec`` so that it works on ``ArraySymbol`` wrappers.

    Pushing forwards the symbol's ``arr`` attribute to ``codec``; popping
    wraps whatever ``codec`` returns in a new ``ArraySymbol``.
    """
    def push_unwrapped(message, x):
        raw = x.arr
        return codec.push(message, raw)

    def pop_wrapped(message):
        message, raw = codec.pop(message)
        wrapped = ArraySymbol(raw)
        return message, wrapped

    return cs.Codec(push_unwrapped, pop_wrapped)
Exemplo n.º 4
0
def FLIF():
    """Build a codec that stores an image as a FLIF-compressed byte stream.

    The codec relies on names defined outside this function:
    ``im_transform`` / ``inverse_im_transform`` (applied before encoding and
    after decoding), ``encode_command`` / ``decode_command`` (command strings,
    split on whitespace and run as subprocesses that read stdin and write
    stdout), ``codec`` (applied once per compressed byte) and ``len_codec``
    (used for the byte count).

    Returns:
        A ``cs.Codec`` whose push encodes an image onto the message and whose
        pop recovers it.

    Raises:
        RuntimeError: from push/pop if the intermediate PPM encoding fails,
            if the external flif process exits with a non-zero status, or if
            the decoder output cannot be read back as an image.
    """
    def push(message, image):
        """expects image to be chw"""
        image = im_transform(image).astype(np.uint8)
        success, im_buffer = cv2.imencode(".ppm", image)
        if not success:
            # Without this check a failed encode would hand an unusable
            # buffer to the flif process and fail with a less direct error.
            raise RuntimeError(
                "flif encode failed: could not serialise image as ppm")
        process = subprocess.run(encode_command.split(),
                                 input=im_buffer.tobytes(),
                                 capture_output=True)
        if process.returncode != 0:
            raise RuntimeError(f"flif encode failed: {process.stderr}")

        # take off the 'FLIF' magic header
        compressed_bytes = process.stdout[4:]
        # can also remove RGB interlaced byte and bytes per chan (next two bytes)
        # then there are 3 varints for width, height and number of frames
        # https://flif.info/spec.html for details

        byte_symbols = list(compressed_bytes)  # list of uint8s
        n_bytes = len(byte_symbols)
        message = cs.repeat(codec, n_bytes).push(message, byte_symbols)
        # The length goes on last so that pop can read it first.
        message = len_codec.push(message, np.uint64(n_bytes))
        return message

    def pop(message):
        message, n_bytes = len_codec.pop(message)
        message, byte_symbols = cs.repeat(codec, n_bytes[0]).pop(message)
        byte_symbols = np.squeeze(byte_symbols).astype(np.uint8)
        # Restore the magic header that push stripped.
        bytes_buffer = b'FLIF' + bytes(byte_symbols)
        process = subprocess.run(decode_command.split(),
                                 input=bytes_buffer,
                                 capture_output=True)
        if process.returncode != 0:
            raise RuntimeError(f"flif decode failed: {process.stderr}")
        im_buffer = np.frombuffer(process.stdout, dtype=np.uint8)
        image = cv2.imdecode(im_buffer, flags=1)  # this gives in hwc
        if image is None:
            # imdecode signals an unreadable buffer by returning None.
            raise RuntimeError(
                "flif decode failed: decoder output is not a readable image")
        return message, inverse_im_transform(image)

    return cs.Codec(push, pop)
Exemplo n.º 5
0
def rvae_serial_with_progress(codecs, previous_dims):
    """Apply ``codecs`` in series, printing progress after every symbol.

    Args:
        codecs: sequence of codecs, one per symbol.
        previous_dims: number of dimensions already accounted for before this
            codec runs; used as the starting point for the bits-per-dimension
            figures printed while encoding.

    Returns:
        A ``cs.Codec``. Its push takes a sequence of symbols (each exposing
        ``size`` and ``shape``) and pushes them in reverse order; its pop
        returns the symbols as a list in forward order.
    """
    def push(message, symbols):
        # cs.flatten(message) is treated as a sequence of 32-bit words.
        init_len = 32 * len(cs.flatten(message))
        t_start = time.time()
        dims = previous_dims

        for i, (codec,
                symbol) in enumerate(reversed(list(zip(codecs, symbols)))):
            t0 = time.time()
            message = codec.push(message, symbol)
            # Rebind instead of ``+=`` so that a mutable ``previous_dims``
            # (e.g. a NumPy array) is never modified in place.
            dims = dims + symbol.size
            flat_message = cs.flatten(message)
            # NOTE: "message length" and "bpd" appear twice in this line; the
            # output format is left as it was.
            print(
                f"Encoded {i+1}/{len(symbols)}[{(i+1)/float(len(symbols))*100:.0f}%], "
                f"message length: {len(flat_message) * (4/1024):.0f}kB, "
                f"bpd: {32 * len(flat_message) / float(dims):.2f}, "
                f"net bitrate: {(32 * len(flat_message) - init_len) / (float(dims - previous_dims)):.2f}, "
                f"net dims: {dims - previous_dims}, "
                f"net bits: {32 * len(flat_message) - init_len}, "
                f"iter time: {time.time() - t0:.2f}s, "
                f"total time: {time.time() - t_start:.2f}s, "
                f"symbol shape: {symbol.shape}, "
                f"message length: {len(flat_message) * (4/1024):.0f}kB, "
                f"bpd: {32 * len(flat_message) / float(dims):.2f}")
        return message

    def pop(message):
        symbols = []
        t_start = time.time()
        total = len(codecs)
        for i, codec in enumerate(codecs):
            t0 = time.time()
            message, symbol = codec.pop(message)
            symbols.append(symbol)
            # The denominator is the number of codecs; len(symbols) would
            # always equal i+1 here and make the counter read "k/k".
            print(
                f"Decoded {i+1}/{total}[{(i+1)/float(total)*100:.0f}%], "
                f"iter time: {time.time() - t0:.2f}s, "
                f"total time: {time.time() - t_start:.2f}s")
        return message, symbols

    return cs.Codec(push, pop)
Exemplo n.º 6
0
def SSM(priors, likelihoods, posterior):
    """Build a codec for a sequence of observations under a state space model.

    Args:
        priors: sequence where ``priors[0]`` is used directly as a codec and
            ``priors[t]`` for ``t >= 1`` is called with the previous latent to
            obtain a codec (see the comment below for the intended
            distributions).
        likelihoods: sequence where ``likelihoods[t]`` is called with a latent
            to obtain the codec for observation ``t``.
        posterior: pair ``(post_init_state, post_update)``.
            ``post_update(t, state, x)`` returns ``(new_state, post_codec)``.
            The final ``post_codec`` is used directly as a codec; earlier ones
            are called with the following latent to obtain a codec.

    Returns:
        A ``cs.Codec`` whose push takes the sequence of observations and whose
        pop returns them as a list.
    """
    # priors = [p(z_1), p(z_2 | z_1), ..., p(z_T | z_{T-1})]
    # likelihoods = [p(x_1 | z_1), ..., p(x_T | z_T)]
    # [lambda z_{t+1}: Q(z_t | x_{1:t}, z_{t+1}) for t in range(T)]
    post_init_state, post_update = posterior

    def push(message, xs):
        """Encode ``xs`` onto ``message``, working from the last step to the first."""
        post_codecs = []
        post_state = post_init_state
        for t, x in enumerate(xs):  # Forward inference pass
            post_state, post_codec = post_update(t, post_state, x)
            post_codecs.append(post_codec)
        # The last latent is popped with the final posterior codec, which is
        # used as-is rather than being called with a following latent.
        message, z_next = post_codecs[-1].pop(message)
        for t in range(len(priors) - 1, 0, -1):  # Backward encoding pass
            message = likelihoods[t](z_next).push(message, xs[t])
            # Pop the preceding latent conditioned on the current one, then
            # push the current latent under the prior given the preceding one.
            message, z = post_codecs[t - 1](z_next).pop(message)
            message = priors[t](z).push(message, z_next)
            z_next = z
        # Step 0 has no preceding latent, so priors[0] is used directly.
        message = likelihoods[0](z_next).push(message, xs[0])
        message = priors[0].push(message, z_next)
        return message

    def pop(message):
        """Decode the observations from ``message``, undoing ``push`` step by step."""
        xs = []
        message, z_next = priors[0].pop(message)
        message, x = likelihoods[0](z_next).pop(message)
        post_state, post_codec = post_update(0, post_init_state, x)
        xs.append(x)
        for t in range(1, len(priors)):  # Forward decoding pass
            z = z_next
            message, z_next = priors[t](z).pop(message)
            # post_codec here is the one produced for step t - 1; pushing z
            # with it reverses the corresponding pop in ``push``.
            message = post_codec(z_next).push(message, z)
            message, x = likelihoods[t](z_next).pop(message)
            post_state, post_codec = post_update(t, post_state, x)
            xs.append(x)
        # Reverses the initial pop of the last latent in ``push``.
        message = post_codec.push(message, z_next)
        return message, xs

    return cs.Codec(push, pop)
Exemplo n.º 7
0
def rvae_variable_size_codec(codec_from_shape,
                             latent_from_image_shape,
                             image_count,
                             dimensions=4,
                             dimension_bits=16,
                             previous_dims=0):
    """Build a codec for ``image_count`` arrays whose shapes are stored in the message.

    Each array is pushed with the codec returned by ``codec_from_shape`` and
    followed by its shape, so that pop can read the shape first and rebuild
    the matching codec.

    Args:
        codec_from_shape: callable mapping a shape tuple to a codec.
        latent_from_image_shape: callable mapping a shape to the latent shape;
            used only to size the message head.
        image_count: number of symbols handled by the returned codec.
        dimensions: number of axes every symbol must have.
        dimension_bits: precision passed to ``cs.Uniform`` for each shape entry.
        previous_dims: forwarded to ``rvae_serial_with_progress``.

    Returns:
        The codec produced by ``rvae_serial_with_progress`` for
        ``image_count`` copies of the per-symbol codec.
    """
    size_codec = cs.repeat(cs.Uniform(dimension_bits), dimensions)

    def push(message, symbol):
        """push sizes and array in alternating order"""
        assert len(symbol.shape) == dimensions

        codec = codec_from_shape(symbol.shape)
        head_size = np.prod(latent_from_image_shape(symbol.shape)) + np.prod(
            symbol.shape)
        message = cs.reshape_head(message, (head_size, ))
        message = codec.push(message, symbol)
        # The shape is coded with a single-element head.
        message = cs.reshape_head(message, (1, ))
        message = size_codec.push(message, np.array(symbol.shape))
        return message

    def pop(message):
        """Pop the stored shape, then the array it describes."""
        message, size = size_codec.pop(message)
        # TODO make codec 0 dimensional:
        size = np.array(size)[:, 0]
        assert size.shape == (dimensions, )
        # ``np.int`` was only an alias for the builtin ``int`` and no longer
        # exists in current NumPy releases; the builtin gives the same dtype.
        size = size.astype(int)
        head_size = np.prod(latent_from_image_shape(size)) + np.prod(size)
        codec = codec_from_shape(tuple(size))

        message = cs.reshape_head(message, (head_size, ))
        message, symbol = codec.pop(message)
        # Return to the single-element head expected by the next shape pop.
        message = cs.reshape_head(message, (1, ))

        return message, symbol

    return rvae_serial_with_progress([cs.Codec(push, pop)] * image_count,
                                     previous_dims)
Exemplo n.º 8
0
    def posterior(data):
        """Build the posterior codec for ``data``.

        ``up_pass(data)`` is evaluated once; the push and pop functions of the
        returned ``cs.Codec`` both use the resulting contexts.
        """
        # run deterministic upper-pass
        contexts = up_pass(data)

        def posterior_push(message, latents):
            """Push latents onto ``message`` with the posterior codecs.

            ``latents`` is a pair of which only the first element is used: a
            sequence of ``(latent, (prior_mean, prior_stdd))`` entries.
            """
            # first run the model top-down to get the params and latent vals
            latents, _ = latents

            (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
            post_params = [(post_mean, post_stdd)]

            for rec_net, latent, context in reversed(
                    list(zip(rec_nets, latents[1:], contexts[:-1]))):
                previous_latent, (prior_mean, prior_stdd) = latent
                # The stored latent indexes the table returned by
                # cs.std_gaussian_centres; scale and shift by the prior
                # parameters kept alongside it to get its value.
                previous_latent_val = prior_mean + \
                                      cs.std_gaussian_centres(prior_prec)[previous_latent] * prior_stdd

                (post_mean,
                 post_stdd), h_rec = rec_net(h_rec, previous_latent_val,
                                             context)
                post_params.append((post_mean, post_stdd))

            # now append bottom up
            # post_params was collected in the opposite order to latents,
            # hence the reversal when pairing them.
            for latent, post_param in zip(latents, reversed(post_params)):
                latent, (prior_mean, prior_stdd) = latent
                post_mean, post_stdd = post_param
                codec = cs.substack(
                    cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                                 prior_mean, prior_stdd,
                                                 latent_prec, prior_prec),
                    z_view)
                message = codec.push(message, latent)
            return message

        def posterior_pop(message):
            """Pop latents from ``message`` with the posterior codecs.

            Returns ``(message, (latents, h_gen))`` where ``latents`` is the
            list of ``(latent, (prior_mean, prior_stdd))`` entries in the
            reverse of the order they were popped, and ``h_gen`` is the last
            state returned by the generative networks.
            """
            # pop top-down
            (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
            (prior_mean, prior_stdd), h_gen = gen_net_top()
            codec = cs.substack(
                cs.DiagGaussian_GaussianBins(post_mean, post_stdd, prior_mean,
                                             prior_stdd, latent_prec,
                                             prior_prec), z_view)
            message, latent = codec.pop(message)
            latents = [(latent, (prior_mean, prior_stdd))]
            for rec_net, gen_net, context in reversed(
                    list(zip(rec_nets, gen_nets, contexts[:-1]))):
                # latents[-1][0] is the latent popped most recently; the prior
                # parameters still in scope are the ones it was popped under.
                previous_latent_val = prior_mean + \
                                      cs.std_gaussian_centres(prior_prec)[latents[-1][0]] * prior_stdd

                (post_mean,
                 post_stdd), h_rec = rec_net(h_rec, previous_latent_val,
                                             context)
                (prior_mean,
                 prior_stdd), h_gen = gen_net(h_gen, previous_latent_val)
                codec = cs.substack(
                    cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                                 prior_mean, prior_stdd,
                                                 latent_prec, prior_prec),
                    z_view)
                message, latent = codec.pop(message)
                latents.append((latent, (prior_mean, prior_stdd)))
            return message, (latents[::-1], h_gen)

        return cs.Codec(posterior_push, posterior_pop)
Exemplo n.º 9
0
def ResNetVAE(up_pass, rec_net_top, rec_nets, gen_net_top, gen_nets, obs_codec,
              prior_prec, latent_prec):
    """
    Codec for a ResNetVAE.
    Assume that the posterior is bidirectional -
    i.e. has a deterministic upper pass but top down sampling.
    Further assume that all latent conditionals are factorised Gaussians,
    both in the generative network p(z_n|z_{n-1})
    and in the inference network q(z_n|x, z_{n-1})

    Assume that everything is ordered bottom up

    Args:
        up_pass: called as ``up_pass(data)``; returns the sequence of contexts.
        rec_net_top: called with the last context; returns
            ``((post_mean, post_stdd), h_rec)``.
        rec_nets: each is called as ``rec_net(h_rec, previous_latent_val,
            context)`` and returns ``((post_mean, post_stdd), h_rec)``.
        gen_net_top: called with no arguments; returns
            ``((prior_mean, prior_stdd), h_gen)``.
        gen_nets: each is called as ``gen_net(h_gen, previous_latent_val)``
            and returns ``((prior_mean, prior_stdd), h_gen)``.
        obs_codec: called as ``obs_codec(h, z1_vals)``; returns the codec for
            the observation.
        prior_prec: precision passed to ``cs.Uniform`` and
            ``cs.std_gaussian_centres``.
        latent_prec: precision passed to ``cs.DiagGaussian_GaussianBins``
            together with ``prior_prec``.

    Returns:
        The result of ``BBANS`` applied to the prior codec, the likelihood
        function and the posterior function defined here.
    """
    # Latents use element 0 of the message head, observations element 1.
    z_view = lambda head: head[0]
    x_view = lambda head: head[1]

    prior_codec = cs.substack(cs.Uniform(prior_prec), z_view)

    def prior_push(message, latents):
        """Push every latent with the uniform prior codec.

        Only the latent stored in each entry is pushed; the prior parameters
        kept next to it and the second element of ``latents`` are ignored.
        """
        # push bottom-up
        latents, _ = latents
        for latent in latents:
            latent, _ = latent
            message = prior_codec.push(message, latent)
        return message

    def prior_pop(message):
        """Pop all latents with the uniform prior codec.

        Returns ``(message, (latents, h_gen))`` where ``latents`` holds
        ``(latent, (prior_mean, prior_stdd))`` entries in the reverse of the
        order they were popped.
        """
        # pop top-down
        (prior_mean, prior_stdd), h_gen = gen_net_top()
        message, latent = prior_codec.pop(message)
        latents = [(latent, (prior_mean, prior_stdd))]
        for gen_net in reversed(gen_nets):
            # The popped latent indexes the table returned by
            # cs.std_gaussian_centres; scale and shift by the prior parameters
            # it was popped under to get its value.
            previous_latent_val = prior_mean + cs.std_gaussian_centres(
                prior_prec)[latent] * prior_stdd
            (prior_mean, prior_stdd), h_gen = gen_net(h_gen,
                                                      previous_latent_val)
            message, latent = prior_codec.pop(message)
            latents.append((latent, (prior_mean, prior_stdd)))
        return message, (latents[::-1], h_gen)

    def posterior(data):
        """Build the posterior codec for ``data``.

        ``up_pass(data)`` is evaluated once; the push and pop functions of the
        returned ``cs.Codec`` both use the resulting contexts.
        """
        # run deterministic upper-pass
        contexts = up_pass(data)

        def posterior_push(message, latents):
            """Push latents with the posterior codecs.

            ``latents`` is a pair of which only the first element is used: a
            sequence of ``(latent, (prior_mean, prior_stdd))`` entries.
            """
            # first run the model top-down to get the params and latent vals
            latents, _ = latents

            (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
            post_params = [(post_mean, post_stdd)]

            for rec_net, latent, context in reversed(
                    list(zip(rec_nets, latents[1:], contexts[:-1]))):
                previous_latent, (prior_mean, prior_stdd) = latent
                previous_latent_val = prior_mean + \
                                      cs.std_gaussian_centres(prior_prec)[previous_latent] * prior_stdd

                (post_mean,
                 post_stdd), h_rec = rec_net(h_rec, previous_latent_val,
                                             context)
                post_params.append((post_mean, post_stdd))

            # now append bottom up
            # post_params was collected in the opposite order to latents,
            # hence the reversal when pairing them.
            for latent, post_param in zip(latents, reversed(post_params)):
                latent, (prior_mean, prior_stdd) = latent
                post_mean, post_stdd = post_param
                codec = cs.substack(
                    cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                                 prior_mean, prior_stdd,
                                                 latent_prec, prior_prec),
                    z_view)
                message = codec.push(message, latent)
            return message

        def posterior_pop(message):
            """Pop latents with the posterior codecs.

            Returns ``(message, (latents, h_gen))`` in the same layout as
            ``prior_pop``.
            """
            # pop top-down
            (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
            (prior_mean, prior_stdd), h_gen = gen_net_top()
            codec = cs.substack(
                cs.DiagGaussian_GaussianBins(post_mean, post_stdd, prior_mean,
                                             prior_stdd, latent_prec,
                                             prior_prec), z_view)
            message, latent = codec.pop(message)
            latents = [(latent, (prior_mean, prior_stdd))]
            for rec_net, gen_net, context in reversed(
                    list(zip(rec_nets, gen_nets, contexts[:-1]))):
                # latents[-1][0] is the latent popped most recently.
                previous_latent_val = prior_mean + \
                                      cs.std_gaussian_centres(prior_prec)[latents[-1][0]] * prior_stdd

                (post_mean,
                 post_stdd), h_rec = rec_net(h_rec, previous_latent_val,
                                             context)
                (prior_mean,
                 prior_stdd), h_gen = gen_net(h_gen, previous_latent_val)
                codec = cs.substack(
                    cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                                 prior_mean, prior_stdd,
                                                 latent_prec, prior_prec),
                    z_view)
                message, latent = codec.pop(message)
                latents.append((latent, (prior_mean, prior_stdd)))
            return message, (latents[::-1], h_gen)

        return cs.Codec(posterior_push, posterior_pop)

    def likelihood(latents):
        """Return the observation codec conditioned on the first latent entry.

        ``latents`` is the ``(latents, h)`` pair produced by ``prior_pop`` or
        ``posterior_pop``.
        """
        # get the z1 vals to condition on
        latents, h = latents
        z1_idxs, (prior_mean, prior_stdd) = latents[0]
        z1_vals = prior_mean + cs.std_gaussian_centres(
            prior_prec)[z1_idxs] * prior_stdd
        return cs.substack(obs_codec(h, z1_vals), x_view)

    return BBANS(cs.Codec(prior_push, prior_pop), likelihood, posterior)