Example 1
def test_flatten_rate():
    n = 1000

    init_data = np.random.randint(1 << 16, size=8 * n, dtype='uint64')

    init_message = cs.base_message((1, ))

    for datum in init_data:
        init_message = cs.Uniform(16).push(init_message, datum)

    l_init = len(cs.flatten(init_message))

    # Random Bernoulli probabilities and corresponding samples
    ps = np.random.rand(n, 1)
    data = np.random.rand(n, 1) < ps

    # Scalar coding: push the batch one symbol at a time
    message = init_message
    for p, datum in zip(ps, data):
        message = cs.Bernoulli(p, 14).push(message, datum)

    l_scalar = len(cs.flatten(message))

    # Vectorized coding: widen the head and push the whole batch at once
    message = init_message
    message = cs.reshape_head(message, (n, 1))
    message = cs.Bernoulli(ps, 14).push(message, data)

    l_vector = len(cs.flatten(message))

    # Vectorized coding should use no more than 0.1% extra bits relative to scalar coding
    assert (l_vector - l_init) / (l_scalar - l_init) - 1 < 0.001
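
For orientation, a minimal hedged sketch of the full round trip this test builds on, using the same conventions as the code above (push returns the updated message, pop returns the message together with the decoded data); sizes and names here are illustrative, not part of the original test.

import numpy as np
import craystack as cs

data = np.random.randint(1 << 16, size=100, dtype='uint64')

# Encode: push the symbols one at a time onto a fresh message
message = cs.base_message((1, ))
for datum in data:
    message = cs.Uniform(16).push(message, datum)

# Serialize to a flat uint32 array, e.g. for writing to disk
coded = cs.flatten(message)

# Decode: symbols come back in reverse (last-in, first-out) order
message = cs.unflatten(coded, (1, ))
decoded = []
for _ in range(len(data)):
    message, datum = cs.Uniform(16).pop(message)
    decoded.append(datum)
np.testing.assert_array_equal(np.ravel(decoded[::-1]), data)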
Example 2
def test_substack():
    n_data = 100
    prec = 4
    head, tail = cs.base_message((4, 4))
    # Split the head in two; the substack view below only exposes the first half
    head = np.split(head, 2)
    message = head, tail
    data = rng.randint(1 << prec, size=(n_data, 2, 4), dtype='uint64')
    view_fun = lambda h: h[0]
    append, pop = cs.substack(cs.repeat(cs.Uniform(prec), n_data), view_fun)
    message_ = append(message, data)
    # The half of the head outside the view must be left untouched by the push
    np.testing.assert_array_equal(message_[0][1], message[0][1])
    message_, data_ = pop(message_)
    np.testing.assert_equal(message, message_)
    np.testing.assert_equal(data, data_)
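
A smaller hedged sketch of the same cs.substack idea in isolation (API assumed to match the test above): the view function selects part of the head, and the wrapped codec only ever reads and writes that view.

import numpy as np
import craystack as cs

prec = 8
message = cs.base_message((2, 4))
first_row = lambda head: head[0]               # the codec below only sees head[0]
push, pop = cs.substack(cs.Uniform(prec), first_row)

symbol = np.arange(4, dtype='uint64')
message = push(message, symbol)
message, symbol_ = pop(message)
np.testing.assert_array_equal(symbol, symbol_)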
Example 3
def test_flatten_unflatten(shape, depth=1000):
    np.random.seed(0)
    p = 8
    bits = np.random.randint(1 << p, size=(depth, ) + shape, dtype=np.uint64)

    message = cs.base_message(shape)

    other_bits_push, _ = cs.repeat(cs.Uniform(p), depth)

    message = other_bits_push(message, bits)

    flattened = cs.flatten(message)
    reconstructed = cs.unflatten(flattened, shape)

    assert_message_equal(message, reconstructed)
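
A practical follow-on, sketched with standard numpy I/O (the file name is illustrative): since flatten produces a plain uint32 array, the message can be written to disk and restored with unflatten, continuing with the variables from the test above.

flattened.tofile('message.bin')                      # raw uint32 bytes, native byte order
flat_loaded = np.fromfile('message.bin', dtype=np.uint32)
message_loaded = cs.unflatten(flat_loaded, shape)
assert_message_equal(message, message_loaded)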
Example 4
def test_reshape_head(old_shape, new_shape, depth=1000):
    np.random.seed(0)
    p = 8
    bits = np.random.randint(1 << p, size=(depth,) + old_shape, dtype=np.uint64)

    message = cs.base_message(old_shape)

    other_bits_push, _ = cs.repeat(cs.Uniform(p), depth)

    message, = other_bits_push(message, bits)

    resized = cs.reshape_head(message, new_shape)
    reconstructed = cs.reshape_head(resized, old_shape)

    assert_message_equal(message, reconstructed)
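
A hedged sketch of why reshape_head matters in practice (same assumed conventions as this test, where push returns a 1-tuple): widening the head lets a whole batch be pushed in one vectorized call, as in Example 1.

n, p = 16, 8
batch = np.random.randint(1 << p, size=(n, 1), dtype=np.uint64)

message = cs.base_message((1,))
message = cs.reshape_head(message, (n, 1))     # widen the head for vectorized coding
message, = cs.Uniform(p).push(message, batch)
coded = cs.flatten(message)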
Example 5
def test_flatten_unflatten():
    n = 100
    shape = (7, 3)
    p = 12
    state = cs.base_message(shape)
    some_bits = rng.randint(1 << p, size=(n, ) + shape).astype(np.uint64)
    # Unit frequencies, so every p-bit value is coded with probability 2**-p
    freqs = np.ones(shape, dtype="uint64")
    for b in some_bits:
        state = cs.rans.push(state, b, freqs, p)
    flat = cs.flatten(state)
    flat_ = cs.rans.flatten(state)
    print('Normal flat len: {}'.format(len(flat_) * 32))
    print('Benford flat len: {}'.format(len(flat) * 32))
    assert flat.dtype is np.dtype("uint32")
    state_ = cs.unflatten(flat, shape)
    flat_ = cs.flatten(state_)
    assert np.all(flat == flat_)
    assert np.all(state[0] == state_[0])
    assert state[1] == state_[1]
Example 6
    h = hmm_codec.HyperParams(T=1 << 9)
    print('Hyperparameter settings:')
    print(h)
    rng = random.default_rng(1)
    params = hmm_codec.hmm_params_sample(h, rng)
    xs = hmm_codec.hmm_sample(h, params, rng)
    lengths = []
    hs = []
    message_lengths = []
    # Encode prefixes of doubling length l = 4, 8, ..., h.T
    l = 4
    while l <= h.T:
        codec = hmm_codec.hmm_codec(replace(h, T=l), params)
        lengths.append(l)
        hs.append(-hmm_codec.hmm_logpmf(h, params, xs[:l]) / l)
        message = cs.base_message(1)
        message = codec.push(message, xs[:l])
        message_lengths.append(len(cs.flatten(message)) * 32 / l)
        l = 2 * l

    fig, ax = plt.subplots(figsize=[2.7, 1.8])
    ax.plot(lengths, np.divide(message_lengths, hs), color='black', lw=.5)
    ax.set_yscale('log')
    ax.yaxis.set_minor_locator(ticker.FixedLocator([1., 2., 3., 4.]))
    ax.yaxis.set_major_locator(ticker.NullLocator())
    ax.yaxis.set_minor_formatter(ticker.ScalarFormatter())

    ax.hlines(y=1, xmin=0, xmax=h.T + 1, color='gray', lw=.5)
    ax.set_xlim(0, h.T)
    ax.set_xlabel('$T$')
    ax.set_ylabel('$l(m)/h(x)$')
Example 7
def check_codec(head_shape, codec, data):
    message = cs.base_message(head_shape)
    push, pop = codec
    message_, data_ = pop(push(message, data))
    assert_message_equal(message, message_)
    np.testing.assert_equal(data, data_)
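
A hedged usage sketch, with shape, precision and data chosen arbitrarily, of how this helper might be exercised with a vectorized codec:

shape = (3, 5)
precision = 8
data = np.random.randint(1 << precision, size=shape, dtype='uint64')
check_codec(shape, cs.Uniform(precision), data)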
Example 8
def vae_view(head):
    return ag_tuple((np.reshape(head[:latent_size], latent_shape),
                     np.reshape(head[latent_size:], obs_shape)))

vae_append, vae_pop = cs.repeat(cs.substack(
    bb_ans.VAE(gen_net, rec_net, obs_codec, prior_precision, q_precision),
    vae_view), num_batches)

## Load mnist images
images = datasets.MNIST(sys.argv[1], train=False, download=True).data.numpy()
# Binarize: sample each pixel from a Bernoulli with p = intensity / 255
images = np.uint64(rng.random_sample(np.shape(images)) < images / 255.)
images = np.split(np.reshape(images, (num_images, -1)), num_batches)

## Encode
# Initialize message with some 'extra' bits
encode_t0 = time.time()
init_message = cs.base_message(obs_size + latent_size)

# Encode the mnist images
message, = vae_append(init_message, images)

flat_message = cs.flatten(message)
encode_t = time.time() - encode_t0

print("All encoded in {:.2f}s.".format(encode_t))

message_len = 32 * len(flat_message)
print("Used {} bits.".format(message_len))
print("This is {:.4f} bits per pixel.".format(message_len / num_pixels))

## Decode
decode_t0 = time.time()
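
The decode half is cut off here; a hedged sketch of how it would typically continue, mirroring the encode side with the vae_pop defined above (the original script may differ in detail):

message = cs.unflatten(flat_message, obs_size + latent_size)
message, decoded_images = vae_pop(message)
decode_t = time.time() - decode_t0

print("All decoded in {:.2f}s.".format(decode_t))
np.testing.assert_equal(images, decoded_images)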