class EncoderBlock(nn.Module):
    num_heads: int
    expand_ratio: float = 4
    attn_dropout_rate: float = 0.
    dropout_rate: float = 0.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = SelfAttentionBlock(num_heads=self.num_heads,
                               attn_drop_rate=self.attn_dropout_rate,
                               out_drop_rate=self.dropout_rate,
                               dtype=self.dtype,
                               precision=self.precision,
                               kernel_init=self.kernel_init)(
                                   x, is_training=is_training)
        x = x + inputs

        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = FFBlock(expand_ratio=self.expand_ratio,
                    dropout_rate=self.dropout_rate,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(y, train=is_training)
        output = x + y
        return output
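A minimal usage sketch (not from the source) for the block above, assuming SelfAttentionBlock and FFBlock with the keyword interfaces used here are importable from the same project; the helper name and shapes are illustrative. Dropout layers need a dedicated 'dropout' RNG stream when is_training=True.

import jax
import jax.numpy as jnp

def encoder_block_demo():
    # Hypothetical shapes: batch of 2, 197 tokens (196 patches + cls), width 768.
    tokens = jnp.ones((2, 197, 768))
    block = EncoderBlock(num_heads=8, dropout_rate=0.1, attn_dropout_rate=0.1)
    params_rng, dropout_rng = jax.random.split(jax.random.PRNGKey(0))
    variables = block.init({'params': params_rng, 'dropout': dropout_rng},
                           tokens, is_training=True)
    out = block.apply(variables, tokens, is_training=True,
                      rngs={'dropout': dropout_rng})
    return out.shape  # (2, 197, 768): the residual connections preserve the shape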
class EncoderBlock(nn.Module):
    num_heads: int
    expand_ratio: float = 4
    leff_kernel_size: Optional[int] = 3
    activation_fn: Callable = nn.activation.gelu
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        x = SelfAttentionBlock(num_heads=self.num_heads,
                               dtype=self.dtype,
                               precision=self.precision,
                               kernel_init=self.kernel_init,
                               bias_init=self.bias_init)(
                                   inputs, is_training=is_training)
        x += inputs
        x = nn.LayerNorm(dtype=self.dtype)(x)

        y = LeFFBlock(expand_ratio=self.expand_ratio,
                      kernel_size=self.leff_kernel_size,
                      activation_fn=self.activation_fn,
                      bn_momentum=self.bn_momentum,
                      bn_epsilon=self.bn_epsilon,
                      dtype=self.dtype,
                      precision=self.precision,
                      kernel_init=self.kernel_init,
                      bias_init=self.bias_init)(x, is_training=is_training)
        y = x + y
        output = nn.LayerNorm(dtype=self.dtype)(y)
        return output
class MixerBlock(nn.Module):
    tokens_expand_ratio: float
    channels_expand_ratio: float
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs):
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = rearrange(x, '... l d -> ... d l')
        x = FFBlock(expand_ratio=self.tokens_expand_ratio,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(x)
        x = rearrange(x, '... d l -> ... l d')
        x = x + inputs

        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = FFBlock(expand_ratio=self.channels_expand_ratio,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(y)
        output = x + y
        return output
class FFBlock(nn.Module):
    expand_ratio: Optional[float] = None
    hidden_ch: Optional[int] = None
    dropout_rate: float = 0.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        in_ch = inputs.shape[-1]
        if self.expand_ratio is None:
            if self.hidden_ch is None:
                raise ValueError(
                    'Must provide one of expand_ratio or hidden_ch')
            hidden_ch = self.hidden_ch
        else:
            hidden_ch = max(1, int(self.expand_ratio * in_ch))
        dense = partial(nn.Dense,
                        use_bias=True,
                        dtype=self.dtype,
                        precision=self.precision,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)

        x = dense(features=hidden_ch)(inputs)
        x = self.activation_fn(x)
        x = nn.Dropout(rate=self.dropout_rate,
                       deterministic=not is_training)(x)
        x = dense(features=in_ch)(x)
        output = nn.Dropout(rate=self.dropout_rate,
                            deterministic=not is_training)(x)
        return output
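A small shape check for FFBlock (a sketch, not from the source; the helper name is illustrative): with expand_ratio=4 the hidden Dense layer has 4 * in_ch features and the output returns to the input width. Running with is_training=False keeps the Dropout layers deterministic, so no 'dropout' RNG is required.

import jax
import jax.numpy as jnp

def ff_block_demo():
    x = jnp.ones((2, 197, 768))
    ff = FFBlock(expand_ratio=4, dropout_rate=0.1)  # hidden width: 4 * 768 = 3072
    variables = ff.init(jax.random.PRNGKey(0), x, is_training=False)
    out = ff.apply(variables, x, is_training=False)
    return out.shape  # (2, 197, 768)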
class CvTSelfAttentionBlock(nn.Module):
    num_heads: int
    head_ch: Optional[int] = None
    out_ch: Optional[int] = None
    talking_heads: bool = False
    attn_dropout_rate: float = 0.
    out_dropout_rate: float = 0.
    kernel_size: int = 3
    strides: Sequence[int] = (1, 2, 2)
    use_bias: bool = False
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        return CvTAttentionBlock(num_heads=self.num_heads,
                                 head_ch=self.head_ch,
                                 out_ch=self.out_ch,
                                 talking_heads=self.talking_heads,
                                 attn_dropout_rate=self.attn_dropout_rate,
                                 out_dropout_rate=self.out_dropout_rate,
                                 kernel_size=self.kernel_size,
                                 strides=self.strides,
                                 bn_momentum=self.bn_momentum,
                                 bn_epsilon=self.bn_epsilon,
                                 use_bias=self.use_bias,
                                 dtype=self.dtype,
                                 precision=self.precision,
                                 kernel_init=self.kernel_init,
                                 bias_init=self.bias_init)(inputs, inputs,
                                                           is_training)
class ClassSelfAttentionBlock(nn.Module):
    num_heads: int
    head_ch: Optional[int] = None
    out_ch: Optional[int] = None
    talking_heads: bool = False
    attn_dropout_rate: float = 0.
    out_dropout_rate: float = 0.
    use_bias: bool = False
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        inputs_q = jnp.expand_dims(inputs[:, 0, :], axis=1)
        return AttentionBlock(num_heads=self.num_heads,
                              head_ch=self.head_ch,
                              out_ch=self.out_ch,
                              talking_heads=self.talking_heads,
                              attn_drop_rate=self.attn_dropout_rate,
                              out_drop_rate=self.out_dropout_rate,
                              use_bias=self.use_bias,
                              dtype=self.dtype,
                              precision=self.precision,
                              kernel_init=self.kernel_init,
                              bias_init=self.bias_init)(
                                  inputs_q, inputs, is_training=is_training)
class Encoder(nn.Module):
    num_layers: int
    num_heads: int
    expand_ratio: float = 4
    attn_dropout_rate: float = 0.
    dropout_rate: float = 0.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        x = AddAbsPosEmbed()(inputs)
        x = nn.Dropout(rate=self.dropout_rate)(x,
                                               deterministic=not is_training)

        for _ in range(self.num_layers):
            x = EncoderBlock(num_heads=self.num_heads,
                             expand_ratio=self.expand_ratio,
                             attn_dropout_rate=self.attn_dropout_rate,
                             dropout_rate=self.dropout_rate,
                             activation_fn=self.activation_fn,
                             dtype=self.dtype,
                             precision=self.precision,
                             kernel_init=self.kernel_init)(
                                 x, is_training=is_training)

        output = nn.LayerNorm(dtype=self.dtype)(x)
        return output
class Encoder(nn.Module):
    num_layers: int
    num_heads: int
    expand_ratio: float = 4
    leff_kernel_size: int = 3
    activation_fn: Callable = nn.activation.gelu
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        encoder_block = partial(EncoderBlock,
                                num_heads=self.num_heads,
                                expand_ratio=self.expand_ratio,
                                leff_kernel_size=self.leff_kernel_size,
                                activation_fn=self.activation_fn,
                                bn_momentum=self.bn_momentum,
                                bn_epsilon=self.bn_epsilon,
                                dtype=self.dtype,
                                precision=self.precision,
                                kernel_init=self.kernel_init,
                                bias_init=self.bias_init)

        x = encoder_block()(inputs, is_training=is_training)
        cls_tokens_lst = [jnp.expand_dims(x[:, 0], axis=1)]
        for _ in range(self.num_layers - 1):
            x = encoder_block()(x, is_training=is_training)
            cls_tokens_lst.append(jnp.expand_dims(x[:, 0], axis=1))

        return jnp.concatenate(cls_tokens_lst, axis=1)
class LCAEncoderBlock(nn.Module):
    num_heads: int
    expand_ratio: float = 4
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        x = LCSelfAttentionBlock(num_heads=self.num_heads,
                                 dtype=self.dtype,
                                 precision=self.precision,
                                 kernel_init=self.kernel_init,
                                 bias_init=self.bias_init)(
                                     inputs, is_training=is_training)
        x += inputs
        x = nn.LayerNorm(dtype=self.dtype)(x)

        y = FFBlock(expand_ratio=self.expand_ratio,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(x, is_training=is_training)
        y = x + y
        output = nn.LayerNorm(dtype=self.dtype)(y)
        return output
class ViT(nn.Module):
    num_classes: int
    num_layers: int
    num_heads: int
    embed_dim: int
    patch_shape: Tuple[int, int]
    expand_ratio: float = 4
    dropout_rate: float = 0.
    attn_dropout_rate: float = 0.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        assert self.embed_dim % self.num_heads == 0

        x = PatchEmbedBlock(
            patch_shape=self.patch_shape,
            embed_dim=self.embed_dim,
            dtype=self.dtype,
            precision=self.precision,
            kernel_init=self.kernel_init,
        )(inputs)

        b, l, _ = x.shape
        cls_shape = (1, 1, self.embed_dim)
        cls_token = self.param('cls', initializers.zeros, cls_shape)
        cls_token = jnp.tile(cls_token, [b, 1, 1])
        x = jnp.concatenate([cls_token, x], axis=1)

        x = Encoder(num_layers=self.num_layers,
                    num_heads=self.num_heads,
                    expand_ratio=self.expand_ratio,
                    attn_dropout_rate=self.attn_dropout_rate,
                    dropout_rate=self.dropout_rate,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(x, is_training=is_training)

        cls_token = x[:, 0]
        output = nn.Dense(
            features=self.num_classes,
            use_bias=True,
            dtype=self.dtype,
            kernel_init=initializers.zeros,
            bias_init=self.bias_init,
        )(cls_token)
        return output
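An end-to-end usage sketch (assumption, not from the source) for the ViT model above, assuming the supporting blocks it references (PatchEmbedBlock, AddAbsPosEmbed, Encoder and its children) are importable from the same project; the helper name, image size, and hyperparameters are illustrative.

import jax
import jax.numpy as jnp

def vit_demo():
    images = jnp.ones((2, 224, 224, 3))  # NHWC image batch
    model = ViT(num_classes=1000,
                num_layers=12,
                num_heads=12,
                embed_dim=768,
                patch_shape=(16, 16))
    params_rng, dropout_rng = jax.random.split(jax.random.PRNGKey(0))
    variables = model.init({'params': params_rng, 'dropout': dropout_rng},
                           images, is_training=False)
    logits = model.apply(variables, images, is_training=False)
    return logits.shape  # (2, 1000): one logit vector per image from the cls token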
class Image2TokenBlock(nn.Module):
    patch_shape: Tuple[int, int]
    num_ch: int
    conv_kernel_size: int
    conv_stride: int
    pool_window_size: int
    pool_stride: int
    embed_dim: int
    use_bias: bool = False
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):

        x = nn.Conv(features=self.num_ch,
                    use_bias=self.use_bias,
                    kernel_size=(self.conv_kernel_size, self.conv_kernel_size),
                    strides=(self.conv_stride, self.conv_stride),
                    padding=[(self.patch_shape[0], ) * 2,
                             (self.patch_shape[1], ) * 2])(inputs)
        x = nn.BatchNorm(use_running_average=not is_training,
                         momentum=self.bn_momentum,
                         epsilon=self.bn_epsilon,
                         dtype=self.dtype)(x)
        x = nn.max_pool(
            inputs=x,
            window_shape=(self.pool_window_size, ) * 2,
            strides=(self.pool_stride, ) * 2,
        )
        x = rearrange(
            x,
            'b (h ph) (w pw) c -> b (h w) (ph pw c)',
            ph=self.patch_shape[0],
            pw=self.patch_shape[1],
        )

        output = nn.Dense(features=self.embed_dim,
                          use_bias=self.use_bias,
                          dtype=self.dtype,
                          precision=self.precision,
                          kernel_init=self.kernel_init,
                          bias_init=self.bias_init)(x)
        return output
class MLPMixer(nn.Module):
    num_classes: int
    num_layers: int
    embed_dim: int
    patch_shape: Tuple[int, int]
    tokens_expand_ratio: float
    channels_expand_ratio: float
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, *unused_args, **unused_kwargs):
        x = PatchEmbedBlock(patch_shape=self.patch_shape,
                            embed_dim=self.embed_dim,
                            dtype=self.dtype,
                            precision=self.precision,
                            kernel_init=self.kernel_init)(inputs)
        x = rearrange(x, 'b h w d -> b (h w) d')

        for _ in range(self.num_layers):
            x = MixerBlock(tokens_expand_ratio=self.tokens_expand_ratio,
                           channels_expand_ratio=self.channels_expand_ratio,
                           activation_fn=self.activation_fn,
                           dtype=self.dtype,
                           precision=self.precision,
                           kernel_init=self.kernel_init,
                           bias_init=self.bias_init)(x)

        x = nn.LayerNorm(dtype=self.dtype)(x)
        x = jnp.mean(x, axis=1)
        output = nn.Dense(features=self.num_classes,
                          dtype=self.dtype,
                          precision=self.precision,
                          kernel_init=nn.initializers.zeros,
                          bias_init=self.bias_init)(x)
        return output
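A usage sketch (assumption, not from the source) for MLPMixer, assuming the PatchEmbedBlock from the same project returns a (batch, height, width, embed_dim) feature map and that the FFBlock variant used by MixerBlock accepts a call without is_training; the helper name and hyperparameters are illustrative.

import jax
import jax.numpy as jnp

def mlp_mixer_demo():
    images = jnp.ones((2, 224, 224, 3))
    model = MLPMixer(num_classes=1000,
                     num_layers=8,
                     embed_dim=512,
                     patch_shape=(16, 16),
                     tokens_expand_ratio=0.5,
                     channels_expand_ratio=4.0)
    variables = model.init(jax.random.PRNGKey(0), images)
    logits = model.apply(variables, images)
    return logits.shape  # (2, 1000)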
class ConvProjectionBlock(nn.Module):
    out_ch: int
    kernel_size: int = 3
    strides: int = 1
    use_bias: bool = True
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        in_ch = inputs.shape[-1]

        conv = partial(nn.Conv,
                       dtype=self.dtype,
                       precision=self.precision,
                       kernel_init=self.kernel_init)

        x = conv(features=in_ch,
                 kernel_size=(self.kernel_size, self.kernel_size),
                 strides=(self.strides, self.strides),
                 padding='SAME',
                 feature_group_count=in_ch,
                 use_bias=False)(inputs)
        x = nn.BatchNorm(use_running_average=not is_training,
                         momentum=self.bn_momentum,
                         epsilon=self.bn_epsilon,
                         dtype=self.dtype)(x)
        output = conv(features=self.out_ch,
                      kernel_size=(1, 1),
                      use_bias=self.use_bias,
                      bias_init=self.bias_init)(x)
        return output
class CeiT(nn.Module):
    num_classes: int
    num_layers: int
    num_heads: int
    embed_dim: int
    patch_shape: Tuple[int, int] = (4, 4)
    num_ch: int = 32
    conv_kernel_size: int = 7
    conv_stride: int = 2
    pool_window_size: int = 3
    pool_stride: int = 2
    expand_ratio: float = 4
    leff_kernel_size: int = 3
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        assert self.embed_dim % self.num_heads == 0

        x = Image2TokenBlock(patch_shape=self.patch_shape,
                             num_ch=self.num_ch,
                             conv_kernel_size=self.conv_kernel_size,
                             conv_stride=self.conv_stride,
                             pool_window_size=self.pool_window_size,
                             pool_stride=self.pool_stride,
                             embed_dim=self.embed_dim,
                             bn_momentum=self.bn_momentum,
                             bn_epsilon=self.bn_epsilon,
                             dtype=self.dtype,
                             precision=self.precision,
                             kernel_init=self.kernel_init,
                             bias_init=self.bias_init)(inputs,
                                                       is_training=is_training)

        b = x.shape[0]
        cls_shape = (1, 1, self.embed_dim)
        cls_token = self.param('cls', initializers.zeros, cls_shape)
        cls_token = jnp.tile(cls_token, [b, 1, 1])

        x = jnp.concatenate([cls_token, x], axis=1)

        cls_tokens = Encoder(num_layers=self.num_layers,
                             num_heads=self.num_heads,
                             expand_ratio=self.expand_ratio,
                             leff_kernel_size=self.leff_kernel_size,
                             bn_momentum=self.bn_momentum,
                             bn_epsilon=self.bn_epsilon,
                             dtype=self.dtype,
                             precision=self.precision,
                             kernel_init=self.kernel_init,
                             bias_init=self.bias_init)(x,
                                                       is_training=is_training)

        cls_tokens = LCSelfAttentionBlock(num_heads=self.num_heads,
                                          dtype=self.dtype,
                                          precision=self.precision,
                                          kernel_init=self.kernel_init,
                                          bias_init=self.bias_init)(
                                              cls_tokens,
                                              is_training=is_training)
        cls = cls_tokens[:, -1]
        output = nn.Dense(features=self.num_classes,
                          use_bias=True,
                          dtype=self.dtype,
                          kernel_init=self.kernel_init,
                          bias_init=self.bias_init)(cls)
        return output
class LeFFBlock(nn.Module):
    expand_ratio: Optional[float] = None
    hidden_ch: Optional[int] = None
    kernel_size: int = 5
    activation_fn: Callable = nn.activation.gelu
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):

        cls, tokens = inputs[:, 0], inputs[:, 1:]
        _, l, in_ch = tokens.shape
        if self.expand_ratio is None:
            if self.hidden_ch is None:
                raise ValueError(
                    'Must provide one of expand_ratio or hidden_ch')
            hidden_ch = self.hidden_ch
        else:
            hidden_ch = max(1, int(self.expand_ratio * in_ch))

        dense = partial(nn.Dense,
                        use_bias=True,
                        dtype=self.dtype,
                        precision=self.precision,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)

        batch_norm = partial(nn.BatchNorm,
                             use_running_average=not is_training,
                             momentum=self.bn_momentum,
                             epsilon=self.bn_epsilon,
                             dtype=self.dtype)

        x = dense(features=hidden_ch)(tokens)
        x = batch_norm()(x)
        x = self.activation_fn(x)

        spatial_ch = int(jnp.sqrt(l))
        x = rearrange(x, 'b (h w) c -> b h w c', h=spatial_ch, w=spatial_ch)

        x = nn.Conv(
            features=hidden_ch,
            kernel_size=(self.kernel_size, self.kernel_size),
            padding='SAME',
            dtype=self.dtype,
            precision=self.precision,
            feature_group_count=hidden_ch,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(x)
        x = batch_norm()(x)
        x = self.activation_fn(x)

        x = rearrange(x, 'b h w c -> b (h w) c')

        x = dense(features=in_ch)(x)
        x = batch_norm()(x)
        x = self.activation_fn(x)

        output = jnp.concatenate([jnp.expand_dims(cls, axis=1), x],
                                 axis=1)  # check if this is correct
        return output
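A usage sketch (assumption, not from the source; the helper name and shapes are illustrative): the token count after the class token must be a perfect square (e.g. 196 = 14 x 14) so the rearrange back to a spatial grid works, and the BatchNorm statistics must be marked mutable when is_training=True.

import jax
import jax.numpy as jnp

def leff_block_demo():
    tokens = jnp.ones((2, 1 + 196, 192))  # cls token followed by 14x14 patch tokens
    block = LeFFBlock(expand_ratio=4)
    variables = block.init(jax.random.PRNGKey(0), tokens, is_training=True)
    out, new_state = block.apply(variables, tokens, is_training=True,
                                 mutable=['batch_stats'])
    return out.shape  # (2, 197, 192)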
class AttentionBlock(nn.Module):
    num_heads: int
    head_ch: Optional[int] = None
    out_ch: Optional[int] = None
    talking_heads: bool = False
    attn_drop_rate: float = 0.
    out_drop_rate: float = 0.
    use_bias: bool = False
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs_q, inputs_kv, is_training: bool):
        assert inputs_q.ndim == inputs_kv.ndim == 3

        in_ch = inputs_q.shape[-1]
        assert in_ch % self.num_heads == 0
        head_ch = self.head_ch or int(in_ch / self.num_heads)
        out_ch = self.out_ch or in_ch

        dense = partial(nn.DenseGeneral,
                        axis=-1,
                        features=(self.num_heads, head_ch),
                        use_bias=self.use_bias,
                        dtype=self.dtype,
                        precision=self.precision,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)

        query = dense(name='queries')(inputs_q)
        key = dense(name='keys')(inputs_kv)
        value = dense(name='values')(inputs_kv)

        query = query / jnp.sqrt(head_ch)

        attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k',
                                  query,
                                  key,
                                  precision=self.precision)
        if self.talking_heads:
            pre_softmax_transform = self.param('pre_softmax', self.kernel_init,
                                               (self.num_heads, self.num_heads))
            attn_weights = jnp.einsum('... h q k, h i -> ... i q k',
                                      attn_weights,
                                      pre_softmax_transform,
                                      precision=self.precision)

        attn_weights = nn.softmax(attn_weights)

        if self.talking_heads:
            post_softmax_transform = self.param(
                'post_softmax', self.kernel_init,
                (self.num_heads, self.num_heads))
            attn_weights = jnp.einsum('... i q k, i h -> ... h q k',
                                      attn_weights,
                                      post_softmax_transform,
                                      precision=self.precision)

        attn_weights = nn.Dropout(rate=self.attn_drop_rate)(
            attn_weights, deterministic=not is_training)

        attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d',
                                 attn_weights,
                                 value,
                                 precision=self.precision)

        output = nn.DenseGeneral(features=out_ch,
                                 axis=(-2, -1),
                                 use_bias=self.use_bias,
                                 dtype=self.dtype,
                                 precision=self.precision,
                                 kernel_init=self.kernel_init,
                                 bias_init=self.bias_init)(attn_scores)

        output = nn.Dropout(rate=self.out_drop_rate)(
            output, deterministic=not is_training)

        return output
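A cross-attention sketch (not from the source; the helper name and shapes are illustrative) for the block above: a single query token attends over a longer key/value sequence, and with talking_heads=True the pre- and post-softmax head-mixing parameters are created.

import jax
import jax.numpy as jnp

def attention_block_demo():
    q = jnp.ones((2, 1, 64))    # a single query token (e.g. a class token)
    kv = jnp.ones((2, 50, 64))  # keys/values from the full sequence
    attn = AttentionBlock(num_heads=4, talking_heads=True)
    variables = attn.init(jax.random.PRNGKey(0), q, kv, is_training=False)
    out = attn.apply(variables, q, kv, is_training=False)
    return out.shape  # (2, 1, 64): one output token per query token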
class CvTAttentionBlock(nn.Module):
    num_heads: int
    head_ch: Optional[int] = None
    out_ch: Optional[int] = None
    talking_heads: bool = False
    attn_dropout_rate: float = 0.
    out_dropout_rate: float = 0.
    kernel_size: int = 3
    strides: Sequence[int] = (1, 2, 2)
    use_bias: bool = False
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs_q, inputs_kv, is_training: bool):
        assert len(self.strides) == 3
        q_strides, k_strides, v_strides = self.strides

        in_ch = inputs_q.shape[-1]
        assert in_ch % self.num_heads == 0
        head_ch = self.head_ch or int(in_ch / self.num_heads)
        out_ch = self.out_ch or in_ch

        inputs_q = zero_pad_and_reshape(inputs_q)
        inputs_kv = zero_pad_and_reshape(inputs_kv)

        conv_proj = partial(ConvProjectionBlock,
                            out_ch=self.num_heads * head_ch,
                            kernel_size=self.kernel_size,
                            use_bias=self.use_bias,
                            bn_momentum=self.bn_momentum,
                            bn_epsilon=self.bn_epsilon,
                            dtype=self.dtype,
                            precision=self.precision,
                            kernel_init=self.kernel_init,
                            bias_init=self.bias_init)

        query = conv_proj(strides=q_strides)(inputs_q, is_training=is_training)
        key = conv_proj(strides=k_strides)(inputs_kv, is_training=is_training)
        value = conv_proj(strides=v_strides)(inputs_kv,
                                             is_training=is_training)

        query = rearrange(query,
                          'b H W (h d) -> b (H W) h d',
                          h=self.num_heads)
        key = rearrange(key, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
        value = rearrange(value,
                          'b H W (h d) -> b (H W) h d',
                          h=self.num_heads)

        query = query / jnp.sqrt(head_ch)

        attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k',
                                  query,
                                  key,
                                  precision=self.precision)

        if self.talking_heads:
            pre_softmax_transform = self.param(
                'pre_softmax', self.kernel_init,
                (self.num_heads, self.num_heads))
            attn_weights = jnp.einsum('... h q k, h i -> ... i q k',
                                      attn_weights,
                                      pre_softmax_transform,
                                      precision=self.precision)

        attn_weights = nn.softmax(attn_weights)

        if self.talking_heads:
            post_softmax_transform = self.param(
                'post_softmax', self.kernel_init,
                (self.num_heads, self.num_heads))
            attn_weights = jnp.einsum('... i q k, i h -> ... h q k',
                                      attn_weights,
                                      post_softmax_transform,
                                      precision=self.precision)

        attn_weights = nn.Dropout(rate=self.attn_dropout_rate)(
            attn_weights, deterministic=not is_training)

        attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d',
                                 attn_weights,
                                 value,
                                 precision=self.precision)

        output = nn.DenseGeneral(features=out_ch,
                                 axis=(-2, -1),
                                 use_bias=self.use_bias,
                                 dtype=self.dtype,
                                 precision=self.precision,
                                 kernel_init=self.kernel_init,
                                 bias_init=self.bias_init)(attn_scores)

        output = nn.Dropout(rate=self.out_dropout_rate)(
            output, deterministic=not is_training)

        return output
class SelfAttentionBlock(nn.Module):
    num_heads: int
    head_ch: int
    out_ch: int
    is_lca: bool = False
    talking_heads: bool = False
    dropout_rate: float = 0.
    use_bias: bool = False
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        dense = partial(nn.DenseGeneral,
                        axis=-1,
                        features=(self.num_heads, self.head_ch),
                        use_bias=self.use_bias,
                        dtype=self.dtype,
                        precision=self.precision,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)

        if self.is_lca:
            q_inputs = jnp.expand_dims(inputs[:, -1, :], axis=1)
        else:
            q_inputs = inputs

        query = dense(name='queries')(q_inputs)
        key = dense(name='keys')(inputs)
        value = dense(name='values')(inputs)

        query = query / jnp.sqrt(self.head_ch)

        attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k',
                                  query,
                                  key,
                                  precision=self.precision)
        if self.talking_heads:
            pre_softmax_transform = self.param(
                'pre_softmax', self.kernel_init,
                (self.num_heads, self.num_heads))
            attn_weights = jnp.einsum('... h q k, h i -> ... i q k',
                                      attn_weights,
                                      pre_softmax_transform,
                                      precision=self.precision)

        attn_weights = nn.softmax(attn_weights)

        if self.talking_heads:
            post_softmax_transform = self.param(
                'post_softmax', self.kernel_init,
                (self.num_heads, self.num_heads))
            attn_weights = jnp.einsum('... i q k, i h -> ... h q k',
                                      attn_weights,
                                      post_softmax_transform,
                                      precision=self.precision)

        if is_training and self.dropout_rate > 0.:
            keep_prob = 1.0 - self.dropout_rate
            dropout_rng = self.make_rng('dropout')
            keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
            multiplier = (keep.astype(attn_weights.dtype) /
                          jnp.asarray(keep_prob, dtype=self.dtype))
            attn_weights = attn_weights * multiplier

        attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d',
                                 attn_weights,
                                 value,
                                 precision=self.precision)

        if (self.num_heads * self.head_ch) == self.out_ch:
            output = rearrange(attn_scores, '... q h d -> ... q (h d)')
        else:
            output = nn.DenseGeneral(features=self.out_ch,
                                     axis=(-2, -1),
                                     use_bias=self.use_bias,
                                     dtype=self.dtype,
                                     precision=self.precision,
                                     kernel_init=self.kernel_init,
                                     bias_init=self.bias_init)(attn_scores)
        return output
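A sketch (not from the source; the helper name and shapes are illustrative) of the is_lca path: only the last token is used as the query, so the block returns a single attended token, and because num_heads * head_ch equals out_ch the final output projection is skipped.

import jax
import jax.numpy as jnp

def lca_self_attention_demo():
    tokens = jnp.ones((2, 12, 64))  # e.g. a stack of per-layer class tokens
    sa = SelfAttentionBlock(num_heads=4, head_ch=16, out_ch=64, is_lca=True)
    variables = sa.init(jax.random.PRNGKey(0), tokens, is_training=False)
    out = sa.apply(variables, tokens, is_training=False)
    return out.shape  # (2, 1, 64)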