Beispiel #1
0
 def reshape_pop(shape, message):
     """Reshape the message head for `shape` and pop one symbol from it.

     The head is reshaped to a flat shape whose length is the element count
     of `latent_from_image_shape(shape)` plus the element count of `shape`;
     the codec built for `shape` then pops from the reshaped message.

     Returns:
         A `(message, symbol)` pair as produced by the codec's `pop`.
     """
     latent_elems = np.prod(latent_from_image_shape(shape))
     image_elems = np.prod(shape)
     reshaped = cs.reshape_head(message, (latent_elems + image_elems, ))
     # The codec is built only after the head has been reshaped, as before.
     image_codec = codec_from_image_shape(shape)
     remaining, symbol = image_codec.pop(reshaped)
     return remaining, symbol
Beispiel #2
0
def rvae_variable_known_size_codec(codec_from_image_shape,
                                   latent_from_image_shape, image_shapes):
    """Build a serial codec for images whose shapes are known in advance.

    Args:
        codec_from_image_shape: Callable mapping an image shape to its codec.
        latent_from_image_shape: Callable mapping an image shape to the
            corresponding latent shape.
        image_shapes: Shapes of the images, one entry per image. Iterated
            twice, so it must be re-iterable (e.g. a list).

    Returns:
        Whatever `serial_with_shapes` returns for the per-image codecs and
        their flat head shapes.
    """
    # First pass: one codec per image shape.
    image_codecs = list(map(codec_from_image_shape, image_shapes))

    # Second pass: flat head length = latent element count + image element
    # count, wrapped in a 1-tuple.
    head_shapes = []
    for shape in image_shapes:
        flat_len = np.prod(latent_from_image_shape(shape)) + np.prod(shape)
        head_shapes.append((flat_len, ))

    return serial_with_shapes(image_codecs, head_shapes)
Beispiel #3
0
    def push(message, symbol):
        """Push `symbol` with its shape-specific codec, then push its shape.

        The head is reshaped to a flat vector sized for the symbol (latent
        element count plus symbol element count) before the symbol is pushed,
        and back to `(1, )` before the shape is pushed with `size_codec`.
        """
        assert len(symbol.shape) == dimensions

        # Build the codec first, then compute the head length, as before.
        symbol_codec = codec_from_shape(symbol.shape)
        latent_elems = np.prod(latent_from_image_shape(symbol.shape))
        flat_len = latent_elems + np.prod(symbol.shape)

        with_symbol = symbol_codec.push(
            cs.reshape_head(message, (flat_len, )), symbol)
        narrowed = cs.reshape_head(with_symbol, (1, ))
        return size_codec.push(narrowed, np.array(symbol.shape))
Beispiel #4
0
    def pop(message):
        """Pop one symbol's shape from the message, then pop the symbol.

        The shape is popped with `size_codec`, and is used both to select the
        codec (`codec_from_shape`) and to size the flat head (latent element
        count plus symbol element count). After the symbol has been popped the
        head is reshaped back to `(1, )`.

        Returns:
            A `(message, symbol)` pair.
        """
        message, size = size_codec.pop(message)
        # TODO make codec 0 dimensional:
        # The popped value is indexed as a 2-D array here; keep column 0.
        size = np.array(size)[:, 0]
        assert size.shape == (dimensions, )
        # `np.int` was only an alias for the builtin `int`; it was deprecated
        # in NumPy 1.20 and removed in 1.24, so use the builtin directly.
        size = size.astype(int)
        head_size = np.prod(latent_from_image_shape(size)) + np.prod(size)
        # A tuple is passed so the shape can be used as a hashable key.
        codec = codec_from_shape(tuple(size))

        message = cs.reshape_head(message, (head_size, ))
        message, symbol = codec.pop(message)
        message = cs.reshape_head(message, (1, ))

        return message, symbol
