def _init_mixing_sublayer(self, layer, model_arch, mixing_key):
  """Constructs the mixing sublayer for the requested model architecture.

  Args:
    layer: Layer index, used only to give the sublayer a unique name.
    model_arch: ModelArchitecture enum value selecting the sublayer type.
    mixing_key: PRNG key; consumed only by the RANDOM architecture.

  Returns:
    A freshly constructed, config-dependent mixing sublayer module.

  Raises:
    ValueError: If model_arch is not a recognized architecture.
  """
  if model_arch == ModelArchitecture.BERT:
    return nn.SelfAttention(
        num_heads=self.config.num_heads,
        qkv_features=self.config.d_model,
        broadcast_dropout=False,
        kernel_init=default_kernel_init,
        bias_init=default_bias_init,
        dropout_rate=self.config.mixing_dropout_rate,
        use_bias=True,
        name=f"self_attention_{layer}")
  if model_arch == ModelArchitecture.F_NET:
    return layers.FourierTransform(
        fourier_transform=self.fourier_transform,
        name=f"fourier_transform_{layer}")
  if model_arch == ModelArchitecture.FF_ONLY:
    return layers.IdentityTransform(name=f"identity_transform_{layer}")
  if model_arch == ModelArchitecture.LINEAR:
    return layers.LinearTransform(
        precision=lax.Precision.DEFAULT,
        name=f"linear_transform_{layer}")
  if model_arch == ModelArchitecture.RANDOM:
    return layers.RandomTransform(
        max_seq_length=self.config.max_seq_length,
        d_model=self.config.d_model,
        key=mixing_key,
        precision=lax.Precision.DEFAULT,
        name=f"random_transform_{layer}")
  raise ValueError("Unexpected model architecture: %s" % model_arch.name)
def test_decoding(self, spatial_shape, attn_dims):
  """Checks autoregressive decoding against full-sequence causal attention.

  Runs causally-masked self-attention over the whole input at once, then
  replays the same inputs one step at a time through a decode-mode clone
  (which maintains an internal cache) and asserts the outputs agree.
  """
  bs = 2
  num_heads = 3
  num_features = 4
  rng = random.PRNGKey(0)
  key1, key2 = random.split(rng)
  inputs = random.normal(
      key1, (bs,) + spatial_shape + (num_heads * num_features,))
  module = nn.SelfAttention(
      num_heads=num_heads,
      qkv_features=num_heads * num_features,
      precision=lax.Precision.HIGHEST,
      decode=False)
  decode_module = module.clone(decode=True)
  # Initialize with the decode-mode clone so the variables include the cache.
  initial_vars = decode_module.init(key2, inputs)
  causal_mask = nn.attention.make_causal_mask(jnp.ones((bs,) + spatial_shape))
  # Reference output: full-sequence attention under a causal mask.
  y_ref = jax.jit(lambda x, y: module.apply(initial_vars, x, y))(
      inputs, causal_mask)
  # feed the inputs sequentially to simulate decoding
  def body_fn(vars_in, x):
    # mutable=['cache'] threads the decode cache through successive steps.
    y, vars_out = decode_module.apply(vars_in, x, mutable=['cache'])
    return vars_out, y
  # scan_in_dim supports scanning multiple dims
  _, y = jax_utils.scan_in_dim(body_fn, initial_vars, inputs, axis=attn_dims,
                               keepdims=True)
  np.testing.assert_allclose(y_ref, y, atol=1e-5)
def __call__(self, inputs, encoder_mask=None):
  """Applies Encoder1DBlock module.

  Args:
    inputs: input data.
    encoder_mask: encoder self-attention mask.

  Returns:
    output after transformer encoder block.
  """
  cfg = self.config

  # Attention block (pre-LayerNorm: x = inputs + Attn(LN(inputs))).
  assert inputs.ndim == 3
  x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
  x = nn.SelfAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=cfg.deterministic)(x, encoder_mask)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  x = x + inputs

  # MLP block with its own pre-LayerNorm residual.
  y = nn.LayerNorm(dtype=cfg.dtype)(x)
  y = MlpBlock(config=cfg)(y)

  return x + y
def __call__(self, inputs, encoder_mask=None):
  """Applies Transformer block.

  Args:
    inputs: input data `[batch_size, ..., length, dim]`
    encoder_mask: encoder self-attention mask

  Returns:
    Encoded input data `[batch_size, ..., length, mlp_dim]`
  """
  cfg = self.config

  # Attention block (pre-LayerNorm residual).
  x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
  x = nn.SelfAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=cfg.deterministic)(x, encoder_mask)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  x = x + inputs

  # MLP block with its own pre-LayerNorm residual.
  y = nn.LayerNorm(dtype=cfg.dtype)(x)
  y = MLPBlock(config=cfg)(y)

  return x + y
