Пример #1
0
    def __call__(self, patch_inputs, pixel_inputs, is_training: bool):
        """Apply the inner (pixel) branch, then the outer (patch) branch.

        Args:
            patch_inputs: Input to the outer branch. It is merged with the
                inner branch result by ``Inner2OuterBlock`` and is also
                added back after the outer attention.
            pixel_inputs: Input to the inner branch.
            is_training: Forwarded unchanged to every attention and
                feed-forward sub-block.

        Returns:
            Tuple ``(outer_output, inner_output)``.
        """
        # --- Inner branch: LayerNorm -> attention -> skip connection.
        pixel_normed = nn.LayerNorm(dtype=self.dtype)(pixel_inputs)
        inner_attention = SelfAttentionBlock(
            num_heads=self.inner_num_heads,
            attn_dropout_rate=self.attn_dropout_rate,
            out_dropout_rate=self.dropout_rate,
            dtype=self.dtype,
        )
        inner_residual = (
            inner_attention(pixel_normed, is_training=is_training)
            + pixel_inputs
        )

        # --- Inner branch: LayerNorm -> feed-forward -> skip connection.
        inner_normed = nn.LayerNorm(dtype=self.dtype)(inner_residual)
        inner_ff = FFBlock(
            expand_ratio=self.inner_expand_ratio,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
        )
        inner_output = inner_residual + inner_ff(inner_normed,
                                                 is_training=is_training)

        # --- Merge the inner result into the patch stream.
        merged = Inner2OuterBlock(dtype=self.dtype)(patch_inputs,
                                                    inner_output)

        # --- Outer branch: LayerNorm -> attention -> skip connection.
        merged_normed = nn.LayerNorm(dtype=self.dtype)(merged)
        outer_attention = SelfAttentionBlock(
            num_heads=self.outer_num_heads,
            attn_dropout_rate=self.attn_dropout_rate,
            out_dropout_rate=self.dropout_rate,
            dtype=self.dtype,
        )
        # NOTE(review): the skip connection adds the raw `patch_inputs`,
        # not the Inner2OuterBlock output (`merged`) -- confirm this is
        # intended.
        outer_residual = (
            outer_attention(merged_normed, is_training=is_training)
            + patch_inputs
        )

        # --- Outer branch: LayerNorm -> feed-forward -> skip connection.
        outer_normed = nn.LayerNorm(dtype=self.dtype)(outer_residual)
        outer_ff = FFBlock(
            expand_ratio=self.outer_expand_ratio,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
        )
        outer_output = outer_residual + outer_ff(outer_normed,
                                                 is_training=is_training)

        return outer_output, inner_output
    def __call__(self, inputs, is_training: bool):
        """Apply attention and feed-forward stages with scaled, droppable skips.

        Each stage runs LayerNorm, the main sub-block, ``LayerScaleBlock``
        and ``StochasticDepthBlock`` before being added to its skip path.

        Args:
            inputs: Block input; also the skip path of the attention stage.
            is_training: Forwarded unchanged to every sub-block that
                accepts it.

        Returns:
            Sum of the attention-stage result and the feed-forward branch.
        """
        # --- Attention stage.
        attn_branch = nn.LayerNorm(dtype=self.dtype)(inputs)
        attention = SelfAttentionBlock(
            num_heads=self.num_heads,
            talking_heads=True,
            attn_dropout_rate=self.attn_dropout_rate,
            out_dropout_rate=self.dropout_rate,
            dtype=self.dtype,
        )
        attn_branch = attention(attn_branch, is_training=is_training)
        attn_scale = LayerScaleBlock(eps=self.layerscale_eps,
                                     dtype=self.dtype)
        attn_branch = attn_scale(attn_branch, is_training=is_training)
        attn_drop = StochasticDepthBlock(drop_rate=self.stoch_depth_rate)
        attn_branch = attn_drop(attn_branch, is_training=is_training)
        attn_out = attn_branch + inputs

        # --- Feed-forward stage.
        ff_branch = nn.LayerNorm(dtype=self.dtype)(attn_out)
        feed_forward = FFBlock(
            expand_ratio=self.expand_ratio,
            dropout_rate=self.dropout_rate,
            activation_fn=self.activation_fn,
            dtype=self.dtype,
        )
        ff_branch = feed_forward(ff_branch, is_training=is_training)
        ff_scale = LayerScaleBlock(eps=self.layerscale_eps,
                                   dtype=self.dtype)
        ff_branch = ff_scale(ff_branch, is_training=is_training)
        ff_drop = StochasticDepthBlock(drop_rate=self.stoch_depth_rate)
        ff_branch = ff_drop(ff_branch, is_training=is_training)

        return attn_out + ff_branch
    def __call__(self, inputs, is_training: bool):
        """Apply LayerNorm + attention + dropout, then LayerNorm + feed-forward.

        Args:
            inputs: Block input; also the skip path of the attention stage.
            is_training: Training flag. Passed to the attention block,
                negated for ``nn.Dropout``'s ``deterministic`` argument, and
                passed to ``FFBlock`` as ``train``.

        Returns:
            Sum of the attention-stage result and the feed-forward branch.
        """
        # --- Attention stage.
        normed_inputs = nn.LayerNorm(dtype=self.dtype)(inputs)
        attention = SelfAttentionBlock(
            num_heads=self.num_heads,
            head_ch=self.head_ch,
            out_ch=self.num_heads * self.head_ch,
            dropout_rate=self.attn_dropout_rate,
            dtype=self.dtype,
            precision=self.precision,
            kernel_init=self.kernel_init,
        )
        attended = attention(normed_inputs, is_training=is_training)
        # Dropout is only active while training.
        dropout = nn.Dropout(rate=self.dropout_rate)
        attended = dropout(attended, deterministic=not is_training)
        attn_out = attended + inputs

        # --- Feed-forward stage.
        normed_attn = nn.LayerNorm(dtype=self.dtype)(attn_out)
        feed_forward = FFBlock(
            expand_ratio=self.expand_ratio,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            precision=self.precision,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        # NOTE(review): FFBlock is invoked with `train=` while
        # SelfAttentionBlock above takes `is_training=` -- confirm against
        # FFBlock's signature.
        ff_out = feed_forward(normed_attn, train=is_training)

        return attn_out + ff_out
    def __call__(self, inputs, is_training: bool):
        """Apply LayerNorm + attention, then LayerNorm + feed-forward.

        Both stages add their result back onto their own input.

        Args:
            inputs: Block input; also the skip path of the attention stage.
            is_training: Forwarded unchanged to the attention and
                feed-forward sub-blocks.

        Returns:
            Sum of the attention-stage result and the feed-forward branch.
        """
        # --- Attention stage.
        normed_inputs = nn.LayerNorm(dtype=self.dtype)(inputs)
        attention = SelfAttentionBlock(
            num_heads=self.num_heads,
            attn_dropout_rate=self.attn_dropout_rate,
            out_dropout_rate=self.dropout_rate,
            dtype=self.dtype,
        )
        attn_out = attention(normed_inputs, is_training=is_training) + inputs

        # --- Feed-forward stage.
        normed_attn = nn.LayerNorm(dtype=self.dtype)(attn_out)
        feed_forward = FFBlock(
            expand_ratio=self.expand_ratio,
            dropout_rate=self.dropout_rate,
            activation_fn=self.activation_fn,
            dtype=self.dtype,
        )
        ff_out = feed_forward(normed_attn, is_training=is_training)

        return attn_out + ff_out
    def __call__(self, inputs, is_training: bool):
        """Apply attention and a LeFF stage, normalising after each skip add.

        Unlike a norm-first layout, LayerNorm is applied here *after* each
        residual sum rather than before the sub-block.

        Args:
            inputs: Block input; fed directly to the attention block and
                added back onto its result.
            is_training: Forwarded unchanged to the attention and LeFF
                sub-blocks.

        Returns:
            The LayerNorm of the LeFF-stage residual sum.
        """
        # --- Attention stage: attention -> skip add -> LayerNorm.
        attention = SelfAttentionBlock(num_heads=self.num_heads,
                                       dtype=self.dtype)
        attended = attention(inputs, is_training=is_training)
        attn_out = nn.LayerNorm(dtype=self.dtype)(attended + inputs)

        # --- LeFF stage: LeFF -> skip add -> LayerNorm.
        leff = LeFFBlock(
            expand_ratio=self.expand_ratio,
            kernel_size=self.leff_kernel_size,
            activation_fn=self.activation_fn,
            bn_momentum=self.bn_momentum,
            bn_epsilon=self.bn_epsilon,
            dtype=self.dtype,
        )
        leff_sum = attn_out + leff(attn_out, is_training=is_training)

        return nn.LayerNorm(dtype=self.dtype)(leff_sum)