def __call__(self, patch_inputs, pixel_inputs, is_training: bool):
        """Apply one inner (pixel) / outer (patch) transformer step.

        Inner branch: ``pixel_inputs`` goes through a pre-norm self-attention
        sub-layer and a pre-norm feed-forward sub-layer, each with a residual
        add. ``Inner2OuterBlock`` then combines ``patch_inputs`` with that
        inner result. Outer branch: the combination is normalized, attended
        over, added to ``patch_inputs``, and passed through a pre-norm
        feed-forward sub-layer with a residual add.

        Args:
            patch_inputs: Outer (patch-level) tokens.
            pixel_inputs: Inner (pixel-level) tokens.
            is_training: Forwarded to the attention and feed-forward sub-blocks.

        Returns:
            Tuple ``(outer_output, inner_output)``.
        """
        # Submodules are constructed in the same order as the straight-line
        # version. If ``nn`` is ``flax.linen``, auto-generated submodule
        # names (and therefore parameter trees) depend on that order.

        # ---- inner (pixel) branch ----
        pixel_attn_norm = nn.LayerNorm(dtype=self.dtype)
        pixel_attn = SelfAttentionBlock(num_heads=self.inner_num_heads,
                                        attn_dropout_rate=self.attn_dropout_rate,
                                        out_dropout_rate=self.dropout_rate,
                                        dtype=self.dtype)
        pixel_hidden = pixel_attn(pixel_attn_norm(pixel_inputs),
                                  is_training=is_training) + pixel_inputs

        pixel_ff_norm = nn.LayerNorm(dtype=self.dtype)
        pixel_ff = FFBlock(expand_ratio=self.inner_expand_ratio,
                           dropout_rate=self.dropout_rate,
                           dtype=self.dtype)
        inner_output = pixel_hidden + pixel_ff(pixel_ff_norm(pixel_hidden),
                                               is_training=is_training)

        # ---- outer (patch) branch ----
        fused = Inner2OuterBlock(dtype=self.dtype)(patch_inputs, inner_output)

        patch_attn_norm = nn.LayerNorm(dtype=self.dtype)
        patch_attn = SelfAttentionBlock(num_heads=self.outer_num_heads,
                                        attn_dropout_rate=self.attn_dropout_rate,
                                        out_dropout_rate=self.dropout_rate,
                                        dtype=self.dtype)
        # NOTE(review): the residual re-adds ``patch_inputs``, not ``fused``,
        # so the Inner2OuterBlock result only reaches the output through the
        # attention path. Confirm this is intended (depends on what
        # Inner2OuterBlock returns, which is not visible here).
        patch_hidden = patch_attn(patch_attn_norm(fused),
                                  is_training=is_training) + patch_inputs

        patch_ff_norm = nn.LayerNorm(dtype=self.dtype)
        patch_ff = FFBlock(expand_ratio=self.outer_expand_ratio,
                           dropout_rate=self.dropout_rate,
                           dtype=self.dtype)
        outer_output = patch_hidden + patch_ff(patch_ff_norm(patch_hidden),
                                               is_training=is_training)

        return outer_output, inner_output
    def __call__(self, inputs, is_training: bool):
        """Apply one token-mixing + channel-mixing block.

        Token mixing: LayerNorm, swap the last two axes, ``FFBlock`` with
        ``tokens_expand_ratio``, swap the axes back, then add ``inputs``.
        Channel mixing: LayerNorm, ``FFBlock`` with ``channels_expand_ratio``,
        then add the token-mixed result.

        Args:
            inputs: Array whose last two axes are treated as ``(l, d)`` by
                the ``rearrange`` patterns below.
            is_training: Forwarded to both ``FFBlock`` calls.

        Returns:
            Array of the same layout as ``inputs``.
        """
        # Construction order (LayerNorm, FFBlock, LayerNorm, FFBlock) matches
        # the straight-line version so auto-generated submodule names are kept.
        token_norm = nn.LayerNorm(dtype=self.dtype)
        swapped = rearrange(token_norm(inputs), '... l d -> ... d l')
        token_mlp = FFBlock(expand_ratio=self.tokens_expand_ratio,
                            activation_fn=self.activation_fn,
                            dtype=self.dtype)
        token_mixed = rearrange(token_mlp(swapped, is_training=is_training),
                                '... d l -> ... l d')
        hidden = token_mixed + inputs

        channel_norm = nn.LayerNorm(dtype=self.dtype)
        channel_mlp = FFBlock(expand_ratio=self.channels_expand_ratio,
                              activation_fn=self.activation_fn,
                              dtype=self.dtype)
        return hidden + channel_mlp(channel_norm(hidden), is_training=is_training)
    def __call__(self, inputs, cls_token, is_training: bool):
        """Update ``cls_token`` with a class-attention sub-layer and a feed-forward sub-layer.

        ``cls_token`` and ``inputs`` are concatenated along axis 1 and
        normalized. The result then goes through ``ClassSelfAttentionBlock``,
        ``LayerScaleBlock`` and ``StochasticDepthBlock``, and is added to
        ``cls_token``. A second branch (LayerNorm -> ``FFBlock`` ->
        ``LayerScaleBlock`` -> ``StochasticDepthBlock``) is then added the
        same way. ``inputs`` is only used as part of the attention input; only
        ``cls_token`` is carried through the residuals.

        Args:
            inputs: Token array concatenated after ``cls_token`` on axis 1.
            cls_token: Class-token array being updated.
            is_training: Forwarded to every sub-block except LayerNorm.

        Returns:
            The updated class token. Presumably shaped like ``cls_token``
            (it must broadcast with it); confirm against
            ClassSelfAttentionBlock.
        """
        joint = jnp.concatenate([cls_token, inputs], axis=1)

        # Submodules are built in the original order so auto-generated names
        # are unchanged; each branch is then applied as a short pipeline.
        attn_norm = nn.LayerNorm(dtype=self.dtype)
        class_attn = ClassSelfAttentionBlock(num_heads=self.num_heads,
                                             attn_dropout_rate=self.attn_dropout_rate,
                                             out_dropout_rate=self.dropout_rate,
                                             dtype=self.dtype)
        attn_scale = LayerScaleBlock(eps=self.layerscale_eps, dtype=self.dtype)
        attn_drop_path = StochasticDepthBlock(drop_rate=self.stoch_depth_rate)

        branch = attn_norm(joint)
        for layer in (class_attn, attn_scale, attn_drop_path):
            branch = layer(branch, is_training=is_training)
        cls_token = cls_token + branch

        ff_norm = nn.LayerNorm(dtype=self.dtype)
        ff = FFBlock(expand_ratio=self.expand_ratio,
                     dropout_rate=self.dropout_rate,
                     activation_fn=self.activation_fn,
                     dtype=self.dtype)
        ff_scale = LayerScaleBlock(eps=self.layerscale_eps, dtype=self.dtype)
        ff_drop_path = StochasticDepthBlock(drop_rate=self.stoch_depth_rate)

        branch = ff_norm(cls_token)
        for layer in (ff, ff_scale, ff_drop_path):
            branch = layer(branch, is_training=is_training)
        return cls_token + branch
    def __call__(self, inputs, is_training: bool):
        """Apply a pre-norm attention + feed-forward block with layer scale and stochastic depth.

        Both sub-layers follow the same recipe: LayerNorm -> sub-layer ->
        ``LayerScaleBlock`` -> ``StochasticDepthBlock`` -> residual add.
        The attention sub-layer is a ``SelfAttentionBlock`` created with
        ``talking_heads=True``.

        Args:
            inputs: Token array; it is added back as the first residual.
            is_training: Forwarded to every sub-block except LayerNorm.

        Returns:
            Output of the second residual add.
        """
        # Submodules are built in the original order so auto-generated names
        # are unchanged; each branch is then applied as a short pipeline.
        attn_norm = nn.LayerNorm(dtype=self.dtype)
        attn = SelfAttentionBlock(num_heads=self.num_heads,
                                  talking_heads=True,
                                  attn_dropout_rate=self.attn_dropout_rate,
                                  out_dropout_rate=self.dropout_rate,
                                  dtype=self.dtype)
        attn_scale = LayerScaleBlock(eps=self.layerscale_eps, dtype=self.dtype)
        attn_drop_path = StochasticDepthBlock(drop_rate=self.stoch_depth_rate)

        branch = attn_norm(inputs)
        for layer in (attn, attn_scale, attn_drop_path):
            branch = layer(branch, is_training=is_training)
        hidden = branch + inputs

        mlp_norm = nn.LayerNorm(dtype=self.dtype)
        mlp = FFBlock(expand_ratio=self.expand_ratio,
                      dropout_rate=self.dropout_rate,
                      activation_fn=self.activation_fn,
                      dtype=self.dtype)
        mlp_scale = LayerScaleBlock(eps=self.layerscale_eps, dtype=self.dtype)
        mlp_drop_path = StochasticDepthBlock(drop_rate=self.stoch_depth_rate)

        branch = mlp_norm(hidden)
        for layer in (mlp, mlp_scale, mlp_drop_path):
            branch = layer(branch, is_training=is_training)
        return hidden + branch
