def test_flatten_rate():
    """Check that vectorized Bernoulli coding costs no more than scalar coding.

    The same ``n`` Bernoulli symbols are encoded twice on top of a shared
    initial message. The first encoding pushes one symbol at a time onto a
    ``(1,)`` head. The second pushes all symbols in one call after reshaping
    the head to ``(n, 1)``. The extra flattened words added by the
    vectorized encoding must not exceed those added by the scalar encoding
    by more than 0.1%.
    """
    # A local fixed-seed generator makes the test reproducible without
    # mutating the global NumPy RNG state.
    rs = np.random.RandomState(0)
    n = 1000
    # Pre-load the message with random 16-bit words so both encodings start
    # from the same non-trivial message.
    init_data = rs.randint(1 << 16, size=8 * n, dtype='uint64')
    uniform16 = cs.Uniform(16)
    init_message = cs.base_message((1, ))
    for datum in init_data:
        init_message = uniform16.push(init_message, datum)
    l_init = len(cs.flatten(init_message))

    ps = rs.rand(n, 1)
    data = rs.rand(n, 1) < ps

    # Scalar path: one push per symbol, each with its own probability.
    message = init_message
    for p, datum in zip(ps, data):
        message = cs.Bernoulli(p, 14).push(message, datum)
    l_scalar = len(cs.flatten(message))

    # Vectorized path: a single push covering all n symbols.
    message = cs.reshape_head(init_message, (n, 1))
    message = cs.Bernoulli(ps, 14).push(message, data)
    l_vector = len(cs.flatten(message))

    assert (l_vector - l_init) / (l_scalar - l_init) - 1 < 0.001
def test_substack():
    """Round-trip a repeated Uniform codec through ``cs.substack``.

    The message head is split into two halves, and the view function exposes
    only the first half to the codec. Pushing must leave the second half
    unchanged. Popping must recover the original message and data exactly.
    """
    n_data = 100
    prec = 4
    full_head, tail = cs.base_message((4, 4))
    message = np.split(full_head, 2), tail
    data = rng.randint(1 << prec, size=(n_data, 2, 4), dtype='uint64')

    def view_fun(h):
        # Expose only the first half of the split head to the codec.
        return h[0]

    codec = cs.substack(cs.repeat(cs.Uniform(prec), n_data), view_fun)
    append, pop = codec
    pushed = append(message, data)
    # The half outside the view must be untouched by the push.
    np.testing.assert_array_equal(pushed[0][1], message[0][1])
    popped_message, popped_data = pop(pushed)
    np.testing.assert_equal(message, popped_message)
    np.testing.assert_equal(data, popped_data)
def test_flatten_unflatten(shape, depth=1000):
    """Check that ``cs.unflatten`` inverts ``cs.flatten`` for head shape ``shape``.

    ``depth`` layers of random 8-bit symbols are pushed with a uniform codec
    before flattening. The message rebuilt from the flat form must equal the
    original.
    """
    np.random.seed(0)
    precision = 8
    symbols = np.random.randint(1 << precision, size=(depth, ) + shape,
                                dtype=np.uint64)
    initial = cs.base_message(shape)
    push_symbols, _ = cs.repeat(cs.Uniform(precision), depth)
    message = push_symbols(initial, symbols)
    reconstructed = cs.unflatten(cs.flatten(message), shape)
    assert_message_equal(message, reconstructed)
def test_reshape_head(old_shape, new_shape, depth=1000):
    """Check that reshaping a message head to ``new_shape`` and back is lossless.

    ``depth`` layers of random 8-bit symbols are pushed onto a message with
    head shape ``old_shape``. The message is then reshaped to ``new_shape``
    and back, and the result must equal the original.
    """
    np.random.seed(0)
    precision = 8
    symbols = np.random.randint(1 << precision, size=(depth,) + old_shape,
                                dtype=np.uint64)
    initial = cs.base_message(old_shape)
    push_symbols, _ = cs.repeat(cs.Uniform(precision), depth)
    # NOTE(review): the push result is unpacked as a 1-tuple here, whereas
    # other code binds the pushed message directly. Confirm which return
    # convention this version of cs.repeat follows.
    (message,) = push_symbols(initial, symbols)
    round_tripped = cs.reshape_head(cs.reshape_head(message, new_shape),
                                    old_shape)
    assert_message_equal(message, round_tripped)
def test_flatten_unflatten():
    """Round-trip a rANS message through ``cs.flatten`` / ``cs.unflatten``.

    ``n`` rows of random ``p``-bit start values, each with frequency 1, are
    pushed onto a message with head shape ``(7, 3)``. The test checks three
    things:

    - the flattened form has dtype ``uint32``;
    - unflattening and re-flattening reproduces the same words;
    - the rebuilt message has the same head and tail.
    """
    n = 100
    shape = (7, 3)
    p = 12
    state = cs.base_message(shape)
    some_bits = rng.randint(1 << p, size=(n, ) + shape).astype(np.uint64)
    freqs = np.ones(shape, dtype="uint64")
    for b in some_bits:
        state = cs.rans.push(state, b, freqs, p)
    flat = cs.flatten(state)
    # The plain rANS flattening is only used for the size comparison printed
    # below. A distinct name keeps it apart from the re-flattened result.
    rans_flat = cs.rans.flatten(state)
    print('Normal flat len: {}'.format(len(rans_flat) * 32))
    print('Benford flat len: {}'.format(len(flat) * 32))
    # Compare dtypes by equality: equal dtypes are not guaranteed to be the
    # same object, so an identity check can spuriously fail.
    assert flat.dtype == np.dtype("uint32")
    state_ = cs.unflatten(flat, shape)
    flat_ = cs.flatten(state_)
    # np.array_equal also checks shapes, so a length or shape mismatch fails
    # the assertion cleanly. Elementwise == would error or broadcast instead.
    assert np.array_equal(flat, flat_)
    assert np.array_equal(state[0], state_[0])
    assert state[1] == state_[1]