def __call__(self, targets, encoded, decoder_mask=None,
             encoder_decoder_mask=None):
  """Applies EncoderDecoder1DBlock module.

  Args:
    targets: input data for decoder
    encoded: input data from encoder
    decoder_mask: decoder self-attention mask.
    encoder_decoder_mask: encoder-decoder attention mask.

  Returns:
    output after transformer encoder-decoder block.
  """
  cfg = self.config

  # Decoder block: masked self-attention over targets (pre-LayerNorm).
  assert targets.ndim == 3
  x = nn.LayerNorm(dtype=cfg.dtype)(targets)
  x = nn.SelfAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=cfg.deterministic,
      decode=cfg.decode)(x, decoder_mask)
  x = nn.Dropout(rate=cfg.dropout_rate)(
      x, deterministic=cfg.deterministic)
  x = x + targets

  # Encoder-Decoder block: cross-attention into the encoder output.
  y = nn.LayerNorm(dtype=cfg.dtype)(x)
  y = nn.MultiHeadDotProductAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=cfg.deterministic)(
          y, encoded, encoder_decoder_mask)
  y = nn.Dropout(rate=cfg.dropout_rate)(
      y, deterministic=cfg.deterministic)
  y = y + x

  # MLP block.
  z = nn.LayerNorm(dtype=cfg.dtype)(y)
  z = MlpBlock(config=cfg)(z)

  return y + z
def __call__(self, inputs, encoder_mask=None, encoder_relative_position=None):
  """Applies Transformer block.

  Args:
    inputs: input data `[batch_size, ..., length, dim]`
    encoder_mask: encoder self-attention mask
    encoder_relative_position: encoder relative positions tensor
      `[batch_sizes..., length, length]'

  Returns:
    Encoded input data `[batch_size, ..., length, mlp_dim]`
  """
  cfg = self.config

  # Attention block: relative-position self-attention when configured,
  # otherwise plain self-attention. Pre-LayerNorm residual either way.
  x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
  if cfg.use_relative_attention:
    x = relative_attention.RelativeSelfAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic,
        bidirectional=self.bidirectional_attention,
        num_relative_position_buckets=self.num_relative_position_buckets,
        max_distance=self.max_distance)(x, encoder_mask,
                                        encoder_relative_position)
  else:
    # encoder_relative_position is ignored on this path.
    x = nn.SelfAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic)(x, encoder_mask)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(dtype=cfg.dtype)(x)
  y = MLPBlock(config=cfg)(y)

  return x + y
def __call__(self, targets, encoded, decoder_mask=None,
             encoder_decoder_mask=None):
  """Applies Transformer block.

  Args:
    targets: input data for decoder `[batch_size, ..., length, dim]`
    encoded: input data from encoder `[batch_size, ..., length2, dim2]`
    decoder_mask: decoder self-attention mask
    encoder_decoder_mask: encoder-decoder attention mask

  Returns:
    Decoded data `[batch_size, ..., length, mlp_dim]`
  """
  cfg = self.config

  # Decoder block: masked self-attention over targets (pre-LayerNorm).
  x = nn.LayerNorm(dtype=cfg.dtype)(targets)
  x = nn.SelfAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=cfg.deterministic,
      decode=cfg.decode)(x, decoder_mask)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  x = x + targets

  # Encoder-Decoder block: cross-attention into the encoder output.
  y = nn.LayerNorm(dtype=cfg.dtype)(x)
  y = nn.MultiHeadDotProductAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=cfg.deterministic)(y, encoded, encoder_decoder_mask)
  y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=cfg.deterministic)
  y = y + x

  # MLP block.
  z = nn.LayerNorm(dtype=cfg.dtype)(y)
  z = MLPBlock(config=cfg)(z)

  return y + z
