def test_serial_resized(shape2, shape1=(5, ), precision=4):
    """Round-trip check for cs.serial when the ANS head is resized mid-sequence.

    The first batch of symbols is coded with head shape ``shape1``, the second
    with ``shape2``; a dedicated codec performs the head reshape in between.
    """
    batch_a = rng.randint(precision, size=(7,) + shape1, dtype="uint64")
    batch_b = rng.randint(precision, size=(20,) + shape2, dtype="uint64")
    symbols = list(batch_a) + list(batch_b)
    base_codec = cs.Uniform(precision)
    base_push, base_pop = base_codec

    def push_resize(message, symbol):
        # Head is still shape2 at this point; switch to shape1, then push.
        assert message[0].shape == shape2
        return base_push(cs.reshape_head(message, shape1), symbol)

    def pop_resize(message):
        # Pop while the head is shape1, then switch the head to shape2.
        assert message[0].shape == shape1
        message, symbol = base_pop(message)
        return cs.reshape_head(message, shape2), symbol

    serial_codec = cs.serial(
        [base_codec] * (len(batch_a) - 1)
        + [cs.Codec(push_resize, pop_resize)]
        + [base_codec] * len(batch_b))
    check_codec(shape2, serial_codec, symbols)
def rvae_variable_known_size_codec(codec_from_image_shape,
                                   latent_from_image_shape, shapes,
                                   previous_dims):
    """
    Applies given codecs in series on a sequence of symbols requiring
    various ANS stack head shapes. The head shape required for each symbol
    is given through shapes.
    """
    def flat_head_shape(shape):
        # Head holds the flattened latent plus the flattened image.
        return (np.prod(latent_from_image_shape(shape)) + np.prod(shape), )

    def reshape_push(shape, message, symbol):
        message = cs.reshape_head(message, flat_head_shape(shape))
        return codec_from_image_shape(shape).push(message, symbol)

    def reshape_pop(shape, message):
        message = cs.reshape_head(message, flat_head_shape(shape))
        message, symbol = codec_from_image_shape(shape).pop(message)
        return message, symbol

    shaped_codecs = []
    for shape in shapes:
        shaped_codecs.append(
            cs.Codec(partial(reshape_push, shape), partial(reshape_pop, shape)))
    return rvae_serial_with_progress(shaped_codecs, previous_dims)
def ArrayCodec(codec):
    """Adapt ``codec`` to code ``ArraySymbol`` wrappers.

    Pushing unwraps ``symbol.arr``; popping re-wraps the decoded array in an
    ``ArraySymbol``.
    """
    def wrapped_push(message, symbol):
        return codec.push(message, symbol.arr)

    def wrapped_pop(message):
        message, decoded = codec.pop(message)
        return message, ArraySymbol(decoded)

    return cs.Codec(wrapped_push, wrapped_pop)
def FLIF():
    """Codec that stores images as FLIF-compressed byte streams.

    The external FLIF binary is driven via subprocess: push encodes the image
    to FLIF, strips the 4-byte 'FLIF' magic, and pushes the remaining bytes as
    uint8 symbols followed by their count; pop reverses the process.
    """
    def push(message, image):
        """expects image to be chw"""
        ppm_image = im_transform(image).astype(np.uint8)
        success, im_buffer = cv2.imencode(".ppm", ppm_image)
        process = subprocess.run(encode_command.split(),
                                 input=im_buffer.tobytes(),
                                 capture_output=True)
        if process.returncode != 0:
            raise Exception(f"flif encode failed: {process.stderr}")
        # take off the 'FLIF' magic header
        # can also remove RGB interlaced byte and bytes per chan (next two bytes)
        # then there are 3 varints for width, height and number of frames
        # https://flif.info/spec.html for details
        byte_values = list(process.stdout[4:])  # list of uint8s
        n_byte_values = len(byte_values)
        message = cs.repeat(codec, n_byte_values).push(message, byte_values)
        return len_codec.push(message, np.uint64(n_byte_values))

    def pop(message):
        message, n_byte_values = len_codec.pop(message)
        message, byte_values = cs.repeat(codec, n_byte_values[0]).pop(message)
        # Re-attach the magic header before handing the stream to the decoder.
        flif_stream = b'FLIF' + bytes(np.squeeze(byte_values).astype(np.uint8))
        process = subprocess.run(decode_command.split(), input=flif_stream,
                                 capture_output=True)
        if process.returncode != 0:
            raise Exception(f"flif decode failed: {process.stderr}")
        im_buffer = np.frombuffer(process.stdout, dtype=np.uint8)
        decoded = cv2.imdecode(im_buffer, flags=1)  # this gives in hwc
        return message, inverse_im_transform(decoded)

    return cs.Codec(push, pop)
