Example #1
import numpy as np

import craystack as cs


def test_flatten_rate():
    n = 1000

    init_data = np.random.randint(1 << 16, size=8 * n, dtype='uint64')

    init_message = cs.base_message((1, ))

    for datum in init_data:
        init_message = cs.Uniform(16).push(init_message, datum)

    l_init = len(cs.flatten(init_message))

    ps = np.random.rand(n, 1)
    data = np.random.rand(n, 1) < ps

    message = init_message
    for p, datum in zip(ps, data):
        message = cs.Bernoulli(p, 14).push(message, datum)

    l_scalar = len(cs.flatten(message))

    message = init_message
    message = cs.reshape_head(message, (n, 1))
    message = cs.Bernoulli(ps, 14).push(message, data)

    l_vector = len(cs.flatten(message))

    assert (l_vector - l_init) / (l_scalar - l_init) - 1 < 0.001
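Here 32 * len(cs.flatten(message)) measures the compressed size in bits, and the final assertion checks that one vectorized Bernoulli push onto a reshaped head costs essentially the same as n scalar pushes. A minimal round-trip sketch under the same API (pop inverts push, last-in first-out):

import numpy as np

import craystack as cs

p = 0.3
message = cs.base_message((1, ))
message = cs.Bernoulli(p, 14).push(message, np.array([True]))
message, x = cs.Bernoulli(p, 14).pop(message)  # recovers the pushed symbol
assert np.all(x == np.array([True]))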
Example #2
import numpy as np

import craystack as cs

# `rng` below is a module-level random source in the original test file
# (e.g. a np.random.RandomState), not shown in this excerpt.


def test_flatten_unflatten():
    n = 100
    shape = (7, 3)
    p = 12
    state = cs.base_message(shape)
    some_bits = rng.randint(1 << p, size=(n, ) + shape).astype(np.uint64)
    freqs = np.ones(shape, dtype="uint64")
    for b in some_bits:
        state = cs.rans.push(state, b, freqs, p)
    flat = cs.flatten(state)
    flat_ = cs.rans.flatten(state)
    print('Normal flat len: {}'.format(len(flat_) * 32))
    print('Benford flat len: {}'.format(len(flat) * 32))
    assert flat.dtype is np.dtype("uint32")
    state_ = cs.unflatten(flat, shape)
    flat_ = cs.flatten(state_)
    assert np.all(flat == flat_)
    assert np.all(state[0] == state_[0])
    assert state[1] == state_[1]
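The two prints contrast the raw rANS serialization with craystack's default one: cs.rans.flatten writes the message head out verbatim, while cs.flatten entropy-codes the head (whose words follow a roughly Benford-shaped distribution, hence the label) before serializing. A hedged sketch of the comparison on a fresh message:

import craystack as cs

state = cs.base_message((4, 4))
raw = cs.rans.flatten(state)  # head serialized verbatim
ben = cs.flatten(state)       # head entropy-coded before serializing
print(len(raw) * 32, len(ben) * 32)  # bit lengths; the latter is usually smaller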
Example #3
import numpy as np

import craystack as cs

# `shape` is supplied by pytest parametrization and `assert_message_equal`
# is a helper defined in the original test module, outside this excerpt.


def test_flatten_unflatten(shape, depth=1000):
    np.random.seed(0)
    p = 8
    bits = np.random.randint(1 << p, size=(depth, ) + shape, dtype=np.uint64)

    message = cs.base_message(shape)

    other_bits_push, _ = cs.repeat(cs.Uniform(p), depth)

    message = other_bits_push(message, bits)

    flattened = cs.flatten(message)
    reconstructed = cs.unflatten(flattened, shape)

    assert_message_equal(message, reconstructed)
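cs.repeat turns a codec for one symbol into a (push, pop) pair over a fixed-length sequence. The test above only exercises push plus flatten/unflatten; a sketch of the full round trip, assuming the same API:

import numpy as np

import craystack as cs

shape, depth, p = (7, 3), 100, 8
bits = np.random.randint(1 << p, size=(depth, ) + shape, dtype=np.uint64)
push, pop = cs.repeat(cs.Uniform(p), depth)
message = push(cs.base_message(shape), bits)
message, decoded = pop(message)  # symbols come back in the original order
assert np.all(np.array(decoded) == bits)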
Example #4
    def push(message, symbols):
        init_len = 32 * len(cs.flatten(message))
        t_start = time.time()
        dims = previous_dims

        for i, (codec, symbol) in enumerate(
                reversed(list(zip(codecs, symbols)))):
            t0 = time.time()
            message = codec.push(message, symbol)
            dims += symbol.size
            flat_message = cs.flatten(message)
            print(
                f"Encoded {i+1}/{len(symbols)} [{(i+1)/float(len(symbols))*100:.0f}%], "
                f"message length: {len(flat_message) * (4/1024):.0f}kB, "
                f"bpd: {32 * len(flat_message) / float(dims):.2f}, "
                f"net bitrate: {(32 * len(flat_message) - init_len) / float(dims - previous_dims):.2f}, "
                f"net dims: {dims - previous_dims}, "
                f"net bits: {32 * len(flat_message) - init_len}, "
                f"iter time: {time.time() - t0:.2f}s, "
                f"total time: {time.time() - t_start:.2f}s, "
                f"symbol shape: {symbol.shape}")
        return message
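The "net bitrate" printed above is the growth of the flattened message divided by the number of newly pushed dimensions, which isolates the cost of these symbols from whatever the message already contained. A self-contained sketch of that bookkeeping, assuming the same API:

import numpy as np

import craystack as cs

n = 1000
data = np.random.randint(1 << 16, size=n, dtype='uint64')
message = cs.base_message((1, ))
init_len = 32 * len(cs.flatten(message))
for d in data:
    message = cs.Uniform(16).push(message, d)
net_bits = 32 * len(cs.flatten(message)) - init_len
print('net bitrate: {:.2f} bits/symbol'.format(net_bits / n))  # close to 16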
Example #5
    print('Hyperparameter settings:')
    print(h)
    rng = random.default_rng(1)
    params = hmm_codec.hmm_params_sample(h, rng)
    xs = hmm_codec.hmm_sample(h, params, rng)
    lengths = []
    hs = []
    message_lengths = []
    l = 4
    while l <= h.T:
        codec = hmm_codec.hmm_codec(replace(h, T=l), params)
        lengths.append(l)
        hs.append(-hmm_codec.hmm_logpmf(h, params, xs[:l]) / l)
        message = cs.base_message(1)
        message = codec.push(message, xs[:l])
        message_lengths.append(len(cs.flatten(message)) * 32 / l)
        l = 2 * l

    fig, ax = plt.subplots(figsize=[2.7, 1.8])
    ax.plot(lengths, np.divide(message_lengths, hs), color='black', lw=.5)
    ax.set_yscale('log')
    ax.yaxis.set_minor_locator(ticker.FixedLocator([1., 2., 3., 4.]))
    ax.yaxis.set_major_locator(ticker.NullLocator())
    ax.yaxis.set_minor_formatter(ticker.ScalarFormatter())

    ax.hlines(y=1, xmin=0, xmax=h.T + 1, color='gray', lw=.5)
    ax.set_xlim(0, h.T)
    ax.set_xlabel('$T$')
    ax.set_ylabel('$l(m)/h(x)$')
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_linewidth(0.5)
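The plotted quantity l(m)/h(x) is the achieved bits per symbol divided by the information content per symbol, and it should approach 1 as T grows. The same check can be run on a source with closed-form entropy using only craystack primitives (a hedged sketch, not the HMM codec itself):

import numpy as np

import craystack as cs

n, p = 10000, 0.2
data = np.random.rand(n, 1) < p
message = cs.base_message((n, 1))
init_len = 32 * len(cs.flatten(message))
message = cs.Bernoulli(np.full((n, 1), p), 14).push(message, data)
rate = (32 * len(cs.flatten(message)) - init_len) / n   # net l(m) per symbol
entropy = -(p * np.log2(p) + (1 - p) * np.log2(1 - p))  # h(x) per symbol
print('l(m)/h(x) = {:.3f}'.format(rate / entropy))      # expected near 1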
Example #6
def compress_samples(model, hparams, step=tf.constant(0), decode=False):

    model.set_compression()

    test_set = utils.load_training_files_tfrecords(record_pattern=os.path.join(
        hparams['tfrecords_dir'], hparams['train_files'] + '*'))

    datapoints = list(test_set.unbatch().batch(
        hparams['compress_batch_size']).take(hparams['n_compress_datapoint']))

    num_pixels = hparams['n_compress_datapoint'] * hparams[
        'compress_batch_size'] * hparams['segment_length']

    ## Load Codec
    waveglow_append, waveglow_pop = cs.repeat(
        Waveglow_codec(model=model, hparams=hparams),
        hparams['n_compress_datapoint'])

    ## Encode
    encode_t0 = time.time()
    init_message = cs.empty_message(shape=(hparams['compress_batch_size'],
                                           hparams['segment_length'] // 4))

    # Encode the audio samples
    message = waveglow_append(init_message, datapoints)

    flat_message = cs.flatten(message)
    encode_t = time.time() - encode_t0

    tf.print("All encoded in {:.2f}s.".format(encode_t))

    original_len = 16 * hparams['n_compress_datapoint'] * hparams[
        'segment_length']
    message_len = 32 * len(flat_message)
    tf.print("Used {} bits.".format(message_len))
    tf.print("This is {:.2f} bits per pixel.".format(message_len / num_pixels))
    tf.print("Compression ratio : {:.2f}".format(original_len / message_len))

    tf.summary.scalar(name='bits_per_dim',
                      data=message_len / num_pixels,
                      step=step)
    tf.summary.scalar(name='compression_ratio',
                      data=original_len / message_len,
                      step=step)

    if decode:
        ## Decode
        decode_t0 = time.time()
        message = cs.unflatten(flat_message,
                               shape=(hparams['compress_batch_size'],
                                      hparams['segment_length'] // 4))

        message, datapoints_ = waveglow_pop(message)
        decode_t = time.time() - decode_t0

        print('All decoded in {:.2f}s.'.format(decode_t))

        datacompare = [
            data['wav'].numpy()[..., np.newaxis] for data in datapoints
        ]
        np.testing.assert_equal(datacompare, datapoints_)
        np.testing.assert_equal(message, init_message)

    model.set_training()
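Since cs.flatten returns a flat uint32 array (Example #2 asserts this), the compressed message can be persisted with plain NumPy I/O and restored with cs.unflatten. A hedged sketch; the file name is illustrative:

import numpy as np

import craystack as cs

shape = (8, )
message = cs.base_message(shape)
message = cs.Uniform(16).push(
    message, np.random.randint(1 << 16, size=shape, dtype='uint64'))
flat = cs.flatten(message)
flat.tofile('compressed.bin')  # plain binary dump of the uint32 words
loaded = np.fromfile('compressed.bin', dtype=np.uint32)
restored = cs.unflatten(loaded, shape)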
Example #7
def run_bbans(hps):
    from autograd.builtins import tuple as ag_tuple
    from rvae.resnet_codec import ResNetVAE

    hps.num_gpus = 1
    hps.batch_size = 1
    batch_size = hps.batch_size
    hps.eval_batch_size = batch_size
    n_flif = hps.n_flif

    _, datasets = images(hps)
    datasets = datasets if isinstance(datasets, list) else [datasets]
    test_images = [
        np.array([image]).astype('uint64') for dataset in datasets
        for image in dataset
    ]
    n_batches = len(test_images) // batch_size
    test_images = [
        np.concatenate(test_images[i * batch_size:(i + 1) * batch_size],
                       axis=0) for i in range(n_batches)
    ]
    flif_images = test_images[:n_flif]
    vae_images = test_images[n_flif:]
    num_dims = np.sum([batch.size for batch in test_images])
    flif_dims = np.sum([batch.size
                        for batch in flif_images]) if flif_images else 0

    prior_precision = 10
    obs_precision = 24
    q_precision = 18

    @lru_cache(maxsize=1)
    def codec_from_shape(shape):
        print("Creating codec for shape " + str(shape))

        hps.image_size = (shape[2], shape[3])

        z_shape = latent_shape(hps)
        z_size = np.prod(z_shape)

        graph = tf.Graph()
        with graph.as_default():
            with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
                x = tf.placeholder(tf.float32, shape, 'x')
                model = CVAE1(hps, "eval", x)
                stepwise_model = LayerwiseCVAE(model)

        saver = tf.train.Saver(model.avg_dict)
        config = tf.ConfigProto(allow_soft_placement=True,
                                intra_op_parallelism_threads=4,
                                inter_op_parallelism_threads=4)
        sess = tf.Session(config=config, graph=graph)
        saver.restore(sess, restore_path())

        run_all_contexts, run_top_prior, runs_down_prior, run_top_posterior, runs_down_posterior, \
        run_reconstruction = stepwise_model.get_model_parts_as_numpy_functions(sess)

        # Setup codecs
        def vae_view(head):
            return ag_tuple(
                (np.reshape(head[:z_size],
                            z_shape), np.reshape(head[z_size:], shape)))

        obs_codec = lambda h, z1: cs.Logistic_UnifBins(
            *run_reconstruction(h, z1), obs_precision,
            bin_prec=8, bin_lb=-0.5, bin_ub=0.5)

        return cs.substack(
            ResNetVAE(run_all_contexts, run_top_posterior, runs_down_posterior,
                      run_top_prior, runs_down_prior, obs_codec,
                      prior_precision, q_precision), vae_view)

    is_fixed = not hps.compression_always_variable and \
               (len(set([dataset[0].shape[-2:] for dataset in datasets])) == 1)
    fixed_size_codec = lambda: cs.repeat(codec_from_shape(vae_images[0].shape),
                                         len(vae_images))
    variable_codec_including_sizes = lambda: rvae_variable_size_codec(
        codec_from_shape,
        latent_from_image_shape=latent_from_image_shape(hps),
        image_count=len(vae_images),
        previous_dims=flif_dims)
    variable_known_sizes_codec = lambda: rvae_variable_known_size_codec(
        codec_from_image_shape=codec_from_shape,
        latent_from_image_shape=latent_from_image_shape(hps),
        shapes=[i.shape for i in vae_images],
        previous_dims=flif_dims)
    variable_size_codec = \
        variable_known_sizes_codec if hps.compression_exclude_sizes else variable_codec_including_sizes
    codec = fixed_size_codec if is_fixed else variable_size_codec
    vae_push, vae_pop = codec()

    np.seterr(divide='raise')

    if n_flif:
        print('Using FLIF to encode initial images...')
        flif_push, flif_pop = cs.repeat(cs.repeat(FLIF, batch_size), n_flif)
        message = cs.empty_message((1, ))
        message = flif_push(message, flif_images)
    else:
        print('Creating a random initial message...')
        message = cs.random_message(hps.initial_bits, (1, ))

    init_head_shape = ((np.prod(image_shape(hps)) + np.prod(latent_shape(hps)))
                       if is_fixed else 1, )
    message = cs.reshape_head(message, init_head_shape)

    print("Encoding with VAE...")
    encode_t0 = time.time()
    message = vae_push(message, vae_images)
    encode_t = time.time() - encode_t0
    print("All encoded in {:.2f}s".format(encode_t))

    flat_message = cs.flatten(message)
    message_len = 32 * len(flat_message)
    print("Used {} bits.".format(message_len))
    print("This is {:.2f} bits per dim.".format(message_len / num_dims))
    if n_flif == 0:
        extra_bits = message_len - 32 * hps.initial_bits
        print('Extra bits: {}'.format(extra_bits))
        print('This is {:.2f} bits per dim.'.format(extra_bits / num_dims))

    print('Decoding with VAE...')
    decode_t0 = time.time()
    message = cs.unflatten(flat_message, init_head_shape)
    message, decoded_vae_images = vae_pop(message)
    message = cs.reshape_head(message, (1, ))

    decode_t = time.time() - decode_t0
    print('All decoded in {:.2f}s'.format(decode_t))

    assert len(vae_images) == len(decoded_vae_images), (
        len(vae_images), len(decoded_vae_images))
    for test_image, decoded_image in zip(vae_images, decoded_vae_images):
        np.testing.assert_equal(test_image, decoded_image)

    if n_flif:
        print('Decoding with FLIF...')
        message, decoded_flif_images = flif_pop(message)
        for test_image, decoded_image in zip(flif_images, decoded_flif_images):
            np.testing.assert_equal(test_image, decoded_image)
        assert cs.is_empty(message)
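Two details of this pipeline are easy to miss: BB-ANS needs a non-empty starting message (supplied here either by FLIF-coded images or by cs.random_message), and cs.reshape_head switches the message head between scalar and vector shapes so batched codecs can push onto it. A minimal sketch of the head manipulation, assuming the same API:

import craystack as cs

message = cs.random_message(100, (1, ))      # initial bits, counted in 32-bit words
message = cs.reshape_head(message, (4, 25))  # vector head for batched pushes
# ... batched pushes/pops would happen here ...
message = cs.reshape_head(message, (1, ))    # back to a scalar head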
Example #8
# This snippet begins mid-statement: the next line is the tail of the codec
# construction (apparently `cs.repeat(cs.substack(..., vae_view), num_batches)`)
# that produces the `vae_append`/`vae_pop` pair used below.
    vae_view), num_batches)

## Load MNIST images (binarized dynamically; `rng` is a module-level random
## source in the original script, not shown in this excerpt)
images = datasets.MNIST(sys.argv[1], train=False, download=True).data.numpy()
images = np.uint64(rng.random_sample(np.shape(images)) < images / 255.)
images = np.split(np.reshape(images, (num_images, -1)), num_batches)

## Encode
# Initialize message with some 'extra' bits
encode_t0 = time.time()
init_message = cs.base_message(obs_size + latent_size)

# Encode the mnist images
message, = vae_append(init_message, images)

flat_message = cs.flatten(message)
encode_t = time.time() - encode_t0

print("All encoded in {:.2f}s.".format(encode_t))

message_len = 32 * len(flat_message)
print("Used {} bits.".format(message_len))
print("This is {:.4f} bits per pixel.".format(message_len / num_pixels))

## Decode
decode_t0 = time.time()
message = cs.unflatten(flat_message, obs_size + latent_size)

message, images_ = vae_pop(message)
decode_t = time.time() - decode_t0
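The bits-per-pixel figure above includes the initial bits baked into init_message; Example #7 subtracts them to report a net cost. A hedged continuation reusing this snippet's variables:

init_len = 32 * len(cs.flatten(init_message))
print("Net: {:.4f} bits per pixel.".format((message_len - init_len) / num_pixels))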
Example #9
            def z_codec(z_next):
                masses = mixing_coeffs * a_mass[:, z_next]
                return Categorical(h, quantized_cdf(h,
                                                    masses / np.sum(masses)))
        else:
            alpha = None
            z_codec = Categorical(
                h, quantized_cdf(h, mixing_coeffs / np.sum(mixing_coeffs)))
        return alpha, z_codec

    return SSM(priors, likelihoods, (np.diff(a0), post_update))


if __name__ == '__main__':
    h = HyperParams()
    print('Hyperparameter settings:')
    print(h)
    rng = random.default_rng(1)
    params = hmm_params_sample(h, rng)
    xs = hmm_sample(h, params, rng)
    print('h(x) = {:.2f} bits/symbol.'.format(
        -hmm_logpmf(h, params, xs) / len(xs)))
    codec = HMM(h, params)
    message = cs.base_message(1)
    message = codec.push(message, xs)
    print('Compression rate: {:.2f} bits/symbol.'.format(
        len(cs.flatten(message)) * 32 / len(xs)))
    message, xs_decoded = codec.pop(message)
    assert np.all(xs == xs_decoded)
    print('Decoded OK!')
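Note the stack discipline: the message behaves like a stack, so the most recently pushed data is popped first; here a single codec.push is immediately inverted by codec.pop. The local Categorical(h, cdf) helper above wraps a quantized CDF; craystack also ships a stock Categorical codec taking probability masses and a precision. A sketch of that stock codec, hedged because the exact signature may vary between versions:

import numpy as np

import craystack as cs

probs = np.array([[0.1, 0.2, 0.3, 0.4]])  # head shape (1, ), four categories
codec = cs.Categorical(probs, 14)         # 14-bit precision (assumed signature)
message = cs.base_message((1, ))
message = codec.push(message, np.array([2], dtype='uint64'))
message, x = codec.pop(message)
assert np.all(x == 2)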