def __call__(self, pixel_embeddings, patch_embeddings):
  """Applies a TNT-style block: an inner (pixel-level) and an outer
  (patch-level) transformer sub-block.

  Args:
    pixel_embeddings: pixel-level token embeddings.
    patch_embeddings: patch-level token embeddings.

  Returns:
    Tuple of (inner_output, outer_output).
  """
  cfg = self.config
  v = cfg.image_size // cfg.patch_size  # number of patches per image side

  #Inner T-Block
  x = nn.LayerNorm(dtype=cfg.dtype)(pixel_embeddings)
  x = nn.SelfAttention(
      num_heads=cfg.inner_heads,
      qkv_features=cfg.inner_heads * cfg.inner_dim_head,
      out_features=cfg.inner_dim,
      use_bias=False,
      kernel_init=cfg.kernel_init,
      deterministic=True)(x)
  x = x + pixel_embeddings
  y = nn.LayerNorm(dtype=cfg.dtype)(x)
  y = MlpBlock(config=cfg, inner=True)(y)
  inner_output = x + y

  # Project flattened pixel embeddings into the outer (patch) dimension.
  # NOTE(review): this projects the block's *input* pixel_embeddings rather
  # than inner_output — confirm that is intentional.
  x = rearrange(pixel_embeddings, '... n d -> ... (n d)')
  x = nn.Dense(cfg.outer_dim,
               dtype=cfg.dtype,
               kernel_init=cfg.kernel_init,
               bias_init=cfg.bias_init)(x)
  x = rearrange(x, '(b h w) d -> b (h w) d', h=v, w=v)
  # Pad one extra token — presumably the class-token slot; verify.
  x = jnp.pad(x, ((0, 0), (0, 1), (0, 0)))
  x = x + patch_embeddings

  #Outer T-Block
  x = nn.LayerNorm(dtype=cfg.dtype)(x)
  x = nn.SelfAttention(
      num_heads=cfg.outer_heads,
      qkv_features=cfg.outer_heads * cfg.outer_dim_head,
      out_features=cfg.outer_dim,
      use_bias=False,
      kernel_init=cfg.kernel_init,
      deterministic=True)(x)
  # NOTE(review): this residual adds patch_embeddings only, dropping the
  # pixel-projection contribution that was summed in above — inconsistent
  # with the inner block's residual pattern; confirm intended.
  x = x + patch_embeddings
  y = nn.LayerNorm(dtype=cfg.dtype)(x)
  y = MlpBlock(config=cfg, inner=False)(y)
  outer_output = x + y

  return inner_output, outer_output
def setup(self):
  """Creates the self-attention, dropout, and layer-norm submodules."""
  # Submodule names in the param tree come from these attribute names.
  self.attention_layer = nn.SelfAttention(
      num_heads=self.num_heads,
      dtype=self.dtype,
      qkv_features=self.model_dim,
      dropout_rate=self.dropout_rate,
      kernel_init=self.kernel_init,
      bias_init=self.bias_init,
  )
  self.dropout = nn.Dropout(self.dropout_rate)
  self.layer_norm = nn.LayerNorm(epsilon=self.layer_norm_epsilon)
def test_multihead_self_attention(self):
  """SelfAttention with default out_features must preserve the input shape."""
  rng = random.PRNGKey(0)
  batch = jnp.ones((4, 6, 5))
  attn = nn.SelfAttention(
      num_heads=8,
      qkv_features=16,
      kernel_init=initializers.ones,
      bias_init=initializers.zeros,
  )
  out, _ = attn.init_with_output(rng, batch)
  self.assertEqual(out.shape, batch.shape)
def __call__(self, inputs, encoder_mask=None, train=True): """Applies Encoder1DBlock module. Args: inputs: input data. encoder_mask: encoder self-attention mask. train: if it is training. Returns: output after transformer encoder block. """ # Attention block. assert inputs.ndim == 3 if self.normalizer in [ 'batch_norm', 'layer_norm', 'pre_layer_norm', 'none' ]: maybe_pre_normalize = model_utils.get_normalizer( self.normalizer, train) maybe_post_normalize = model_utils.get_normalizer('none', train) elif self.normalizer == 'post_layer_norm': maybe_pre_normalize = model_utils.get_normalizer('none', train) maybe_post_normalize = model_utils.get_normalizer( self.normalizer, train) else: raise ValueError('Unsupported normalizer: {}'.format( self.normalizer)) x = maybe_pre_normalize()(inputs) x = nn.SelfAttention(num_heads=self.num_heads, dtype=self.dtype, qkv_features=self.qkv_dim, kernel_init=self.enc_self_attn_kernel_init_fn, bias_init=nn.initializers.normal(stddev=1e-6), use_bias=False, broadcast_dropout=False, dropout_rate=self.attention_dropout_rate, name='EncoderSelfAttention')( x, mask=encoder_mask, deterministic=not train) x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) x = x + inputs x = maybe_post_normalize()(x) # MLP block. y = maybe_pre_normalize()(x) y = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate, name='MLPBlock')(y, train=train) res = x + y return maybe_post_normalize()(res)
def __call__(self, x, training: bool = True):
  """Embeds integer tokens, flattens each example, and self-attends.

  Args:
    x: integer token array; cast to int32 before embedding.
    training: enables attention dropout when True.

  Returns:
    Self-attention output over the flattened embeddings.
  """
  tokens = x.astype(jnp.int32)
  embedded = nn.Embed(
      num_embeddings=self.num_embeddings,
      features=self.embedding_dim,
      name="embed",
  )(tokens)
  # Collapse all token/feature axes into one vector per example.
  flat = embedded.reshape((embedded.shape[0], -1))
  return nn.SelfAttention(
      num_heads=self.num_heads,
      qkv_features=self.qkv_features,
      out_features=self.out_features,
      use_bias=False,
      deterministic=not training,
  )(flat)
