# Shared imports used by the listings and usage sketches below.
from functools import partial
from typing import Callable, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp
from einops import rearrange
from flax import linen as nn
from flax.linen import initializers
from jax import random
from jax.lax import Precision


class EncoderBlock(nn.Module):
    num_heads: int
    expand_ratio: float = 4
    attn_dropout_rate: float = 0.
    dropout_rate: float = 0.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        # Pre-norm self-attention sub-block with a residual connection.
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = SelfAttentionBlock(num_heads=self.num_heads,
                               attn_drop_rate=self.attn_dropout_rate,
                               out_drop_rate=self.dropout_rate,
                               dtype=self.dtype,
                               precision=self.precision,
                               kernel_init=self.kernel_init)(
                                   x, is_training=is_training)
        x = x + inputs

        # Pre-norm feed-forward sub-block with a residual connection.
        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = FFBlock(expand_ratio=self.expand_ratio,
                    dropout_rate=self.dropout_rate,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(y, is_training=is_training)
        output = x + y
        return output
class EncoderBlock(nn.Module):
    num_heads: int
    expand_ratio: float = 4
    leff_kernel_size: Optional[int] = 3
    activation_fn: Callable = nn.activation.gelu
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        x = SelfAttentionBlock(num_heads=self.num_heads,
                               dtype=self.dtype,
                               precision=self.precision,
                               kernel_init=self.kernel_init,
                               bias_init=self.bias_init)(
                                   inputs, is_training=is_training)
        x += inputs
        x = nn.LayerNorm(dtype=self.dtype)(x)

        y = LeFFBlock(expand_ratio=self.expand_ratio,
                      kernel_size=self.leff_kernel_size,
                      activation_fn=self.activation_fn,
                      bn_momentum=self.bn_momentum,
                      bn_epsilon=self.bn_epsilon,
                      dtype=self.dtype,
                      precision=self.precision,
                      kernel_init=self.kernel_init,
                      bias_init=self.bias_init)(x, is_training=is_training)
        y = x + y
        output = nn.LayerNorm(dtype=self.dtype)(y)
        return output
class MixerBlock(nn.Module):
    tokens_expand_ratio: float
    channels_expand_ratio: float
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs):
        # Token-mixing MLP: transpose so the MLP acts across the token axis.
        # FFBlock's dropout rate defaults to 0 here, so is_training=False is
        # equivalent to training behaviour.
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = rearrange(x, '... l d -> ... d l')
        x = FFBlock(expand_ratio=self.tokens_expand_ratio,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(x, is_training=False)
        x = rearrange(x, '... d l -> ... l d')
        x = x + inputs

        # Channel-mixing MLP applied independently to every token.
        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = FFBlock(expand_ratio=self.channels_expand_ratio,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(y, is_training=False)
        output = x + y
        return output
class FFBlock(nn.Module):
    expand_ratio: Optional[float] = None
    hidden_ch: Optional[int] = None
    dropout_rate: float = 0.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        in_ch = inputs.shape[-1]
        if self.expand_ratio is None:
            if self.hidden_ch is None:
                raise ValueError(
                    'Must provide one of expand_ratio or hidden_ch')
            hidden_ch = self.hidden_ch
        else:
            hidden_ch = max(1, int(self.expand_ratio * in_ch))

        dense = partial(nn.Dense,
                        use_bias=True,
                        dtype=self.dtype,
                        precision=self.precision,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)

        x = dense(features=hidden_ch)(inputs)
        x = self.activation_fn(x)
        x = nn.Dropout(rate=self.dropout_rate,
                       deterministic=not is_training)(x)
        x = dense(features=in_ch)(x)
        output = nn.Dropout(rate=self.dropout_rate,
                            deterministic=not is_training)(x)
        return output
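# Usage sketch for FFBlock (added for illustration, not part of the original
# listings; shapes and hyper-parameters are arbitrary).
ff = FFBlock(expand_ratio=4.)
tokens = jnp.zeros((2, 196, 256))                   # (batch, tokens, channels)
ff_vars = ff.init(jax.random.PRNGKey(0), tokens, is_training=False)
ff_out = ff.apply(ff_vars, tokens, is_training=False)
# ff_out.shape == (2, 196, 256): the block expands to 4 * 256 channels
# internally and projects back to the input width.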
class CvTSelfAttentionBlock(nn.Module):
    num_heads: int
    head_ch: Optional[int] = None
    out_ch: Optional[int] = None
    talking_heads: bool = False
    attn_dropout_rate: float = 0.
    out_dropout_rate: float = 0.
    kernel_size: int = 3
    strides: Sequence[int] = (1, 2, 2)
    use_bias: bool = False
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        # Self-attention is cross-attention with queries and keys/values
        # drawn from the same sequence.
        return CvTAttentionBlock(num_heads=self.num_heads,
                                 head_ch=self.head_ch,
                                 out_ch=self.out_ch,
                                 talking_heads=self.talking_heads,
                                 attn_dropout_rate=self.attn_dropout_rate,
                                 out_dropout_rate=self.out_dropout_rate,
                                 kernel_size=self.kernel_size,
                                 strides=self.strides,
                                 bn_momentum=self.bn_momentum,
                                 bn_epsilon=self.bn_epsilon,
                                 use_bias=self.use_bias,
                                 dtype=self.dtype,
                                 precision=self.precision,
                                 kernel_init=self.kernel_init,
                                 bias_init=self.bias_init)(inputs, inputs,
                                                           is_training)
class ClassSelfAttentionBlock(nn.Module):
    num_heads: int
    head_ch: Optional[int] = None
    out_ch: Optional[int] = None
    talking_heads: bool = False
    attn_dropout_rate: float = 0.
    out_dropout_rate: float = 0.
    use_bias: bool = False
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        # Only the class token (position 0) forms the query; it attends over
        # the full token sequence.
        inputs_q = jnp.expand_dims(inputs[:, 0, :], axis=1)
        return AttentionBlock(num_heads=self.num_heads,
                              head_ch=self.head_ch,
                              out_ch=self.out_ch,
                              talking_heads=self.talking_heads,
                              attn_drop_rate=self.attn_dropout_rate,
                              out_drop_rate=self.out_dropout_rate,
                              use_bias=self.use_bias,
                              dtype=self.dtype,
                              precision=self.precision,
                              kernel_init=self.kernel_init,
                              bias_init=self.bias_init)(
                                  inputs_q, inputs, is_training=is_training)
class Encoder(nn.Module):
    num_layers: int
    num_heads: int
    expand_ratio: float = 4
    attn_dropout_rate: float = 0.
    dropout_rate: float = 0.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        x = AddAbsPosEmbed()(inputs)
        x = nn.Dropout(rate=self.dropout_rate)(x,
                                               deterministic=not is_training)

        for _ in range(self.num_layers):
            x = EncoderBlock(num_heads=self.num_heads,
                             expand_ratio=self.expand_ratio,
                             attn_dropout_rate=self.attn_dropout_rate,
                             dropout_rate=self.dropout_rate,
                             activation_fn=self.activation_fn,
                             dtype=self.dtype,
                             precision=self.precision,
                             kernel_init=self.kernel_init,
                             bias_init=self.bias_init)(
                                 x, is_training=is_training)

        output = nn.LayerNorm(dtype=self.dtype)(x)
        return output
class Encoder(nn.Module):
    num_layers: int
    num_heads: int
    expand_ratio: float = 4
    leff_kernel_size: int = 3
    activation_fn: Callable = nn.activation.gelu
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        encoder_block = partial(EncoderBlock,
                                num_heads=self.num_heads,
                                expand_ratio=self.expand_ratio,
                                leff_kernel_size=self.leff_kernel_size,
                                activation_fn=self.activation_fn,
                                bn_momentum=self.bn_momentum,
                                bn_epsilon=self.bn_epsilon,
                                dtype=self.dtype,
                                precision=self.precision,
                                kernel_init=self.kernel_init,
                                bias_init=self.bias_init)

        x = encoder_block()(inputs, is_training=is_training)
        cls_tokens_lst = [jnp.expand_dims(x[:, 0], axis=1)]
        for _ in range(self.num_layers - 1):
            x = encoder_block()(x, is_training=is_training)
            cls_tokens_lst.append(jnp.expand_dims(x[:, 0], axis=1))

        return jnp.concatenate(cls_tokens_lst, axis=1)
class LCAEncoderBlock(nn.Module):
    num_heads: int
    expand_ratio: float = 4
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        x = LCSelfAttentionBlock(num_heads=self.num_heads,
                                 dtype=self.dtype,
                                 precision=self.precision,
                                 kernel_init=self.kernel_init,
                                 bias_init=self.bias_init)(
                                     inputs, is_training=is_training)
        x += inputs
        x = nn.LayerNorm(dtype=self.dtype)(x)

        y = FFBlock(expand_ratio=self.expand_ratio,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(x, is_training=is_training)
        y = x + y
        output = nn.LayerNorm(dtype=self.dtype)(y)
        return output
class ViT(nn.Module):
    num_classes: int
    num_layers: int
    num_heads: int
    embed_dim: int
    patch_shape: Tuple[int, int]
    expand_ratio: float = 4
    dropout_rate: float = 0.
    attn_dropout_rate: float = 0.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        assert self.embed_dim % self.num_heads == 0

        x = PatchEmbedBlock(
            patch_shape=self.patch_shape,
            embed_dim=self.embed_dim,
            dtype=self.dtype,
            precision=self.precision,
            kernel_init=self.kernel_init,
        )(inputs)

        b, l, _ = x.shape
        cls_shape = (1, 1, self.embed_dim)
        cls_token = self.param('cls', initializers.zeros, cls_shape)
        cls_token = jnp.tile(cls_token, [b, 1, 1])
        x = jnp.concatenate([cls_token, x], axis=1)

        x = Encoder(num_layers=self.num_layers,
                    num_heads=self.num_heads,
                    expand_ratio=self.expand_ratio,
                    attn_dropout_rate=self.attn_dropout_rate,
                    dropout_rate=self.dropout_rate,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(x, is_training=is_training)

        cls_token = x[:, 0]
        output = nn.Dense(
            features=self.num_classes,
            use_bias=True,
            dtype=self.dtype,
            kernel_init=initializers.zeros,
            bias_init=self.bias_init,
        )(cls_token)
        return output
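# Usage sketch for ViT (added for illustration, not part of the original
# listings). Assumes PatchEmbedBlock, AddAbsPosEmbed and a SelfAttentionBlock
# accepting attn_drop_rate/out_drop_rate (as used by the ViT EncoderBlock
# above) are importable from this codebase; shapes are arbitrary.
vit = ViT(num_classes=10, num_layers=4, num_heads=4,
          embed_dim=128, patch_shape=(4, 4))
images = jnp.zeros((2, 32, 32, 3))
vit_vars = vit.init(jax.random.PRNGKey(0), images, is_training=False)
logits = vit.apply(vit_vars, images, is_training=True,
                   rngs={'dropout': jax.random.PRNGKey(1)})
# logits.shape == (2, 10)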
class Image2TokenBlock(nn.Module):
    patch_shape: Tuple[int, int]
    num_ch: int
    conv_kernel_size: int
    conv_stride: int
    pool_window_size: int
    pool_stride: int
    embed_dim: int
    use_bias: bool = False
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        x = nn.Conv(features=self.num_ch,
                    use_bias=self.use_bias,
                    kernel_size=(self.conv_kernel_size,
                                 self.conv_kernel_size),
                    strides=(self.conv_stride, self.conv_stride),
                    padding=[(self.patch_shape[0], ) * 2,
                             (self.patch_shape[1], ) * 2])(inputs)
        x = nn.BatchNorm(use_running_average=not is_training,
                         momentum=self.bn_momentum,
                         epsilon=self.bn_epsilon,
                         dtype=self.dtype)(x)
        x = nn.max_pool(
            inputs=x,
            window_shape=(self.pool_window_size, ) * 2,
            strides=(self.pool_stride, ) * 2,
        )
        x = rearrange(
            x,
            'b (h ph) (w pw) c -> b (h w) (ph pw c)',
            ph=self.patch_shape[0],
            pw=self.patch_shape[1],
        )
        output = nn.Dense(features=self.embed_dim,
                          use_bias=self.use_bias,
                          dtype=self.dtype,
                          precision=self.precision,
                          kernel_init=self.kernel_init,
                          bias_init=self.bias_init)(x)
        return output
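# Shape sketch for Image2TokenBlock (added for illustration, not part of the
# original listings), using the CeiT defaults from further below. With a
# 224x224 input: conv (stride 2, padding 4) -> 113x113, max-pool (3, stride 2)
# -> 56x56, then 4x4 patches -> 14*14 = 196 tokens of width embed_dim.
i2t = Image2TokenBlock(patch_shape=(4, 4), num_ch=32, conv_kernel_size=7,
                       conv_stride=2, pool_window_size=3, pool_stride=2,
                       embed_dim=128)
images = jnp.zeros((2, 224, 224, 3))
i2t_vars = i2t.init(jax.random.PRNGKey(0), images, is_training=False)
i2t_tokens = i2t.apply(i2t_vars, images, is_training=False)
# i2t_tokens.shape == (2, 196, 128)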
class MLPMixer(nn.Module):
    num_classes: int
    num_layers: int
    embed_dim: int
    patch_shape: Tuple[int, int]
    tokens_expand_ratio: float
    channels_expand_ratio: float
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, *unused_args, **unused_kwargs):
        x = PatchEmbedBlock(patch_shape=self.patch_shape,
                            embed_dim=self.embed_dim,
                            dtype=self.dtype,
                            precision=self.precision,
                            kernel_init=self.kernel_init)(inputs)
        x = rearrange(x, 'b h w d -> b (h w) d')

        for _ in range(self.num_layers):
            x = MixerBlock(tokens_expand_ratio=self.tokens_expand_ratio,
                           channels_expand_ratio=self.channels_expand_ratio,
                           activation_fn=self.activation_fn,
                           dtype=self.dtype,
                           precision=self.precision,
                           kernel_init=self.kernel_init,
                           bias_init=self.bias_init)(x)

        x = nn.LayerNorm(dtype=self.dtype)(x)
        x = jnp.mean(x, axis=1)
        output = nn.Dense(features=self.num_classes,
                          dtype=self.dtype,
                          precision=self.precision,
                          kernel_init=nn.initializers.zeros,
                          bias_init=self.bias_init)(x)
        return output
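# Usage sketch for MLPMixer (added for illustration, not part of the original
# listings). Assumes PatchEmbedBlock is importable and returns a b x h x w x d
# feature map, as the rearrange above expects; hyper-parameters are arbitrary.
mixer = MLPMixer(num_classes=10, num_layers=8, embed_dim=256,
                 patch_shape=(4, 4), tokens_expand_ratio=0.5,
                 channels_expand_ratio=4.)
mixer_images = jnp.zeros((2, 32, 32, 3))
mixer_vars = mixer.init(jax.random.PRNGKey(0), mixer_images)
mixer_logits = mixer.apply(mixer_vars, mixer_images)
# mixer_logits.shape == (2, 10)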
class ConvProjectionBlock(nn.Module):
    out_ch: int
    kernel_size: int = 3
    strides: int = 1
    use_bias: bool = True
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        in_ch = inputs.shape[-1]
        conv = partial(nn.Conv,
                       dtype=self.dtype,
                       precision=self.precision,
                       kernel_init=self.kernel_init)

        x = conv(features=in_ch,
                 kernel_size=(self.kernel_size, self.kernel_size),
                 strides=(self.strides, self.strides),
                 padding='SAME',
                 feature_group_count=in_ch,
                 use_bias=False)(inputs)
        x = nn.BatchNorm(use_running_average=not is_training,
                         momentum=self.bn_momentum,
                         epsilon=self.bn_epsilon,
                         dtype=self.dtype)(x)
        output = conv(features=self.out_ch,
                      kernel_size=(1, 1),
                      use_bias=self.use_bias,
                      bias_init=self.bias_init)(x)
        return output
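# Sketch of the convolutional projection (added for illustration, not part of
# the original listings): a depthwise conv, optionally strided, followed by a
# 1x1 pointwise conv. A stride of 2 halves the spatial resolution, which is
# how the CvT attention below shrinks its key/value token count.
proj = ConvProjectionBlock(out_ch=96, kernel_size=3, strides=2)
feature_map = jnp.zeros((2, 14, 14, 96))
proj_vars = proj.init(jax.random.PRNGKey(0), feature_map, is_training=False)
proj_out = proj.apply(proj_vars, feature_map, is_training=False)
# proj_out.shape == (2, 7, 7, 96)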
class CeiT(nn.Module):
    num_classes: int
    num_layers: int
    num_heads: int
    embed_dim: int
    patch_shape: Tuple[int, int] = (4, 4)
    num_ch: int = 32
    conv_kernel_size: int = 7
    conv_stride: int = 2
    pool_window_size: int = 3
    pool_stride: int = 2
    expand_ratio: float = 4
    leff_kernel_size: int = 3
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        assert self.embed_dim % self.num_heads == 0

        x = Image2TokenBlock(patch_shape=self.patch_shape,
                             num_ch=self.num_ch,
                             conv_kernel_size=self.conv_kernel_size,
                             conv_stride=self.conv_stride,
                             pool_window_size=self.pool_window_size,
                             pool_stride=self.pool_stride,
                             embed_dim=self.embed_dim,
                             bn_momentum=self.bn_momentum,
                             bn_epsilon=self.bn_epsilon,
                             dtype=self.dtype,
                             precision=self.precision,
                             kernel_init=self.kernel_init,
                             bias_init=self.bias_init)(
                                 inputs, is_training=is_training)

        b = x.shape[0]
        cls_shape = (1, 1, self.embed_dim)
        cls_token = self.param('cls', initializers.zeros, cls_shape)
        cls_token = jnp.tile(cls_token, [b, 1, 1])
        x = jnp.concatenate([cls_token, x], axis=1)

        cls_tokens = Encoder(num_layers=self.num_layers,
                             num_heads=self.num_heads,
                             expand_ratio=self.expand_ratio,
                             leff_kernel_size=self.leff_kernel_size,
                             bn_momentum=self.bn_momentum,
                             bn_epsilon=self.bn_epsilon,
                             dtype=self.dtype,
                             precision=self.precision,
                             kernel_init=self.kernel_init,
                             bias_init=self.bias_init)(
                                 x, is_training=is_training)

        cls_tokens = LCSelfAttentionBlock(num_heads=self.num_heads,
                                          dtype=self.dtype,
                                          precision=self.precision,
                                          kernel_init=self.kernel_init,
                                          bias_init=self.bias_init)(
                                              cls_tokens,
                                              is_training=is_training)

        cls = cls_tokens[:, -1]
        output = nn.Dense(features=self.num_classes,
                          use_bias=True,
                          dtype=self.dtype,
                          kernel_init=self.kernel_init,
                          bias_init=self.bias_init)(cls)
        return output
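# Usage sketch for CeiT (added for illustration, not part of the original
# listings). Assumes LCSelfAttentionBlock and the blocks above are importable
# from this codebase. BatchNorm statistics live in the 'batch_stats'
# collection, so a training-mode call threads it through `mutable`.
ceit = CeiT(num_classes=10, num_layers=4, num_heads=4, embed_dim=128)
ceit_images = jnp.zeros((2, 224, 224, 3))
ceit_vars = ceit.init(jax.random.PRNGKey(0), ceit_images, is_training=False)
ceit_logits, ceit_state = ceit.apply(ceit_vars, ceit_images,
                                     is_training=True,
                                     mutable=['batch_stats'])
# ceit_logits.shape == (2, 10); ceit_state['batch_stats'] holds updated stats.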
class LeFFBlock(nn.Module):
    expand_ratio: Optional[float] = None
    hidden_ch: Optional[int] = None
    kernel_size: int = 5
    activation_fn: Callable = nn.activation.gelu
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        # The class token bypasses the locally-enhanced feed-forward path and
        # is re-attached at the end.
        cls, tokens = inputs[:, 0], inputs[:, 1:]
        _, l, in_ch = tokens.shape

        if self.expand_ratio is None:
            if self.hidden_ch is None:
                raise ValueError(
                    'Must provide one of expand_ratio or hidden_ch')
            hidden_ch = self.hidden_ch
        else:
            hidden_ch = max(1, int(self.expand_ratio * in_ch))

        dense = partial(nn.Dense,
                        use_bias=True,
                        dtype=self.dtype,
                        precision=self.precision,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)
        batch_norm = partial(nn.BatchNorm,
                             use_running_average=not is_training,
                             momentum=self.bn_momentum,
                             epsilon=self.bn_epsilon,
                             dtype=self.dtype)

        x = dense(features=hidden_ch)(tokens)
        x = batch_norm()(x)
        x = self.activation_fn(x)

        # Fold the patch tokens back onto a square spatial grid for the
        # depthwise convolution; l must be a perfect square.
        spatial_ch = int(jnp.sqrt(l))
        x = rearrange(x, 'b (h w) c -> b h w c', h=spatial_ch, w=spatial_ch)
        x = nn.Conv(
            features=hidden_ch,
            kernel_size=(self.kernel_size, self.kernel_size),
            padding='SAME',
            dtype=self.dtype,
            precision=self.precision,
            feature_group_count=hidden_ch,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(x)
        x = batch_norm()(x)
        x = self.activation_fn(x)

        x = rearrange(x, 'b h w c -> b (h w) c')
        x = dense(features=in_ch)(x)
        x = batch_norm()(x)
        x = self.activation_fn(x)

        # Prepend the untouched class token to the processed patch tokens.
        output = jnp.concatenate([jnp.expand_dims(cls, axis=1), x], axis=1)
        return output
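# Shape sketch for LeFFBlock (added for illustration, not part of the original
# listings). The sequence must be one class token plus a perfect-square number
# of patch tokens, since the block folds them onto an h x h grid.
leff = LeFFBlock(expand_ratio=4.)
leff_in = jnp.zeros((2, 1 + 14 * 14, 192))   # class token + 14x14 patch tokens
leff_vars = leff.init(jax.random.PRNGKey(0), leff_in, is_training=False)
leff_out = leff.apply(leff_vars, leff_in, is_training=False)
# leff_out.shape == (2, 197, 192)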
class AttentionBlock(nn.Module):
    num_heads: int
    head_ch: Optional[int] = None
    out_ch: Optional[int] = None
    talking_heads: bool = False
    attn_drop_rate: float = 0.
    out_drop_rate: float = 0.
    use_bias: bool = False
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs_q, inputs_kv, is_training: bool):
        assert inputs_q.ndim == inputs_kv.ndim == 3

        in_ch = inputs_q.shape[-1]
        assert in_ch % self.num_heads == 0
        head_ch = self.head_ch or int(in_ch / self.num_heads)
        out_ch = self.out_ch or in_ch

        dense = partial(nn.DenseGeneral,
                        axis=-1,
                        features=(self.num_heads, head_ch),
                        use_bias=self.use_bias,
                        dtype=self.dtype,
                        precision=self.precision,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)

        query = dense(name='queries')(inputs_q)
        key = dense(name='keys')(inputs_kv)
        value = dense(name='values')(inputs_kv)

        query = query / jnp.sqrt(head_ch)

        attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k',
                                  query,
                                  key,
                                  precision=self.precision)

        if self.talking_heads:
            pre_softmax_transform = self.param(
                'pre_softmax', self.kernel_init,
                (self.num_heads, self.num_heads))
            attn_weights = jnp.einsum('... h q k, h i -> ... i q k',
                                      attn_weights,
                                      pre_softmax_transform,
                                      precision=self.precision)

        attn_weights = nn.softmax(attn_weights)

        if self.talking_heads:
            post_softmax_transform = self.param(
                'post_softmax', self.kernel_init,
                (self.num_heads, self.num_heads))
            attn_weights = jnp.einsum('... i q k, i h -> ... h q k',
                                      attn_weights,
                                      post_softmax_transform,
                                      precision=self.precision)

        attn_weights = nn.Dropout(rate=self.attn_drop_rate)(
            attn_weights, deterministic=not is_training)

        attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d',
                                 attn_weights,
                                 value,
                                 precision=self.precision)

        output = nn.DenseGeneral(features=out_ch,
                                 axis=(-2, -1),
                                 use_bias=self.use_bias,
                                 dtype=self.dtype,
                                 precision=self.precision,
                                 kernel_init=self.kernel_init,
                                 bias_init=self.bias_init)(attn_scores)
        output = nn.Dropout(rate=self.out_drop_rate)(
            output, deterministic=not is_training)
        return output
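# Shape sketch for AttentionBlock used as cross-attention (added for
# illustration, not part of the original listings): a single class-token query
# attends over a full token sequence, the pattern ClassSelfAttentionBlock
# above relies on.
attn = AttentionBlock(num_heads=8)
attn_q = jnp.zeros((2, 1, 256))     # one query token per example
attn_kv = jnp.zeros((2, 197, 256))  # full sequence as keys/values
attn_vars = attn.init(jax.random.PRNGKey(0), attn_q, attn_kv,
                      is_training=False)
attn_out = attn.apply(attn_vars, attn_q, attn_kv, is_training=False)
# attn_out.shape == (2, 1, 256)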
class CvTAttentionBlock(nn.Module):
    num_heads: int
    head_ch: Optional[int] = None
    out_ch: Optional[int] = None
    talking_heads: bool = False
    attn_dropout_rate: float = 0.
    out_dropout_rate: float = 0.
    kernel_size: int = 3
    strides: Sequence[int] = (1, 2, 2)
    use_bias: bool = False
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs_q, inputs_kv, is_training: bool):
        assert len(self.strides) == 3
        q_strides, k_strides, v_strides = self.strides

        in_ch = inputs_q.shape[-1]
        assert in_ch % self.num_heads == 0
        head_ch = self.head_ch or int(in_ch / self.num_heads)
        out_ch = self.out_ch or in_ch

        inputs_q = zero_pad_and_reshape(inputs_q)
        inputs_kv = zero_pad_and_reshape(inputs_kv)

        conv_proj = partial(ConvProjectionBlock,
                            out_ch=self.num_heads * head_ch,
                            kernel_size=self.kernel_size,
                            use_bias=self.use_bias,
                            bn_momentum=self.bn_momentum,
                            bn_epsilon=self.bn_epsilon,
                            dtype=self.dtype,
                            precision=self.precision,
                            kernel_init=self.kernel_init,
                            bias_init=self.bias_init)

        query = conv_proj(strides=q_strides)(inputs_q,
                                             is_training=is_training)
        key = conv_proj(strides=k_strides)(inputs_kv,
                                           is_training=is_training)
        value = conv_proj(strides=v_strides)(inputs_kv,
                                             is_training=is_training)

        query = rearrange(query, 'b H W (h d) -> b (H W) h d',
                          h=self.num_heads)
        key = rearrange(key, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
        value = rearrange(value, 'b H W (h d) -> b (H W) h d',
                          h=self.num_heads)

        query = query / jnp.sqrt(head_ch)

        attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k',
                                  query,
                                  key,
                                  precision=self.precision)

        if self.talking_heads:
            pre_softmax_transform = self.param(
                'pre_softmax', self.kernel_init,
                (self.num_heads, self.num_heads))
            attn_weights = jnp.einsum('... h q k, h i -> ... i q k',
                                      attn_weights,
                                      pre_softmax_transform,
                                      precision=self.precision)

        attn_weights = nn.softmax(attn_weights)

        if self.talking_heads:
            post_softmax_transform = self.param(
                'post_softmax', self.kernel_init,
                (self.num_heads, self.num_heads))
            attn_weights = jnp.einsum('... i q k, i h -> ... h q k',
                                      attn_weights,
                                      post_softmax_transform,
                                      precision=self.precision)

        attn_weights = nn.Dropout(rate=self.attn_dropout_rate)(
            attn_weights, deterministic=not is_training)

        attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d',
                                 attn_weights,
                                 value,
                                 precision=self.precision)

        output = nn.DenseGeneral(features=out_ch,
                                 axis=(-2, -1),
                                 use_bias=self.use_bias,
                                 dtype=self.dtype,
                                 precision=self.precision,
                                 kernel_init=self.kernel_init,
                                 bias_init=self.bias_init)(attn_scores)
        output = nn.Dropout(rate=self.out_dropout_rate)(
            output, deterministic=not is_training)
        return output
class SelfAttentionBlock(nn.Module):
    num_heads: int
    head_ch: Optional[int] = None
    out_ch: Optional[int] = None
    is_lca: bool = False
    talking_heads: bool = False
    dropout_rate: float = 0.
    use_bias: bool = False
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        # Derive per-head and output widths from the input when they are not
        # given explicitly, mirroring AttentionBlock above.
        in_ch = inputs.shape[-1]
        assert in_ch % self.num_heads == 0
        head_ch = self.head_ch or int(in_ch / self.num_heads)
        out_ch = self.out_ch or in_ch

        dense = partial(nn.DenseGeneral,
                        axis=-1,
                        features=(self.num_heads, head_ch),
                        use_bias=self.use_bias,
                        dtype=self.dtype,
                        precision=self.precision,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)

        if self.is_lca:
            # Layer-wise class-token attention: only the last token queries.
            q_inputs = jnp.expand_dims(inputs[:, -1, :], axis=1)
        else:
            q_inputs = inputs

        query = dense(name='queries')(q_inputs)
        key = dense(name='keys')(inputs)
        value = dense(name='values')(inputs)

        query = query / jnp.sqrt(head_ch)

        attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k',
                                  query,
                                  key,
                                  precision=self.precision)

        if self.talking_heads:
            pre_softmax_transform = self.param(
                'pre_softmax', self.kernel_init,
                (self.num_heads, self.num_heads))
            attn_weights = jnp.einsum('... h q k, h i -> ... i q k',
                                      attn_weights,
                                      pre_softmax_transform,
                                      precision=self.precision)

        attn_weights = nn.softmax(attn_weights)

        if self.talking_heads:
            post_softmax_transform = self.param(
                'post_softmax', self.kernel_init,
                (self.num_heads, self.num_heads))
            attn_weights = jnp.einsum('... i q k, i h -> ... h q k',
                                      attn_weights,
                                      post_softmax_transform,
                                      precision=self.precision)

        if is_training and self.dropout_rate > 0.:
            keep_prob = 1.0 - self.dropout_rate
            dropout_rng = self.make_rng('dropout')
            keep = random.bernoulli(dropout_rng, keep_prob,
                                    attn_weights.shape)
            multiplier = (keep.astype(attn_weights.dtype) /
                          jnp.asarray(keep_prob, dtype=self.dtype))
            attn_weights = attn_weights * multiplier

        attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d',
                                 attn_weights,
                                 value,
                                 precision=self.precision)

        if (self.num_heads * head_ch) == out_ch:
            output = rearrange(attn_scores, '... q h d -> ... q (h d)')
        else:
            output = nn.DenseGeneral(features=out_ch,
                                     axis=(-2, -1),
                                     use_bias=self.use_bias,
                                     dtype=self.dtype,
                                     precision=self.precision,
                                     kernel_init=self.kernel_init,
                                     bias_init=self.bias_init)(attn_scores)
        return output