def build(self, unused_input_shapes):
  """Creates the stack of decoder blocks, then runs the base-class build().

  One `layers.TransformerDecoderBlock` is constructed per hidden layer
  (`self.num_hidden_layers` of them) and the resulting list is stored, in
  index order, on `self.layers`. Block `k` is named "layer_k".

  Args:
    unused_input_shapes: Not inspected here; forwarded unchanged to the
      parent class's build().
  """

  def _make_block(index):
    # Each block gets its own TruncatedNormal initializer instance rather
    # than sharing a single one across the stack.
    kernel_init = tf.keras.initializers.TruncatedNormal(
        stddev=self.initializer_range)
    return layers.TransformerDecoderBlock(
        num_attention_heads=self.num_attention_heads,
        intermediate_size=self.intermediate_size,
        intermediate_activation=self.intermediate_activation,
        dropout_rate=self.hidden_dropout_prob,
        attention_dropout_rate=self.attention_probs_dropout_prob,
        kernel_initializer=kernel_init,
        multi_channel_cross_attention=self.multi_channel_cross_attention,
        name="layer_%d" % index)

  self.layers = [_make_block(index) for index in range(self.num_hidden_layers)]
  super(TransformerDecoder, self).build(unused_input_shapes)
def build(self, input_shape):
  """Creates the decoder block stack and the final output normalization.

  Constructs `self.num_layers` `layers.TransformerDecoderBlock`s, stored in
  index order on `self.decoder_layers` (block `k` is named "layer_k"). It
  also creates a float32 `LayerNormalization`, stored on
  `self.output_normalization`, that uses the same configured epsilon
  (`self._norm_epsilon`) as the per-block normalizations. Finally it calls
  the parent class's build().

  Args:
    input_shape: Shape passed to build(). Its element at index 2 is given to
      `attention_initializer` for every block. This is presumably the hidden
      (feature) size of a (batch, length, hidden) input; TODO confirm.
  """
  self.decoder_layers = []
  for i in range(self.num_layers):
    self.decoder_layers.append(
        layers.TransformerDecoderBlock(
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self._intermediate_size,
            intermediate_activation=self._activation,
            dropout_rate=self._dropout_rate,
            attention_dropout_rate=self._attention_dropout_rate,
            use_bias=self._use_bias,
            norm_first=self._norm_first,
            norm_epsilon=self._norm_epsilon,
            intermediate_dropout=self._intermediate_dropout,
            # Called once per block so no initializer instance is shared
            # between blocks.
            attention_initializer=attention_initializer(input_shape[2]),
            name=("layer_%d" % i)))
  # Honor the configured norm_epsilon here too, so the final normalization
  # agrees with the per-block norms instead of using a fixed constant.
  self.output_normalization = tf.keras.layers.LayerNormalization(
      epsilon=self._norm_epsilon, dtype="float32")
  super(TransformerDecoder, self).build(input_shape)