def __call__(self, query, deterministic):
  """Post-LayerNorm transformer block: self-attention then MLP, each with a
  residual connection and a LayerNorm applied after the add.

  Args:
    query: input sequence; its last dimension fixes the output width.
    deterministic: disables dropout when True.

  Returns:
    The normalized output of the attention + MLP block.
  """
  feature_dim = query.shape[-1]
  # Attention from query to value
  attn = nn.SelfAttention(
      num_heads=self.attention_heads,
      qkv_features=self.qkv_params,
      out_features=feature_dim,
      dropout_rate=self.dropout_rate)(query, deterministic=deterministic)
  attn_normed = nn.LayerNorm()(query + attn)
  mlp_out = Mlp(
      hidden_params=self.mlp_params,
      out_params=feature_dim,
      dropout_rate=self.dropout_rate)(attn_normed, deterministic=deterministic)
  return nn.LayerNorm()(attn_normed + mlp_out)
def __call__(
    self,
    inputs,
    inputs_segmentation=None,  # REFACTOR
    padding_mask=None):  # REFACTOR
  """Applies Encoder1DBlock module.

  Args:
    inputs: input data.
    inputs_segmentation: input segmentation info for packed examples.
    padding_mask: bool, mask padding tokens.

  Returns:
    output after transformer encoder block.
  """
  cfg = self.config

  # Attention block.
  assert inputs.ndim == 3
  x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
  # NOTE(review): attention_axis/causal_mask/segmentation/padding_mask are
  # pre-Linen Flax attention arguments — this block targets the old API.
  x = nn.SelfAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      attention_axis=(1,),
      causal_mask=False,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=cfg.deterministic)(
          x,
          segmentation=inputs_segmentation,  # REFACTOR
          padding_mask=padding_mask)  # REFACTOR
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(dtype=cfg.dtype)(x)
  y = MlpBlock(config=cfg)(y)

  return x + y
def __call__(self, inputs, temb, deterministic, decoder_mask=None,
             encoder_decoder_mask=None):
  """Applies EncoderDecoder1DBlock module.

  Args:
    inputs: Input data for decoder.
    temb: Time embedding representation.
    deterministic: Should be deterministic in dropout?
    decoder_mask: Decoder self-attention mask.
    encoder_decoder_mask: Encoder-decoder attention mask.
      NOTE(review): currently unused in this block — confirm whether
      cross-attention was intentionally omitted.

  Returns:
    output after transformer encoder-decoder block.
  """
  cfg = self.config

  # Decoder block: self-attention with a pre-LayerNorm residual.
  assert inputs.ndim == 3
  x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
  x = nn.SelfAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=deterministic,
      decode=False)(x, decoder_mask)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic)
  x = x + inputs

  # MLP block, conditioned on the time embedding.
  z = nn.LayerNorm(dtype=cfg.dtype)(x)
  z = MlpBlock(config=cfg)(z, temb, deterministic)

  return x + z
def __call__(self, x, train=True):
  """Applies a ViT encoder block, also returning intermediate activations.

  Args:
    x: input sequence of embeddings.
    train: NOTE(review): passed directly where Flax expects `deterministic`
      (both in SelfAttention and Dropout), so dropout is active when
      train=False — this parameter appears to mean "deterministic" despite
      its name; confirm against callers.

  Returns:
    Tuple of (block output, dict of intermediate activations keyed by
    'sa', '+sa', 'mlp', '+mlp').
  """
  out = {}
  y = nn.LayerNorm(name='LayerNorm_0')(x)
  y = out['sa'] = nn.SelfAttention(
      num_heads=self.num_heads,
      kernel_init=nn.initializers.xavier_uniform(),
      deterministic=train,
      name='MultiHeadDotProductAttention_1',
  )(y)
  y = nn.Dropout(rate=self.dropout)(y, train)
  x = out['+sa'] = x + y

  y = nn.LayerNorm(name='LayerNorm_2')(x)
  y = out['mlp'] = MlpBlock(
      mlp_dim=self.mlp_dim,
      dropout=self.dropout,
      name='MlpBlock_3',
  )(y, train)
  y = nn.Dropout(rate=self.dropout)(y, train)
  x = out['+mlp'] = x + y

  return x, out
