def __call__(self, key, shape, dtype=None):
    """Sample Kaiming-uniform values of the given shape.

    Args:
        key: PRNG key used to draw the values.
        shape: output array shape.
        dtype: output dtype; defaults to "float32" when omitted.

    Returns:
        The initialized array produced by the Kaiming-uniform initializer.
    """
    target_dtype = "float32" if dtype is None else dtype
    return jax_initializers.kaiming_uniform()(key, shape, target_dtype)
class CvT(nn.Module):
    """Convolutional vision Transformer (CvT) classifier.

    Runs ``len(stage_sizes)`` ``Stage`` modules; the final stage inserts a
    class token whose embedding feeds a Dense classification head.
    """
    num_classes: int
    stage_sizes: Sequence[int]
    num_heads: Sequence[int]
    embed_dim: Sequence[int]
    # NOTE: tuples replace the original mutable list defaults — mutable
    # defaults on dataclass fields are shared across all instances.
    embed_kernel_size: Sequence[int] = (7, 3, 3)
    embed_strides: Sequence[int] = (4, 2, 2)
    sa_kernel_size: Sequence[int] = (3, 3, 3)
    use_bias: bool = False
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    expand_ratio: int = 4
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        x = inputs
        # Every stage except the last: run the stage, then fold the token
        # sequence back into an (H, W) feature map for the next stage's
        # convolutional embedding. Assumes a square token grid — TODO confirm.
        for i in range(len(self.stage_sizes) - 1):
            x = Stage(size=self.stage_sizes[i],
                      num_heads=self.num_heads[i],
                      embed_dim=self.embed_dim[i],
                      embed_kernel_size=self.embed_kernel_size[i],
                      embed_strides=self.embed_strides[i],
                      sa_kernel_size=self.sa_kernel_size[i],
                      use_bias=self.use_bias,
                      bn_momentum=self.bn_momentum,
                      bn_epsilon=self.bn_epsilon,
                      expand_ratio=self.expand_ratio,
                      dtype=self.dtype,
                      precision=self.precision,
                      kernel_init=self.kernel_init,
                      bias_init=self.bias_init)(x, is_training=is_training)
            seq_len = x.shape[1]
            side = int(jnp.sqrt(seq_len))
            x = rearrange(x, 'b (H W) c -> b H W c', H=side)
        # BUG FIX: the final stage previously indexed with the stale loop
        # variable `i` (== len(stage_sizes) - 2), silently reusing the
        # second-to-last stage's configuration; it must use the last entries.
        x = Stage(size=self.stage_sizes[-1],
                  num_heads=self.num_heads[-1],
                  embed_dim=self.embed_dim[-1],
                  embed_kernel_size=self.embed_kernel_size[-1],
                  embed_strides=self.embed_strides[-1],
                  sa_kernel_size=self.sa_kernel_size[-1],
                  use_bias=self.use_bias,
                  bn_momentum=self.bn_momentum,
                  bn_epsilon=self.bn_epsilon,
                  expand_ratio=self.expand_ratio,
                  insert_cls=True,
                  dtype=self.dtype,
                  precision=self.precision,
                  kernel_init=self.kernel_init,
                  bias_init=self.bias_init)(x, is_training=is_training)
        # Classify from the class token (position 0 after insert_cls=True).
        cls_token = x[:, 0]
        output = nn.Dense(features=self.num_classes,
                          use_bias=True,
                          dtype=self.dtype,
                          precision=self.precision,
                          kernel_init=self.kernel_init,
                          bias_init=initializers.zeros)(cls_token)
        return output