Beispiel #5
0
def run_bbans(hps):
    """Encode the test images into one message, report its size, and decode.

    The first `hps.n_flif` image batches are pushed with the FLIF codec; the
    remaining batches are pushed with a VAE-based codec. When `n_flif` is 0
    the message instead starts from `cs.random_message(hps.initial_bits, ...)`.
    After encoding, the message is flattened, its size is printed, and it is
    decoded again; every decoded image is asserted equal to its original.

    Args:
        hps: Hyperparameter object. It is mutated in place (`num_gpus`,
            `batch_size`, `eval_batch_size`, and `image_size` whenever a codec
            is built). Also read here: `n_flif`,
            `compression_always_variable`, `compression_exclude_sizes`,
            `initial_bits`.

    Raises:
        AssertionError: If the decoded images differ from the inputs, or if
            the message is not empty after the FLIF images were popped.
    """
    from autograd.builtins import tuple as ag_tuple
    from rvae.resnet_codec import ResNetVAE

    # Override the caller's settings: one GPU, one image per batch.
    hps.num_gpus = 1
    hps.batch_size = 1
    batch_size = hps.batch_size
    hps.eval_batch_size = batch_size
    n_flif = hps.n_flif

    _, datasets = images(hps)
    # Normalise to a list so a single dataset and several are handled alike.
    datasets = datasets if isinstance(datasets, list) else [datasets]
    # Each image gets a leading axis of length 1 and is cast to uint64.
    test_images = [
        np.array([image]).astype('uint64') for dataset in datasets
        for image in dataset
    ]
    # Group images into batches of `batch_size`; a trailing partial batch is
    # dropped by the floor division.
    n_batches = len(test_images) // batch_size
    test_images = [
        np.concatenate(test_images[i * batch_size:(i + 1) * batch_size],
                       axis=0) for i in range(n_batches)
    ]
    # The first n_flif batches go to FLIF, the rest to the VAE codec.
    flif_images = test_images[:n_flif]
    vae_images = test_images[n_flif:]
    # Element counts, used for the bits-per-dimension figures below.
    num_dims = np.sum([batch.size for batch in test_images])
    flif_dims = np.sum([batch.size
                        for batch in flif_images]) if flif_images else 0

    prior_precision = 10
    obs_precision = 24
    q_precision = 18

    # maxsize=1: only the most recently requested shape keeps its codec (and
    # its TF graph/session) cached; a different shape triggers a rebuild.
    @lru_cache(maxsize=1)
    def codec_from_shape(shape):
        print("Creating codec for shape " + str(shape))

        # Must be set before latent_shape(hps) is evaluated below.
        hps.image_size = (shape[2], shape[3])

        z_shape = latent_shape(hps)
        z_size = np.prod(z_shape)

        # A fresh graph per codec so that models for different shapes do not
        # share a graph.
        graph = tf.Graph()
        with graph.as_default():
            with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
                x = tf.placeholder(tf.float32, shape, 'x')
                model = CVAE1(hps, "eval", x)
                stepwise_model = LayerwiseCVAE(model)

        saver = tf.train.Saver(model.avg_dict)
        config = tf.ConfigProto(allow_soft_placement=True,
                                intra_op_parallelism_threads=4,
                                inter_op_parallelism_threads=4)
        sess = tf.Session(config=config, graph=graph)
        saver.restore(sess, restore_path())

        run_all_contexts, run_top_prior, runs_down_prior, run_top_posterior, runs_down_posterior, \
        run_reconstruction = stepwise_model.get_model_parts_as_numpy_functions(sess)

        # Setup codecs
        # View of the flat head: the first z_size entries reshaped to
        # z_shape, the remainder reshaped to the image shape.
        def vae_view(head):
            return ag_tuple(
                (np.reshape(head[:z_size],
                            z_shape), np.reshape(head[z_size:], shape)))

        obs_codec = lambda h, z1: cs.Logistic_UnifBins(*run_reconstruction(
            h, z1),
                                                       obs_precision,
                                                       bin_prec=8,
                                                       bin_lb=-0.5,
                                                       bin_ub=0.5)

        return cs.substack(
            ResNetVAE(run_all_contexts, run_top_posterior, runs_down_posterior,
                      run_top_prior, runs_down_prior, obs_codec,
                      prior_precision, q_precision), vae_view)

    # Fixed-size mode applies when variable mode is not forced and the first
    # image of every dataset has the same last two dimensions.
    is_fixed = not hps.compression_always_variable and \
               (len(set([dataset[0].shape[-2:] for dataset in datasets])) == 1)
    # The candidates are wrapped in lambdas so that only the selected codec
    # is actually constructed (see `codec()` below).
    fixed_size_codec = lambda: cs.repeat(codec_from_shape(vae_images[0].shape),
                                         len(vae_images))
    variable_codec_including_sizes = lambda: rvae_variable_size_codec(
        codec_from_shape,
        latent_from_image_shape=latent_from_image_shape(hps),
        image_count=len(vae_images),
        previous_dims=flif_dims)
    # NOTE(review): the keyword names used here (`shapes`, `previous_dims`)
    # must match the signature of the rvae_variable_known_size_codec that is
    # in scope for this function; verify.
    variable_known_sizes_codec = lambda: rvae_variable_known_size_codec(
        codec_from_image_shape=codec_from_shape,
        latent_from_image_shape=latent_from_image_shape(hps),
        shapes=[i.shape for i in vae_images],
        previous_dims=flif_dims)
    variable_size_codec = \
        variable_known_sizes_codec if hps.compression_exclude_sizes else variable_codec_including_sizes
    codec = fixed_size_codec if is_fixed else variable_size_codec
    vae_push, vae_pop = codec()

    # Process-wide NumPy setting: division by zero raises instead of warning.
    np.seterr(divide='raise')

    if n_flif:
        print('Using FLIF to encode initial images...')
        flif_push, flif_pop = cs.repeat(cs.repeat(FLIF, batch_size), n_flif)
        message = cs.empty_message((1, ))
        message = flif_push(message, flif_images)
    else:
        print('Creating a random initial message...')
        message = cs.random_message(hps.initial_bits, (1, ))

    # The conditional binds looser than `+`, so this is
    # `(image elems + latent elems) if is_fixed else 1`, wrapped in a 1-tuple.
    init_head_shape = (np.prod(image_shape(hps)) +
                       np.prod(latent_shape(hps)) if is_fixed else 1, )
    message = cs.reshape_head(message, init_head_shape)

    print("Encoding with VAE...")
    encode_t0 = time.time()
    message = vae_push(message, vae_images)
    encode_t = time.time() - encode_t0
    print("All encoded in {:.2f}s".format(encode_t))

    flat_message = cs.flatten(message)
    # Presumably each element of the flat message is a 32-bit word; confirm
    # against cs.flatten.
    message_len = 32 * len(flat_message)
    print("Used {} bits.".format(message_len))
    print("This is {:.2f} bits per dim.".format(message_len / num_dims))
    if n_flif == 0:
        # Size beyond the random initial message.
        extra_bits = message_len - 32 * hps.initial_bits
        print('Extra bits: {}'.format(extra_bits))
        print('This is {:.2f} bits per dim.'.format(extra_bits / num_dims))

    print('Decoding with VAE...')
    decode_t0 = time.time()
    # Decoding starts from the flat message only, using the same head shape
    # that encoding started with.
    message = cs.unflatten(flat_message, init_head_shape)
    message, decoded_vae_images = vae_pop(message)
    message = cs.reshape_head(message, (1, ))

    decode_t = time.time() - decode_t0
    print('All decoded in {:.2f}s'.format(decode_t))

    # The round trip must be lossless.
    assert len(vae_images) == len(decoded_vae_images), (
        len(vae_images), len(decoded_vae_images))
    for test_image, decoded_image in zip(vae_images, decoded_vae_images):
        np.testing.assert_equal(test_image, decoded_image)

    if n_flif:
        print('Decoding with FLIF...')
        message, decoded_flif_images = flif_pop(message)
        for test_image, decoded_image in zip(flif_images, decoded_flif_images):
            np.testing.assert_equal(test_image, decoded_image)
        # With FLIF the message started empty, so it must be empty again.
        assert cs.is_empty(message)