def __call__(self,
             targets,
             encoded,
             decoder_mask=None,
             encoder_decoder_mask=None,
             decoder_relative_position=None,
             encoder_decoder_relative_position=None):
  """Applies Transformer block.

  Args:
    targets: input data for decoder `[batch_size, ..., length, dim]`
    encoded: input data from encoder `[batch_size, ..., length2, dim2]`
    decoder_mask: decoder self-attention mask
    encoder_decoder_mask: encoder-decoder attention mask
    decoder_relative_position: decoder relative positions tensor
      `[batch_sizes..., length2, length2]'
    encoder_decoder_relative_position: encoder-decoder relative tensor
      `[batch_sizes..., length2, length]'

  Returns:
    Decoded data `[batch_size, ..., length2, mlp_dim]`
  """
  cfg = self.config

  # Decoder block: relative or plain masked self-attention over targets.
  x = nn.LayerNorm(dtype=cfg.dtype)(targets)
  if cfg.use_relative_attention:
    x = relative_attention.RelativeSelfAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic,
        bidirectional=self.bidirectional_attention,
        num_relative_position_buckets=self.num_relative_position_buckets,
        max_distance=self.max_distance)(
            x, decoder_mask, decoder_relative_position)
  else:
    # decoder_relative_position is ignored on this path.
    x = nn.SelfAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic)(x, decoder_mask)
  x = nn.Dropout(rate=cfg.dropout_rate)(
      x, deterministic=cfg.deterministic)
  x = x + targets

  # Encoder-Decoder block: relative or plain cross-attention into encoder.
  y = nn.LayerNorm(dtype=cfg.dtype)(x)
  if self.relative_cross_attention:
    y = relative_attention.RelativeMultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic,
        bidirectional=self.bidirectional_cross_attention,
        num_relative_position_buckets=(
            self.num_relative_position_buckets_cross_attention),
        max_distance=self.max_distance_cross_attention)(
            y, encoded, encoder_decoder_mask,
            encoder_decoder_relative_position)
  else:
    # encoder_decoder_relative_position is ignored on this path.
    y = nn.MultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic)(y, encoded, encoder_decoder_mask)
  y = nn.Dropout(rate=cfg.dropout_rate)(
      y, deterministic=cfg.deterministic)
  y = y + x

  # MLP block.
  z = nn.LayerNorm(dtype=cfg.dtype)(y)
  z = MLPBlock(config=cfg)(z)

  return y + z
def __call__(
    self,
    targets,
    encoded,
    inputs_segmentation=None,  # REFACTOR
    targets_segmentation=None,  # REFACTOR
    padding_mask=None,  # REFACTOR
    key_padding_mask=None):  # REFACTOR
  """Applies EncoderDecoder1DBlock module.

  Args:
    targets: input data for decoder
    encoded: input data from encoder
    inputs_segmentation: input segmentation info for packed examples.
    targets_segmentation: target segmentation info for packed examples.
    padding_mask: bool, mask padding tokens
    key_padding_mask: bool, mask padding tokens

  Returns:
    output after transformer encoder-decoder block.
  """
  cfg = self.config

  # Decoder block: causal self-attention over targets (pre-Linen Flax API:
  # attention_axis/causal_mask/segmentation/padding_mask arguments).
  assert targets.ndim == 3
  x = nn.LayerNorm(dtype=cfg.dtype)(targets)
  x = nn.SelfAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      attention_axis=(1,),
      causal_mask=True,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=cfg.deterministic,
      decode=cfg.decode)(
          x, padding_mask=padding_mask, segmentation=targets_segmentation)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  x = x + targets

  # Encoder-Decoder block: non-causal cross-attention into encoder output.
  y = nn.LayerNorm(dtype=cfg.dtype)(x)
  y = nn.MultiHeadDotProductAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      attention_axis=(1,),
      causal_mask=False,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=cfg.deterministic)(
          y,
          encoded,
          padding_mask=padding_mask,
          key_padding_mask=key_padding_mask,
          segmentation=targets_segmentation,
          key_segmentation=inputs_segmentation)
  y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=cfg.deterministic)
  y = y + x

  # MLP block.
  z = nn.LayerNorm(dtype=cfg.dtype)(y)
  z = MlpBlock(config=cfg)(z)

  return y + z
