示例#1
0
    def test_cache_layer(self):
        """Checks that successive writes into a (2, 4) `Cache` canvas read back.

        Three patches are written in turn. After each write the relevant
        slices of the layer output are compared against what was written,
        including a check that the second write leaves the first one intact.
        """
        canvas = layers.Cache(canvas_shape=(2, 4))

        def write(patch, position):
            # One cache update; returns the layer output as a numpy array.
            return canvas(inputs=(patch, tf.stack(position))).numpy()

        def expect_close(actual, wanted):
            self.assertTrue(np.allclose(actual, wanted))

        # First write: values 0..7 at index [0, 0].
        patch_a = tf.reshape(tf.range(8, dtype=tf.float32), (1, 2, 2, 2))
        canvas_np = write(patch_a, [0, 0])
        expect_close(canvas_np[:1, :2, :2, :2], patch_a.numpy())

        # Second write: values 8..15 at index [0, 2]; the earlier patch must
        # still be readable from its own slice.
        patch_b = tf.reshape(tf.range(8, 16, dtype=tf.float32), (1, 2, 2, 2))
        canvas_np = write(patch_b, [0, 2])
        left_part = canvas_np[:1, :2, :2, :2]
        right_part = canvas_np[:1, :2, 2:, :2]
        expect_close(right_part, patch_b.numpy())
        expect_close(left_part, patch_a.numpy())

        # Third write (special case): a single-position patch at index [0, 0].
        patch_c = tf.reshape([50.0, 100.0], (1, 1, 1, 2))
        canvas_np = write(patch_c, [0, 0])
        expect_close(canvas_np[0, 0, 0, :2], [50.0, 100.0])
示例#2
0
    def autoregressive_sample(self, z_gray, mode='sample'):
        """Generates an image one pixel at a time, row by row.

        Within a row, the inner decoder is called once per column. Its input
        is the current contents of the row cache, the current row of the
        upper context, and the matching row of `z_gray`. Each generated pixel
        embedding is written into both the row cache and the full-image
        cache. After a row is finished, the outer decoder is rerun on the
        full-image cache and `z_gray` to produce the upper context that
        conditions the following row.

        Args:
          z_gray: grayscale image; its first three dimensions are read as
            (batch, height, width).
          mode: forwarded to `act_logit_sample_embed`; documented values are
            'sample' or 'argmax'.

        Returns:
          image: the stacked per-pixel samples passed through
            `post_process_image` (documented as a coarse image of shape
            (B, H, W)).
          image_proba: the stacked per-pixel probabilities (documented as
            shape (B, H, W, 512)).
        """
        hidden = self.config.hidden_size
        batch, n_rows, n_cols = z_gray.shape[:3]
        full_shape = (batch, n_rows, n_cols, hidden)

        # image_cache[i, j] holds the embedding of the pixel at row i, col j.
        image_cache = coltran_layers.Cache(canvas_shape=(n_rows, n_cols))
        zero_canvas = tf.zeros(shape=full_shape)
        origin = tf.stack([0, 0])
        image_cache(inputs=(zero_canvas, origin))

        # context[:, i] is the conditioning for row i; it starts as zeros and
        # is replaced by the outer decoder output after every finished row.
        context = tf.zeros(shape=full_shape)

        # line_cache[0, j] holds the embedding at column j of the row being
        # generated; it is reset at the start of every row.
        line_cache = coltran_layers.Cache(canvas_shape=(1, n_cols))
        zero_line = tf.zeros(shape=(batch, 1, n_cols, hidden))
        line_cache(inputs=(zero_line, origin))

        sample_rows = []
        proba_rows = []

        for row in range(n_rows):
            gray_line = tf.expand_dims(z_gray[:, row], axis=1)
            context_line = tf.expand_dims(context[:, row], axis=1)
            line_cache.reset()

            row_samples = []
            row_probas = []
            for col in range(n_cols):
                # Output activations for the pixel at `col`.
                activations = self.inner_decoder(
                    (line_cache.cache, context_line, gray_line),
                    row_ind=row,
                    training=False)

                sample, embed, proba = self.act_logit_sample_embed(
                    activations, col, mode=mode)
                row_samples.append(sample)
                row_probas.append(proba)

                # Record the new embedding in the row cache first, then in
                # the full-image cache.
                line_cache(inputs=(embed, tf.stack([0, col])))
                image_cache(inputs=(embed, tf.stack([row, col])))

            sample_rows.append(tf.stack(row_samples, axis=-1))
            proba_rows.append(tf.stack(row_probas, axis=1))

            # Row complete: recompute the context that the next row reads.
            context = self.outer_decoder(
                inputs=(image_cache.cache, z_gray), training=False)

        image = self.post_process_image(tf.stack(sample_rows, axis=1))
        image_proba = tf.stack(proba_rows, axis=1)
        return image, image_proba