def build(self, input_shape):
    """Instantiate the sub-layers of the colorizer.

    Builds the grayscale encoder, an optional parallel prediction head,
    and the autoregressive decoder stack (outer decoder -> inner decoder
    -> per-pixel logits).
    """
    # Grayscale conditioning branch.
    self.encoder = core.GrayScaleEncoder(self.enc_cfg)

    # Optional auxiliary head that predicts symbols directly from the
    # encoder output (only built when the parallel loss is enabled).
    if self.is_parallel_loss:
        self.parallel_dense = layers.Dense(
            units=self.num_symbols, name='parallel_logits', use_bias=False)

    # Autoregressive branch: embed pixels, run the two-stage decoder,
    # then normalize and project to symbol logits.
    self.pixel_embed_layer = layers.Dense(
        units=self.hidden_size, use_bias=False)
    self.outer_decoder = core.OuterDecoder(self.dec_cfg)
    self.inner_decoder = core.InnerDecoder(self.dec_cfg)
    self.final_dense = layers.Dense(
        units=self.num_symbols, name='auto_logits')
    self.final_norm = layers.LayerNormalization()
def test_outer_decoder(self, cond_mlp, cond_ln, cond_att_q, cond_att_scale):
    """Checks OuterDecoder output shape and its causal-masking property."""
    emb = tf.random.uniform(shape=(2, 8, 8, 256))
    ctx = tf.random.uniform(shape=(2, 8, 8, 256))

    # Exercise the decoder under each conditioning configuration.
    cfg = self.get_config()
    cfg.cond_mlp = cond_mlp
    cfg.cond_ln = cond_ln
    cfg.cond_att_q = cond_att_q
    cfg.cond_att_scale = cond_att_scale

    decoder = core.OuterDecoder(config=cfg)
    logging.info(get_num_variables(decoder))

    upper = decoder(inputs=(emb, ctx)).numpy()
    self.assertEqual(upper.shape, (2, 8, 8, 256))
    # The first row slice should carry zero context, since both the
    # present and the future rows are effectively masked out.
    self.assertTrue(np.allclose(upper[:, 0], 0.0))