Code example #1
# __call__ of a custom initializer class; `jax_initializers` here refers to
# jax.nn.initializers. Defaults the dtype to float32, then delegates to the
# built-in Kaiming (He) uniform initializer.
def __call__(self, key, shape, dtype=None):
    if dtype is None:
        dtype = "float32"
    initializer_fn = jax_initializers.kaiming_uniform()
    return initializer_fn(key, shape, dtype)
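
For reference, a minimal sketch of what the method above delegates to: JAX's built-in Kaiming (He) uniform initializer, called directly with an explicit PRNG key, shape, and dtype (the shape is illustrative).

import jax
from jax.nn import initializers as jax_initializers

init_fn = jax_initializers.kaiming_uniform()
w = init_fn(jax.random.PRNGKey(0), (3, 3, 64, 128), "float32")
print(w.shape, w.dtype)  # (3, 3, 64, 128) float32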
Code example #2
class CvT(nn.Module):
    num_classes: int
    stage_sizes: Sequence[int]
    num_heads: Sequence[int]
    embed_dim: Sequence[int]
    embed_kernel_size: Sequence[int] = (7, 3, 3)
    embed_strides: Sequence[int] = (4, 2, 2)
    sa_kernel_size: Sequence[int] = (3, 3, 3)
    use_bias: bool = False
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    expand_ratio: int = 4
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        x = inputs
        for i in range(len(self.stage_sizes) - 1):
            x = Stage(size=self.stage_sizes[i],
                      num_heads=self.num_heads[i],
                      embed_dim=self.embed_dim[i],
                      embed_kernel_size=self.embed_kernel_size[i],
                      embed_strides=self.embed_strides[i],
                      sa_kernel_size=self.sa_kernel_size[i],
                      use_bias=self.use_bias,
                      bn_momentum=self.bn_momentum,
                      bn_epsilon=self.bn_epsilon,
                      expand_ratio=self.expand_ratio,
                      dtype=self.dtype,
                      precision=self.precision,
                      kernel_init=self.kernel_init,
                      bias_init=self.bias_init)(x, is_training=is_training)

            # Fold the (B, H*W, C) token sequence back into a square
            # (B, H, W, C) feature map for the next convolutional stage.
            l = x.shape[1]
            spatial_ch = int(jnp.sqrt(l))
            x = rearrange(x, 'b (H W) c -> b H W c', H=spatial_ch)

        # Final stage: uses the last entry of each per-stage setting and
        # inserts the classification token.
        x = Stage(size=self.stage_sizes[-1],
                  num_heads=self.num_heads[-1],
                  embed_dim=self.embed_dim[-1],
                  embed_kernel_size=self.embed_kernel_size[-1],
                  embed_strides=self.embed_strides[-1],
                  sa_kernel_size=self.sa_kernel_size[-1],
                  use_bias=self.use_bias,
                  bn_momentum=self.bn_momentum,
                  bn_epsilon=self.bn_epsilon,
                  expand_ratio=self.expand_ratio,
                  insert_cls=True,
                  dtype=self.dtype,
                  precision=self.precision,
                  kernel_init=self.kernel_init,
                  bias_init=self.bias_init)(x, is_training=is_training)

        cls_token = x[:, 0]
        output = nn.Dense(features=self.num_classes,
                          use_bias=True,
                          dtype=self.dtype,
                          precision=self.precision,
                          kernel_init=self.kernel_init,
                          bias_init=initializers.zeros)(cls_token)
        return output
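
Between stages, CvT folds the (B, H*W, C) token sequence back into a square (B, H, W, C) feature map before the next convolutional embedding (the rearrange above). A self-contained sketch of that reshaping, with illustrative sizes:

import jax.numpy as jnp
from einops import rearrange

tokens = jnp.ones((2, 49, 64))          # (B, H*W, C) token sequence
side = int(jnp.sqrt(tokens.shape[1]))   # 7: the spatial map stays square
fmap = rearrange(tokens, 'b (H W) c -> b H W c', H=side)
print(fmap.shape)  # (2, 7, 7, 64)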
Code example #3
class EncoderBlock(nn.Module):
    inner_num_heads: int
    outer_num_heads: int
    inner_expand_ratio: float = 4
    outer_expand_ratio: float = 4
    attn_dropout_rate: float = 0.
    dropout_rate: float = 0.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, patch_inputs, pixel_inputs, is_training: bool):
        inner_x = nn.LayerNorm(dtype=self.dtype)(pixel_inputs)
        inner_x = SelfAttentionBlock(num_heads=self.inner_num_heads,
                                     attn_drop_rate=self.attn_dropout_rate,
                                     out_drop_rate=self.dropout_rate,
                                     dtype=self.dtype,
                                     precision=self.precision,
                                     kernel_init=self.kernel_init,
                                     bias_init=self.bias_init)(
                                         inner_x, is_training=is_training)
        inner_x = inner_x + pixel_inputs
        inner_y = nn.LayerNorm(dtype=self.dtype)(inner_x)
        inner_y = FFBlock(expand_ratio=self.inner_expand_ratio,
                          dropout_rate=self.dropout_rate,
                          dtype=self.dtype,
                          precision=self.precision,
                          kernel_init=self.kernel_init,
                          bias_init=self.bias_init)(inner_y,
                                                    is_training=is_training)
        inner_output = inner_x + inner_y

        outer_x = Inner2OuterBlock(dtype=self.dtype,
                                   precision=self.precision,
                                   kernel_init=self.kernel_init,
                                   bias_init=self.bias_init)(inner_output,
                                                             patch_inputs)

        outer_x = nn.LayerNorm(dtype=self.dtype)(outer_x)
        outer_x = SelfAttentionBlock(num_heads=self.outer_num_heads,
                                     attn_drop_rate=self.attn_dropout_rate,
                                     out_drop_rate=self.dropout_rate,
                                     dtype=self.dtype,
                                     precision=self.precision,
                                     kernel_init=self.kernel_init,
                                     bias_init=self.bias_init)(
                                         outer_x, is_training=is_training)
        outer_x = outer_x + patch_inputs
        outer_y = nn.LayerNorm(dtype=self.dtype)(outer_x)
        outer_y = FFBlock(expand_ratio=self.outer_expand_ratio,
                          dropout_rate=self.dropout_rate,
                          dtype=self.dtype,
                          precision=self.precision,
                          kernel_init=self.kernel_init,
                          bias_init=self.bias_init)(outer_y,
                                                    is_training=is_training)
        outer_output = outer_x + outer_y

        return outer_output, inner_output
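
Both the inner (pixel) stream and the outer (patch) stream above follow the same pre-norm residual pattern: LayerNorm, a sub-block, then a residual add. A minimal, self-contained sketch of that pattern, with a plain Dense standing in for the attention and feed-forward sub-blocks:

import jax
import jax.numpy as jnp
import flax.linen as nn

class PreNormResidual(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        y = nn.LayerNorm()(x)
        y = nn.Dense(self.features)(y)  # stand-in for SelfAttentionBlock / FFBlock
        return x + y

x = jnp.ones((2, 17, 32))
block = PreNormResidual(features=32)
params = block.init(jax.random.PRNGKey(0), x)
print(block.apply(params, x).shape)  # (2, 17, 32)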
Code example #4
class TNT(nn.Module):
    num_classes: int
    num_layers: int
    patch_shape: Tuple[int, int]
    transformed_patch_shape: Tuple[int, int]
    inner_num_heads: int
    outer_num_heads: int
    inner_embed_dim: int
    outer_embed_dim: int
    inner_expand_ratio: float = 4
    outer_expand_ratio: float = 4
    attn_dropout_rate: float = 0.
    dropout_rate: float = 0.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros
    pos_embed_init: Callable = initializers.normal(stddev=0.02)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        pixel_embeddings = PixelEmbedBlock(
            patch_shape=self.patch_shape,
            transformed_patch_shape=self.transformed_patch_shape,
            embed_dim=self.inner_embed_dim,
            dtype=self.dtype,
            precision=self.precision,
            kernel_init=self.kernel_init)(inputs)

        pixel_embeddings = AddAbsPosEmbed(
            embed_init=self.pos_embed_init)(pixel_embeddings)

        patch_embeddings = PatchEmbedBlock(patch_shape=self.patch_shape,
                                           embed_dim=self.outer_embed_dim,
                                           dtype=self.dtype,
                                           precision=self.precision,
                                           kernel_init=self.kernel_init,
                                           bias_init=self.bias_init)(inputs)

        b, l, _ = patch_embeddings.shape
        cls_shape = (1, 1, self.outer_embed_dim)
        cls_token = self.param('cls', initializers.zeros, cls_shape)
        cls_token = jnp.tile(cls_token, [b, 1, 1])
        patch_embeddings = jnp.concatenate([cls_token, patch_embeddings],
                                           axis=1)

        patch_embeddings = AddAbsPosEmbed(
            embed_init=self.pos_embed_init)(patch_embeddings)
        patch_embeddings = nn.Dropout(rate=self.dropout_rate)(
            patch_embeddings, deterministic=not is_training)

        patch_embeddings = Encoder(num_layers=self.num_layers,
                                   inner_num_heads=self.inner_num_heads,
                                   outer_num_heads=self.outer_num_heads,
                                   attn_dropout_rate=self.attn_dropout_rate,
                                   dropout_rate=self.dropout_rate,
                                   activation_fn=self.activation_fn,
                                   dtype=self.dtype,
                                   precision=self.precision,
                                   kernel_init=self.kernel_init,
                                   bias_init=self.bias_init)(
                                       patch_embeddings,
                                       pixel_embeddings,
                                       is_training=is_training)

        cls_token = patch_embeddings[:, 0]
        output = nn.Dense(
            features=self.num_classes,
            use_bias=True,
            dtype=self.dtype,
            kernel_init=initializers.zeros,
            bias_init=self.bias_init,
        )(cls_token)
        return output
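
The cls-token handling above reduces to tiling a learned (1, 1, D) parameter across the batch and prepending it to the patch tokens. A self-contained sketch of that step, with a constant array standing in for the learned 'cls' parameter:

import jax.numpy as jnp

batch, num_patches, embed_dim = 2, 196, 384
patches = jnp.ones((batch, num_patches, embed_dim))
cls = jnp.zeros((1, 1, embed_dim))   # stands in for the learned 'cls' param
cls = jnp.tile(cls, [batch, 1, 1])
tokens = jnp.concatenate([cls, patches], axis=1)
print(tokens.shape)  # (2, 197, 384)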
Code example #5
class CaiT(nn.Module):
    num_classes: int
    num_layers: int
    num_layers_token_only: int
    num_heads: int
    embed_dim: int
    patch_shape: Tuple[int, int]
    expand_ratio: float = 4
    attn_dropout_rate: float = 0.
    dropout_rate: float = 0.
    stoch_depth_rate: float = 0.
    layerscale_eps: float = 0.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        x = PatchEmbedBlock(
            patch_shape=self.patch_shape,
            embed_dim=self.embed_dim,
            dtype=self.dtype,
            precision=self.precision,
        )(inputs)

        x = Encoder(num_layers=self.num_layers,
                    num_heads=self.num_heads,
                    expand_ratio=self.expand_ratio,
                    attn_dropout_rate=self.attn_dropout_rate,
                    dropout_rate=self.dropout_rate,
                    stoch_depth_rate=self.stoch_depth_rate,
                    layerscale_eps=self.layerscale_eps,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init)(x, is_training=is_training)

        b = x.shape[0]
        cls_shape = (1, 1, self.embed_dim)
        cls_token = self.param('cls', initializers.zeros, cls_shape)
        cls_token = jnp.tile(cls_token, [b, 1, 1])

        for _ in range(self.num_layers_token_only):
            cls_token = CAEncoderBlock(
                num_heads=self.num_heads,
                expand_ratio=self.expand_ratio,
                attn_dropout_rate=self.attn_dropout_rate,
                dropout_rate=self.dropout_rate,
                stoch_depth_rate=self.stoch_depth_rate,
                layerscale_eps=self.layerscale_eps,
                attn_class=ClassSelfAttentionBlock,
                activation_fn=self.activation_fn,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init)(x,
                                              cls_token,
                                              is_training=is_training)

        x = jnp.concatenate([cls_token, x], axis=1)
        x = nn.LayerNorm(dtype=self.dtype)(x)

        cls_token = x[:, 0]
        output = nn.Dense(
            features=self.num_classes,
            use_bias=True,
            dtype=self.dtype,
            kernel_init=initializers.zeros,
            bias_init=self.bias_init,
        )(cls_token)
        return output
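
In the class-attention layers used here, only the cls token forms queries, while keys and values come from the full token sequence, which keeps the token-only layers cheap. A single-head, projection-free sketch of that contraction (sizes are illustrative; the actual ClassSelfAttentionBlock also applies learned projections and multiple heads):

import jax
import jax.numpy as jnp

d = 16
cls = jnp.ones((2, 1, d))                                    # queries: cls token only
ctx = jax.random.normal(jax.random.PRNGKey(0), (2, 50, d))   # keys/values

weights = jax.nn.softmax(jnp.einsum('bqd,bkd->bqk', cls / jnp.sqrt(d), ctx))
updated_cls = jnp.einsum('bqk,bkd->bqd', weights, ctx)
print(updated_cls.shape)  # (2, 1, 16)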
Code example #6
class CvTSelfAttentionBlock(nn.Module):
    num_heads: int
    head_ch: int
    out_ch: Optional[int] = None
    kernel_size: int = 3
    strides: Sequence[int] = (1, 2, 2)
    use_bias: bool = False
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        assert len(self.strides) == 3
        assert inputs.ndim == 3
        q_strides, k_strides, v_strides = self.strides
        b, l, c = inputs.shape
        out_ch = self.out_ch if self.out_ch is not None else c
        # Pad the token sequence up to a perfect square and fold it into a
        # (B, H, W, C) map so that the convolutional projections can be applied.
        spatial_ch = int(jnp.ceil(jnp.sqrt(l)))
        inputs = jnp.pad(inputs, ((0, 0), (0, spatial_ch**2 - l), (0, 0)))
        inputs = rearrange(inputs, 'b (H W) c -> b H W c', W=spatial_ch)

        conv_proj = partial(ConvProjectionBlock,
                            out_ch=self.num_heads * self.head_ch,
                            kernel_size=self.kernel_size,
                            use_bias=self.use_bias,
                            bn_momentum=self.bn_momentum,
                            bn_epsilon=self.bn_epsilon,
                            dtype=self.dtype,
                            precision=self.precision,
                            kernel_init=self.kernel_init,
                            bias_init=self.bias_init)

        query = conv_proj(strides=q_strides)(inputs, is_training=is_training)
        key = conv_proj(strides=k_strides)(inputs, is_training=is_training)
        value = conv_proj(strides=v_strides)(inputs, is_training=is_training)

        query = rearrange(query, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
        key = rearrange(key, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
        value = rearrange(value, 'b H W (h d) -> b (H W) h d', h=self.num_heads)

        query = query / jnp.sqrt(self.head_ch)

        attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k',
                                  query,
                                  key,
                                  precision=self.precision)

        attn_weights = nn.softmax(attn_weights)

        attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d',
                                 attn_weights,
                                 value,
                                 precision=self.precision)

        # Project the concatenated heads only when the attention width differs
        # from the requested output width.
        if (self.num_heads * self.head_ch) == out_ch:
            output = rearrange(attn_scores, '... q h d -> ... q (h d)')
        else:
            output = nn.DenseGeneral(features=out_ch,
                                     axis=(-2, -1),
                                     use_bias=self.use_bias,
                                     dtype=self.dtype,
                                     precision=self.precision,
                                     kernel_init=self.kernel_init,
                                     bias_init=self.bias_init)(attn_scores)
        return output
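
The two einsum contractions above are the standard multi-head attention pattern with the head axis kept explicit. A self-contained sketch of just that pattern, without the convolutional projections (shapes are illustrative):

import jax
import jax.numpy as jnp
from einops import rearrange

num_heads, head_ch = 4, 8
x = jax.random.normal(jax.random.PRNGKey(0), (2, 16, num_heads * head_ch))

q = rearrange(x, 'b l (h d) -> b l h d', h=num_heads) / jnp.sqrt(head_ch)
k = rearrange(x, 'b l (h d) -> b l h d', h=num_heads)
v = rearrange(x, 'b l (h d) -> b l h d', h=num_heads)

weights = jax.nn.softmax(jnp.einsum('b q h d, b k h d -> b h q k', q, k))
scores = jnp.einsum('b h q k, b k h d -> b q h d', weights, v)
out = rearrange(scores, 'b q h d -> b q (h d)')
print(out.shape)  # (2, 16, 32)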