Ejemplo n.º 1
    def build(self, input_shape):
        """Keras build hook: instantiate all encoder/decoder sub-layers.

        NOTE(review): sub-layer creation order is deliberately identical to
        the original — Keras auto-uniquified layer names depend on it.

        Args:
          input_shape: unused; sub-layer shapes are inferred on first call.
        """
        # Encoder over the grayscale input.
        self.encoder = core.GrayScaleEncoder(self.enc_cfg)

        # Optional head for the parallel (non-autoregressive) loss.
        if self.is_parallel_loss:
            self.parallel_dense = layers.Dense(
                units=self.num_symbols, name='parallel_logits', use_bias=False)

        # Decoder pipeline: pixel embedding -> outer -> inner -> logits.
        self.pixel_embed_layer = layers.Dense(
            units=self.hidden_size, use_bias=False)
        self.outer_decoder = core.OuterDecoder(self.dec_cfg)
        self.inner_decoder = core.InnerDecoder(self.dec_cfg)
        self.final_dense = layers.Dense(
            units=self.num_symbols, name='auto_logits')
        self.final_norm = layers.LayerNormalization()
Ejemplo n.º 2
  def test_outer_decoder(self, cond_mlp, cond_ln, cond_att_q, cond_att_scale):
    """Smoke-tests OuterDecoder across the parameterized conditioning modes.

    Verifies output shape and that the first row of the upper context is all
    zeros (present and future rows are masked for the first row).
    """
    pixel_embeddings = tf.random.uniform(shape=(2, 8, 8, 256))
    channel_ctx = tf.random.uniform(shape=(2, 8, 8, 256))

    # Apply the parameterized conditioning overrides to a fresh config.
    config = self.get_config()
    overrides = dict(cond_mlp=cond_mlp, cond_ln=cond_ln,
                     cond_att_q=cond_att_q, cond_att_scale=cond_att_scale)
    for attr, value in overrides.items():
      setattr(config, attr, value)

    decoder = core.OuterDecoder(config=config)
    logging.info(get_num_variables(decoder))
    upper_np = decoder(inputs=(pixel_embeddings, channel_ctx)).numpy()

    self.assertEqual(upper_np.shape, (2, 8, 8, 256))
    # First row slice must carry zero context: both the present and future
    # rows are effectively masked out for it.
    self.assertTrue(np.allclose(upper_np[:, 0], 0.0))