# Plot, for growing prefix lengths of one sampled HMM sequence, the ratio of
# the compressed message length to -log p of the prefix (both per symbol).

# HMM hyperparameters; T is the full sequence length (1 << 9 = 512).
h = hmm_codec.HyperParams(T=1 << 9)
print('Hyperparameter settings:')
print(h)
# Fixed seed so the sampled parameters and sequence are identical every run.
# NOTE(review): `random` presumably refers to numpy.random (default_rng);
# confirm.
rng = random.default_rng(1)
params = hmm_codec.hmm_params_sample(h, rng)
xs = hmm_codec.hmm_sample(h, params, rng)

lengths = []          # prefix lengths l that were encoded
hs = []               # -hmm_logpmf(xs[:l]) / l for each prefix
message_lengths = []  # flattened message size of xs[:l], per symbol
l = 4
# Encode prefixes of length 4, 8, 16, ..., doubling while l <= h.T.
while l <= h.T:
    # Codec built for sequences of length l. `replace` is presumably
    # dataclasses.replace; confirm.
    codec = hmm_codec.hmm_codec(replace(h, T=l), params)
    lengths.append(l)
    # NOTE(review): this passes the full-length `h`, not replace(h, T=l);
    # confirm hmm_logpmf does not depend on h.T. The unit (nats vs bits)
    # depends on hmm_logpmf's log base; confirm before reading the plotted
    # ratio as bits/bits.
    hs.append(-hmm_codec.hmm_logpmf(h, params, xs[:l]) / l)
    # Start from a fresh base message each iteration.
    message = cs.base_message(1)
    message = codec.push(message, xs[:l])
    # * 32 converts flattened words to bits (assumes uint32 words);
    # / l makes it per symbol.
    message_lengths.append(len(cs.flatten(message)) * 32 / l)
    l = 2 * l

fig, ax = plt.subplots(figsize=[2.7, 1.8])
ax.plot(lengths, np.divide(message_lengths, hs), color='black', lw=.5)
ax.set_yscale('log')
# On the log axis, label only ratios 1-4 (as minor ticks) and hide the
# major ticks.
ax.yaxis.set_minor_locator(ticker.FixedLocator([1., 2., 3., 4.]))
ax.yaxis.set_major_locator(ticker.NullLocator())
ax.yaxis.set_minor_formatter(ticker.ScalarFormatter())
# Reference line at ratio 1 (message length equal to -log p per symbol).
ax.hlines(y=1, xmin=0, xmax=h.T + 1, color='gray', lw=.5)
ax.set_xlim(0, h.T)
ax.set_xlabel('$T$')
ax.set_ylabel('$l(m)/h(x)$')
def check_codec(head_shape, codec, data):
    """Assert that popping undoes pushing for ``codec`` on ``data``.

    A fresh message with head shape ``head_shape`` is pushed with ``data``
    and then popped. Both the resulting message and the recovered data must
    equal the originals.

    Args:
        head_shape: shape passed to ``cs.base_message``.
        codec: a ``(push, pop)`` pair.
        data: the value to push and recover.
    """
    original = cs.base_message(head_shape)
    push, pop = codec
    pushed = push(original, data)
    recovered_message, recovered_data = pop(pushed)
    assert_message_equal(original, recovered_message)
    np.testing.assert_equal(data, recovered_data)
return ag_tuple((np.reshape(head[:latent_size], latent_shape), np.reshape(head[latent_size:], obs_shape))) vae_append, vae_pop = cs.repeat(cs.substack( bb_ans.VAE(gen_net, rec_net, obs_codec, prior_precision, q_precision), vae_view), num_batches) ## Load mnist images images = datasets.MNIST(sys.argv[1], train=False, download=True).data.numpy() images = np.uint64(rng.random_sample(np.shape(images)) < images / 255.) images = np.split(np.reshape(images, (num_images, -1)), num_batches) ## Encode # Initialize message with some 'extra' bits encode_t0 = time.time() init_message = cs.base_message(obs_size + latent_size) # Encode the mnist images message, = vae_append(init_message, images) flat_message = cs.flatten(message) encode_t = time.time() - encode_t0 print("All encoded in {:.2f}s.".format(encode_t)) message_len = 32 * len(flat_message) print("Used {} bits.".format(message_len)) print("This is {:.4f} bits per pixel.".format(message_len / num_pixels)) ## Decode decode_t0 = time.time()