def rvae_serial_with_progress(codecs, previous_dims):
    """Apply ``codecs`` in series, logging compression progress to stdout.

    Push iterates (codec, symbol) pairs in reverse order (last symbol encoded
    first, so decoding runs forward); pop iterates codecs forward.

    Args:
        codecs: sequence of codec objects, one per symbol.
        previous_dims: symbol dimensions already accounted for in the message,
            used to report net bitrate for just this batch.

    Fixes vs. previous version: the push log printed "message length" and
    "bpd" twice, and the pop counter divided by the growing ``symbols`` list
    (always printing k/k) instead of ``len(codecs)``.
    """
    def push(message, symbols):
        init_len = 32 * len(cs.flatten(message))  # message bits before batch
        t_start = time.time()
        dims = previous_dims
        for i, (codec, symbol) in enumerate(
                reversed(list(zip(codecs, symbols)))):
            t0 = time.time()
            message = codec.push(message, symbol)
            dims += symbol.size
            flat_message = cs.flatten(message)
            print(
                f"Encoded {i+1}/{len(symbols)}[{(i+1)/float(len(symbols))*100:.0f}%], "
                f"message length: {len(flat_message) * (4/1024):.0f}kB, "
                f"bpd: {32 * len(flat_message) / float(dims):.2f}, "
                f"net bitrate: {(32 * len(flat_message) - init_len) / (float(dims - previous_dims)):.2f}, "
                f"net dims: {dims - previous_dims}, "
                f"net bits: {32 * len(flat_message) - init_len}, "
                f"iter time: {time.time() - t0:.2f}s, "
                f"total time: {time.time() - t_start:.2f}s, "
                f"symbol shape: {symbol.shape}")
        return message

    def pop(message):
        symbols = []
        t_start = time.time()
        for i, codec in enumerate(codecs):
            t0 = time.time()
            message, symbol = codec.pop(message)
            symbols.append(symbol)
            print(
                f"Decoded {i+1}/{len(codecs)}[{(i+1)/float(len(codecs))*100:.0f}%], "
                f"iter time: {time.time() - t0:.2f}s, "
                f"total time: {time.time() - t_start:.2f}s")
        return message, symbols

    return cs.Codec(push, pop)