def __call__(self, inputs, decoder_mask=None, encoder_decoder_mask=None,
             inputs_kv=None):
  """Applies EncoderDecoder1DBlock module.

  Args:
    inputs: input data for decoder
    decoder_mask: decoder self-attention mask.
    encoder_decoder_mask: encoder-decoder attention mask.
      NOTE(review): currently unused — the cross-attention path below masks
      with decoder_mask instead; confirm intended.
    inputs_kv: key/value inputs for cross-attention; required when
      self.is_self_att is False.

  Returns:
    output after transformer encoder-decoder block.
  """
  cfg = self.config

  # Decoder block.
  assert inputs.ndim == 3
  # assert decoder_mask.ndim == 4
  if cfg.use_layernorm:
    x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
  else:
    x = inputs
  if self.is_self_att:
    x = nn.SelfAttention(
        num_heads=cfg.num_heads,
        out_features=self.out_features,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic,
        attention_fn=self.attention_fn,
        decode=cfg.decode)(x, decoder_mask)
  else:
    # Cross-attention: keys/values come from inputs_kv.
    if cfg.use_layernorm:
      x_kv = nn.LayerNorm(dtype=cfg.dtype)(inputs_kv)
    else:
      x_kv = inputs_kv
    x = nn.MultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        out_features=self.out_features,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic,
        attention_fn=self.attention_fn,
        decode=cfg.decode)(x, x_kv, mask=decoder_mask)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  x = x + inputs

  # MLP block.
  if cfg.use_layernorm:
    z = nn.LayerNorm(dtype=cfg.dtype)(x)
  else:
    z = x
  z = MlpBlock(config=cfg)(z)

  return x + z
