def __call__(self, patch_inputs, pixel_inputs, is_training: bool):
    inner_x = nn.LayerNorm(dtype=self.dtype)(pixel_inputs)
    inner_x = SelfAttentionBlock(num_heads=self.inner_num_heads,
                                 attn_dropout_rate=self.attn_dropout_rate,
                                 out_dropout_rate=self.dropout_rate,
                                 dtype=self.dtype)(inner_x, is_training=is_training)
    inner_x = inner_x + pixel_inputs
    inner_y = nn.LayerNorm(dtype=self.dtype)(inner_x)
    inner_y = FFBlock(expand_ratio=self.inner_expand_ratio,
                      dropout_rate=self.dropout_rate,
                      dtype=self.dtype)(inner_y, is_training=is_training)
    inner_output = inner_x + inner_y

    outer_x = Inner2OuterBlock(dtype=self.dtype)(patch_inputs, inner_output)
    outer_x = nn.LayerNorm(dtype=self.dtype)(outer_x)
    outer_x = SelfAttentionBlock(num_heads=self.outer_num_heads,
                                 attn_dropout_rate=self.attn_dropout_rate,
                                 out_dropout_rate=self.dropout_rate,
                                 dtype=self.dtype)(outer_x, is_training=is_training)
    outer_x = outer_x + patch_inputs
    outer_y = nn.LayerNorm(dtype=self.dtype)(outer_x)
    outer_y = FFBlock(expand_ratio=self.outer_expand_ratio,
                      dropout_rate=self.dropout_rate,
                      dtype=self.dtype)(outer_y, is_training=is_training)
    outer_output = outer_x + outer_y
    return outer_output, inner_output
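# --- Hedged sketch (an assumption, not the repo's code): the Inner2OuterBlock used
# in the TNT-style block above typically flattens each patch's pixel embeddings,
# projects them to the patch embedding width, and adds the result to the patch
# tokens. Names and shapes here are illustrative only.
import flax.linen as nn
import jax.numpy as jnp
from einops import rearrange


class Inner2Outer(nn.Module):
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, patch_emb, pixel_emb):
        # patch_emb: (batch, num_patches, patch_dim)
        # pixel_emb: (batch, num_patches, pixels_per_patch, pixel_dim)
        flat = rearrange(pixel_emb, 'b n p d -> b n (p d)')
        flat = nn.LayerNorm(dtype=self.dtype)(flat)
        proj = nn.Dense(patch_emb.shape[-1], dtype=self.dtype)(flat)
        return patch_emb + proj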
def __call__(self, inputs, is_training: bool):
    x = nn.LayerNorm(dtype=self.dtype)(inputs)
    x = rearrange(x, '... l d -> ... d l')
    x = FFBlock(expand_ratio=self.tokens_expand_ratio,
                activation_fn=self.activation_fn,
                dtype=self.dtype)(x, is_training=is_training)
    x = rearrange(x, '... d l -> ... l d')
    x = x + inputs

    y = nn.LayerNorm(dtype=self.dtype)(x)
    y = FFBlock(expand_ratio=self.channels_expand_ratio,
                activation_fn=self.activation_fn,
                dtype=self.dtype)(y, is_training=is_training)
    output = x + y
    return output
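# --- Hedged illustration (not from the source): the rearrange/transpose trick in the
# token-mixing branch above lets a plain dense layer mix information across the token
# axis rather than the channel axis. Shapes below are illustrative.
import jax.numpy as jnp
from einops import rearrange

x = jnp.ones((2, 196, 512))              # (batch, tokens l, channels d)
xt = rearrange(x, '... l d -> ... d l')  # (batch, channels, tokens)
# A Dense layer applied to the last axis of `xt` now acts across the 196 tokens
# (the "token-mixing MLP"); transposing back restores (batch, l, d).
print(xt.shape)                          # (2, 512, 196)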
def __call__(self, inputs, cls_token, is_training: bool):
    x = jnp.concatenate([cls_token, inputs], axis=1)
    x = nn.LayerNorm(dtype=self.dtype)(x)
    x = ClassSelfAttentionBlock(num_heads=self.num_heads,
                                attn_dropout_rate=self.attn_dropout_rate,
                                out_dropout_rate=self.dropout_rate,
                                dtype=self.dtype)(x, is_training=is_training)
    x = LayerScaleBlock(eps=self.layerscale_eps,
                        dtype=self.dtype)(x, is_training=is_training)
    x = StochasticDepthBlock(drop_rate=self.stoch_depth_rate)(
        x, is_training=is_training)
    cls_token = cls_token + x

    y = nn.LayerNorm(dtype=self.dtype)(cls_token)
    y = FFBlock(expand_ratio=self.expand_ratio,
                dropout_rate=self.dropout_rate,
                activation_fn=self.activation_fn,
                dtype=self.dtype)(y, is_training=is_training)
    y = LayerScaleBlock(eps=self.layerscale_eps,
                        dtype=self.dtype)(y, is_training=is_training)
    y = StochasticDepthBlock(drop_rate=self.stoch_depth_rate)(
        y, is_training=is_training)
    output = cls_token + y
    return output
def __call__(self, inputs, is_training: bool):
    x = nn.LayerNorm(dtype=self.dtype)(inputs)
    x = SelfAttentionBlock(num_heads=self.num_heads,
                           talking_heads=True,
                           attn_dropout_rate=self.attn_dropout_rate,
                           out_dropout_rate=self.dropout_rate,
                           dtype=self.dtype)(x, is_training=is_training)
    x = LayerScaleBlock(eps=self.layerscale_eps,
                        dtype=self.dtype)(x, is_training=is_training)
    x = StochasticDepthBlock(drop_rate=self.stoch_depth_rate)(
        x, is_training=is_training)
    x = x + inputs

    y = nn.LayerNorm(dtype=self.dtype)(x)
    y = FFBlock(expand_ratio=self.expand_ratio,
                dropout_rate=self.dropout_rate,
                activation_fn=self.activation_fn,
                dtype=self.dtype)(y, is_training=is_training)
    y = LayerScaleBlock(eps=self.layerscale_eps,
                        dtype=self.dtype)(y, is_training=is_training)
    y = StochasticDepthBlock(drop_rate=self.stoch_depth_rate)(
        y, is_training=is_training)
    output = x + y
    return output
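# --- Hedged sketch (an assumption, not the repo's code): typical implementations of
# the LayerScale and stochastic-depth operations used by the two CaiT-style blocks
# above, for readers unfamiliar with them. Class and parameter names are illustrative.
import flax.linen as nn
import jax
import jax.numpy as jnp


class LayerScale(nn.Module):
    eps: float = 1e-5  # initial value of the learnable per-channel scale

    @nn.compact
    def __call__(self, x):
        # Scale the residual branch by a learnable vector initialised near zero.
        gamma = self.param('gamma', nn.initializers.constant(self.eps),
                           (x.shape[-1],))
        return x * gamma


class StochasticDepth(nn.Module):
    drop_rate: float = 0.0

    @nn.compact
    def __call__(self, x, is_training: bool):
        if not is_training or self.drop_rate == 0.0:
            return x
        # Drop the whole residual branch per sample with probability drop_rate,
        # rescaling the kept samples to preserve the expectation.
        rng = self.make_rng('dropout')
        keep = jax.random.bernoulli(rng, 1.0 - self.drop_rate,
                                    (x.shape[0],) + (1,) * (x.ndim - 1))
        return x * keep / (1.0 - self.drop_rate)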
def __call__(self, inputs, is_training: bool):
    assert inputs.ndim == 4
    assert self.embed_dim % self.num_heads == 0
    head_ch = self.embed_dim // self.num_heads

    x = CvTSelfAttentionBlock(num_heads=self.num_heads,
                              head_ch=head_ch,
                              out_ch=self.embed_dim,
                              kernel_size=self.kernel_size,
                              use_bias=self.use_bias,
                              bn_momentum=self.bn_momentum,
                              bn_epsilon=self.bn_epsilon,
                              dtype=self.dtype,
                              precision=self.precision,
                              kernel_init=self.kernel_init,
                              bias_init=self.bias_init)(
                                  inputs, is_training=is_training)
    x = x + inputs

    y = nn.LayerNorm(dtype=self.dtype)(x)
    y = FFBlock(expand_ratio=self.expand_ratio,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init)(y, is_training=is_training)
    output = x + y
    return output
def __call__(self, inputs, is_training: bool):
    x = nn.LayerNorm(dtype=self.dtype)(inputs)
    x = SelfAttentionBlock(num_heads=self.num_heads,
                           head_ch=self.head_ch,
                           out_ch=self.num_heads * self.head_ch,
                           dropout_rate=self.attn_dropout_rate,
                           dtype=self.dtype,
                           precision=self.precision,
                           kernel_init=self.kernel_init)(
                               x, is_training=is_training)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not is_training)
    x = x + inputs

    y = nn.LayerNorm(dtype=self.dtype)(x)
    y = FFBlock(expand_ratio=self.expand_ratio,
                dropout_rate=self.dropout_rate,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init)(y, is_training=is_training)
    output = x + y
    return output
def __call__(self, inputs):
    x = nn.LayerNorm(dtype=self.dtype)(inputs)
    x = rearrange(x, '... l d -> ... d l')
    x = FFBlock(expand_ratio=self.tokens_expand_ratio,
                activation_fn=self.activation_fn,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init)(x)
    x = rearrange(x, '... d l -> ... l d')
    x = x + inputs

    y = nn.LayerNorm(dtype=self.dtype)(x)
    y = FFBlock(expand_ratio=self.channels_expand_ratio,
                activation_fn=self.activation_fn,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init)(y)
    output = x + y
    return output
def __call__(self, inputs, is_training: bool):
    x = LCSelfAttentionBlock(num_heads=self.num_heads,
                             dtype=self.dtype)(inputs, is_training=is_training)
    x = x + inputs
    x = nn.LayerNorm(dtype=self.dtype)(x)

    y = FFBlock(expand_ratio=self.expand_ratio,
                activation_fn=self.activation_fn,
                dtype=self.dtype)(x, is_training=is_training)
    y = x + y
    output = nn.LayerNorm(dtype=self.dtype)(y)
    return output
def __call__(self, inputs, is_training: bool):
    x = nn.LayerNorm(dtype=self.dtype)(inputs)
    x = SelfAttentionBlock(num_heads=self.num_heads,
                           attn_dropout_rate=self.attn_dropout_rate,
                           out_dropout_rate=self.dropout_rate,
                           dtype=self.dtype)(x, is_training=is_training)
    x = x + inputs

    y = nn.LayerNorm(dtype=self.dtype)(x)
    y = FFBlock(expand_ratio=self.expand_ratio,
                dropout_rate=self.dropout_rate,
                activation_fn=self.activation_fn,
                dtype=self.dtype)(y, is_training=is_training)
    output = x + y
    return output
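# --- Hedged sketch (not from the original source): a minimal, self-contained
# illustration of how a pre-norm encoder block like the one above fits into a full
# Flax module. `SimpleFF` and `SimpleEncoderBlock` are simplified stand-ins for the
# repo's FFBlock / SelfAttentionBlock helpers; names and defaults are assumptions.
from typing import Callable

import flax.linen as nn
import jax
import jax.numpy as jnp


class SimpleFF(nn.Module):
    expand_ratio: int = 4
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        d = x.shape[-1]
        x = nn.Dense(self.expand_ratio * d, dtype=self.dtype)(x)
        x = nn.gelu(x)
        return nn.Dense(d, dtype=self.dtype)(x)


class SimpleEncoderBlock(nn.Module):
    num_heads: int
    expand_ratio: int = 4
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        # Pre-norm self-attention sub-block with residual connection.
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads, dtype=self.dtype,
            deterministic=not is_training)(x, x)
        x = x + inputs
        # Pre-norm feed-forward sub-block with residual connection.
        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = SimpleFF(expand_ratio=self.expand_ratio, dtype=self.dtype)(y)
        return x + y


# Usage: initialise and apply on a dummy token sequence.
tokens = jnp.ones((2, 197, 192))                  # (batch, tokens, channels)
block = SimpleEncoderBlock(num_heads=3)
variables = block.init(jax.random.PRNGKey(0), tokens, is_training=False)
out = block.apply(variables, tokens, is_training=False)  # same shape as input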
def __call__(self, inputs, is_training: bool):
    inputs = zero_pad_and_reshape(inputs)
    x = CvTSelfAttentionBlock(num_heads=self.num_heads,
                              kernel_size=self.kernel_size,
                              use_bias=self.use_bias,
                              bn_momentum=self.bn_momentum,
                              bn_epsilon=self.bn_epsilon,
                              dtype=self.dtype)(inputs, is_training=is_training)
    x = x + rearrange(inputs, 'b h w d -> b (h w) d')

    y = nn.LayerNorm(dtype=self.dtype)(x)
    y = FFBlock(expand_ratio=self.expand_ratio,
                activation_fn=self.activation_fn,
                dtype=self.dtype)(y, is_training=is_training)
    output = x + y
    return output
def __call__(self, inputs, is_training: bool):
    x = CvTSelfAttentionBlock(num_heads=self.num_heads,
                              kernel_size=self.kernel_size,
                              use_bias=self.use_bias,
                              bn_momentum=self.bn_momentum,
                              bn_epsilon=self.bn_epsilon,
                              dtype=self.dtype,
                              precision=self.precision,
                              kernel_init=self.kernel_init,
                              bias_init=self.bias_init)(
                                  inputs, is_training=is_training)
    x = x + inputs

    y = nn.LayerNorm(dtype=self.dtype)(x)
    y = FFBlock(expand_ratio=self.expand_ratio,
                activation_fn=self.activation_fn,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init)(y, is_training=is_training)
    output = x + y
    return output