class EncoderBlock(nn.Module):
    """TNT encoder block: an inner (pixel-level) transformer sub-block
    followed by an outer (patch-level) sub-block, with the inner output
    folded into the outer stream via ``Inner2OuterBlock``.

    Returns ``(outer_output, inner_output)`` so both streams can be chained.
    """
    inner_num_heads: int
    outer_num_heads: int
    inner_expand_ratio: float = 4
    outer_expand_ratio: float = 4
    attn_dropout_rate: float = 0.
    dropout_rate: float = 0.
    # BUG FIX: was `activation_fn = nn.activation.gelu` (no annotation), so
    # the dataclass machinery never registered it as a field and callers
    # could not configure it. NOTE(review): it is still unused in this body —
    # presumably FFBlock should receive it; verify against FFBlock.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, patch_inputs, pixel_inputs, is_training: bool):
        # --- inner (pixel) stream: pre-LN attention + FF with residuals ---
        inner_x = nn.LayerNorm(dtype=self.dtype)(pixel_inputs)
        inner_x = SelfAttentionBlock(num_heads=self.inner_num_heads,
                                     attn_drop_rate=self.attn_dropout_rate,
                                     out_drop_rate=self.dropout_rate,
                                     dtype=self.dtype,
                                     precision=self.precision,
                                     kernel_init=self.kernel_init,
                                     bias_init=self.bias_init)(
                                         inner_x, is_training=is_training)
        inner_x = inner_x + pixel_inputs
        inner_y = nn.LayerNorm(dtype=self.dtype)(inner_x)
        inner_y = FFBlock(expand_ratio=self.inner_expand_ratio,
                          dropout_rate=self.dropout_rate,
                          dtype=self.dtype,
                          precision=self.precision,
                          kernel_init=self.kernel_init,
                          bias_init=self.bias_init)(inner_y,
                                                    is_training=is_training)
        inner_output = inner_x + inner_y

        # --- outer (patch) stream, seeded from the inner output ---
        outer_x = Inner2OuterBlock(dtype=self.dtype,
                                   precision=self.precision,
                                   kernel_init=self.kernel_init,
                                   bias_init=self.bias_init)(inner_output,
                                                             patch_inputs)
        outer_x = nn.LayerNorm(dtype=self.dtype)(outer_x)
        outer_x = SelfAttentionBlock(num_heads=self.outer_num_heads,
                                     attn_drop_rate=self.attn_dropout_rate,
                                     out_drop_rate=self.dropout_rate,
                                     dtype=self.dtype,
                                     precision=self.precision,
                                     kernel_init=self.kernel_init,
                                     bias_init=self.bias_init)(
                                         outer_x, is_training=is_training)
        # NOTE(review): the residual re-uses patch_inputs (not the
        # Inner2OuterBlock output) — preserved as-is; confirm intended.
        outer_x = outer_x + patch_inputs
        outer_y = nn.LayerNorm(dtype=self.dtype)(outer_x)
        outer_y = FFBlock(expand_ratio=self.outer_expand_ratio,
                          dropout_rate=self.dropout_rate,
                          dtype=self.dtype,
                          precision=self.precision,
                          kernel_init=self.kernel_init,
                          bias_init=self.bias_init)(outer_y,
                                                    is_training=is_training)
        outer_output = outer_x + outer_y
        return outer_output, inner_output