def __call__(self, inputs, train, decoder_mask=None, encoder_decoder_mask=None, inputs_positions=None, inputs_segmentation=None): """Applies Transformer1DBlock module. Args: inputs: input data train: bool: if model is training. decoder_mask: decoder self-attention mask. encoder_decoder_mask: encoder-decoder attention mask. inputs_positions: input subsequence positions for packed examples. inputs_segmentation: input segmentation info for packed examples. Returns: output after transformer block. """ # Attention block. assert inputs.ndim == 3 if self.normalizer in [ 'batch_norm', 'layer_norm', 'pre_layer_norm', 'none' ]: maybe_pre_normalize = model_utils.get_normalizer(self.normalizer, train, dtype=self.dtype) maybe_post_normalize = model_utils.get_normalizer('none', train, dtype=self.dtype) elif self.normalizer == 'post_layer_norm': maybe_pre_normalize = model_utils.get_normalizer('none', train, dtype=self.dtype) maybe_post_normalize = model_utils.get_normalizer(self.normalizer, train, dtype=self.dtype) else: raise ValueError('Unsupported normalizer: {}'.format( self.normalizer)) x = maybe_pre_normalize()(inputs) if self.attention_fn is None: attention_fn = nn.dot_product_attention else: attention_fn = self.attention_fn x = nn.SelfAttention(num_heads=self.num_heads, qkv_features=self.qkv_dim, decode=self.decode, dtype=self.dtype, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6), use_bias=False, broadcast_dropout=False, attention_fn=attention_fn, dropout_rate=self.attention_dropout_rate, deterministic=not train)(x, decoder_mask) x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) x = x + inputs x = maybe_post_normalize()(x) # MLP block. y = maybe_pre_normalize()(x) y = MlpBlock(mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate)(y, train=train) res = x + y return maybe_post_normalize()(res)
def exec_op(self, op, input_values, deterministic, training, **_):
  """Executes an op according to the normal concrete semantics.

  Dispatches on op.type, applying the corresponding Flax module, jax/lax
  primitive, or parameter creation to input_values. Each branch asserts the
  expected arity of input_values / input_kwargs / op_kwargs before running.

  Args:
    op: Op record carrying .type, .input_kwargs (call-time kwargs), and
      .op_kwargs (constructor kwargs; must include a "name").
    input_values: List of concrete input arrays for the op.
    deterministic: Disables dropout / stochastic depth when True.
    training: Controls BatchNorm's use_running_average default.
    **_: Ignored extra keyword arguments.

  Returns:
    A single-element list containing the op's output value.

  Raises:
    ValueError: If op_kwargs lacks a name or op.type is unsupported.
  """
  # NOTE(review): several branches mutate op.input_kwargs / op.op_kwargs in
  # place (CONV, RESHAPE, pooling, PARAM) — presumably ops are executed once;
  # confirm re-execution of the same op graph is not expected.
  input_kwargs: Dict[str, Any] = op.input_kwargs
  op_kwargs: Dict[str, Any] = op.op_kwargs
  op_type = op.type
  if "name" not in op_kwargs:
    raise ValueError("Op kwargs must contain a name.")
  op_name = op_kwargs["name"]

  if op_type == OpType.NONE:
    # Pass-through that blocks gradient flow.
    assert len(input_values) == 1
    input_value = input_values[0]
    assert not input_kwargs
    assert len(op_kwargs) == 1
    output_values = [lax.stop_gradient(input_value)]
  elif op_type == OpType.IDENTITY:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert not input_kwargs
    assert len(op_kwargs) == 1
    output_values = [input_value]

  # nn.linear
  elif op_type == OpType.DENSE:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert not input_kwargs
    output_values = [nn.Dense(**op_kwargs)(input_value)]
  elif op_type == OpType.DENSE_GENERAL:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert not input_kwargs
    assert 2 <= len(op_kwargs) <= 7
    output_values = [nn.DenseGeneral(**op_kwargs)(input_value)]
  elif op_type == OpType.CONV:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert not input_kwargs
    ks = op_kwargs["kernel_size"]
    # Broadcast a scalar kernel size across all spatial dims.
    if isinstance(ks, int):
      op_kwargs["kernel_size"] = (ks,) * (input_value.ndim - 2)
    output_values = [nn.Conv(**op_kwargs)(input_value)]

  # others
  elif op_type == OpType.MUL:
    assert len(input_values) == 2
    assert not input_kwargs
    assert len(op_kwargs) == 1  # name
    output_values = [input_values[0] * input_values[1]]
  elif op_type in [OpType.ADD, OpType.STOCH_DEPTH]:
    assert len(op_kwargs) == 1  # name
    input_value = input_values[0]
    if "layer_drop_rate" in input_kwargs:
      # Stochastic depth: drop the whole residual branch per example.
      assert len(input_kwargs) == 1
      survival_rate = 1 - input_kwargs["layer_drop_rate"]
      if survival_rate == 1.0 or deterministic:
        pass
      else:
        # Reuse dropout's rng stream.
        rng = self.make_rng("dropout")
        mask_shape = [input_value.shape[0]] + [1] * (input_value.ndim - 1)
        mask = random.bernoulli(rng, p=survival_rate, shape=mask_shape)
        mask = jnp.tile(mask, [1] + list(input_value.shape[1:]))
        # Scale surviving branches by 1/survival_rate to keep expectation.
        input_value = lax.select(mask, input_value / survival_rate,
                                 jnp.zeros_like(input_value))
    else:
      assert not input_kwargs
      assert op_type == OpType.ADD
    if op_type == OpType.ADD:
      assert len(input_values) == 2
      output_values = [input_value + input_values[1]]
    else:
      assert len(input_values) == 1
      output_values = [input_value]
  elif op_type == OpType.SCALAR_MUL:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert len(input_kwargs) <= 1
    assert len(op_kwargs) == 1  # name
    if "const" in input_kwargs:
      c = input_kwargs["const"]
    else:
      # Default: 1/sqrt(d) scaling (as used before a softmax in attention).
      c = 1 / jnp.sqrt(input_values[0].shape[-1])
    output_values = [input_values[0] * c]
  elif op_type == OpType.SCALAR_ADD:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert len(input_kwargs) <= 1
    assert len(op_kwargs) == 1  # name
    assert "const" in input_kwargs
    c = input_kwargs["const"]
    output_values = [input_values[0] + c]
  elif op_type == OpType.DOT_GENERAL:
    assert len(input_values) == 2
    assert 0 < len(input_kwargs) <= 3
    assert len(op_kwargs) == 1  # name
    output_values = [
        lax.dot_general(input_values[0], input_values[1], **input_kwargs)
    ]
  elif op_type == OpType.EINSUM:
    assert len(input_values) == 2
    assert len(input_kwargs) == 1
    assert "sum" in input_kwargs  # the einsum spec string
    output_values = [
        jnp.einsum(input_kwargs["sum"], input_values[0], input_values[1])
    ]

  # nn.attention
  elif op_type == OpType.SELF_ATTENTION:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert not input_kwargs
    output_values = [
        nn.SelfAttention(**op_kwargs,
                         deterministic=deterministic)(input_value)
    ]

  # nn.activation
  elif op_type in [OpType.RELU, OpType.GELU, OpType.SWISH, OpType.SIGMOID]:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert not input_kwargs
    fn = {
        OpType.RELU: nn.relu,
        OpType.GELU: nn.gelu,
        OpType.SWISH: nn.swish,
        OpType.SIGMOID: nn.sigmoid
    }[op_type]
    output_values = [fn(input_value)]
  elif op_type == OpType.SOFTMAX:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert len(input_kwargs) <= 1
    output_values = [nn.softmax(input_value, **input_kwargs)]

  # nn.normalization
  elif op_type == OpType.BATCH_NORM:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert len(input_kwargs) <= 1
    add_kwargs = {}
    # Default use_running_average from the training flag unless the caller
    # supplied it explicitly.
    if "use_running_average" not in input_kwargs:
      add_kwargs = {"use_running_average": not training}
    else:
      add_kwargs = {}
    output_values = [
        nn.BatchNorm(**op_kwargs)(input_value, **input_kwargs, **add_kwargs)
    ]
  elif op_type == OpType.LAYER_NORM:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert not input_kwargs
    output_values = [nn.LayerNorm(**op_kwargs)(input_value)]
  elif op_type == OpType.GROUP_NORM:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert not input_kwargs
    output_values = [nn.GroupNorm(**op_kwargs)(input_value)]

  # reshape operators
  elif op_type == OpType.RESHAPE:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert 0 < len(input_kwargs) < 3
    new_shape = input_kwargs.pop("new_shape")
    # A leading "B" placeholder stands for the runtime batch dimension.
    if new_shape[0] == "B":
      new_shape = (input_value.shape[0],) + new_shape[1:]
    output_values = [
        jnp.reshape(input_value, new_shape, **input_kwargs)
    ]
  elif op_type == OpType.FLATTEN:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert not input_kwargs
    new_shape = (input_value.shape[0], -1)
    output_values = [jnp.reshape(input_value, new_shape)]
  elif op_type == OpType.TRANSPOSE:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert len(input_kwargs) == 1
    assert len(op_kwargs) == 1  # name
    output_values = [jnp.transpose(input_value, **input_kwargs)]

  # nn.stochastic
  elif op_type == OpType.DROPOUT:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert len(input_kwargs) <= 1
    # NOTE(review): if input_kwargs ever contains "deterministic" this call
    # passes it twice and raises TypeError — presumably input_kwargs only
    # carries other call-time options here; confirm.
    output_values = [
        nn.Dropout(**op_kwargs)(input_value,
                                deterministic=deterministic,
                                **input_kwargs)
    ]

  # nn.pooling
  elif op_type == OpType.AVG_POOL or op_type == OpType.MAX_POOL:
    op_fn = nn.avg_pool if op_type == OpType.AVG_POOL else nn.max_pool
    assert len(input_values) == 1
    input_value = input_values[0]
    assert input_kwargs
    ws = input_kwargs["window_shape"]
    # Broadcast a scalar window across all spatial dims.
    if isinstance(ws, int):
      ws = [ws] * (input_value.ndim - 2)
    # A window entry of 0 means "cover the full extent of that dim"
    # (i.e. global pooling along that axis).
    new_ws = []
    for window_dim_shape, dim_shape in zip(ws, input_value.shape[1:]):
      if window_dim_shape == 0:
        new_ws.append(dim_shape)
      else:
        new_ws.append(window_dim_shape)
    input_kwargs["window_shape"] = tuple(new_ws)
    if "strides" in input_kwargs:
      s = input_kwargs["strides"]
      if isinstance(s, int):
        input_kwargs["strides"] = (s,) * (input_value.ndim - 2)
    output_values = [op_fn(input_value, **input_kwargs)]
  elif op_type == OpType.MEAN:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert input_kwargs
    output_values = [jnp.mean(input_value, **input_kwargs)]

  # new param
  elif op_type == OpType.PARAM:
    # Creates a fresh parameter instead of transforming an input.
    assert not input_values
    assert 0 < len(input_kwargs) <= 2
    init_fn = input_kwargs.pop("init_fn")
    init_fn_with_kwargs = functools.partial(init_fn, **input_kwargs)
    output_values = [self.param(op_name, init_fn_with_kwargs)]
  else:
    raise ValueError(f"op_type {op_type} not supported...")

  return output_values