def test_inner_decoder(self, cond_mlp, cond_ln, cond_att_q, cond_att_scale):
    """Check that `core.InnerDecoder` returns a (2, 8, 8, 256) output.

    Three random tensors of shape (2, 8, 8, 256) are fed to the decoder as
    `(embeddings, upper_context, channel_context)`.

    Args:
        cond_mlp: Value written to `config.cond_mlp`.
        cond_ln: Value written to `config.cond_ln`.
        cond_att_q: Value written to `config.cond_att_q`.
        cond_att_scale: Value written to `config.cond_att_scale`.
    """
    tensor_shape = (2, 8, 8, 256)

    # Draw order matches the original: embeddings, channel, then upper.
    embeddings = tf.random.uniform(shape=tensor_shape)
    channel_context = tf.random.uniform(shape=tensor_shape)
    upper_context = tf.random.uniform(shape=tensor_shape)

    # Start from the fixture config and apply the conditioning options
    # supplied for this test case.
    config = self.get_config()
    config.cond_mlp = cond_mlp
    config.cond_ln = cond_ln
    config.cond_att_q = cond_att_q
    config.cond_att_scale = cond_att_scale

    decoder = core.InnerDecoder(config=config)
    result = decoder(inputs=(embeddings, upper_context, channel_context))

    # The variable count is only logged, not asserted on.
    logging.info(get_num_variables(decoder))

    self.assertEqual(result.shape, tensor_shape)
def build(self, input_shape):
    """Instantiate the encoder and decoder sub-layers.

    Creates the grayscale encoder, an optional bias-free dense head named
    'parallel_logits' (only when `self.is_parallel_loss` is truthy), the
    pixel embedding layer, the outer and inner decoders, the final dense
    layer named 'auto_logits', and a layer normalization.

    Args:
        input_shape: Not used by this method.
    """
    # --- Encoder graph ---
    self.encoder = core.GrayScaleEncoder(self.enc_cfg)

    symbol_count = self.num_symbols
    if self.is_parallel_loss:
        self.parallel_dense = layers.Dense(
            name='parallel_logits',
            units=symbol_count,
            use_bias=False,
        )

    # --- Decoder graph: outer decoder -> inner decoder -> logits ---
    # Layers are created in the same order as before.
    self.pixel_embed_layer = layers.Dense(
        units=self.hidden_size,
        use_bias=False,
    )

    decoder_cfg = self.dec_cfg
    self.outer_decoder = core.OuterDecoder(decoder_cfg)
    self.inner_decoder = core.InnerDecoder(decoder_cfg)

    self.final_dense = layers.Dense(name='auto_logits', units=symbol_count)
    self.final_norm = layers.LayerNormalization()