def __call__(self, patch_inputs, pixel_inputs, is_training: bool):
    """Run one Transformer-in-Transformer block (inner + outer transformer).

    The inner transformer updates the pixel-level embeddings first. Its output
    is merged into the patch-level embeddings by ``Inner2OuterBlock``, and the
    outer transformer then updates that merged representation. Both
    transformers are pre-norm: LayerNorm -> sub-layer -> residual add.

    Args:
        patch_inputs: Patch-level (outer) embeddings.
        pixel_inputs: Pixel-level (inner) embeddings.
        is_training: Forwarded as ``is_training`` to every attention and
            feed-forward sub-block.

    Returns:
        Tuple ``(outer_output, inner_output)``: the updated patch-level and
        pixel-level embeddings, in that order.
    """
    # ---- Inner (pixel-level) transformer ----
    inner_x = nn.LayerNorm(dtype=self.dtype)(pixel_inputs)
    inner_x = SelfAttentionBlock(
        num_heads=self.inner_num_heads,
        attn_dropout_rate=self.attn_dropout_rate,
        out_dropout_rate=self.dropout_rate,
        dtype=self.dtype,
    )(inner_x, is_training=is_training)
    inner_x = inner_x + pixel_inputs
    inner_y = nn.LayerNorm(dtype=self.dtype)(inner_x)
    inner_y = FFBlock(
        expand_ratio=self.inner_expand_ratio,
        dropout_rate=self.dropout_rate,
        dtype=self.dtype,
    )(inner_y, is_training=is_training)
    inner_output = inner_x + inner_y

    # ---- Outer (patch-level) transformer ----
    # Fold the inner output into the patch embeddings. This merged tensor is
    # the outer transformer's input, so it is also the skip connection of the
    # attention sub-layer (Z' = Z + FC(Y); Z = Z' + MSA(LN(Z'))). Using the raw
    # ``patch_inputs`` as the skip would drop the inner-to-outer contribution
    # from the residual stream.
    # NOTE(review): assumes Inner2OuterBlock returns patch_inputs plus the
    # projected pixel features (it is given patch_inputs) -- confirm.
    outer_inputs = Inner2OuterBlock(dtype=self.dtype)(patch_inputs, inner_output)
    outer_x = nn.LayerNorm(dtype=self.dtype)(outer_inputs)
    outer_x = SelfAttentionBlock(
        num_heads=self.outer_num_heads,
        attn_dropout_rate=self.attn_dropout_rate,
        out_dropout_rate=self.dropout_rate,
        dtype=self.dtype,
    )(outer_x, is_training=is_training)
    outer_x = outer_x + outer_inputs
    outer_y = nn.LayerNorm(dtype=self.dtype)(outer_x)
    outer_y = FFBlock(
        expand_ratio=self.outer_expand_ratio,
        dropout_rate=self.dropout_rate,
        dtype=self.dtype,
    )(outer_y, is_training=is_training)
    outer_output = outer_x + outer_y
    return outer_output, inner_output
def __call__(self, inputs, is_training: bool):
    """Apply a talking-heads attention layer with LayerScale and stochastic depth.

    Each of the two sub-layers (talking-heads self-attention, then the
    feed-forward block) is pre-normalised, passed through a LayerScale block
    and a stochastic-depth block, and then added back onto its input.

    Args:
        inputs: Token embeddings entering the layer.
        is_training: Forwarded as ``is_training`` to every sub-block that
            accepts it.

    Returns:
        The layer output after both residual sub-layers.
    """

    def scale_and_drop(branch):
        # LayerScale then stochastic depth on one residual branch. Sub-modules
        # are still created in call order, so order-based automatic naming of
        # sub-modules is unchanged.
        branch = LayerScaleBlock(eps=self.layerscale_eps, dtype=self.dtype)(
            branch, is_training=is_training)
        return StochasticDepthBlock(drop_rate=self.stoch_depth_rate)(
            branch, is_training=is_training)

    # Attention sub-layer.
    normed = nn.LayerNorm(dtype=self.dtype)(inputs)
    attn_out = SelfAttentionBlock(
        num_heads=self.num_heads,
        talking_heads=True,
        attn_dropout_rate=self.attn_dropout_rate,
        out_dropout_rate=self.dropout_rate,
        dtype=self.dtype,
    )(normed, is_training=is_training)
    hidden = scale_and_drop(attn_out) + inputs

    # Feed-forward sub-layer.
    normed = nn.LayerNorm(dtype=self.dtype)(hidden)
    mlp_out = FFBlock(
        expand_ratio=self.expand_ratio,
        dropout_rate=self.dropout_rate,
        activation_fn=self.activation_fn,
        dtype=self.dtype,
    )(normed, is_training=is_training)
    return hidden + scale_and_drop(mlp_out)