Beispiel #6
0
def run_bbans(hps):
    """Encode all test images onto a random initial message, then decode.

    A random stack of `other_bits_count` entries is created as the initial
    message, every test image is appended to it with a VAE-based codec, and
    the size of the flattened result is printed. The message is then decoded
    again; the decoded images are asserted equal to the originals and the
    remaining message is asserted equal to the initial one.

    Args:
        hps: Hyperparameter object. It is mutated in place (`num_gpus`,
            `batch_size`, `eval_batch_size`, and `image_size` whenever a codec
            is built). Also read here: `compression_always_variable`,
            `compression_exclude_sizes`.

    Raises:
        AssertionError: If the decoded images or the recovered initial
            message differ from the originals.
    """
    from autograd.builtins import tuple as ag_tuple
    from rvae.resnet_codec import ResNetVAE

    # Override the caller's settings: one GPU, one image per batch.
    hps.num_gpus = 1
    batch_size = 1
    hps.batch_size = batch_size
    hps.eval_batch_size = batch_size

    _, datasets = images(hps)
    # Normalise to a list so a single dataset and several are handled alike.
    datasets = datasets if isinstance(datasets, list) else [datasets]
    # Each image gets a leading axis of length 1 and is cast to uint64.
    test_images = [
        np.array([image]).astype('uint64') for dataset in datasets
        for image in dataset
    ]

    # Total element count, used for the bits-per-dimension figures below.
    num_dims = np.sum([batch.size for batch in test_images])

    prior_precision = 10
    obs_precision = 24
    q_precision = 20

    # Cached per shape, so each distinct image shape builds its model and
    # session only once (subject to lru_cache's default size limit).
    @lru_cache()
    def codec_from_shape(shape):
        print("Creating codec for shape " + str(shape))

        # Must be set before latent_shape(hps) is evaluated below.
        hps.image_size = (shape[2], shape[3])

        z_shape = latent_shape(hps)
        z_size = np.prod(z_shape)

        # A fresh graph per codec so that models for different shapes do not
        # share a graph.
        graph = tf.Graph()
        with graph.as_default():
            with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
                x = tf.placeholder(tf.float32, shape, 'x')
                model = CVAE1(hps, "eval", x)
                stepwise_model = LayerwiseCVAE(model)

        saver = tf.train.Saver(model.avg_dict)
        # device_count={'GPU': 0} asks TensorFlow to use no GPU devices for
        # this session.
        config = tf.ConfigProto(allow_soft_placement=True,
                                intra_op_parallelism_threads=4,
                                inter_op_parallelism_threads=4,
                                device_count={'GPU': 0})
        sess = tf.Session(config=config, graph=graph)
        saver.restore(sess, restore_path())

        run_all_contexts, run_top_prior, runs_down_prior, run_top_posterior, runs_down_posterior, \
        run_reconstruction = stepwise_model.get_model_parts_as_numpy_functions(sess)

        # Setup codecs
        # View of the flat head: the first z_size entries reshaped to
        # z_shape, the remainder reshaped to the image shape.
        def vae_view(head):
            return ag_tuple(
                (np.reshape(head[:z_size],
                            z_shape), np.reshape(head[z_size:], shape)))

        obs_codec = lambda h, z1: codecs.Logistic_UnifBins(
            *run_reconstruction(h, z1), obs_precision, bin_prec=8)

        return codecs.substack(
            ResNetVAE(run_all_contexts, run_top_posterior, runs_down_posterior,
                      run_top_prior, runs_down_prior, obs_codec,
                      prior_precision, q_precision), vae_view)

    # Fixed-size mode applies when variable mode is not forced and the first
    # image of every dataset has the same last two dimensions.
    is_fixed = not hps.compression_always_variable and \
               (len(set([dataset[0].shape[-2:] for dataset in datasets])) == 1)
    # The candidates are wrapped in lambdas so that only the selected codec
    # is actually constructed (see `codec()` below).
    fixed_size_codec = lambda: codecs.repeat(
        codec_from_shape(test_images[0].shape), len(test_images))
    variable_codec_including_sizes = lambda: rvae_variable_size_codec(
        codec_from_shape,
        latent_from_image_shape=latent_from_image_shape(hps),
        image_count=len(test_images))
    variable_known_sizes_codec = lambda: rvae_variable_known_size_codec(
        codec_from_image_shape=codec_from_shape,
        latent_from_image_shape=latent_from_image_shape(hps),
        image_shapes=[i.shape for i in test_images])
    variable_size_codec = \
        variable_known_sizes_codec if hps.compression_exclude_sizes else variable_codec_including_sizes
    codec = fixed_size_codec if is_fixed else variable_size_codec
    vae_append, vae_pop = codec()

    # Fixed seed so the random initial message is reproducible.
    rng = np.random.RandomState(0)

    # Process-wide NumPy setting: division by zero raises instead of warning.
    np.seterr(divide='raise')

    # Codec for adding extra bits to the start of the chain (necessary for bits back).
    other_bits_count = 100000000

    encode_t0 = time.time()
    # The conditional binds looser than `+`, so this is
    # `(image elems + latent elems) if is_fixed else 1`, wrapped in a 1-tuple.
    init_head_shape = (np.prod(image_shape(hps)) +
                       np.prod(latent_shape(hps)) if is_fixed else 1, )
    init_message = codecs.random_stack(other_bits_count, init_head_shape, rng)
    # Presumably each of the other_bits_count entries is a 32-bit word;
    # confirm against codecs.random_stack.
    init_len = 32 * other_bits_count

    print("Encoding...")
    message = vae_append(init_message, test_images)
    flat_message = codecs.flatten(message)
    encode_t = time.time() - encode_t0

    print("All encoded in {:.2f}s".format(encode_t))

    message_len = 32 * len(flat_message)
    print("Used {} bits.".format(message_len))
    print("This is {:.2f} bits per dim.".format(message_len / num_dims))
    # "Extra" means the size beyond the random initial message.
    print('Extra bits per dim: {:.2f}'.format(
        (message_len - init_len) / num_dims))
    print('Extra bits: {:.2f}'.format(message_len - init_len))

    ## Decode
    decode_t0 = time.time()
    # Decoding starts from the flat message only, using the same head shape
    # that encoding started with.
    message = codecs.unflatten(flat_message, init_head_shape)
    message, images_ = vae_pop(message)
    decode_t = time.time() - decode_t0

    print('All decoded in {:.2f}s'.format(decode_t))

    # The round trip must be lossless.
    assert len(test_images) == len(images_), (len(test_images), len(images_))
    for test_image, image in zip(test_images, images_):
        np.testing.assert_equal(test_image, image)
    init_head, init_tail = init_message

    # After popping every image, the message should equal the initial one.
    message = codecs.reshape_head(message, init_head_shape)
    head, tail = message
    assert init_head.shape == head.shape, (init_head.shape, head.shape)
    assert np.all(init_head == head)

    # use this, or get into recursion issues
    # Both tails are unpacked as nested (element, rest) pairs and compared
    # element by element in a loop rather than all at once.
    while init_tail:
        el, init_tail = init_tail
        el_, tail = tail
        assert el == el_