def test_cache_layer(self):
    """Checks three successive writes into a `Cache` with canvas (2, 4).

    Two (1, 2, 2, 2) blocks are written at offsets [0, 0] and [0, 2], and the
    output is checked to contain both. A final (1, 1, 1, 2) block is then
    written at [0, 0] and read back from the output.
    """
    canvas = layers.Cache(canvas_shape=(2, 4))
    block_shape = (1, 2, 2, 2)

    # Write 1: values 0..7 at offset [0, 0].
    left_block = tf.reshape(tf.range(8, dtype=tf.float32), block_shape)
    result = canvas(inputs=(left_block, tf.stack([0, 0]))).numpy()
    self.assertTrue(np.allclose(result[:1, :2, :2, :2], left_block.numpy()))

    # Write 2: values 8..15 at offset [0, 2]. The region filled by the first
    # write must still hold its values afterwards.
    right_block = tf.reshape(tf.range(8, 16, dtype=tf.float32), block_shape)
    result = canvas(inputs=(right_block, tf.stack([0, 2]))).numpy()
    left_region = result[:1, :2, :2, :2]
    right_region = result[:1, :2, 2:, :2]
    self.assertTrue(np.allclose(right_region, right_block.numpy()))
    self.assertTrue(np.allclose(left_region, left_block.numpy()))

    # Write 3 (special case): a single-position block at offset [0, 0].
    single = tf.reshape([50.0, 100.0], (1, 1, 1, 2))
    result = canvas(inputs=(single, tf.stack([0, 0]))).numpy()
    self.assertTrue(np.allclose(result[0, 0, 0, :2], [50.0, 100.0]))
def autoregressive_sample(self, z_gray, mode='sample'):
    """Generates the coarse image one pixel at a time, row by row.

    Work is split across three levels:
      1. The encoder is run once per-channel.
      2. The outer decoder is run once per-row.
      3. The inner decoder is run once per-pixel.

    The inner decoder is conditioned on the context from the encoder and the
    outer decoder, and produces a row one pixel at a time. Once every pixel
    of a row exists, the outer decoder is run again to recompute the context,
    which then conditions the inner decoder for the next row.

    Args:
      z_gray: grayscale image.
      mode: sample or argmax.

    Returns:
      image: coarse image of shape (B, H, W)
      image_proba: probabilities, shape (B, H, W, 512)
    """
    hidden_size = self.config.hidden_size
    batch, num_rows, num_cols = z_gray.shape[:3]
    full_shape = (batch, num_rows, num_cols, hidden_size)

    # channel_cache[i, j] holds the embedding of the pixel at row i, col j.
    # It is seeded with zeros written at the origin.
    channel_cache = coltran_layers.Cache(canvas_shape=(num_rows, num_cols))
    zero_canvas = tf.zeros(shape=full_shape)
    origin = tf.stack([0, 0])
    channel_cache(inputs=(zero_canvas, origin))

    # context[:, row] is the conditioning derived from previously generated
    # rows; it is all zeros until the outer decoder has run once.
    context = tf.zeros(shape=full_shape)

    # row_cache[0, j] holds the embedding for column j of the row currently
    # being generated; it is reset at the start of every row.
    row_cache = coltran_layers.Cache(canvas_shape=(1, num_cols))
    zero_row = tf.zeros(shape=(batch, 1, num_cols, hidden_size))
    row_cache(inputs=(zero_row, origin))

    sampled_rows, proba_rows = [], []
    for row in range(num_rows):
        gray_row = tf.expand_dims(z_gray[:, row], axis=1)
        context_row = tf.expand_dims(context[:, row], axis=1)
        row_cache.reset()

        row_samples, row_probas = [], []
        for col in range(num_cols):
            # Output activations for position `col`, given what the row cache
            # currently holds.
            activations = self.inner_decoder(
                (row_cache.cache, context_row, gray_row),
                row_ind=row,
                training=False,
            )
            sample, embedding, proba = self.act_logit_sample_embed(
                activations, col, mode=mode)
            row_probas.append(proba)
            row_samples.append(sample)

            # Store the new embedding in both caches: at (0, col) for the
            # current row and at (row, col) for the full canvas.
            row_cache(inputs=(embedding, tf.stack([0, col])))
            channel_cache(inputs=(embedding, tf.stack([row, col])))

        sampled_rows.append(tf.stack(row_samples, axis=-1))
        proba_rows.append(tf.stack(row_probas, axis=1))

        # With the row complete, recompute the context used by the next row
        # from everything cached so far.
        context = self.outer_decoder(
            inputs=(channel_cache.cache, z_gray), training=False)

    image = self.post_process_image(tf.stack(sampled_rows, axis=1))
    image_proba = tf.stack(proba_rows, axis=1)
    return image, image_proba