Example #1
    def __init__(self,
                 num_tokens,
                 embed_dims,
                 tokens_mlp_dims,
                 channels_mlp_dims,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(MixerBlock, self).__init__(init_cfg=init_cfg)

        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, embed_dims, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.token_mix = FFN(
            embed_dims=num_tokens,
            feedforward_channels=tokens_mlp_dims,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=False)

        self.norm2_name, norm2 = build_norm_layer(
            norm_cfg, embed_dims, postfix=2)
        self.add_module(self.norm2_name, norm2)
        self.channel_mix = FFN(
            embed_dims=embed_dims,
            feedforward_channels=channels_mlp_dims,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)
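
The token_mix FFN above is built with embed_dims=num_tokens because it is applied to the transposed feature map, so it mixes information across tokens rather than channels. A minimal sketch of that pattern, assuming mmcv's FFN and purely illustrative sizes:

import torch
from mmcv.cnn.bricks.transformer import FFN

num_tokens, embed_dims = 196, 768                # assumed example sizes
token_mix = FFN(embed_dims=num_tokens,           # the FFN acts on the token axis
                feedforward_channels=384,
                add_identity=False)
x = torch.rand(2, num_tokens, embed_dims)        # (batch, tokens, channels)
out = x + token_mix(x.transpose(1, 2)).transpose(1, 2)
assert out.shape == x.shape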
Example #2
    def __init__(self,
                 embed_dims,
                 num_heads,
                 ffn_ratio=4,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=False,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 batch_first=True,
                 init_cfg=None):
        super(TransformerBlock, self).__init__(init_cfg=init_cfg)

        self.norm_attn = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            batch_first=batch_first)

        self.norm_ffn = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=embed_dims * ffn_ratio,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

        if not qkv_bias:
            self.attn.attn.in_proj_bias = None
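
The last two lines disable the qkv bias after construction by reaching into the torch.nn.MultiheadAttention wrapped inside mmcv's MultiheadAttention. A small check of that trick on a bare torch module (not part of the original file):

import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, bias=True)
assert mha.in_proj_bias is not None
mha.in_proj_bias = None   # same effect as `self.attn.attn.in_proj_bias = None`
assert mha.in_proj_bias is None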
Example #3
import pytest
import torch

from mmcv.cnn.bricks.transformer import FFN


def test_ffn():
    with pytest.raises(AssertionError):
        # num_fcs should be no less than 2
        FFN(num_fcs=1)
    FFN(dropout=0, add_residual=True)
    ffn = FFN(dropout=0, add_identity=True)

    input_tensor = torch.rand(2, 20, 256)
    input_tensor_nbc = input_tensor.transpose(0, 1)
    assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())
    residual = torch.rand_like(input_tensor)
    assert torch.allclose(
        ffn(input_tensor, residual=residual).sum(),
        ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())

    # `identity` is the current name for the legacy `residual` argument
    assert torch.allclose(
        ffn(input_tensor, identity=residual).sum(),
        ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())
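
One way to read the last check: with zero dropout, passing identity simply adds that tensor to the FFN output. A small equivalence sketch along those lines, reusing the imports above (assumed, not part of the original test):

ffn_demo = FFN(embed_dims=8, feedforward_channels=16)
x = torch.rand(1, 4, 8)
zero = torch.zeros_like(x)
# `identity=x` is the same as adding the residual by hand
assert torch.allclose(ffn_demo(x, identity=x),
                      x + ffn_demo(x, identity=zero))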
Example #4
    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(TransformerEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims

        self.norm1_name, norm1 = build_norm_layer(norm_cfg,
                                                  self.embed_dims,
                                                  postfix=1)
        self.add_module(self.norm1_name, norm1)

        self.attn = MultiheadAttention(embed_dims=embed_dims,
                                       num_heads=num_heads,
                                       attn_drop=attn_drop_rate,
                                       proj_drop=drop_rate,
                                       dropout_layer=dict(
                                           type='DropPath',
                                           drop_prob=drop_path_rate),
                                       qkv_bias=qkv_bias)

        self.norm2_name, norm2 = build_norm_layer(norm_cfg,
                                                  self.embed_dims,
                                                  postfix=2)
        self.add_module(self.norm2_name, norm2)

        self.ffn = FFN(embed_dims=embed_dims,
                       feedforward_channels=feedforward_channels,
                       num_fcs=num_fcs,
                       ffn_drop=drop_rate,
                       dropout_layer=dict(type='DropPath',
                                          drop_prob=drop_path_rate),
                       act_cfg=act_cfg)
Example #5
    def _init_layers(self):
        """Initialize layers of the transformer head."""
        self.input_proj = Conv2d(self.in_channels,
                                 self.embed_dims,
                                 kernel_size=1)
        self.fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        self.reg_ffn = FFN(self.embed_dims,
                           self.embed_dims,
                           self.num_reg_fcs,
                           self.act_cfg,
                           dropout=0.0,
                           add_residual=False)
        self.fc_reg = Linear(self.embed_dims, 4)
        self.query_embedding = nn.Embedding(self.num_query, self.embed_dims)
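
A hedged sketch of how reg_ffn and fc_reg are typically chained for box regression (the sizes and the trailing sigmoid are assumptions, not taken from the original head):

import torch
from torch.nn import Linear
from mmcv.cnn.bricks.transformer import FFN

reg_ffn = FFN(embed_dims=256, feedforward_channels=256,
              num_fcs=2, act_cfg=dict(type='ReLU', inplace=True),
              ffn_drop=0.0, add_identity=False)
fc_reg = Linear(256, 4)
feat = torch.rand(2, 100, 256)            # (batch, num_query, embed_dims)
bbox_pred = fc_reg(reg_ffn(feat))         # (2, 100, 4); DETR-style heads then apply sigmoid
assert bbox_pred.shape == (2, 100, 4)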
Example #6
    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 input_dims=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=False,
                 qk_scale=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(T2TTransformerLayer, self).__init__(init_cfg=init_cfg)

        self.v_shortcut = True if input_dims is not None else False
        input_dims = input_dims or embed_dims

        self.norm1_name, norm1 = build_norm_layer(norm_cfg,
                                                  input_dims,
                                                  postfix=1)
        self.add_module(self.norm1_name, norm1)

        self.attn = MultiheadAttention(
            input_dims=input_dims,
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            qk_scale=qk_scale or (input_dims // num_heads)**-0.5,
            v_shortcut=self.v_shortcut)

        self.norm2_name, norm2 = build_norm_layer(norm_cfg,
                                                  embed_dims,
                                                  postfix=2)
        self.add_module(self.norm2_name, norm2)

        self.ffn = FFN(embed_dims=embed_dims,
                       feedforward_channels=feedforward_channels,
                       num_fcs=num_fcs,
                       ffn_drop=drop_rate,
                       dropout_layer=dict(type='DropPath',
                                          drop_prob=drop_path_rate),
                       act_cfg=act_cfg)
Example #7
    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 window_size=7,
                 shift=False,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 init_cfg=None):

        super(SwinBlock, self).__init__()

        self.init_cfg = init_cfg
        self.with_cp = with_cp

        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = ShiftWindowMSA(embed_dims=embed_dims,
                                   num_heads=num_heads,
                                   window_size=window_size,
                                   shift_size=window_size // 2 if shift else 0,
                                   qkv_bias=qkv_bias,
                                   qk_scale=qk_scale,
                                   attn_drop_rate=attn_drop_rate,
                                   proj_drop_rate=drop_rate,
                                   dropout_layer=dict(
                                       type='DropPath',
                                       drop_prob=drop_path_rate),
                                   init_cfg=None)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(embed_dims=embed_dims,
                       feedforward_channels=feedforward_channels,
                       num_fcs=2,
                       ffn_drop=drop_rate,
                       dropout_layer=dict(type='DropPath',
                                          drop_prob=drop_path_rate),
                       act_cfg=act_cfg,
                       add_identity=True,
                       init_cfg=None)
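
with_cp enables gradient checkpointing in the block's forward. A rough sketch of the usual pattern, with the attention call signature simplified; this is not the exact forward of this class:

import torch.utils.checkpoint as cp

def block_forward(self, x):
    def _inner_forward(x):
        x = x + self.attn(self.norm1(x))          # windowed attention branch (args simplified)
        x = self.ffn(self.norm2(x), identity=x)   # the FFN adds the residual itself
        return x

    if self.with_cp and x.requires_grad:
        return cp.checkpoint(_inner_forward, x)   # recompute activations in backward
    return _inner_forward(x)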
Example #8
    def __init__(self,
                 embed_dims,
                 input_resolution,
                 num_heads,
                 window_size=7,
                 shift=False,
                 ffn_ratio=4.,
                 drop_path=0.,
                 attn_cfgs=dict(),
                 ffn_cfgs=dict(),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 auto_pad=False,
                 init_cfg=None):

        super(SwinBlock, self).__init__(init_cfg)
        self.with_cp = with_cp

        _attn_cfgs = {
            'embed_dims': embed_dims,
            'input_resolution': input_resolution,
            'num_heads': num_heads,
            'shift_size': window_size // 2 if shift else 0,
            'window_size': window_size,
            'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
            'auto_pad': auto_pad,
            **attn_cfgs
        }
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = ShiftWindowMSA(**_attn_cfgs)

        _ffn_cfgs = {
            'embed_dims': embed_dims,
            'feedforward_channels': int(embed_dims * ffn_ratio),
            'num_fcs': 2,
            'ffn_drop': 0,
            'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
            'act_cfg': dict(type='GELU'),
            **ffn_cfgs
        }
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(**_ffn_cfgs)
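
The `{**defaults, **overrides}` pattern used for _attn_cfgs and _ffn_cfgs lets a caller override any single default through attn_cfgs / ffn_cfgs. A minimal illustration:

defaults = {'num_fcs': 2, 'ffn_drop': 0}
overrides = {'ffn_drop': 0.1}
merged = {**defaults, **overrides}
assert merged == {'num_fcs': 2, 'ffn_drop': 0.1}   # later keys win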
Example #9
    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1.,
                 init_cfg=None):
        super(GSAEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
        self.attn = GlobalSubsampledAttention(embed_dims=embed_dims,
                                              num_heads=num_heads,
                                              attn_drop=attn_drop_rate,
                                              proj_drop=drop_rate,
                                              dropout_layer=dict(
                                                  type='DropPath',
                                                  drop_prob=drop_path_rate),
                                              qkv_bias=qkv_bias,
                                              norm_cfg=norm_cfg,
                                              sr_ratio=sr_ratio)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
        self.ffn = FFN(embed_dims=embed_dims,
                       feedforward_channels=feedforward_channels,
                       num_fcs=num_fcs,
                       ffn_drop=drop_rate,
                       dropout_layer=dict(type='DropPath',
                                          drop_prob=drop_path_rate),
                       act_cfg=act_cfg,
                       add_identity=False)

        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate)
        ) if drop_path_rate > 0. else nn.Identity()
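
Because the FFN is built with add_identity=False and a separate drop_path is created, the residual connections are applied by hand in the layer's forward; the LSAEncoderLayer in the next example composes its pieces the same way. A rough sketch of that composition (the attention call arguments are assumptions):

def encoder_layer_forward(self, x, hw_shape):
    # residual + stochastic depth are added explicitly, since the sub-modules do not
    x = x + self.drop_path(self.attn(self.norm1(x), hw_shape, identity=0.))
    x = x + self.drop_path(self.ffn(self.norm2(x)))
    return x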
Example #10
    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 qk_scale=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 window_size=1,
                 init_cfg=None):

        super(LSAEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
        self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads,
                                                qkv_bias, qk_scale,
                                                attn_drop_rate, drop_rate,
                                                window_size)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
        self.ffn = FFN(embed_dims=embed_dims,
                       feedforward_channels=feedforward_channels,
                       num_fcs=num_fcs,
                       ffn_drop=drop_rate,
                       dropout_layer=dict(type='DropPath',
                                          drop_prob=drop_path_rate),
                       act_cfg=act_cfg,
                       add_identity=False)

        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate)
        ) if drop_path_rate > 0. else nn.Identity()
Example #11
    def __init__(self,
                 num_classes=80,
                 num_ffn_fcs=2,
                 num_heads=8,
                 num_cls_fcs=1,
                 num_reg_fcs=3,
                 feedforward_channels=2048,
                 in_channels=256,
                 dropout=0.0,
                 ffn_act_cfg=dict(type='ReLU', inplace=True),
                 dynamic_conv_cfg=dict(type='DynamicConv',
                                       in_channels=256,
                                       feat_channels=64,
                                       out_channels=256,
                                       input_feat_shape=7,
                                       act_cfg=dict(type='ReLU', inplace=True),
                                       norm_cfg=dict(type='LN')),
                 loss_iou=dict(type='GIoULoss', loss_weight=2.0),
                 init_cfg=None,
                 **kwargs):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super(DIIHead, self).__init__(num_classes=num_classes,
                                      reg_decoded_bbox=True,
                                      reg_class_agnostic=True,
                                      init_cfg=init_cfg,
                                      **kwargs)
        self.loss_iou = build_loss(loss_iou)
        self.in_channels = in_channels
        self.fp16_enabled = False
        self.attention = MultiheadAttention(in_channels, num_heads, dropout)
        self.attention_norm = build_norm_layer(dict(type='LN'), in_channels)[1]

        self.instance_interactive_conv = build_transformer(dynamic_conv_cfg)
        self.instance_interactive_conv_dropout = nn.Dropout(dropout)
        self.instance_interactive_conv_norm = build_norm_layer(
            dict(type='LN'), in_channels)[1]

        self.ffn = FFN(in_channels,
                       feedforward_channels,
                       num_ffn_fcs,
                       act_cfg=ffn_act_cfg,
                       dropout=dropout)
        self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]

        self.cls_fcs = nn.ModuleList()
        for _ in range(num_cls_fcs):
            self.cls_fcs.append(nn.Linear(in_channels, in_channels,
                                          bias=False))
            self.cls_fcs.append(
                build_norm_layer(dict(type='LN'), in_channels)[1])
            self.cls_fcs.append(
                build_activation_layer(dict(type='ReLU', inplace=True)))

        # override self.fc_cls from BBoxHead
        if self.loss_cls.use_sigmoid:
            self.fc_cls = nn.Linear(in_channels, self.num_classes)
        else:
            self.fc_cls = nn.Linear(in_channels, self.num_classes + 1)

        self.reg_fcs = nn.ModuleList()
        for _ in range(num_reg_fcs):
            self.reg_fcs.append(nn.Linear(in_channels, in_channels,
                                          bias=False))
            self.reg_fcs.append(
                build_norm_layer(dict(type='LN'), in_channels)[1])
            self.reg_fcs.append(
                build_activation_layer(dict(type='ReLU', inplace=True)))
        # override self.fc_reg from BBoxHead
        self.fc_reg = nn.Linear(in_channels, 4)

        assert self.reg_class_agnostic, 'DIIHead only ' \
            'supports `reg_class_agnostic=True` '
        assert self.reg_decoded_bbox, 'DIIHead only ' \
            'supports `reg_decoded_bbox=True`'
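
For reference, the positional FFN call above corresponds to the keyword form below; `dropout` is the legacy name that newer mmcv releases alias to `ffn_drop` (version behaviour assumed, check the mmcv in use):

from mmcv.cnn.bricks.transformer import FFN

ffn = FFN(embed_dims=256,              # in_channels
          feedforward_channels=2048,   # feedforward_channels
          num_fcs=2,                   # num_ffn_fcs
          act_cfg=dict(type='ReLU', inplace=True),
          ffn_drop=0.0)                # newer name for the legacy `dropout` kwarg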
Example #12
    def __init__(self,
                 num_classes=150,
                 num_ffn_fcs=2,
                 num_heads=8,
                 num_mask_fcs=3,
                 feedforward_channels=2048,
                 in_channels=256,
                 out_channels=256,
                 dropout=0.0,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_act_cfg=dict(type='ReLU', inplace=True),
                 conv_kernel_size=1,
                 feat_transform_cfg=None,
                 kernel_init=False,
                 with_ffn=True,
                 feat_gather_stride=1,
                 mask_transform_stride=1,
                 kernel_updator_cfg=dict(type='DynamicConv',
                                         in_channels=256,
                                         feat_channels=64,
                                         out_channels=256,
                                         act_cfg=dict(type='ReLU',
                                                      inplace=True),
                                         norm_cfg=dict(type='LN'))):
        super(KernelUpdateHead, self).__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.fp16_enabled = False
        self.dropout = dropout
        self.num_heads = num_heads
        self.kernel_init = kernel_init
        self.with_ffn = with_ffn
        self.conv_kernel_size = conv_kernel_size
        self.feat_gather_stride = feat_gather_stride
        self.mask_transform_stride = mask_transform_stride

        self.attention = MultiheadAttention(in_channels * conv_kernel_size**2,
                                            num_heads, dropout)
        self.attention_norm = build_norm_layer(
            dict(type='LN'), in_channels * conv_kernel_size**2)[1]
        self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg)

        if feat_transform_cfg is not None:
            kernel_size = feat_transform_cfg.pop('kernel_size', 1)
            transform_channels = in_channels
            self.feat_transform = ConvModule(transform_channels,
                                             in_channels,
                                             kernel_size,
                                             stride=feat_gather_stride,
                                             padding=int(feat_gather_stride //
                                                         2),
                                             **feat_transform_cfg)
        else:
            self.feat_transform = None

        if self.with_ffn:
            self.ffn = FFN(in_channels,
                           feedforward_channels,
                           num_ffn_fcs,
                           act_cfg=ffn_act_cfg,
                           dropout=dropout)
            self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]

        self.mask_fcs = nn.ModuleList()
        for _ in range(num_mask_fcs):
            self.mask_fcs.append(
                nn.Linear(in_channels, in_channels, bias=False))
            self.mask_fcs.append(
                build_norm_layer(dict(type='LN'), in_channels)[1])
            self.mask_fcs.append(build_activation_layer(act_cfg))

        self.fc_mask = nn.Linear(in_channels, out_channels)
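
The attention dimension in_channels * conv_kernel_size**2 comes from flattening each dynamic mask kernel into a single token; a rough illustration with assumed sizes:

import torch

in_channels, conv_kernel_size, num_proposals = 256, 1, 100    # assumed sizes
kernels = torch.rand(2, num_proposals, in_channels,
                     conv_kernel_size, conv_kernel_size)
tokens = kernels.reshape(2, num_proposals, -1)                # one token per kernel
assert tokens.shape[-1] == in_channels * conv_kernel_size**2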
Example #13
class TransformerEncoderLayer(BaseModule):
    """Implements one encoder layer in Vision Transformer.

    Args:
        embed_dims (int): The feature dimension
        num_heads (int): Parallel attention heads
        feedforward_channels (int): The hidden dimension for FFNs
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The dropout rate for attention output weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): Enable bias for the qkv projections if True.
            Defaults to True.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(TransformerEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims

        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, self.embed_dims, postfix=1)
        self.add_module(self.norm1_name, norm1)

        self.attn = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias)

        self.norm2_name, norm2 = build_norm_layer(
            norm_cfg, self.embed_dims, postfix=2)
        self.add_module(self.norm2_name, norm2)

        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def init_weights(self):
        super(TransformerEncoderLayer, self).init_weights()
        for m in self.ffn.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = self.ffn(self.norm2(x), identity=x)
        return x
Example #14
class MixerBlock(BaseModule):
    """Mlp-Mixer basic block.

    Basic module of `MLP-Mixer: An all-MLP Architecture for Vision
    <https://arxiv.org/pdf/2105.01601.pdf>`_

    Args:
        num_tokens (int): The number of patched tokens
        embed_dims (int): The feature dimension
        tokens_mlp_dims (int): The hidden dimension for tokens FFNs
        channels_mlp_dims (int): The hidden dimension for channels FFNs
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 num_tokens,
                 embed_dims,
                 tokens_mlp_dims,
                 channels_mlp_dims,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(MixerBlock, self).__init__(init_cfg=init_cfg)

        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, embed_dims, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.token_mix = FFN(
            embed_dims=num_tokens,
            feedforward_channels=tokens_mlp_dims,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=False)

        self.norm2_name, norm2 = build_norm_layer(
            norm_cfg, embed_dims, postfix=2)
        self.add_module(self.norm2_name, norm2)
        self.channel_mix = FFN(
            embed_dims=embed_dims,
            feedforward_channels=channels_mlp_dims,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def init_weights(self):
        super(MixerBlock, self).init_weights()
        for m in self.token_mix.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.normal_(m.bias, std=1e-6)
        for m in self.channel_mix.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        out = self.norm1(x).transpose(1, 2)
        x = x + self.token_mix(out).transpose(1, 2)
        x = self.channel_mix(self.norm2(x), identity=x)
        return x
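
Assuming the MixerBlock above and its mmcv dependencies are importable, a minimal shape check might look like this (sizes are only examples):

import torch

block = MixerBlock(num_tokens=196, embed_dims=768,
                   tokens_mlp_dims=384, channels_mlp_dims=3072)
x = torch.rand(2, 196, 768)    # (batch, tokens, channels)
assert block(x).shape == (2, 196, 768)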
Example #15
    def build_ffn(self, ffn_cfg):
        self.ffn = FFN(**ffn_cfg)
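
A hypothetical ffn_cfg that such a helper could receive; every key maps straight onto an mmcv FFN argument:

ffn_cfg = dict(embed_dims=256,
               feedforward_channels=1024,
               num_fcs=2,
               ffn_drop=0.1,
               dropout_layer=dict(type='DropPath', drop_prob=0.1),
               act_cfg=dict(type='ReLU', inplace=True))
# self.build_ffn(ffn_cfg)  ->  self.ffn = FFN(**ffn_cfg)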