class TNT(nn.Module):
    """Transformer-in-Transformer (TNT) classifier.

    Builds pixel-level (inner) and patch-level (outer) embeddings, runs them
    through the TNT ``Encoder``, and classifies from the class token.
    """
    # BUG FIX: `num_classes` was read in __call__ (the final Dense head) but
    # never declared as a field, so every forward pass raised AttributeError.
    num_classes: int
    num_layers: int
    patch_shape: Tuple[int, int]
    transformed_patch_shape: Tuple[int, int]
    inner_num_heads: int
    outer_num_heads: int
    inner_embed_dim: int
    outer_embed_dim: int
    inner_expand_ratio: float = 4
    outer_expand_ratio: float = 4
    attn_dropout_rate: float = 0.
    dropout_rate: float = 0.
    # BUG FIX: was missing the `: Callable` annotation, so it was not a
    # dataclass field and `activation_fn=...` could not be configured even
    # though this class forwards it to Encoder below.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros
    pos_embed_init: Callable = initializers.normal(stddev=0.02)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        # Pixel-level (inner) embeddings with their own positional embedding.
        pixel_embeddings = PixelEmbedBlock(
            patch_shape=self.patch_shape,
            transformed_patch_shape=self.transformed_patch_shape,
            embed_dim=self.inner_embed_dim,
            dtype=self.dtype,
            precision=self.precision,
            kernel_init=self.kernel_init)(inputs)
        pixel_embeddings = AddAbsPosEmbed(
            embed_init=self.pos_embed_init)(pixel_embeddings)

        # Patch-level (outer) embeddings; prepend a learnable class token.
        patch_embeddings = PatchEmbedBlock(patch_shape=self.patch_shape,
                                           embed_dim=self.outer_embed_dim,
                                           dtype=self.dtype,
                                           precision=self.precision,
                                           kernel_init=self.kernel_init,
                                           bias_init=self.bias_init)(inputs)
        b = patch_embeddings.shape[0]
        cls_shape = (1, 1, self.outer_embed_dim)
        cls_token = self.param('cls', initializers.zeros, cls_shape)
        cls_token = jnp.tile(cls_token, [b, 1, 1])
        patch_embeddings = jnp.concatenate([cls_token, patch_embeddings],
                                           axis=1)
        patch_embeddings = AddAbsPosEmbed(
            embed_init=self.pos_embed_init)(patch_embeddings)
        patch_embeddings = nn.Dropout(rate=self.dropout_rate)(
            patch_embeddings, deterministic=not is_training)

        patch_embeddings = Encoder(num_layers=self.num_layers,
                                   inner_num_heads=self.inner_num_heads,
                                   outer_num_heads=self.outer_num_heads,
                                   attn_dropout_rate=self.attn_dropout_rate,
                                   dropout_rate=self.dropout_rate,
                                   activation_fn=self.activation_fn,
                                   dtype=self.dtype,
                                   precision=self.precision,
                                   kernel_init=self.kernel_init,
                                   bias_init=self.bias_init)(
                                       patch_embeddings,
                                       pixel_embeddings,
                                       is_training=is_training)

        # Zero-initialized classification head over the class token.
        cls_token = patch_embeddings[:, 0]
        output = nn.Dense(
            features=self.num_classes,
            use_bias=True,
            dtype=self.dtype,
            kernel_init=initializers.zeros,
            bias_init=self.bias_init,
        )(cls_token)
        return output
class CaiT(nn.Module):
    """Class-Attention in Image Transformers (CaiT) classifier.

    Patch-embeds the input, runs a self-attention ``Encoder``, then refines a
    class token through ``num_layers_token_only`` class-attention blocks
    before the final Dense head.
    """
    num_classes: int
    num_layers: int
    num_layers_token_only: int
    num_heads: int
    embed_dim: int
    patch_shape: Tuple[int, int]
    expand_ratio: float = 4
    attn_dropout_rate: float = 0.
    dropout_rate: float = 0.
    stoch_depth_rate: float = 0.
    # BUG FIX: was `layerscale_eps = float = 0.` — a chained assignment that
    # bound BOTH `layerscale_eps` and the name `float` to 0.0 at class scope
    # (shadowing the builtin) and left `layerscale_eps` unregistered as a
    # dataclass field.
    layerscale_eps: float = 0.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        x = PatchEmbedBlock(
            patch_shape=self.patch_shape,
            embed_dim=self.embed_dim,
            dtype=self.dtype,
            precision=self.precision,
        )(inputs)
        x = Encoder(num_layers=self.num_layers,
                    num_heads=self.num_heads,
                    expand_ratio=self.expand_ratio,
                    attn_dropout_rate=self.attn_dropout_rate,
                    dropout_rate=self.dropout_rate,
                    stoch_depth_rate=self.stoch_depth_rate,
                    layerscale_eps=self.layerscale_eps,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init)(x, is_training=is_training)

        # Learnable class token, refined by class-attention-only layers.
        b = x.shape[0]
        cls_shape = (1, 1, self.embed_dim)
        cls_token = self.param('cls', initializers.zeros, cls_shape)
        cls_token = jnp.tile(cls_token, [b, 1, 1])
        for _ in range(self.num_layers_token_only):
            cls_token = CAEncoderBlock(
                num_heads=self.num_heads,
                expand_ratio=self.expand_ratio,
                attn_dropout_rate=self.attn_dropout_rate,
                dropout_rate=self.dropout_rate,
                stoch_depth_rate=self.stoch_depth_rate,
                layerscale_eps=self.layerscale_eps,
                attn_class=ClassSelfAttentionBlock,
                activation_fn=self.activation_fn,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init)(x, cls_token,
                                              is_training=is_training)

        # Final norm over [cls; patches], then a zero-initialized head on
        # the class token.
        x = jnp.concatenate([cls_token, x], axis=1)
        x = nn.LayerNorm(dtype=self.dtype)(x)
        cls_token = x[:, 0]
        output = nn.Dense(
            features=self.num_classes,
            use_bias=True,
            dtype=self.dtype,
            kernel_init=initializers.zeros,
            bias_init=self.bias_init,
        )(cls_token)
        return output