def __call__(self, inputs, is_training: bool):
    """Apply a pre-norm encoder layer: attention sub-layer, then feed-forward.

    Args:
        inputs: Token embeddings entering the layer.
        is_training: Forwarded to the sub-blocks. The output dropout after
            attention runs with ``deterministic=not is_training``.

    Returns:
        ``inputs`` plus the (dropped-out) attention branch, plus the
        feed-forward branch computed on that intermediate sum.
    """
    # Attention sub-layer: LN -> multi-head attention -> dropout -> residual.
    attn_in = nn.LayerNorm(dtype=self.dtype)(inputs)
    attn_out = SelfAttentionBlock(
        num_heads=self.num_heads,
        head_ch=self.head_ch,
        out_ch=self.num_heads * self.head_ch,
        dropout_rate=self.attn_dropout_rate,
        dtype=self.dtype,
        precision=self.precision,
        kernel_init=self.kernel_init,
    )(attn_in, is_training=is_training)
    attn_out = nn.Dropout(rate=self.dropout_rate)(
        attn_out, deterministic=not is_training)
    residual = attn_out + inputs

    # Feed-forward sub-layer: LN -> FFBlock -> residual.
    ffn_in = nn.LayerNorm(dtype=self.dtype)(residual)
    # NOTE(review): FFBlock is called with ``train=`` here, while
    # SelfAttentionBlock gets ``is_training=``; confirm this matches
    # FFBlock's signature.
    ffn_out = FFBlock(
        expand_ratio=self.expand_ratio,
        dropout_rate=self.dropout_rate,
        dtype=self.dtype,
        precision=self.precision,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
    )(ffn_in, train=is_training)
    return residual + ffn_out
def __call__(self, inputs, is_training: bool):
    """Apply a pre-norm transformer encoder layer.

    Computes ``h = inputs + attn(LN(inputs))`` and returns
    ``h + ffn(LN(h))``.

    Args:
        inputs: Token embeddings entering the layer.
        is_training: Forwarded as ``is_training`` to the attention and
            feed-forward sub-blocks.

    Returns:
        The layer output after both residual sub-layers.
    """
    attn_block = SelfAttentionBlock
    ffn_block = FFBlock

    # Attention sub-layer.
    pre_attn = nn.LayerNorm(dtype=self.dtype)(inputs)
    attended = attn_block(
        num_heads=self.num_heads,
        attn_dropout_rate=self.attn_dropout_rate,
        out_dropout_rate=self.dropout_rate,
        dtype=self.dtype,
    )(pre_attn, is_training=is_training)
    hidden = attended + inputs

    # Feed-forward sub-layer.
    pre_ffn = nn.LayerNorm(dtype=self.dtype)(hidden)
    transformed = ffn_block(
        expand_ratio=self.expand_ratio,
        dropout_rate=self.dropout_rate,
        activation_fn=self.activation_fn,
        dtype=self.dtype,
    )(pre_ffn, is_training=is_training)
    return hidden + transformed
def __call__(self, inputs, is_training: bool):
    """Apply a post-norm encoder layer whose feed-forward branch is a LeFFBlock.

    Computes ``h = LN(inputs + attn(inputs))`` and returns
    ``LN(h + LeFF(h))``. Normalisation happens after each residual add.

    Args:
        inputs: Token embeddings entering the layer.
        is_training: Forwarded as ``is_training`` to the attention block and
            the LeFF block.

    Returns:
        The normalised output of the second residual sub-layer.
    """
    attended = SelfAttentionBlock(num_heads=self.num_heads, dtype=self.dtype)(
        inputs, is_training=is_training)
    hidden = nn.LayerNorm(dtype=self.dtype)(attended + inputs)

    leff_out = LeFFBlock(
        expand_ratio=self.expand_ratio,
        kernel_size=self.leff_kernel_size,
        activation_fn=self.activation_fn,
        bn_momentum=self.bn_momentum,
        bn_epsilon=self.bn_epsilon,
        dtype=self.dtype,
    )(hidden, is_training=is_training)
    return nn.LayerNorm(dtype=self.dtype)(hidden + leff_out)