def SSM(priors, likelihoods, posterior):
    """Bit-back codec for a sequence x_1..x_T under a state-space model.

    ``priors[0]`` is used directly as a codec, while ``priors[t]`` for t > 0
    is called with the previous latent to obtain a codec; each
    ``likelihoods[t]`` is called with the latent z_t.  ``posterior`` is a pair
    ``(init_state, update)`` where ``update(t, state, x)`` returns the next
    filtering state and the step-t posterior codec.

    NOTE(review): the final posterior codec is used directly, while earlier
    ones are called with the following latent (per the comment below,
    presumably Q(z_t | x_{1:t}, z_{t+1})) — confirm against the posterior
    implementations passed in.
    """
    # priors = [p(z_1), p(z_2 | z_1), ..., p(z_T | z_{T-1})]
    # likelihoods = [p(x_1 | z_1), ..., p(x_T | z_T)]
    # [lambda z_{t+1}: Q(z_t | x_{1:t}, z_{t+1}) for t in range(T)]
    post_init_state, post_update = posterior

    def push(message, xs):
        post_codecs = []
        post_state = post_init_state
        for t, x in enumerate(xs):  # Forward inference pass
            post_state, post_codec = post_update(t, post_state, x)
            post_codecs.append(post_codec)
        # Bit-back: pop (sample) the top latent from the message using the
        # last posterior codec, then encode backwards through time.
        message, z_next = post_codecs[-1].pop(message)
        for t in range(len(priors) - 1, 0, -1):  # Backward encoding pass
            message = likelihoods[t](z_next).push(message, xs[t])
            # Pop z_t from the posterior conditioned on z_{t+1}, then push
            # z_{t+1} under its prior conditioned on z_t.
            message, z = post_codecs[t - 1](z_next).pop(message)
            message = priors[t](z).push(message, z_next)
            z_next = z
        message = likelihoods[0](z_next).push(message, xs[0])
        message = priors[0].push(message, z_next)
        return message

    def pop(message):
        xs = []
        message, z_next = priors[0].pop(message)
        message, x = likelihoods[0](z_next).pop(message)
        post_state, post_codec = post_update(0, post_init_state, x)
        xs.append(x)
        for t in range(1, len(priors)):  # Forward decoding pass
            z = z_next
            message, z_next = priors[t](z).pop(message)
            # Mirror of the posterior pop in push: return the borrowed bits.
            message = post_codec(z_next).push(message, z)
            message, x = likelihoods[t](z_next).pop(message)
            post_state, post_codec = post_update(t, post_state, x)
            xs.append(x)
        # Return the bits initially borrowed for the top latent.
        message = post_codec.push(message, z_next)
        return message, xs

    return cs.Codec(push, pop)
def rvae_variable_size_codec(codec_from_shape, latent_from_image_shape,
                             image_count, dimensions=4, dimension_bits=16,
                             previous_dims=0):
    """Code ``image_count`` arrays whose shapes are not known to the decoder.

    Each array's shape (``dimensions`` values of up to ``dimension_bits`` bits
    each) is itself entropy-coded after the array, so pop can recover the
    per-image head size with no side information.

    Fix vs. previous version: ``astype(np.int)`` — ``np.int`` (an alias of
    builtin ``int``) was deprecated in NumPy 1.20 and removed in 1.24.
    """
    size_codec = cs.repeat(cs.Uniform(dimension_bits), dimensions)

    def push(message, symbol):
        """push sizes and array in alternating order"""
        assert len(symbol.shape) == dimensions
        codec = codec_from_shape(symbol.shape)
        # Head must hold the flattened latent plus the flattened image.
        head_size = np.prod(latent_from_image_shape(symbol.shape)) + np.prod(
            symbol.shape)
        message = cs.reshape_head(message, (head_size, ))
        message = codec.push(message, symbol)
        # Shape is pushed last so it is the first thing popped on decode.
        message = cs.reshape_head(message, (1, ))
        message = size_codec.push(message, np.array(symbol.shape))
        return message

    def pop(message):
        message, size = size_codec.pop(message)
        # TODO make codec 0 dimensional:
        size = np.array(size)[:, 0]
        assert size.shape == (dimensions, )
        size = size.astype(int)  # np.int removed in NumPy 1.24; same behavior
        head_size = np.prod(latent_from_image_shape(size)) + np.prod(size)
        codec = codec_from_shape(tuple(size))
        message = cs.reshape_head(message, (head_size, ))
        message, symbol = codec.pop(message)
        message = cs.reshape_head(message, (1, ))
        return message, symbol

    return rvae_serial_with_progress([cs.Codec(push, pop)] * image_count,
                                     previous_dims)