class CvTSelfAttentionBlock(nn.Module):
    """Multi-head self-attention with convolutional Q/K/V projections (CvT).

    The token sequence is padded to a square, viewed as an (H, W) map for the
    conv projections, and the attention output is flattened back to tokens.
    """
    num_heads: int
    head_ch: int
    # NOTE(review): callers may pass None here to mean "same as the input
    # channel count" (see the None check below) — annotation kept as-is.
    out_ch: int
    kernel_size: int = 3
    # Tuple default replaces the original mutable list default (shared
    # across instances); order is (query, key, value) strides.
    strides: Sequence[int] = (1, 2, 2)
    use_bias: bool = False
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        assert len(self.strides) == 3
        assert inputs.ndim == 3
        q_strides, k_strides, v_strides = self.strides
        _, l, c = inputs.shape
        # None means "match the input channel count".
        out_ch = self.out_ch if self.out_ch is not None else c

        # Pad the sequence up to a perfect square so it can be viewed as an
        # (H, W) feature map for the convolutional projections.
        spatial_ch = int(jnp.ceil(jnp.sqrt(l)))
        inputs = jnp.pad(inputs, ((0, 0), (0, spatial_ch**2 - l), (0, 0)))
        inputs = rearrange(inputs, 'b (H W) c -> b H W c', W=spatial_ch)

        conv_proj = partial(ConvProjectionBlock,
                            out_ch=self.num_heads * self.head_ch,
                            kernel_size=self.kernel_size,
                            use_bias=self.use_bias,
                            bn_momentum=self.bn_momentum,
                            bn_epsilon=self.bn_epsilon,
                            dtype=self.dtype,
                            precision=self.precision,
                            kernel_init=self.kernel_init,
                            bias_init=self.bias_init)
        query = conv_proj(strides=q_strides)(inputs, is_training=is_training)
        key = conv_proj(strides=k_strides)(inputs, is_training=is_training)
        value = conv_proj(strides=v_strides)(inputs, is_training=is_training)

        # Back to token sequences, split per head.
        query = rearrange(query, 'b H W (h d) -> b (H W) h d',
                          h=self.num_heads)
        key = rearrange(key, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
        value = rearrange(value, 'b H W (h d) -> b (H W) h d',
                          h=self.num_heads)

        # Scaled dot-product attention.
        query = query / jnp.sqrt(self.head_ch)
        attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k',
                                  query, key, precision=self.precision)
        attn_weights = nn.softmax(attn_weights)
        attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d',
                                 attn_weights, value,
                                 precision=self.precision)

        # BUG FIX: the comparison and DenseGeneral below previously read
        # `self.out_ch` instead of the resolved local `out_ch`; with
        # out_ch=None the equality test was always False and DenseGeneral
        # received features=None.
        if (self.num_heads * self.head_ch) == out_ch:
            output = rearrange(attn_scores, '... q h d -> ... q (h d)')
        else:
            output = nn.DenseGeneral(features=out_ch,
                                     axis=(-2, -1),
                                     use_bias=self.use_bias,
                                     dtype=self.dtype,
                                     precision=self.precision,
                                     kernel_init=self.kernel_init,
                                     bias_init=self.bias_init)(attn_scores)
        return output