def build(self, input_shape):
    # Bias-free embeddings for the coarse color bits and the grayscale channel.
    self.bit_embedding = layers.Dense(units=self.hidden_size, use_bias=False)
    self.gray_embedding = layers.Dense(units=self.hidden_size, use_bias=False)
    self.input_dense = layers.Dense(units=self.hidden_size)
    # Factorized (row/column) self-attention encoder over the embedded grid.
    self.encoder = coltran_layers.FactorizedAttention(self.config)
    # Projection to per-pixel logits over 256 intensity values.
    self.final_dense = layers.Dense(units=256)
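
# A minimal call-time sketch (an assumption, not the repository's actual
# forward pass) of how the sublayers built above could be wired: embed the
# coarse color bits and the grayscale context, fuse them, run the factorized
# attention encoder, and project to per-pixel logits. The method name,
# argument names, and input layout (batch, H, W, channels) are hypothetical.
def call_sketch(self, bits, grayscale):
    x = self.bit_embedding(bits) + self.gray_embedding(grayscale)
    x = self.input_dense(x)
    x = self.encoder(x)
    return self.final_dense(x)  # 256-way logits per pixel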
def test_factorized_attention(self):
    config = ConfigDict()
    config.hidden_size = 256
    config.ff_size = 256
    config.num_encoder_layers = 2
    config.num_heads = 2
    fact = layers.FactorizedAttention(config)
    # Inputs are laid out as (batch, height, width, hidden_size).
    inputs = tf.random.uniform(shape=(8, 8, 8, 256))
    output = fact(inputs)
    # Factorized attention preserves the input shape.
    self.assertEqual(output.shape, (8, 8, 8, 256))
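
# Self-contained usage sketch mirroring the test above. The import paths
# (ml_collections for ConfigDict, coltran.models.layers for the layer) are
# assumptions about how the surrounding repository is laid out.
import tensorflow as tf
from ml_collections import ConfigDict
from coltran.models import layers

config = ConfigDict()
config.hidden_size = 256
config.ff_size = 256
config.num_encoder_layers = 2
config.num_heads = 2

fact = layers.FactorizedAttention(config)
x = tf.random.uniform(shape=(1, 32, 32, 256))  # (batch, H, W, hidden_size)
y = fact(x)
assert y.shape == x.shape  # attention over rows/columns keeps the shape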
def build(self, input_shapes):
    # A single dense embedding of the input, followed by the same factorized
    # attention encoder used elsewhere in the model.
    self.embedding = layers.Dense(units=self.config.hidden_size)
    self.encoder = coltran_layers.FactorizedAttention(self.config)
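
# Minimal call-time sketch for the build above (an assumption, not the
# repository's actual method): embed the input, then encode it with the
# factorized attention stack.
def call_sketch(self, inputs):
    return self.encoder(self.embedding(inputs))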