def posterior(data):
    """Build the bidirectional-posterior codec for one datapoint.

    Runs a deterministic upper pass to get per-layer contexts, then codes the
    latent stack top-down with diagonal-Gaussian posteriors discretised into
    bins defined by the corresponding priors.

    NOTE(review): this relies on free variables (``up_pass``, ``rec_net_top``,
    ``rec_nets``, ``gen_net_top``, ``gen_nets``, ``prior_prec``,
    ``latent_prec``, ``z_view``) defined elsewhere, and appears to duplicate
    the ``posterior`` nested inside ``ResNetVAE`` — confirm it is still used.
    """
    # run deterministic upper-pass
    contexts = up_pass(data)

    def posterior_push(message, latents):
        # first run the model top-down to get the params and latent vals
        latents, _ = latents
        (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
        post_params = [(post_mean, post_stdd)]
        for rec_net, latent, context in reversed(
                list(zip(rec_nets, latents[1:], contexts[:-1]))):
            previous_latent, (prior_mean, prior_stdd) = latent
            # Discrete latent index -> real value via standard-Gaussian bin
            # centres, scaled/shifted by the prior parameters.
            previous_latent_val = prior_mean + \
                cs.std_gaussian_centres(prior_prec)[previous_latent] * prior_stdd
            (post_mean, post_stdd), h_rec = rec_net(h_rec, previous_latent_val,
                                                    context)
            post_params.append((post_mean, post_stdd))
        # now append bottom up
        for latent, post_param in zip(latents, reversed(post_params)):
            latent, (prior_mean, prior_stdd) = latent
            post_mean, post_stdd = post_param
            codec = cs.substack(
                cs.DiagGaussian_GaussianBins(post_mean, post_stdd, prior_mean,
                                             prior_stdd, latent_prec,
                                             prior_prec), z_view)
            message = codec.push(message, latent)
        return message

    def posterior_pop(message):
        # pop top-down
        (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
        (prior_mean, prior_stdd), h_gen = gen_net_top()
        codec = cs.substack(
            cs.DiagGaussian_GaussianBins(post_mean, post_stdd, prior_mean,
                                         prior_stdd, latent_prec, prior_prec),
            z_view)
        message, latent = codec.pop(message)
        latents = [(latent, (prior_mean, prior_stdd))]
        for rec_net, gen_net, context in reversed(
                list(zip(rec_nets, gen_nets, contexts[:-1]))):
            previous_latent_val = prior_mean + \
                cs.std_gaussian_centres(prior_prec)[latents[-1][0]] * prior_stdd
            (post_mean, post_stdd), h_rec = rec_net(h_rec, previous_latent_val,
                                                    context)
            (prior_mean, prior_stdd), h_gen = gen_net(h_gen,
                                                      previous_latent_val)
            codec = cs.substack(
                cs.DiagGaussian_GaussianBins(post_mean, post_stdd, prior_mean,
                                             prior_stdd, latent_prec,
                                             prior_prec), z_view)
            message, latent = codec.pop(message)
            latents.append((latent, (prior_mean, prior_stdd)))
        # Latents were popped top-down; reverse to return them bottom-up.
        return message, (latents[::-1], h_gen)

    return cs.Codec(posterior_push, posterior_pop)
def ResNetVAE(up_pass, rec_net_top, rec_nets, gen_net_top, gen_nets,
              obs_codec, prior_prec, latent_prec):
    """
    Codec for a ResNetVAE.
    Assume that the posterior is bidirectional -
    i.e. has a deterministic upper pass but top down sampling.
    Further assume that all latent conditionals are factorised Gaussians,
    both in the generative network p(z_n|z_{n-1})
    and in the inference network q(z_n|x, z_{n-1})

    Assume that everything is ordered bottom up
    """
    # The message head is split into a latent view and an observation view.
    z_view = lambda head: head[0]
    x_view = lambda head: head[1]

    # Latent indices are pushed under a uniform prior over bin indices.
    prior_codec = cs.substack(cs.Uniform(prior_prec), z_view)

    def prior_push(message, latents):
        # push bottom-up
        latents, _ = latents
        for latent in latents:
            latent, _ = latent
            message = prior_codec.push(message, latent)
        return message

    def prior_pop(message):
        # pop top-down
        (prior_mean, prior_stdd), h_gen = gen_net_top()
        message, latent = prior_codec.pop(message)
        latents = [(latent, (prior_mean, prior_stdd))]
        for gen_net in reversed(gen_nets):
            # Convert the previous layer's discrete latent index to a real
            # value via standard-Gaussian bin centres.
            previous_latent_val = prior_mean + cs.std_gaussian_centres(
                prior_prec)[latent] * prior_stdd
            (prior_mean, prior_stdd), h_gen = gen_net(h_gen,
                                                      previous_latent_val)
            message, latent = prior_codec.pop(message)
            latents.append((latent, (prior_mean, prior_stdd)))
        # Latents were popped top-down; reverse to return them bottom-up.
        return message, (latents[::-1], h_gen)

    def posterior(data):
        # run deterministic upper-pass
        contexts = up_pass(data)

        def posterior_push(message, latents):
            # first run the model top-down to get the params and latent vals
            latents, _ = latents
            (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
            post_params = [(post_mean, post_stdd)]
            for rec_net, latent, context in reversed(
                    list(zip(rec_nets, latents[1:], contexts[:-1]))):
                previous_latent, (prior_mean, prior_stdd) = latent
                previous_latent_val = prior_mean + \
                    cs.std_gaussian_centres(prior_prec)[previous_latent] * prior_stdd
                (post_mean, post_stdd), h_rec = rec_net(h_rec,
                                                        previous_latent_val,
                                                        context)
                post_params.append((post_mean, post_stdd))
            # now append bottom up
            for latent, post_param in zip(latents, reversed(post_params)):
                latent, (prior_mean, prior_stdd) = latent
                post_mean, post_stdd = post_param
                codec = cs.substack(
                    cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                                 prior_mean, prior_stdd,
                                                 latent_prec, prior_prec),
                    z_view)
                message = codec.push(message, latent)
            return message

        def posterior_pop(message):
            # pop top-down
            (post_mean, post_stdd), h_rec = rec_net_top(contexts[-1])
            (prior_mean, prior_stdd), h_gen = gen_net_top()
            codec = cs.substack(
                cs.DiagGaussian_GaussianBins(post_mean, post_stdd, prior_mean,
                                             prior_stdd, latent_prec,
                                             prior_prec), z_view)
            message, latent = codec.pop(message)
            latents = [(latent, (prior_mean, prior_stdd))]
            for rec_net, gen_net, context in reversed(
                    list(zip(rec_nets, gen_nets, contexts[:-1]))):
                previous_latent_val = prior_mean + \
                    cs.std_gaussian_centres(prior_prec)[latents[-1][0]] * prior_stdd
                (post_mean, post_stdd), h_rec = rec_net(h_rec,
                                                        previous_latent_val,
                                                        context)
                (prior_mean, prior_stdd), h_gen = gen_net(h_gen,
                                                          previous_latent_val)
                codec = cs.substack(
                    cs.DiagGaussian_GaussianBins(post_mean, post_stdd,
                                                 prior_mean, prior_stdd,
                                                 latent_prec, prior_prec),
                    z_view)
                message, latent = codec.pop(message)
                latents.append((latent, (prior_mean, prior_stdd)))
            # Latents were popped top-down; reverse to return them bottom-up.
            return message, (latents[::-1], h_gen)

        return cs.Codec(posterior_push, posterior_pop)

    def likelihood(latents):
        # get the z1 vals to condition on
        latents, h = latents
        z1_idxs, (prior_mean, prior_stdd) = latents[0]
        z1_vals = prior_mean + cs.std_gaussian_centres(
            prior_prec)[z1_idxs] * prior_stdd
        # Observation codec runs on the x-portion of the message head.
        return cs.substack(obs_codec(h, z1_vals), x_view)

    return BBANS(cs.Codec(prior_push, prior_pop), likelihood, posterior)