Example #5
0
    def __call__(self, inputs, is_training: bool):
        """Apply a convolutional self-attention block followed by a pre-norm feed-forward block.

        ``CvTSelfAttentionBlock`` is applied directly to ``inputs`` (no
        LayerNorm before it in this method) and its output is added to
        ``inputs``. A LayerNorm -> ``FFBlock`` branch is then added as a
        second residual.

        Args:
            inputs: Rank-4 array (checked below).
            is_training: Forwarded to the attention and feed-forward sub-blocks.

        Returns:
            Output of the second residual add.

        Raises:
            ValueError: If ``inputs`` is not rank 4, or if ``embed_dim`` is not
                divisible by ``num_heads``.
        """
        # Explicit checks instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently disable this validation.
        if inputs.ndim != 4:
            raise ValueError(
                f'Expected a rank-4 input, got ndim={inputs.ndim}.')
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f'embed_dim ({self.embed_dim}) must be divisible by '
                f'num_heads ({self.num_heads}).')
        # Exact integer division; avoids a round-trip through float.
        head_ch = self.embed_dim // self.num_heads

        x = CvTSelfAttentionBlock(num_heads=self.num_heads,
                                  head_ch=head_ch,
                                  out_ch=self.embed_dim,
                                  kernel_size=self.kernel_size,
                                  use_bias=self.use_bias,
                                  bn_momentum=self.bn_momentum,
                                  bn_epsilon=self.bn_epsilon,
                                  dtype=self.dtype,
                                  precision=self.precision,
                                  kernel_init=self.kernel_init,
                                  bias_init=self.bias_init)(
                                      inputs, is_training=is_training)
        # Residual add: assumes the attention output broadcasts against the
        # rank-4 ``inputs``. TODO confirm CvTSelfAttentionBlock's output layout.
        x = x + inputs

        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = FFBlock(expand_ratio=self.expand_ratio,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(y, is_training=is_training)

        output = x + y
        return output
    def __call__(self, inputs, is_training: bool):
        """Apply a pre-norm transformer encoder block (self-attention + feed-forward).

        Attention branch: LayerNorm -> ``SelfAttentionBlock`` (output width
        ``num_heads * head_ch``) -> ``nn.Dropout`` -> residual add of
        ``inputs``. Feed-forward branch: LayerNorm -> ``FFBlock`` -> residual
        add.

        Args:
            inputs: Token array. Presumably its last dimension equals
                ``num_heads * head_ch`` so the residual add lines up; confirm.
            is_training: Forwarded to the sub-blocks. Dropout is deterministic
                when it is False.

        Returns:
            Output of the second residual add.
        """
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = SelfAttentionBlock(num_heads=self.num_heads,
                               head_ch=self.head_ch,
                               out_ch=self.num_heads * self.head_ch,
                               dropout_rate=self.attn_dropout_rate,
                               dtype=self.dtype,
                               precision=self.precision,
                               kernel_init=self.kernel_init)(
                                   x, is_training=is_training)
        # Output dropout on the attention branch; disabled outside training.
        x = nn.Dropout(rate=self.dropout_rate)(x,
                                               deterministic=not is_training)

        x += inputs

        y = nn.LayerNorm(dtype=self.dtype)(x)
        # The training flag is passed as ``is_training=``, the keyword used at
        # every other FFBlock call site; the previous ``train=`` did not match.
        y = FFBlock(expand_ratio=self.expand_ratio,
                    dropout_rate=self.dropout_rate,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(y, is_training=is_training)

        output = x + y
        return output
    def __call__(self, inputs):
        """Apply one token-mixing + channel-mixing block.

        Token mixing: LayerNorm, swap the last two axes, ``FFBlock`` with
        ``tokens_expand_ratio``, swap back, then add ``inputs``. Channel
        mixing: LayerNorm, ``FFBlock`` with ``channels_expand_ratio``, then
        add the token-mixed result. No training flag is passed to the
        ``FFBlock`` calls in this variant.

        Args:
            inputs: Array whose last two axes are treated as ``(l, d)`` by
                the ``rearrange`` patterns below.

        Returns:
            Array of the same layout as ``inputs``.
        """
        # Fixed typos: ``self.dype`` -> ``self.dtype`` and the ``dype=``
        # keyword -> ``dtype=``. The old spellings referenced a nonexistent
        # attribute and passed a misspelled field to FFBlock.
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = rearrange(x, '... l d -> ... d l')
        x = FFBlock(expand_ratio=self.tokens_expand_ratio,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(x)
        x = rearrange(x, '... d l -> ... l d')
        x = x + inputs

        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = FFBlock(expand_ratio=self.channels_expand_ratio,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(y)
        output = x + y
        return output
    def __call__(self, inputs, is_training: bool):
        """Apply a post-norm block: each sub-layer is followed by residual add and LayerNorm.

        ``LCSelfAttentionBlock`` output is added to ``inputs`` and normalized.
        That result goes through ``FFBlock``, is added to itself, and is
        normalized again.

        Args:
            inputs: Token array added back after the attention sub-layer.
            is_training: Forwarded to the attention and feed-forward sub-blocks.

        Returns:
            Output of the final LayerNorm.
        """
        # Construction order (attention, LayerNorm, FFBlock, LayerNorm)
        # matches the straight-line version.
        attention = LCSelfAttentionBlock(num_heads=self.num_heads,
                                         dtype=self.dtype)
        hidden = attention(inputs, is_training=is_training)
        hidden += inputs
        hidden = nn.LayerNorm(dtype=self.dtype)(hidden)

        feed_forward = FFBlock(expand_ratio=self.expand_ratio,
                               activation_fn=self.activation_fn,
                               dtype=self.dtype)
        summed = hidden + feed_forward(hidden, is_training=is_training)
        return nn.LayerNorm(dtype=self.dtype)(summed)
    def __call__(self, inputs, is_training: bool):
        """Apply a pre-norm transformer encoder block (self-attention + feed-forward).

        Each sub-layer is LayerNorm -> sub-block -> residual add.

        Args:
            inputs: Token array added back after the attention sub-layer.
            is_training: Forwarded to the attention and feed-forward sub-blocks.

        Returns:
            Output of the second residual add.
        """
        # Submodules are built in the original order
        # (LayerNorm, SelfAttentionBlock, LayerNorm, FFBlock).
        attn_norm = nn.LayerNorm(dtype=self.dtype)
        attn = SelfAttentionBlock(num_heads=self.num_heads,
                                  attn_dropout_rate=self.attn_dropout_rate,
                                  out_dropout_rate=self.dropout_rate,
                                  dtype=self.dtype)
        hidden = attn(attn_norm(inputs), is_training=is_training) + inputs

        mlp_norm = nn.LayerNorm(dtype=self.dtype)
        mlp = FFBlock(expand_ratio=self.expand_ratio,
                      dropout_rate=self.dropout_rate,
                      activation_fn=self.activation_fn,
                      dtype=self.dtype)
        return hidden + mlp(mlp_norm(hidden), is_training=is_training)
    def __call__(self, inputs, is_training: bool):
        """Apply convolutional self-attention to a padded grid, then a pre-norm feed-forward block.

        ``inputs`` is first passed through ``zero_pad_and_reshape``. The
        attention output is added to that grid flattened from
        ``b h w d`` to ``b (h w) d``. A LayerNorm -> ``FFBlock`` branch is
        then added as a second residual.

        Args:
            inputs: Input accepted by ``zero_pad_and_reshape``; its result
                must match the ``b h w d`` rearrange pattern.
            is_training: Forwarded to the attention and feed-forward sub-blocks.

        Returns:
            Output of the second residual add, in the flattened token layout.
        """
        grid = zero_pad_and_reshape(inputs)

        conv_attn = CvTSelfAttentionBlock(num_heads=self.num_heads,
                                          kernel_size=self.kernel_size,
                                          use_bias=self.use_bias,
                                          bn_momentum=self.bn_momentum,
                                          bn_epsilon=self.bn_epsilon,
                                          dtype=self.dtype)
        attended = conv_attn(grid, is_training=is_training)
        # Residual uses the grid flattened to one token axis.
        tokens = attended + rearrange(grid, 'b h w d -> b (h w) d')

        mlp_norm = nn.LayerNorm(dtype=self.dtype)
        mlp = FFBlock(expand_ratio=self.expand_ratio,
                      activation_fn=self.activation_fn,
                      dtype=self.dtype)
        return tokens + mlp(mlp_norm(tokens), is_training=is_training)
    def __call__(self, inputs, is_training: bool):
        """Apply convolutional self-attention with a residual, then a pre-norm feed-forward block.

        ``CvTSelfAttentionBlock`` is applied directly to ``inputs`` (no
        LayerNorm before it in this method) and added back to ``inputs``.
        A LayerNorm -> ``FFBlock`` branch is then added as a second residual.

        Args:
            inputs: Array added back after the attention sub-layer.
            is_training: Forwarded to the attention and feed-forward sub-blocks.

        Returns:
            Output of the second residual add.
        """
        conv_attn = CvTSelfAttentionBlock(num_heads=self.num_heads,
                                          kernel_size=self.kernel_size,
                                          use_bias=self.use_bias,
                                          bn_momentum=self.bn_momentum,
                                          bn_epsilon=self.bn_epsilon,
                                          dtype=self.dtype,
                                          precision=self.precision,
                                          kernel_init=self.kernel_init,
                                          bias_init=self.bias_init)
        hidden = conv_attn(inputs, is_training=is_training) + inputs

        mlp_norm = nn.LayerNorm(dtype=self.dtype)
        mlp = FFBlock(expand_ratio=self.expand_ratio,
                      activation_fn=self.activation_fn,
                      dtype=self.dtype,
                      precision=self.precision,
                      kernel_init=self.kernel_init,
                      bias_init=self.bias_init)
        return hidden + mlp(mlp_norm(hidden), is_training=is_training)