def test_decoding(self, spatial_shape, attn_dims):
  bs = 2
  num_heads = 3
  num_features = 4
  rng = random.PRNGKey(0)
  key1, key2 = random.split(rng)
  inputs = random.normal(
      key1, (bs,) + spatial_shape + (num_heads * num_features,))
  module = nn.MultiHeadDotProductAttention(
      num_heads=num_heads,
      qkv_features=num_heads * num_features,
      attention_axis=attn_dims,
      causal_mask=True,
      precision=lax.Precision.HIGHEST,
      decode=False)
  decode_module = module.clone(decode=True)

  initial_vars = decode_module.init(key2, inputs, inputs)
  y_ref = jax.jit(lambda x: module.apply(initial_vars, x, x))(inputs)

  # Feed the inputs sequentially to simulate decoding.
  def body_fn(vars_in, x):
    y, vars_out = decode_module.apply(
        vars_in, x, x, decode=True, mutable=['cache'])
    return vars_out, y

  # scan_in_dim supports scanning multiple dims.
  _, y = jax_utils.scan_in_dim(
      body_fn, initial_vars, inputs, axis=attn_dims, keepdims=True)
  np.testing.assert_allclose(y_ref, y, atol=1e-5)
def __call__(self, inputs, *, deterministic):
  """Applies Encoder1DBlock module.

  Args:
    inputs: Inputs to the layer.
    deterministic: Dropout will not be applied when set to true.

  Returns:
    Output after transformer encoder block.
  """
  # Attention block.
  assert inputs.ndim == 3, f'Expected (batch, seq, hidden) got {inputs.shape}'
  x = nn.LayerNorm(dtype=self.dtype)(inputs)
  x = nn.MultiHeadDotProductAttention(
      dtype=self.dtype,
      kernel_init=nn.initializers.xavier_uniform(),
      broadcast_dropout=False,
      deterministic=deterministic,
      dropout_rate=self.attention_dropout_rate,
      num_heads=self.num_heads)(x, x)
  x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(dtype=self.dtype)(x)
  y = MlpBlock(
      mlp_dim=self.mlp_dim,
      dtype=self.dtype,
      dropout_rate=self.dropout_rate)(y, deterministic=deterministic)

  return x + y
def __call__(self,
             targets,
             encoded,
             decoder_mask=None,
             encoder_decoder_mask=None):
  """Applies EncoderDecoder1DBlock module.

  Args:
    targets: input data for decoder.
    encoded: input data from encoder.
    decoder_mask: decoder self-attention mask.
    encoder_decoder_mask: encoder-decoder attention mask.

  Returns:
    Output after transformer encoder-decoder block.
  """
  cfg = self.config

  # Decoder block.
  assert targets.ndim == 3
  x = nn.LayerNorm(dtype=cfg.dtype)(targets)
  x = nn.SelfAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=cfg.deterministic,
      decode=cfg.decode)(x, decoder_mask)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  x = x + targets

  # Encoder-Decoder block.
  y = nn.LayerNorm(dtype=cfg.dtype)(x)
  y = nn.MultiHeadDotProductAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=cfg.deterministic)(y, encoded, encoder_decoder_mask)
  y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=cfg.deterministic)
  y = y + x

  # MLP block.
  z = nn.LayerNorm(dtype=cfg.dtype)(y)
  z = MlpBlock(config=cfg)(z)

  return y + z
def __call__(self,
             targets,
             encoded,
             decoder_mask=None,
             encoder_decoder_mask=None):
  """Applies Transformer block.

  Args:
    targets: input data for decoder `[batch_size, ..., length, dim]`
    encoded: input data from encoder `[batch_size, ..., length2, dim2]`
    decoder_mask: decoder self-attention mask
    encoder_decoder_mask: encoder-decoder attention mask

  Returns:
    Decoded data `[batch_size, ..., length, mlp_dim]`
  """
  cfg = self.config

  # Decoder block.
  x = nn.LayerNorm(dtype=cfg.dtype)(targets)
  x = nn.SelfAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=cfg.deterministic,
      decode=cfg.decode)(x, decoder_mask)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  x = x + targets

  # Encoder-Decoder block.
  y = nn.LayerNorm(dtype=cfg.dtype)(x)
  y = nn.MultiHeadDotProductAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=cfg.deterministic)(y, encoded, encoder_decoder_mask)
  y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=cfg.deterministic)
  y = y + x

  # MLP block.
  z = nn.LayerNorm(dtype=cfg.dtype)(y)
  z = MLPBlock(config=cfg)(z)

  return y + z
def test_multihead_encoder_decoder_attention(self):
  rng = random.PRNGKey(0)
  q = jnp.ones((4, 2, 3, 5))
  kv = jnp.ones((4, 2, 3, 5))
  sa_module = nn.MultiHeadDotProductAttention(
      num_heads=8,
      qkv_features=16,
      kernel_init=initializers.ones,
      bias_init=initializers.zeros,
  )
  y, _ = sa_module.init_with_output(rng, q, kv)
  self.assertEqual(y.shape, q.shape)
def test_multihead_self_attention_w_dropout(self):
  rng = random.PRNGKey(0)
  x = jnp.ones((4, 2, 3, 5))
  sa_module = nn.MultiHeadDotProductAttention(
      num_heads=8,
      qkv_features=16,
      kernel_init=initializers.ones,
      bias_init=initializers.zeros,
      dropout_rate=0.1,
  )
  rng1, rng2 = random.split(rng)
  rngs = {'params': rng1, 'dropout': rng2}
  y, _ = sa_module.init_with_output(rngs, x, x)
  self.assertEqual(y.shape, x.shape)
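# A minimal, self-contained sketch (not taken from the snippets above) of the
# basic init/apply pattern for nn.MultiHeadDotProductAttention, assuming a
# recent flax.linen API; the shapes and variable names here are illustrative
# assumptions only.
import jax
import jax.numpy as jnp
from flax import linen as nn

attn = nn.MultiHeadDotProductAttention(num_heads=4, qkv_features=32)
x = jnp.ones((2, 7, 16))  # (batch, seq, features)
variables = attn.init(jax.random.PRNGKey(0), x, x)  # query and key/value inputs
y = attn.apply(variables, x, x)  # output matches the query input shape
assert y.shape == x.shape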
def __call__(self, x):
  # TODO(lbeyer): condition on GAP(x)
  n, _, d = x.shape
  probe = self.param('probe', nn.initializers.xavier_uniform(),
                     (1, 1, d), x.dtype)
  probe = jnp.tile(probe, [n, 1, 1])

  x = nn.MultiHeadDotProductAttention(
      num_heads=self.num_heads,
      kernel_init=nn.initializers.xavier_uniform())(probe, x)

  # TODO(lbeyer): dropout on head?
  y = nn.LayerNorm()(x)
  x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)
  return x[:, 0]
def __call__(self, images: jnp.ndarray, train: Optional[bool] = None):
  train = nn.module.merge_param("train", self.train, train)
  transformer = self.transformer or {}
  # Convert images to patches.
  x = self.patches(images, self.hidden_size, self.patch_size, self.patch_grid)
  # Add "class" token if necessary.
  n, _, c = x.shape
  if self.classifier == "token":
    cls = self.param("cls", nn.initializers.zeros, (1, 1, self.hidden_size))
    cls = jnp.tile(cls, [n, 1, 1])
    x = jnp.concatenate([cls, x], axis=1)
  # Encode tokens.
  x, extra_info = BatchEnsembleEncoder(
      train=train, name="BatchEnsembleTransformer", **transformer)(x)
  # Reduce tokens to a single vector representation.
  if self.classifier == "token":
    # Take the first token's output as representation as in BERT.
    x = x[:, 0]
  elif self.classifier == "gap":
    # Average all tokens.
    x = jnp.mean(x, axis=tuple(range(1, x.ndim - 1)))  # (1,) or (1, 2)
  elif self.classifier == "map":
    probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, c))
    probe = jnp.tile(probe, [n, 1, 1])
    attention = nn.MultiHeadDotProductAttention(
        deterministic=not train,
        num_heads=transformer.get("attention", {}).get("num_heads", 1),
        kernel_init=nn.initializers.xavier_uniform())
    x = attention(inputs_q=probe, inputs_kv=x)
    y = nn.LayerNorm()(x)
    y = patch_transformer_lib.MlpBlock(
        mlp_dim=transformer["mlp_dim"],
        dropout_rate=0,
        deterministic=not train)(y)
    x = (x + y)[:, 0]
  else:
    raise ValueError(f"Unknown classifier: {self.classifier}")

  if self.representation_size is None:
    x = identity.IdentityLayer(name="pre_logits")(x)
  else:
    x = nn.Dense(self.representation_size, name="pre_logits")(x)
    x = nn.tanh(x)

  x = nn.Dense(self.num_classes,
               kernel_init=self.head_kernel_init,
               name="head")(x)
  return x, extra_info
def test_autoregresive_receptive_field_1d(self):
  """Tests the autoregressive self-attention receptive field."""
  rng = random.PRNGKey(0)
  rng1, rng2 = random.split(rng, num=2)

  length = 10
  dim = 1
  num_heads = 1
  input_shape = (1, length, dim)
  inputs = random.normal(rng2, input_shape)

  module = nn.MultiHeadDotProductAttention(
      num_heads=num_heads,
      kernel_init=jax.nn.initializers.ones,
      deterministic=False)

  initial_vars = module.init(rng1, inputs, inputs)
  causal_mask = nn.attention.make_causal_mask(jnp.ones(input_shape[:-1]))

  def model_loss(inputs, pos):
    out = module.apply(initial_vars, inputs, inputs, causal_mask)
    assert out.shape == input_shape
    assert len(out.shape) == 3
    return out[0, pos, :].sum()

  grad_fn = jax.jit(jax.grad(model_loss))

  def get_receptive_field_1d(pos):
    g = grad_fn(inputs, pos)[0, :, :]
    return jnp.any((jnp.abs(g) > 1e-5).astype(jnp.uint32), axis=-1)

  for i in range(length):
    deps = get_receptive_field_1d(i)
    assert (deps[:i] == 1).all(), ('Receptive Field Error: Some of the '
                                   'previous positions are not reachable '
                                   'in autoregressive self-attention.')
    if i != length - 1:
      k = i + 1
      assert (deps[k:] == 0).all(), ('Receptive Field Error: Some of the '
                                     'future positions are reachable in '
                                     'autoregressive self-attention.')
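# A small illustrative sketch of the causal-mask helper used in the test
# above (an assumption-labeled example, not part of that test): flax.linen
# exposes make_causal_mask, which builds a [batch, 1, len, len] mask with 1s
# where a query position may attend to a key position (i.e. key <= query).
import jax.numpy as jnp
from flax import linen as nn

mask = nn.make_causal_mask(jnp.ones((1, 4)))
print(mask.shape)  # (1, 1, 4, 4): position i attends only to positions <= i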
def __call__(self,
             inputs: jnp.ndarray,
             *,
             deterministic: Optional[bool] = None):
  """Applies Encoder1DBlock module."""
  assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}"
  x = nn.LayerNorm(dtype=self.dtype, name="LayerNorm_0")(inputs)
  x = nn.MultiHeadDotProductAttention(
      dtype=self.dtype,
      kernel_init=nn.initializers.xavier_uniform(),
      broadcast_dropout=False,
      deterministic=deterministic,
      name="MultiHeadDotProductAttention_1",
      num_heads=self.num_heads,
      dropout_rate=self.attention_dropout_rate)(x, x)
  x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(dtype=self.dtype, name="LayerNorm_2")(x)
  y = self.mlp_class(name="MlpBlock_3")(y, deterministic=deterministic)
  return x + y
def __call__(self,
             targets,
             encoded,
             decoder_mask=None,
             encoder_decoder_mask=None,
             decoder_relative_position=None,
             encoder_decoder_relative_position=None):
  """Applies Transformer block.

  Args:
    targets: input data for decoder `[batch_size, ..., length, dim]`
    encoded: input data from encoder `[batch_size, ..., length2, dim2]`
    decoder_mask: decoder self-attention mask
    encoder_decoder_mask: encoder-decoder attention mask
    decoder_relative_position: decoder relative positions tensor
      `[batch_sizes..., length2, length2]`
    encoder_decoder_relative_position: encoder-decoder relative positions
      tensor `[batch_sizes..., length2, length]`

  Returns:
    Decoded data `[batch_size, ..., length2, mlp_dim]`
  """
  cfg = self.config

  # Decoder block.
  x = nn.LayerNorm(dtype=cfg.dtype)(targets)
  if cfg.use_relative_attention:
    x = relative_attention.RelativeSelfAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic,
        bidirectional=self.bidirectional_attention,
        num_relative_position_buckets=self.num_relative_position_buckets,
        max_distance=self.max_distance)(
            x, decoder_mask, decoder_relative_position)
  else:
    x = nn.SelfAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic)(x, decoder_mask)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  x = x + targets

  # Encoder-Decoder block.
  y = nn.LayerNorm(dtype=cfg.dtype)(x)
  if self.relative_cross_attention:
    y = relative_attention.RelativeMultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic,
        bidirectional=self.bidirectional_cross_attention,
        num_relative_position_buckets=(
            self.num_relative_position_buckets_cross_attention),
        max_distance=self.max_distance_cross_attention)(
            y, encoded, encoder_decoder_mask,
            encoder_decoder_relative_position)
  else:
    y = nn.MultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic)(y, encoded, encoder_decoder_mask)
  y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=cfg.deterministic)
  y = y + x

  # MLP block.
  z = nn.LayerNorm(dtype=cfg.dtype)(y)
  z = MLPBlock(config=cfg)(z)

  return y + z
def __call__(self,
             images: jnp.ndarray,
             train: Optional[bool] = None,
             mean_field_factor: float = -1.,
             **gp_kwargs):
  train = nn.module.merge_param("train", self.train, train)
  transformer = self.transformer or {}
  # Convert images to patches.
  x = self.patches(images, self.hidden_size, self.patch_size, self.patch_grid)
  # Add "class" token if necessary.
  n, _, c = x.shape
  if self.classifier == "token":
    cls = self.param("cls", nn.initializers.zeros, (1, 1, self.hidden_size))
    cls = jnp.tile(cls, [n, 1, 1])
    x = jnp.concatenate([cls, x], axis=1)
  # Encode tokens.
  x, extra_info = vit_batchensemble.BatchEnsembleEncoder(
      train=train, name="Transformer", **transformer)(x)
  # Reduce tokens to a single vector representation.
  if self.classifier == "token":
    # Take the first token's output as representation as in BERT.
    x = x[:, 0]
  elif self.classifier == "gap":
    # Average all tokens.
    x = jnp.mean(x, axis=tuple(range(1, x.ndim - 1)))  # (1,) or (1, 2)
  elif self.classifier == "map":
    probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, c))
    # x may have been subject to tiling, n can be different from x.shape[0].
    probe = jnp.tile(probe, [x.shape[0], 1, 1])
    attention = nn.MultiHeadDotProductAttention(
        deterministic=not train,
        num_heads=transformer.get("attention", {}).get("num_heads", 1),
        kernel_init=nn.initializers.xavier_uniform())
    x = attention(inputs_q=probe, inputs_kv=x)
    y = nn.LayerNorm()(x)
    y = vit.MlpBlock(
        mlp_dim=transformer["mlp_dim"],
        dropout_rate=0)(y, deterministic=not train)
    x = (x + y)[:, 0]
  else:
    raise ValueError(f"Unknown classifier: {self.classifier}")

  if self.representation_size is None:
    x = vit.IdentityLayer(name="pre_logits")(x)
    extra_info["pre_logits"] = x
  else:
    x = nn.Dense(self.representation_size, name="pre_logits")(x)
    extra_info["pre_logits"] = x
    x = nn.tanh(x)

  if self.use_gp_layer:
    x_gp = self.gp_layer(x, **gp_kwargs)
    # Gaussian process layer output: a tuple of logits, covmat, and
    # optionally random features.
    extra_info["covmat"] = x_gp[1]
    if len(x_gp) > 2:
      extra_info["random_features"] = x_gp[2]
    if train:
      x = x_gp[0]
    else:
      # During inference, compute posterior mean by adjusting the original
      # logits with predictive uncertainty.
      x = ed.nn.utils.mean_field_logits(
          logits=x_gp[0],
          covmat=x_gp[1],
          mean_field_factor=mean_field_factor)
  else:
    x = nn.Dense(self.num_classes,
                 kernel_init=self.head_kernel_init,
                 name="batchensemble_head")(x)
  return x, extra_info
def __call__(self, images: jnp.ndarray, train: Optional[bool] = None):
  train = nn.module.merge_param("train", self.train, train)
  transformer = self.transformer or {}
  # Convert images to patches.
  x = self.embed(images, self.hidden_size, self.patches.size)
  # Add "class" token if necessary.
  n, _, c = x.shape
  if self.classifier == "token":
    cls = self.param("cls", nn.initializers.zeros, (1, 1, self.hidden_size))
    cls = jnp.tile(cls, [n, 1, 1])
    x = jnp.concatenate([cls, x], axis=1)
  # Encode tokens.
  x, extra_info = BatchEnsembleEncoder(
      train=train, name="Transformer", **transformer)(x)
  # Reduce tokens to a single vector representation.
  if self.classifier == "token":
    # Take the first token's output as representation as in BERT.
    x = x[:, 0]
  elif self.classifier == "gap":
    # Average all tokens.
    x = jnp.mean(x, axis=tuple(range(1, x.ndim - 1)))  # (1,) or (1, 2)
  elif self.classifier == "map":
    probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, c))
    # x may have been subject to tiling, n can be different from x.shape[0].
    probe = jnp.tile(probe, [x.shape[0], 1, 1])
    attention = nn.MultiHeadDotProductAttention(
        deterministic=not train,
        num_heads=transformer.get("attention", {}).get("num_heads", 1),
        kernel_init=nn.initializers.xavier_uniform())
    x = attention(inputs_q=probe, inputs_kv=x)
    y = nn.LayerNorm()(x)
    y = vit.MlpBlock(
        mlp_dim=transformer["mlp_dim"],
        dropout_rate=0)(y, deterministic=not train)
    x = (x + y)[:, 0]
  else:
    raise ValueError(f"Unknown classifier: {self.classifier}")

  if self.representation_size is None:
    x = IdentityLayer(name="pre_logits")(x)
    extra_info["pre_logits"] = x
  else:
    x = ed.nn.DenseBatchEnsemble(
        self.representation_size,
        self.transformer.get("ens_size"),
        activation=None,
        alpha_init=ed.nn.utils.make_sign_initializer(
            self.transformer.get("random_sign_init")),
        gamma_init=ed.nn.utils.make_sign_initializer(
            self.transformer.get("random_sign_init")),
        name="pre_logits")(x)
    extra_info["pre_logits"] = x
    x = nn.tanh(x)

  x = ed.nn.DenseBatchEnsemble(
      self.num_classes,
      self.transformer.get("ens_size"),
      activation=None,
      alpha_init=ed.nn.utils.make_sign_initializer(
          self.transformer.get("random_sign_init")),
      gamma_init=ed.nn.utils.make_sign_initializer(
          self.transformer.get("random_sign_init")),
      kernel_init=self.head_kernel_init,
      name="batchensemble_head")(x)
  return x, extra_info
def __call__(self,
             inputs,
             decoder_mask=None,
             encoder_decoder_mask=None,
             inputs_kv=None):
  """Applies EncoderDecoder1DBlock module.

  Args:
    inputs: input data for decoder.
    decoder_mask: decoder self-attention mask.
    encoder_decoder_mask: encoder-decoder attention mask.
    inputs_kv: key/value inputs used when this block is not self-attention.

  Returns:
    Output after transformer encoder-decoder block.
  """
  cfg = self.config

  # Decoder block.
  assert inputs.ndim == 3
  # assert decoder_mask.ndim == 4
  if cfg.use_layernorm:
    x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
  else:
    x = inputs
  if self.is_self_att:
    x = nn.SelfAttention(
        num_heads=cfg.num_heads,
        out_features=self.out_features,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic,
        attention_fn=self.attention_fn,
        decode=cfg.decode)(x, decoder_mask)
  else:
    if cfg.use_layernorm:
      x_kv = nn.LayerNorm(dtype=cfg.dtype)(inputs_kv)
    else:
      x_kv = inputs_kv
    x = nn.MultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        out_features=self.out_features,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic,
        attention_fn=self.attention_fn,
        decode=cfg.decode)(x, x_kv, mask=decoder_mask)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  x = x + inputs

  # MLP block.
  if cfg.use_layernorm:
    z = nn.LayerNorm(dtype=cfg.dtype)(x)
  else:
    z = x
  z = MlpBlock(config=cfg)(z)

  return x + z
def __call__(self,
             targets,
             encoded,
             decoder_mask=None,
             encoder_decoder_mask=None,
             train=True):
  """Applies EncoderDecoder1DBlock module.

  Args:
    targets: input data for decoder.
    encoded: input data from encoder.
    decoder_mask: decoder self-attention mask.
    encoder_decoder_mask: encoder-decoder attention mask.
    train: if it is training.

  Returns:
    Output after transformer encoder-decoder block.
  """
  # Decoder block.
  assert targets.ndim == 3
  if self.normalizer in [
      'batch_norm', 'layer_norm', 'pre_layer_norm', 'none'
  ]:
    maybe_pre_normalize = model_utils.get_normalizer(self.normalizer, train)
    maybe_post_normalize = model_utils.get_normalizer('none', train)
  elif self.normalizer == 'post_layer_norm':
    maybe_pre_normalize = model_utils.get_normalizer('none', train)
    maybe_post_normalize = model_utils.get_normalizer(self.normalizer, train)
  else:
    raise ValueError('Unsupported normalizer: {}'.format(self.normalizer))

  x = maybe_pre_normalize()(targets)
  x = nn.SelfAttention(
      num_heads=self.num_heads,
      dtype=self.dtype,
      qkv_features=self.qkv_dim,
      kernel_init=self.dec_self_attn_kernel_init_fn,
      bias_init=nn.initializers.normal(stddev=1e-6),
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=self.attention_dropout_rate,
      decode=self.decode,
      name='DecoderSelfAttention')(x, decoder_mask, deterministic=not train)
  x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
  x = x + targets
  x = maybe_post_normalize()(x)

  # Encoder-Decoder block.
  y = maybe_pre_normalize()(x)
  y = nn.MultiHeadDotProductAttention(
      num_heads=self.num_heads,
      dtype=self.dtype,
      qkv_features=self.qkv_dim,
      kernel_init=self.dec_cross_attn_kernel_init_fn,
      bias_init=nn.initializers.normal(stddev=1e-6),
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=self.attention_dropout_rate)(
          y, encoded, encoder_decoder_mask, deterministic=not train)
  y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train)
  y = y + x
  y = maybe_post_normalize()(y)

  # MLP block.
  z = maybe_pre_normalize()(y)
  z = MlpBlock(
      mlp_dim=self.mlp_dim,
      dtype=self.dtype,
      dropout_rate=self.dropout_rate,
      name='MLPBlock')(z, train=train)

  res = y + z
  return maybe_post_normalize()(res)
def __call__(self,
             targets,
             encoded,
             inputs_segmentation=None,  # REFACTOR
             targets_segmentation=None,  # REFACTOR
             padding_mask=None,  # REFACTOR
             key_padding_mask=None):  # REFACTOR
  """Applies EncoderDecoder1DBlock module.

  Args:
    targets: input data for decoder.
    encoded: input data from encoder.
    inputs_segmentation: input segmentation info for packed examples.
    targets_segmentation: target segmentation info for packed examples.
    padding_mask: bool, mask padding tokens.
    key_padding_mask: bool, mask padding tokens.

  Returns:
    Output after transformer encoder-decoder block.
  """
  cfg = self.config

  # Decoder block.
  assert targets.ndim == 3
  x = nn.LayerNorm(dtype=cfg.dtype)(targets)
  x = nn.SelfAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      attention_axis=(1,),
      causal_mask=True,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=cfg.deterministic,
      decode=cfg.decode)(
          x,
          padding_mask=padding_mask,
          segmentation=targets_segmentation)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  x = x + targets

  # Encoder-Decoder block.
  y = nn.LayerNorm(dtype=cfg.dtype)(x)
  y = nn.MultiHeadDotProductAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      attention_axis=(1,),
      causal_mask=False,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=cfg.deterministic)(
          y,
          encoded,
          padding_mask=padding_mask,
          key_padding_mask=key_padding_mask,
          segmentation=targets_segmentation,
          key_segmentation=inputs_segmentation)
  y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=cfg.deterministic)
  y = y + x

  # MLP block.
  z = nn.LayerNorm(dtype=cfg.dtype)(y)
  z = MlpBlock(config=cfg)(z)

  return y + z