Example #1
0
 def test_grayscale_encoder(self):
   """Checks GrayScaleEncoder's output shape on a random grayscale batch.

   Two random 32x32 RGB images with int32 pixel values in [0, 256) are
   converted to grayscale and passed through the encoder; the encoded
   result must have shape (2, 32, 32, 256).
   """
   encoder_config = self.get_config()
   rgb_batch = tf.random.uniform(
       shape=(2, 32, 32, 3),
       minval=0,
       maxval=256,  # exclusive upper bound for ints -> values in [0, 255]
       dtype=tf.int32,
   )
   gray_batch = tf.image.rgb_to_grayscale(rgb_batch)
   gray_encoder = core.GrayScaleEncoder(encoder_config)
   encoded = gray_encoder(gray_batch)
   expected_shape = (2, 32, 32, 256)
   self.assertEqual(encoded.shape, expected_shape)
Example #2
0
    def build(self, input_shape):
        """Instantiates every sub-layer owned by this model.

        Creates the grayscale encoder (core.GrayScaleEncoder), an optional
        bias-free 'parallel_logits' Dense projection, and the decoder stack:
        a bias-free pixel-embedding Dense, core.OuterDecoder,
        core.InnerDecoder, an 'auto_logits' Dense projection and a
        LayerNormalization.

        NOTE(review): keep the creation order as-is; Keras tracks sub-layers
        in assignment order, which presumably also fixes weight ordering.

        Args:
          input_shape: Not referenced. All sizes come from attributes that
            must already be set on self (enc_cfg, dec_cfg, hidden_size,
            num_symbols, is_parallel_loss).
        """
        del input_shape  # Unused.

        # --- Encoder side ---
        self.encoder = core.GrayScaleEncoder(self.enc_cfg)
        if self.is_parallel_loss:
            # Extra projection to num_symbols logits; only created when
            # is_parallel_loss is truthy.
            self.parallel_dense = layers.Dense(
                name='parallel_logits',
                units=self.num_symbols,
                use_bias=False,
            )

        # --- Decoder side: outer decoder -> inner decoder -> logits ---
        self.pixel_embed_layer = layers.Dense(
            use_bias=False, units=self.hidden_size)
        self.outer_decoder = core.OuterDecoder(self.dec_cfg)
        self.inner_decoder = core.InnerDecoder(self.dec_cfg)
        self.final_dense = layers.Dense(
            name='auto_logits', units=self.num_symbols)
        self.final_norm = layers.LayerNormalization()