# NOTE: assumed imports for this excerpt (mmcv 1.x style). ``MultiheadAttention``
# is expected to be the project-local variant that accepts ``qkv_bias``;
# mmcv's own ``mmcv.cnn.bricks.transformer.MultiheadAttention`` has a
# different signature.
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmcv.runner import BaseModule


class TransformerEncoderLayer(BaseModule):
    """Implements one encoder layer in Vision Transformer.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The dropout rate for attention output weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to True.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(TransformerEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims

        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, self.embed_dims, postfix=1)
        self.add_module(self.norm1_name, norm1)

        # Multi-head self-attention; the residual connection is added in
        # ``forward`` (pre-norm layout).
        self.attn = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias)

        self.norm2_name, norm2 = build_norm_layer(
            norm_cfg, self.embed_dims, postfix=2)
        self.add_module(self.norm2_name, norm2)

        # Feed-forward network; ``identity`` is passed in ``forward`` so the
        # FFN adds the residual itself.
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def init_weights(self):
        super(TransformerEncoderLayer, self).init_weights()
        for m in self.ffn.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = self.ffn(self.norm2(x), identity=x)
        return x
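
# Usage sketch for ``TransformerEncoderLayer`` (illustrative only; the ViT-B-like
# sizes below are assumptions, not values defined in this module). The layer
# preserves the (batch, num_tokens, embed_dims) layout.
def _demo_transformer_encoder_layer():
    import torch
    layer = TransformerEncoderLayer(
        embed_dims=768, num_heads=12, feedforward_channels=3072)
    layer.init_weights()
    x = torch.randn(2, 197, 768)  # e.g. 196 patch tokens + 1 class token
    out = layer(x)
    assert out.shape == x.shape
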
class MixerBlock(BaseModule):
    """MLP-Mixer basic block.

    Basic module of `MLP-Mixer: An all-MLP Architecture for Vision
    <https://arxiv.org/pdf/2105.01601.pdf>`_

    Args:
        num_tokens (int): The number of patched tokens.
        embed_dims (int): The feature dimension.
        tokens_mlp_dims (int): The hidden dimension for the token-mixing FFN.
        channels_mlp_dims (int): The hidden dimension for the channel-mixing
            FFN.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 num_tokens,
                 embed_dims,
                 tokens_mlp_dims,
                 channels_mlp_dims,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(MixerBlock, self).__init__(init_cfg=init_cfg)

        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, embed_dims, postfix=1)
        self.add_module(self.norm1_name, norm1)

        # Token-mixing MLP: applied along the token axis, so its
        # ``embed_dims`` is the number of tokens. The residual is added
        # manually in ``forward``, hence ``add_identity=False``.
        self.token_mix = FFN(
            embed_dims=num_tokens,
            feedforward_channels=tokens_mlp_dims,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=False)

        self.norm2_name, norm2 = build_norm_layer(
            norm_cfg, embed_dims, postfix=2)
        self.add_module(self.norm2_name, norm2)

        # Channel-mixing MLP: applied along the channel (embedding) axis.
        self.channel_mix = FFN(
            embed_dims=embed_dims,
            feedforward_channels=channels_mlp_dims,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def init_weights(self):
        super(MixerBlock, self).init_weights()
        for m in self.token_mix.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.normal_(m.bias, std=1e-6)
        for m in self.channel_mix.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        # Transpose to (batch, embed_dims, num_tokens) so the token-mixing
        # FFN acts across tokens, then transpose back for the residual.
        out = self.norm1(x).transpose(1, 2)
        x = x + self.token_mix(out).transpose(1, 2)
        x = self.channel_mix(self.norm2(x), identity=x)
        return x
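
# Usage sketch for ``MixerBlock`` (illustrative only; the Mixer-B/16-like sizes
# below are assumptions, not values defined in this module). ``num_tokens`` must
# match the input sequence length, because the token-mixing FFN is built with
# ``embed_dims=num_tokens``.
def _demo_mixer_block():
    import torch
    block = MixerBlock(
        num_tokens=196,
        embed_dims=768,
        tokens_mlp_dims=384,
        channels_mlp_dims=3072)
    block.init_weights()
    x = torch.randn(2, 196, 768)  # (batch, num_tokens, embed_dims)
    out = block(x)
    assert out.shape == x.shape
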