class TNT(BaseBackbone):
    """Transformer in Transformer.

    A PyTorch implementation of `Transformer in Transformer
    <https://arxiv.org/abs/2103.00112>`_

    Inspiration from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/tnt.py

    Args:
        arch (str | dict): Architecture name (a key of ``arch_zoo``,
            case-insensitive) or a dict holding exactly the keys
            ``embed_dims_outer``, ``embed_dims_inner``, ``num_layers``,
            ``num_heads_inner`` and ``num_heads_outer``. Default: 'b'
        img_size (int | tuple): Input image size. Default to 224
        patch_size (int | tuple): The patch size. Default to 16
        in_channels (int): Number of input channels. Default to 3
        ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer.
            Default: 4
        qkv_bias (bool): Enable bias for qkv if True. Default False
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default 0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.
        drop_path_rate (float): stochastic depth rate. Default 0.
        act_cfg (dict): The activation config for FFNs. Defaults to GELU.
            NOTE: this constructor does not forward ``act_cfg`` to any
            sub-module.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization
        first_stride (int): The stride of the conv2d layer. We use a conv2d
            layer and a unfold layer to implement image to pixel embedding.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default 2
        init_cfg (dict, optional): Initialization config dict
    """
    arch_zoo = {
        **dict.fromkeys(
            ['s', 'small'], {
                'embed_dims_outer': 384,
                'embed_dims_inner': 24,
                'num_layers': 12,
                'num_heads_outer': 6,
                'num_heads_inner': 4
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims_outer': 640,
                'embed_dims_inner': 40,
                'num_layers': 12,
                'num_heads_outer': 10,
                'num_heads_inner': 4
            })
    }

    def __init__(self,
                 arch='b',
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 ffn_ratio=4,
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 first_stride=4,
                 num_fcs=2,
                 init_cfg=[
                     dict(type='TruncNormal', layer='Linear', std=.02),
                     dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
                 ]):
        super().__init__(init_cfg=init_cfg)

        # Resolve the architecture: custom dicts are validated against the
        # exact key set, strings are looked up (case-insensitively) in the zoo.
        if not isinstance(arch, str):
            essential_keys = {
                'embed_dims_outer', 'embed_dims_inner', 'num_layers',
                'num_heads_inner', 'num_heads_outer'
            }
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch
        else:
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]

        settings = self.arch_settings
        self.embed_dims_inner = settings['embed_dims_inner']
        self.embed_dims_outer = settings['embed_dims_outer']
        # embed_dims for consistency with other models
        self.embed_dims = self.embed_dims_outer
        self.num_layers = settings['num_layers']
        self.num_heads_inner = settings['num_heads_inner']
        self.num_heads_outer = settings['num_heads_outer']

        # Image -> per-patch pixel embedding.
        self.pixel_embed = PixelEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dims_inner=self.embed_dims_inner,
            stride=first_stride)
        self.num_patches = self.pixel_embed.num_patches
        inner_patch_size = self.pixel_embed.new_patch_size
        num_pixel = inner_patch_size[0] * inner_patch_size[1]

        # Projection of the flattened pixel embedding of a patch into the
        # outer (patch-level) embedding space, normalised on both sides.
        flat_inner_dims = num_pixel * self.embed_dims_inner
        self.norm1_proj = build_norm_layer(norm_cfg, flat_inner_dims)[1]
        self.projection = nn.Linear(flat_inner_dims, self.embed_dims_outer)
        self.norm2_proj = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]

        # Learnable class token and position embeddings; the patch position
        # table has one extra slot for the class token.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims_outer))
        self.patch_pos = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, self.embed_dims_outer))
        self.pixel_pos = nn.Parameter(
            torch.zeros(1, self.embed_dims_inner, inner_patch_size[0],
                        inner_patch_size[1]))
        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: linearly spaced from 0 to
        # drop_path_rate, one value per layer
        layer_drop_paths = torch.linspace(0, drop_path_rate,
                                          self.num_layers).tolist()

        self.layers = ModuleList()
        for layer_drop_path in layer_drop_paths:
            # The very same config dict is handed to both the inner and the
            # outer block of a layer.
            shared_block_cfg = dict(
                ffn_ratio=ffn_ratio,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=layer_drop_path,
                num_fcs=num_fcs,
                qkv_bias=qkv_bias,
                norm_cfg=norm_cfg,
                batch_first=True)
            tnt_layer = TnTLayer(
                num_pixel=num_pixel,
                embed_dims_inner=self.embed_dims_inner,
                embed_dims_outer=self.embed_dims_outer,
                num_heads_inner=self.num_heads_inner,
                num_heads_outer=self.num_heads_outer,
                inner_block_cfg=shared_block_cfg,
                outer_block_cfg=shared_block_cfg,
                norm_cfg=norm_cfg)
            self.layers.append(tnt_layer)

        self.norm = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]

        for param in (self.cls_token, self.patch_pos, self.pixel_pos):
            trunc_normal_(param, std=.02)

    def forward(self, x):
        """Run the backbone.

        Args:
            x (Tensor): Input batch; its first dimension is the batch size.

        Returns:
            tuple[Tensor]: A 1-tuple holding the normalised token at index 0
            along dim 1 (the position where the class token was
            concatenated).
        """
        batch_size = x.shape[0]
        pixel_embed = self.pixel_embed(x, self.pixel_pos)

        # Flatten each patch's pixel embedding and project it to a patch token.
        flat_pixels = pixel_embed.reshape(batch_size, self.num_patches, -1)
        patch_embed = self.norm1_proj(flat_pixels)
        patch_embed = self.projection(patch_embed)
        patch_embed = self.norm2_proj(patch_embed)

        # Prepend the class token, then add position information.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        patch_embed = torch.cat((cls_tokens, patch_embed), dim=1)
        patch_embed = self.drop_after_pos(patch_embed + self.patch_pos)

        for tnt_layer in self.layers:
            pixel_embed, patch_embed = tnt_layer(pixel_embed, patch_embed)

        normed = self.norm(patch_embed)
        return (normed[:, 0], )
class BaseTransformerLayer(BaseModule):
    """Base `TransformerLayer` for vision transformer.

    It can be built from `mmcv.ConfigDict` and support more flexible
    customization, for example, using any number of `FFN or LN ` and
    use different kinds of `attention` by specifying a list of `ConfigDict`
    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
    when you specifying `norm` as the first element of `operation_order`.
    More details about the `prenorm`: `On Layer Normalization in the
    Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
            Configs for `self_attention` or `cross_attention` modules,
            The order of the configs in the list should be consistent with
            corresponding attentions in operation_order.
            If it is a dict, all of the attention modules in operation_order
            will be built with this config. Default: None.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
            Configs for FFN, The order of the configs in the list should be
            consistent with corresponding ffn in operation_order.
            If it is a dict, all of the ffn modules in operation_order
            will be built with this config. The argument is deep-copied
            before use, so neither the caller's object nor the shared
            default is modified.
        operation_order (tuple[str]): The execution order of operation
            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            Support `prenorm` when you specifying first element as `norm`.
            Every entry must be one of 'self_attn', 'norm', 'ffn',
            'cross_attn'. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Key, Query and Value are shape
            of (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to False.
    """

    def __init__(self,
                 attn_cfgs=None,
                 ffn_cfgs=dict(
                     type='FFN',
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):

        # `ffn_cfgs` defaults to a dict that is shared by every call, and it
        # is written to below (deprecated kwargs, missing `embed_dims`).
        # Work on a private copy so one instance cannot leak settings into
        # another one or into the caller's config.
        ffn_cfgs = copy.deepcopy(ffn_cfgs)

        # Old keyword name -> key inside `ffn_cfgs` that replaces it.
        deprecated_args = dict(
            feedforward_channels='feedforward_channels',
            ffn_dropout='ffn_drop',
            ffn_num_fcs='num_fcs')
        for ori_name, new_name in deprecated_args.items():
            if ori_name in kwargs:
                warnings.warn(
                    f'The arguments `{ori_name}` in BaseTransformerLayer '
                    f'has been deprecated, now you should set `{new_name}` '
                    f'and other FFN related arguments '
                    f'to a dict named `ffn_cfgs`. ', DeprecationWarning)
                ffn_cfgs[new_name] = kwargs[ori_name]

        super(BaseTransformerLayer, self).__init__(init_cfg)

        self.batch_first = batch_first

        # Every entry of `operation_order` has to be a known operation name
        # (i.e. operation_order must be a subset of the four names).
        assert set(operation_order) & set(
            ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
            set(operation_order), f'The operation_order of' \
            f' {self.__class__.__name__} should ' \
            f'contains all four operation type ' \
            f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"

        num_attn = operation_order.count('self_attn') + operation_order.count(
            'cross_attn')
        if isinstance(attn_cfgs, dict):
            # A single config is replicated, one independent copy per
            # attention module.
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), f'The length ' \
                f'of attn_cfg {num_attn} is ' \
                f'not consistent with the number of attention' \
                f'in operation_order {operation_order}.'

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'
        self.attentions = ModuleList()

        index = 0
        for operation_name in operation_order:
            if operation_name in ['self_attn', 'cross_attn']:
                # The layer-level `batch_first` must agree with the one of
                # each attention config; it is filled in when absent.
                if 'batch_first' in attn_cfgs[index]:
                    assert self.batch_first == attn_cfgs[index]['batch_first']
                else:
                    attn_cfgs[index]['batch_first'] = self.batch_first
                attention = build_attention(attn_cfgs[index])
                # Some custom attentions used as `self_attn`
                # or `cross_attn` can have different behavior.
                attention.operation_name = operation_name
                self.attentions.append(attention)
                index += 1

        # The embedding size of the layer is taken from its first attention.
        self.embed_dims = self.attentions[0].embed_dims

        self.ffns = ModuleList()
        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = ConfigDict(ffn_cfgs)
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            if 'embed_dims' not in ffn_cfgs[ffn_index]:
                # `ffn_cfgs` is a list here, so the per-FFN config has to be
                # addressed by index before setting the key.
                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
            else:
                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
            self.ffns.append(
                build_feedforward_network(ffn_cfgs[ffn_index],
                                          dict(type='FFN')))

        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `BaseTransformerLayer`.

        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims] .
            value (Tensor): The value tensor with same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | Tensor | None): 2D Tensor used in
                calculation of corresponding attention. When a list, its
                length should equal the number of `attention` in
                `operation_order`; a single Tensor is copied for every
                attention (with a warning). Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layer.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Only used in `cross_attn` layer.
                Default: None.

        Returns:
            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
        """

        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            if layer == 'self_attn':
                # Self attention: key and value are the query itself, and
                # the query's positional encoding / padding mask are reused
                # for the key side.
                temp_key = temp_value = query
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    # With pre-norm the un-normalised input is passed as the
                    # identity argument.
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query
class VisionTransformer(BaseBackbone):
    """Vision Transformer.

    A PyTorch implement of : `An Image is Worth 16x16 Words: Transformers
    for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_

    Args:
        arch (str | dict): Vision Transformer architecture. Either a key of
            ``arch_zoo`` (case-insensitive) or a dict holding exactly the
            keys ``embed_dims``, ``num_layers``, ``num_heads`` and
            ``feedforward_channels``. Default: 'b'
        img_size (int | tuple): Input image size
        patch_size (int | tuple): The patch size
        out_indices (Sequence | int): Output from which stages. Negative
            values count from the last layer. The argument is copied, so a
            tuple is accepted and a caller-owned list is left untouched.
            Defaults to -1, means the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Defaults to True.
        output_cls_token (bool): Whether output the cls_token.
            Defaults to True.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['s', 'small'], {
                'embed_dims': 768,
                'num_layers': 8,
                'num_heads': 8,
                'feedforward_channels': 768 * 3,
                'qkv_bias': False
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 3072
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'num_layers': 24,
                'num_heads': 16,
                'feedforward_channels': 4096
            }),
    }

    def __init__(self,
                 arch='b',
                 img_size=224,
                 patch_size=16,
                 out_indices=-1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 final_norm=True,
                 output_cls_token=True,
                 interpolate_mode='bicubic',
                 patch_cfg=dict(),
                 layer_cfgs=dict(),
                 init_cfg=None):
        super(VisionTransformer, self).__init__(init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
            }
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.num_layers = self.arch_settings['num_layers']
        self.img_size = to_2tuple(img_size)

        # Set patch embedding; user supplied `patch_cfg` entries override
        # the defaults assembled here.
        _patch_cfg = dict(
            img_size=img_size,
            embed_dims=self.embed_dims,
            conv_cfg=dict(
                type='Conv2d', kernel_size=patch_size, stride=patch_size),
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        num_patches = self.patch_embed.num_patches

        # Set cls token
        self.output_cls_token = output_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

        # Set position embedding (one extra slot for the cls token)
        self.interpolate_mode = interpolate_mode
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, self.embed_dims))
        self.drop_after_pos = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Normalise on a private list: item assignment on the argument
        # itself would fail for tuples and would silently rewrite a list
        # owned by the caller (e.g. a shared config).
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_layers + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, self.arch_settings['num_layers'])

        self.layers = ModuleList()
        if isinstance(layer_cfgs, dict):
            # One shared config applies to every layer.
            layer_cfgs = [layer_cfgs] * self.num_layers
        for i in range(self.num_layers):
            _layer_cfg = dict(
                embed_dims=self.embed_dims,
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                # qkv bias is on unless the arch settings say otherwise.
                qkv_bias=self.arch_settings.get('qkv_bias', True),
                norm_cfg=norm_cfg)
            _layer_cfg.update(layer_cfgs[i])
            self.layers.append(TransformerEncoderLayer(**_layer_cfg))

        self.final_norm = final_norm
        if final_norm:
            self.norm1_name, norm1 = build_norm_layer(
                norm_cfg, self.embed_dims, postfix=1)
            self.add_module(self.norm1_name, norm1)

    @property
    def norm1(self):
        """The final norm layer, looked up by its registered name.

        Only available when the model was built with ``final_norm=True``.
        """
        return getattr(self, self.norm1_name)

    def init_weights(self):
        """Initialise weights, or load a checkpoint for 'Pretrained'."""
        # Suppress default init if use pretrained model.
        # And use custom load_checkpoint function to load checkpoint.
        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            init_cfg = deepcopy(self.init_cfg)
            init_cfg.pop('type')
            self._load_checkpoint(**init_cfg)
        else:
            super(VisionTransformer, self).init_weights()
            # Modified from ClassyVision
            nn.init.normal_(self.pos_embed, std=0.02)

    def _load_checkpoint(self, checkpoint, prefix=None, map_location=None):
        """Load a checkpoint, resizing ``pos_embed`` when shapes differ.

        Args:
            checkpoint: Checkpoint location handed to the mmcv loaders.
            prefix (str, optional): When given, only the sub-state-dict under
                this prefix is loaded. Defaults to None.
            map_location: Forwarded to the mmcv loaders. Defaults to None.
        """
        from mmcv.runner import (_load_checkpoint,
                                 _load_checkpoint_with_prefix, load_state_dict)
        from mmcv.utils import print_log

        logger = get_root_logger()

        if prefix is None:
            print_log(f'load model from: {checkpoint}', logger=logger)
            checkpoint = _load_checkpoint(checkpoint, map_location, logger)
            # get state_dict from checkpoint
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint
        else:
            print_log(
                f'load {prefix} in model from: {checkpoint}', logger=logger)
            state_dict = _load_checkpoint_with_prefix(prefix, checkpoint,
                                                      map_location)

        if 'pos_embed' in state_dict.keys():
            ckpt_pos_embed_shape = state_dict['pos_embed'].shape
            if self.pos_embed.shape != ckpt_pos_embed_shape:
                print_log(
                    f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                    f'to {self.pos_embed.shape}.',
                    logger=logger)

                # The checkpoint grid is taken to be square: its side is the
                # square root of the token count minus the cls token.
                ckpt_pos_embed_shape = to_2tuple(
                    int(np.sqrt(ckpt_pos_embed_shape[1] - 1)))
                pos_embed_shape = self.patch_embed.patches_resolution

                state_dict['pos_embed'] = self.resize_pos_embed(
                    state_dict['pos_embed'], ckpt_pos_embed_shape,
                    pos_embed_shape, self.interpolate_mode)

        # load state_dict
        load_state_dict(self, state_dict, strict=False, logger=logger)

    @staticmethod
    def resize_pos_embed(pos_embed, src_shape, dst_shape, mode='bicubic'):
        """Resize pos_embed weights.

        Args:
            pos_embed (torch.Tensor): Position embedding weights with shape
                [1, L, C].
            src_shape (tuple): The resolution of downsampled origin training
                image.
            dst_shape (tuple): The resolution of downsampled new training
                image.
            mode (str): Algorithm used for upsampling:
                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
                ``'trilinear'``. Default: ``'bicubic'``
        Return:
            torch.Tensor: The resized pos_embed of shape [1, L_new, C]
        """
        assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
        _, L, C = pos_embed.shape
        src_h, src_w = src_shape
        assert L == src_h * src_w + 1
        # The first token (cls) is kept as is; only the grid is resampled.
        cls_token = pos_embed[:, :1]

        src_weight = pos_embed[:, 1:]
        src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)

        dst_weight = F.interpolate(
            src_weight, size=dst_shape, align_corners=False, mode=mode)
        dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)

        return torch.cat((cls_token, dst_weight), dim=1)

    def forward(self, x):
        """Run the backbone.

        Returns:
            tuple: One entry per layer index in ``out_indices``. Each entry
            is ``[patch_token, cls_token]`` when ``output_cls_token`` is set,
            otherwise just ``patch_token``.
        """
        B = x.shape[0]
        x = self.patch_embed(x)
        patch_resolution = self.patch_embed.patches_resolution

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.drop_after_pos(x)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)

            # The final norm is applied to the output of the last layer only.
            if i == len(self.layers) - 1 and self.final_norm:
                x = self.norm1(x)

            if i in self.out_indices:
                B, _, C = x.shape
                # Drop the cls token and fold the remaining tokens back into
                # a channel-first grid.
                patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
                patch_token = patch_token.permute(0, 3, 1, 2)
                cls_token = x[:, 0]
                if self.output_cls_token:
                    out = [patch_token, cls_token]
                else:
                    out = patch_token
                outs.append(out)

        return tuple(outs)
class T2T_ViT(BaseBackbone):
    """Tokens-to-Token Vision Transformer (T2T-ViT)

    A PyTorch implementation of `Tokens-to-Token ViT: Training Vision
    Transformers from Scratch on ImageNet <https://arxiv.org/abs/2101.11986>`_

    Args:
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        in_channels (int): Number of input channels.
        embed_dims (int): Embedding dimension.
        num_layers (int): Num of transformer layers in encoder.
            Defaults to 14.
        out_indices (Sequence | int): Output from which stages. Negative
            values count from the last layer. The argument is copied, so a
            tuple is accepted and a caller-owned list is left untouched.
            Defaults to -1, means the last stage.
        drop_rate (float): Dropout rate after position embedding.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            ``dict(type='LN')``.
        final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Defaults to True.
        with_cls_token (bool): Whether concatenating class token into image
            tokens as transformer input. Defaults to True.
        output_cls_token (bool): Whether output the cls_token. If set True,
            ``with_cls_token`` must be True. Defaults to True.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        t2t_cfg (dict): Extra config of Tokens-to-Token module.
            Defaults to an empty dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """
    num_extra_tokens = 1  # cls_token

    def __init__(self,
                 img_size=224,
                 in_channels=3,
                 embed_dims=384,
                 num_layers=14,
                 out_indices=-1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 final_norm=True,
                 with_cls_token=True,
                 output_cls_token=True,
                 interpolate_mode='bicubic',
                 t2t_cfg=dict(),
                 layer_cfgs=dict(),
                 init_cfg=None):
        super(T2T_ViT, self).__init__(init_cfg)

        # Token-to-Token Module
        self.tokens_to_token = T2TModule(
            img_size=img_size,
            in_channels=in_channels,
            embed_dims=embed_dims,
            **t2t_cfg)
        self.patch_resolution = self.tokens_to_token.init_out_size
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Set cls token
        if output_cls_token:
            assert with_cls_token is True, f'with_cls_token must be True if' \
                f'set output_cls_token to True, but got {with_cls_token}'
        self.with_cls_token = with_cls_token
        self.output_cls_token = output_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))

        # Set position embedding: a fixed sinusoid table stored as a buffer;
        # the pre-hook resizes a checkpoint's table when shapes differ.
        self.interpolate_mode = interpolate_mode
        sinusoid_table = get_sinusoid_encoding(
            num_patches + self.num_extra_tokens, embed_dims)
        self.register_buffer('pos_embed', sinusoid_table)
        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Normalise on a private list: item assignment on the argument
        # itself would fail for tuples and would silently rewrite a list
        # owned by the caller (e.g. a shared config).
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = num_layers + index
            # NOTE(review): the upper bound admits `num_layers` itself,
            # which no encoder layer index equals - confirm whether
            # `< num_layers` was intended.
            assert 0 <= out_indices[i] <= num_layers, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        # stochastic depth decay rule
        dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)]

        self.encoder = ModuleList()
        for i in range(num_layers):
            if isinstance(layer_cfgs, Sequence):
                layer_cfg = layer_cfgs[i]
            else:
                layer_cfg = deepcopy(layer_cfgs)
            # Defaults first, so entries of the user config win.
            layer_cfg = {
                'embed_dims': embed_dims,
                'num_heads': 6,
                'feedforward_channels': 3 * embed_dims,
                'drop_path_rate': dpr[i],
                'qkv_bias': False,
                'norm_cfg': norm_cfg,
                **layer_cfg
            }

            layer = T2TTransformerLayer(**layer_cfg)
            self.encoder.append(layer)

        self.final_norm = final_norm
        if final_norm:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = nn.Identity()

    def init_weights(self):
        """Run the base init, then init ``cls_token`` unless pretrained."""
        super().init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress custom init if use pretrained model.
            return

        trunc_normal_(self.cls_token, std=.02)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        """State-dict pre-hook resizing a mismatching ``pos_embed`` entry.

        The entry in ``state_dict`` is replaced in place; nothing happens
        when the key is absent or the shapes already agree.
        """
        name = prefix + 'pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmcls.utils import get_root_logger
            logger = get_root_logger()
            logger.info(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.')

            # The checkpoint grid is taken to be square: its side is the
            # square root of the token count minus the extra tokens.
            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.tokens_to_token.init_out_size

            state_dict[name] = resize_pos_embed(state_dict[name],
                                                ckpt_pos_embed_shape,
                                                pos_embed_shape,
                                                self.interpolate_mode,
                                                self.num_extra_tokens)

    def forward(self, x):
        """Run the backbone.

        Returns:
            tuple: One entry per layer index in ``out_indices``. Each entry
            is ``[patch_token, cls_token]`` when ``output_cls_token`` is set,
            otherwise just ``patch_token``.
        """
        B = x.shape[0]
        x, patch_resolution = self.tokens_to_token(x)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # The stored table is resized to the resolution of this input.
        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        if not self.with_cls_token:
            # Remove class token for transformer encoder input
            x = x[:, 1:]

        outs = []
        for i, layer in enumerate(self.encoder):
            x = layer(x)

            # The final norm is applied to the output of the last layer only.
            if i == len(self.encoder) - 1 and self.final_norm:
                x = self.norm(x)

            if i in self.out_indices:
                B, _, C = x.shape
                if self.with_cls_token:
                    patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = x[:, 0]
                else:
                    patch_token = x.reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = None
                if self.output_cls_token:
                    out = [patch_token, cls_token]
                else:
                    out = patch_token
                outs.append(out)

        return tuple(outs)
class SwinTransformer(BaseBackbone):
    """ Swin Transformer
    A PyTorch implement of : `Swin Transformer:
    Hierarchical Vision Transformer using Shifted Windows`  -
        https://arxiv.org/abs/2103.14030

    Inspiration from
    https://github.com/microsoft/Swin-Transformer

    Args:
        arch (str | dict): Swin Transformer architecture. Either a key of
            ``arch_zoo`` (case-insensitive) or a dict holding exactly the
            keys ``embed_dims``, ``depths`` and ``num_heads``.
            Defaults to 'T'.
        img_size (int | tuple): The size of input image.
            Defaults to 224.
        in_channels (int): The num of input channels.
            Defaults to 3.
        drop_rate (float): Dropout rate after embedding.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate.
            Defaults to 0.1.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults to False.
        auto_pad (bool): If True, auto pad feature map to fit window_size.
            Defaults to False.
        norm_cfg (dict, optional): Config dict for normalization layer at end
            of backbone. ``None`` disables the final norm.
            Defaults to dict(type='LN')
        stage_cfgs (Sequence | dict, optional): Extra config dict for each
            stage. Defaults to empty dict.
        patch_cfg (dict, optional): Extra config dict for patch embedding;
            its entries override the defaults built by this class.
            Defaults to empty dict.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmcls.models import SwinTransformer
        >>> import torch
        >>> extra_config = dict(
        >>>     arch='tiny',
        >>>     stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
        >>>                                     'expansion_ratio': 3}),
        >>>     auto_pad=True)
        >>> self = SwinTransformer(**extra_config)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> output = self.forward(inputs)
        >>> print(output.shape)
        (1, 2592, 4)
    """
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': 96,
                         'depths':     [2, 2,  6,  2],
                         'num_heads':  [3, 6, 12, 24]}),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': 96,
                         'depths':     [2, 2, 18,  2],
                         'num_heads':  [3, 6, 12, 24]}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': 128,
                         'depths':     [2, 2, 18,  2],
                         'num_heads':  [4, 8, 16, 32]}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': 192,
                         'depths':     [2,  2, 18,  2],
                         'num_heads':  [6, 12, 24, 48]}),
    }  # yapf: disable

    def __init__(self,
                 arch='T',
                 img_size=224,
                 in_channels=3,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 use_abs_pos_embed=False,
                 auto_pad=False,
                 norm_cfg=dict(type='LN'),
                 stage_cfgs=dict(),
                 patch_cfg=dict(),
                 init_cfg=None):
        super(SwinTransformer, self).__init__(init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            # These have to be the keys read from `arch_settings` below
            # ('num_heads', plural); otherwise no custom dict can both pass
            # this check and provide the required entries.
            essential_keys = {'embed_dims', 'depths', 'num_heads'}
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.num_layers = len(self.depths)
        self.use_abs_pos_embed = use_abs_pos_embed
        self.auto_pad = auto_pad

        _patch_cfg = dict(
            img_size=img_size,
            in_channels=in_channels,
            embed_dims=self.embed_dims,
            conv_cfg=dict(
                type='Conv2d', kernel_size=4, stride=4, padding=0, dilation=1),
            norm_cfg=dict(type='LN'))
        # Merge instead of splatting into `dict(...)`: a user key that
        # coincides with a default now overrides it rather than raising
        # "got multiple values for keyword argument".
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        if self.use_abs_pos_embed:
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, self.embed_dims))

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # stochastic depth
        total_depth = sum(self.depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        self.stages = ModuleList()
        embed_dims = self.embed_dims
        input_resolution = patches_resolution
        for i, (depth,
                num_heads) in enumerate(zip(self.depths, self.num_heads)):
            if isinstance(stage_cfgs, Sequence):
                stage_cfg = stage_cfgs[i]
            else:
                stage_cfg = deepcopy(stage_cfgs)
            # Every stage except the last one ends with a downsample.
            downsample = i < self.num_layers - 1
            # Defaults first, so entries of the user config win.
            _stage_cfg = {
                'embed_dims': embed_dims,
                'depth': depth,
                'num_heads': num_heads,
                'downsample': downsample,
                'input_resolution': input_resolution,
                'drop_paths': dpr[:depth],
                'auto_pad': auto_pad,
                **stage_cfg
            }

            stage = SwinBlockSequence(**_stage_cfg)
            self.stages.append(stage)

            # Consume the drop-path rates used by this stage.
            dpr = dpr[depth:]
            if downsample:
                # The next stage works on the downsampled feature map.
                embed_dims = stage.downsample.out_channels
                input_resolution = stage.downsample.output_resolution

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

    def init_weights(self):
        """Run the base init, then init the absolute position embedding."""
        super(SwinTransformer, self).init_weights()

        if self.use_abs_pos_embed:
            trunc_normal_(self.absolute_pos_embed, std=0.02)

    def forward(self, x):
        """Run the backbone and return the output with dims 1 and 2 swapped."""
        x = self.patch_embed(x)
        if self.use_abs_pos_embed:
            x = x + self.absolute_pos_embed
        x = self.drop_after_pos(x)

        for stage in self.stages:
            x = stage(x)

        # Explicit None check: do not depend on the truthiness of a module.
        if self.norm is not None:
            x = self.norm(x)

        return x.transpose(1, 2)
class T2T_ViT(BaseBackbone):
    """Tokens-to-Token Vision Transformer (T2T-ViT)

    A PyTorch implementation of `Tokens-to-Token ViT: Training Vision
    Transformers from Scratch on ImageNet <https://arxiv.org/abs/2101.11986>`_

    Args:
        img_size (int): Input image size.
        in_channels (int): Number of input channels.
        embed_dims (int): Embedding dimension.
        t2t_cfg (dict): Extra config of Tokens-to-Token module.
            Defaults to an empty dict.
        drop_rate (float): Dropout rate after position embedding.
            Defaults to 0.
        num_layers (int): Num of transformer layers in encoder.
            Defaults to 14.
        out_indices (Sequence | int): Output from which stages. Negative
            values count from the last layer. The argument is copied, so a
            tuple is accepted and a caller-owned list is left untouched.
            Defaults to -1, means the last stage.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            ``dict(type='LN')``.
        final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Defaults to True.
        output_cls_token (bool): Whether output the cls_token.
            Defaults to True.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 img_size=224,
                 in_channels=3,
                 embed_dims=384,
                 t2t_cfg=dict(),
                 drop_rate=0.,
                 num_layers=14,
                 out_indices=-1,
                 layer_cfgs=dict(),
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 final_norm=True,
                 output_cls_token=True,
                 init_cfg=None):
        super(T2T_ViT, self).__init__(init_cfg)

        # Token-to-Token Module
        self.tokens_to_token = T2TModule(
            img_size=img_size,
            in_channels=in_channels,
            embed_dims=embed_dims,
            **t2t_cfg)
        num_patches = self.tokens_to_token.num_patches

        # Class token
        self.output_cls_token = output_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))

        # Position Embedding: a fixed sinusoid table stored as a buffer,
        # with one extra slot for the class token.
        sinusoid_table = get_sinusoid_encoding(num_patches + 1, embed_dims)
        self.register_buffer('pos_embed', sinusoid_table)
        self.drop_after_pos = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Normalise on a private list: item assignment on the argument
        # itself would fail for tuples and would silently rewrite a list
        # owned by the caller (e.g. a shared config).
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = num_layers + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        # stochastic depth decay rule
        dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)]
        self.encoder = ModuleList()
        for i in range(num_layers):
            if isinstance(layer_cfgs, Sequence):
                layer_cfg = layer_cfgs[i]
            else:
                layer_cfg = deepcopy(layer_cfgs)
            # Defaults first, so entries of the user config win.
            layer_cfg = {
                'embed_dims': embed_dims,
                'num_heads': 6,
                'feedforward_channels': 3 * embed_dims,
                'drop_path_rate': dpr[i],
                'qkv_bias': False,
                'norm_cfg': norm_cfg,
                **layer_cfg
            }

            layer = T2TTransformerLayer(**layer_cfg)
            self.encoder.append(layer)

        self.final_norm = final_norm
        if final_norm:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = nn.Identity()

    def init_weights(self):
        """Run the base init, then init ``cls_token`` unless pretrained."""
        super().init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress custom init if use pretrained model.
            return

        trunc_normal_(self.cls_token, std=.02)

    def forward(self, x):
        """Run the backbone.

        Returns:
            tuple: One entry per layer index in ``out_indices``. Each entry
            is ``[patch_token, cls_token]`` when ``output_cls_token`` is set,
            otherwise just ``patch_token``.
        """
        B = x.shape[0]
        x = self.tokens_to_token(x)
        num_patches = self.tokens_to_token.num_patches
        # The token grid is taken to be square.
        patch_resolution = [int(np.sqrt(num_patches))] * 2

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.drop_after_pos(x)

        outs = []
        for i, layer in enumerate(self.encoder):
            x = layer(x)

            # The final norm is applied to the output of the last layer only.
            if i == len(self.encoder) - 1 and self.final_norm:
                x = self.norm(x)

            if i in self.out_indices:
                B, _, C = x.shape
                # Drop the cls token and fold the remaining tokens back into
                # a channel-first grid.
                patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
                patch_token = patch_token.permute(0, 3, 1, 2)
                cls_token = x[:, 0]
                if self.output_cls_token:
                    out = [patch_token, cls_token]
                else:
                    out = patch_token
                outs.append(out)

        return tuple(outs)
class VisionTransformer(BaseBackbone):
    """Vision Transformer.

    A PyTorch implement of : `An Image is Worth 16x16 Words: Transformers
    for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_

    Args:
        arch (str | dict): Vision Transformer architecture. If use string,
            choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
            and 'deit-base'. If use dict, it should have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimensions in
              feedforward modules.

            Defaults to 'base'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 16.
        in_channels (int): The num of input channels. Defaults to 3.
        out_indices (Sequence | int): Output from which stages.
            Negative values count from the last layer. The argument is
            copied, so tuples are accepted and the caller's sequence is
            never modified. Defaults to -1, means the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Defaults to True.
        with_cls_token (bool): Whether concatenating class token into image
            tokens as transformer input. Defaults to True.
        output_cls_token (bool): Whether output the cls_token. If set True,
            ``with_cls_token`` must be True. Defaults to True.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    arch_zoo = {
        **dict.fromkeys(
            ['s', 'small'], {
                'embed_dims': 768,
                'num_layers': 8,
                'num_heads': 8,
                'feedforward_channels': 768 * 3,
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 3072
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'num_layers': 24,
                'num_heads': 16,
                'feedforward_channels': 4096
            }),
        **dict.fromkeys(
            ['deit-t', 'deit-tiny'], {
                'embed_dims': 192,
                'num_layers': 12,
                'num_heads': 3,
                'feedforward_channels': 192 * 4
            }),
        **dict.fromkeys(
            ['deit-s', 'deit-small'], {
                'embed_dims': 384,
                'num_layers': 12,
                'num_heads': 6,
                'feedforward_channels': 384 * 4
            }),
        **dict.fromkeys(
            ['deit-b', 'deit-base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 768 * 4
            }),
    }
    # Some structures have multiple extra tokens, like DeiT.
    num_extra_tokens = 1  # cls_token

    def __init__(self,
                 arch='base',
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 out_indices=-1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 final_norm=True,
                 with_cls_token=True,
                 output_cls_token=True,
                 interpolate_mode='bicubic',
                 patch_cfg=dict(),
                 layer_cfgs=dict(),
                 init_cfg=None):
        super(VisionTransformer, self).__init__(init_cfg)

        # Resolve the architecture either from the zoo or a custom dict.
        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
            }
            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.num_layers = self.arch_settings['num_layers']
        self.img_size = to_2tuple(img_size)

        # Set patch embedding; user-supplied ``patch_cfg`` overrides defaults.
        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Set cls token
        if output_cls_token:
            assert with_cls_token is True, f'with_cls_token must be True if ' \
                f'set output_cls_token to True, but got {with_cls_token}'
        self.with_cls_token = with_cls_token
        self.output_cls_token = output_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

        # Set position embedding
        self.interpolate_mode = interpolate_mode
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_extra_tokens,
                        self.embed_dims))
        # Resize checkpoint position embeddings on load when shapes differ.
        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Work on a private list: tuples are immutable and the caller's
        # sequence must not be modified while negative indices are resolved.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_layers + index
            # NOTE(review): the upper bound is inclusive, but ``forward`` only
            # emits outputs for layer indices below ``num_layers`` -- confirm
            # whether ``<`` was intended.
            assert 0 <= out_indices[i] <= self.num_layers, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, self.num_layers)

        self.layers = ModuleList()
        if isinstance(layer_cfgs, dict):
            layer_cfgs = [layer_cfgs] * self.num_layers
        for i in range(self.num_layers):
            _layer_cfg = dict(
                embed_dims=self.embed_dims,
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                qkv_bias=qkv_bias,
                norm_cfg=norm_cfg)
            _layer_cfg.update(layer_cfgs[i])
            self.layers.append(TransformerEncoderLayer(**_layer_cfg))

        self.final_norm = final_norm
        if final_norm:
            self.norm1_name, norm1 = build_norm_layer(
                norm_cfg, self.embed_dims, postfix=1)
            self.add_module(self.norm1_name, norm1)

    @property
    def norm1(self):
        """The final norm layer (only registered when ``final_norm``)."""
        return getattr(self, self.norm1_name)

    def init_weights(self):
        """Run the base initialization, then init ``pos_embed``.

        The custom ``pos_embed`` init is skipped when ``init_cfg`` requests
        pretrained weights.
        """
        super(VisionTransformer, self).init_weights()

        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            trunc_normal_(self.pos_embed, std=0.02)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        """Pre-hook resizing a checkpoint ``pos_embed`` to the model shape."""
        name = prefix + 'pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmcv.utils import print_log
            logger = get_root_logger()
            print_log(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.',
                logger=logger)

            # The checkpoint grid is taken to be square: its side is the
            # square root of the token count without the extra tokens.
            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.patch_embed.init_out_size

            state_dict[name] = resize_pos_embed(state_dict[name],
                                                ckpt_pos_embed_shape,
                                                pos_embed_shape,
                                                self.interpolate_mode,
                                                self.num_extra_tokens)

    @staticmethod
    def resize_pos_embed(*args, **kwargs):
        """Interface for backward-compatibility."""
        return resize_pos_embed(*args, **kwargs)

    def forward(self, x):
        """Run the backbone.

        Returns:
            tuple: One entry per index in ``out_indices``. Each entry is
            ``[patch_token, cls_token]`` when ``output_cls_token`` is True,
            otherwise just ``patch_token``.
        """
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        if not self.with_cls_token:
            # Remove class token for transformer encoder input
            x = x[:, 1:]

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)

            if i == len(self.layers) - 1 and self.final_norm:
                x = self.norm1(x)

            if i in self.out_indices:
                B, _, C = x.shape
                if self.with_cls_token:
                    patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = x[:, 0]
                else:
                    patch_token = x.reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = None
                if self.output_cls_token:
                    out = [patch_token, cls_token]
                else:
                    out = patch_token
                outs.append(out)

        return tuple(outs)
class STDCNet(BaseModule):
    """This backbone is the implementation of `Rethinking BiSeNet For Real-time
    Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    Args:
        stdc_type (int): The type of backbone structure,
            `STDCNet1` and`STDCNet2` denotes two main backbones in paper,
            whose FLOPs is 813M and 1446M, respectively.
        in_channels (int): The num of input_channels.
        channels (tuple[int]): The output channels for each stage.
        bottleneck_type (str): The type of STDC Module type, the value must
            be 'add' or 'cat'.
        norm_cfg (dict): Config dict for normalization layer.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Numbers of conv layer at each STDC Module.
            Default: 4.
        with_final_conv (bool): Whether add a conv layer at the Module output.
            Default: False.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> import torch
        >>> stdc_type = 'STDCNet1'
        >>> in_channels = 3
        >>> channels = (32, 64, 256, 512, 1024)
        >>> bottleneck_type = 'cat'
        >>> inputs = torch.rand(1, 3, 1024, 2048)
        >>> self = STDCNet(stdc_type, in_channels,
        ...                 channels, bottleneck_type).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 256, 128, 256])
        outputs[1].shape = torch.Size([1, 512, 64, 128])
        outputs[2].shape = torch.Size([1, 1024, 32, 64])
    """

    arch_settings = {
        'STDCNet1': [(2, 1), (2, 1), (2, 1)],
        'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
    }

    def __init__(self,
                 stdc_type,
                 in_channels,
                 channels,
                 bottleneck_type,
                 norm_cfg,
                 act_cfg,
                 num_convs=4,
                 with_final_conv=False,
                 pretrained=None,
                 init_cfg=None):
        super(STDCNet, self).__init__(init_cfg=init_cfg)
        assert stdc_type in self.arch_settings, \
            f'invalid structure {stdc_type} for STDCNet.'
        assert bottleneck_type in ['add', 'cat'],\
            f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'

        assert len(channels) == 5,\
            f'invalid channels length {len(channels)} for STDCNet.'

        self.in_channels = in_channels
        self.channels = channels
        self.stage_strides = self.arch_settings[stdc_type]
        self.pretrained = pretrained
        # Misspelled attribute kept as an alias for backward compatibility
        # with code that may already read it.
        self.prtrained = pretrained
        self.num_convs = num_convs
        self.with_final_conv = with_final_conv

        self.stages = ModuleList([
            ConvModule(
                self.in_channels,
                self.channels[0],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                self.channels[0],
                self.channels[1],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        ])
        # `self.num_shallow_features` is the number of shallow modules in
        # `STDCNet`, which is noted as `Stage1` and `Stage2` in original paper.
        # They are both not used for following modules like Attention
        # Refinement Module and Feature Fusion Module.
        # Thus they would be cut from `outs`. Please refer to Figure 4
        # of original paper for more details.
        self.num_shallow_features = len(self.stages)

        for strides in self.stage_strides:
            # The newest stage's output channels feed the next stage.
            idx = len(self.stages) - 1
            self.stages.append(
                self._make_stage(self.channels[idx], self.channels[idx + 1],
                                 strides, norm_cfg, act_cfg, bottleneck_type))
        # After appending, `self.stages` is a ModuleList including several
        # shallow modules and STDCModules.
        # (len(self.stages) ==
        # self.num_shallow_features + len(self.stage_strides))
        if self.with_final_conv:
            self.final_conv = ConvModule(
                self.channels[-1],
                max(1024, self.channels[-1]),
                1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
                    act_cfg, bottleneck_type):
        """Build one stage: a ``Sequential`` with one STDCModule per stride.

        Only the first module maps ``in_channels`` to ``out_channels``; the
        following ones keep ``out_channels``.
        """
        layers = []
        for i, stride in enumerate(strides):
            layers.append(
                STDCModule(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    stride,
                    norm_cfg,
                    act_cfg,
                    num_convs=self.num_convs,
                    fusion_type=bottleneck_type))
        return Sequential(*layers)

    def forward(self, x):
        """Run all stages and return the non-shallow stage outputs.

        Returns:
            tuple: Outputs of the stages after the first
            ``num_shallow_features`` ones; the last entry is passed through
            ``final_conv`` when ``with_final_conv`` is set.
        """
        outs = []
        for stage in self.stages:
            x = stage(x)
            outs.append(x)
        if self.with_final_conv:
            outs[-1] = self.final_conv(outs[-1])
        outs = outs[self.num_shallow_features:]
        return tuple(outs)
class SwinTransformer(BaseModule):
    """ Swin Transformer
    A PyTorch implement of : `Swin Transformer:
    Hierarchical Vision Transformer using Shifted Windows`  -
        https://arxiv.org/abs/2103.14030

    Inspiration from
    https://github.com/microsoft/Swin-Transformer

    Args:
        pretrain_img_size (int | tuple[int]): The size of input image when
            pretrain. Defaults: 224.
        in_channels (int): The num of input channels.
            Defaults: 3.
        embed_dims (int): The feature dimension. Default: 96.
        patch_size (int | tuple[int]): Patch size. Default: 4.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            Default: 4.
        depths (tuple[int]): Depths of each Swin Transformer stage.
            Default: (2, 2, 6, 2).
        num_heads (tuple[int]): Parallel attention heads of each Swin
            Transformer stage. Default: (3, 6, 12, 24).
        strides (tuple[int]): The patch merging or patch embedding stride of
            each Swin Transformer stage. (In swin, we set kernel size equal to
            stride.) Default: (4, 2, 2, 2).
        out_indices (tuple[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        qkv_bias (bool, optional): If True, add a learnable bias to query, key,
            value. Default: True
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        patch_norm (bool): If add a norm layer for patch embed and patch
            merging. Default: True.
        drop_rate (float): Dropout rate. Defaults: 0.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults: False.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer at
            output of backbone. Defaults: dict(type='LN').
        pretrain_style (str): Choose to use official or mmcls pretrain weights.
            Default: official.
        pretrained (str, optional): model pretrained path. Deprecated in
            favour of ``init_cfg``. Default: None.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 in_channels=3,
                 embed_dims=96,
                 patch_size=4,
                 window_size=7,
                 mlp_ratio=4,
                 depths=(2, 2, 6, 2),
                 num_heads=(3, 6, 12, 24),
                 strides=(4, 2, 2, 2),
                 out_indices=(0, 1, 2, 3),
                 qkv_bias=True,
                 qk_scale=None,
                 patch_norm=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 use_abs_pos_embed=False,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 pretrain_style='official',
                 pretrained=None,
                 init_cfg=None):
        super(SwinTransformer, self).__init__()

        if isinstance(pretrain_img_size, int):
            pretrain_img_size = to_2tuple(pretrain_img_size)
        elif isinstance(pretrain_img_size, tuple):
            if len(pretrain_img_size) == 1:
                pretrain_img_size = to_2tuple(pretrain_img_size[0])
            assert len(pretrain_img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(pretrain_img_size)}'

        # Parenthesised so both literals form ONE assertion message.
        assert pretrain_style in ['official', 'mmcls'], (
            'We only support load '
            'official ckpt and mmcls ckpt.')

        # Only warn when the deprecated argument is actually used; the
        # default ``None`` must stay silent.
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                          'please use "init_cfg" instead')
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        num_layers = len(depths)
        self.out_indices = out_indices
        self.use_abs_pos_embed = use_abs_pos_embed
        self.pretrain_style = pretrain_style
        self.pretrained = pretrained
        self.init_cfg = init_cfg

        assert strides[0] == patch_size, 'Use non-overlapping patch embed.'

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=strides[0],
            pad_to_patch_size=True,
            norm_cfg=norm_cfg if patch_norm else None,
            init_cfg=None)

        if self.use_abs_pos_embed:
            patch_row = pretrain_img_size[0] // patch_size
            patch_col = pretrain_img_size[1] // patch_size
            num_patches = patch_row * patch_col
            # Stored token-major: (1, num_patches, embed_dims).
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros((1, num_patches, embed_dims)))

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # stochastic depth
        total_depth = sum(depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        self.stages = ModuleList()
        in_channels = embed_dims
        for i in range(num_layers):
            # Every stage except the last one ends with a patch merging
            # layer that doubles the channel count.
            if i < num_layers - 1:
                downsample = PatchMerging(
                    in_channels=in_channels,
                    out_channels=2 * in_channels,
                    stride=strides[i + 1],
                    norm_cfg=norm_cfg if patch_norm else None,
                    init_cfg=None)
            else:
                downsample = None

            stage = SwinBlockSequence(
                embed_dims=in_channels,
                num_heads=num_heads[i],
                feedforward_channels=mlp_ratio * in_channels,
                depth=depths[i],
                window_size=window_size,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=dpr[:depths[i]],
                downsample=downsample,
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                init_cfg=None)
            self.stages.append(stage)

            # Consume the drop-path rates used by this stage.
            dpr = dpr[depths[i]:]
            if downsample:
                in_channels = downsample.out_channels

        self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
        # Add a norm layer for each output
        for i in out_indices:
            layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
            layer_name = f'norm{i}'
            self.add_module(layer_name, layer)

    def init_weights(self):
        """Initialize weights from scratch or from ``self.pretrained``.

        With ``pretrained=None`` the base initialization runs followed by the
        custom Linear / LayerNorm init below. With a checkpoint path the
        state dict is loaded (non-strict) after optional conversion, prefix
        stripping and relative-position-bias-table interpolation.
        """
        if self.pretrained is None:
            super().init_weights()
            # NOTE(review): the init helpers below receive parameters /
            # tensors rather than modules -- confirm they support that.
            if self.use_abs_pos_embed:
                trunc_normal_init(self.absolute_pos_embed, std=0.02)
            for m in self.modules():
                if isinstance(m, Linear):
                    trunc_normal_init(m.weight, std=.02)
                    if m.bias is not None:
                        constant_init(m.bias, 0)
                elif isinstance(m, LayerNorm):
                    constant_init(m.bias, 0)
                    constant_init(m.weight, 1.0)
        elif isinstance(self.pretrained, str):
            logger = get_root_logger()
            ckpt = _load_checkpoint(
                self.pretrained, logger=logger, map_location='cpu')
            if 'state_dict' in ckpt:
                state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                state_dict = ckpt['model']
            else:
                state_dict = ckpt

            if self.pretrain_style == 'official':
                state_dict = swin_convert(state_dict)

            # strip prefix of state_dict
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}

            # check absolute position embedding
            # The model parameter is 3-D, (1, num_patches, embed_dims), so a
            # checkpoint tensor is usable only when its shape matches
            # exactly; a matching tensor needs no reshaping. The attribute
            # only exists when ``use_abs_pos_embed`` is True.
            ckpt_abs_pos_embed = state_dict.get('absolute_pos_embed')
            if ckpt_abs_pos_embed is not None and self.use_abs_pos_embed:
                if ckpt_abs_pos_embed.size() != \
                        self.absolute_pos_embed.size():
                    logger.warning('Error in loading absolute_pos_embed, pass')

            # interpolate position bias table if needed
            relative_position_bias_table_keys = [
                k for k in state_dict.keys()
                if 'relative_position_bias_table' in k
            ]
            for table_key in relative_position_bias_table_keys:
                table_pretrained = state_dict[table_key]
                table_current = self.state_dict()[table_key]
                L1, nH1 = table_pretrained.size()
                L2, nH2 = table_current.size()
                if nH1 != nH2:
                    logger.warning(f'Error in loading {table_key}, pass')
                else:
                    if L1 != L2:
                        # Tables are treated as square grids of side sqrt(L).
                        S1 = int(L1**0.5)
                        S2 = int(L2**0.5)
                        table_pretrained_resized = F.interpolate(
                            table_pretrained.permute(1, 0).reshape(
                                1, nH1, S1, S1),
                            size=(S2, S2),
                            mode='bicubic')
                        state_dict[table_key] = table_pretrained_resized.view(
                            nH2, L2).permute(1, 0).contiguous()

            # load state_dict
            self.load_state_dict(state_dict, False)

    def forward(self, x):
        """Run the backbone and collect the normalized stage outputs.

        Returns:
            list[Tensor]: One tensor per stage index in ``out_indices``,
            rearranged to channel-first layout.
        """
        x = self.patch_embed(x)

        hw_shape = (self.patch_embed.DH, self.patch_embed.DW)
        if self.use_abs_pos_embed:
            x = x + self.absolute_pos_embed
        x = self.drop_after_pos(x)

        outs = []
        for i, stage in enumerate(self.stages):
            # ``out`` is the stage output before downsampling.
            x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                out = norm_layer(out)
                out = out.view(-1, *out_hw_shape,
                               self.num_features[i]).permute(0, 3, 1,
                                                             2).contiguous()
                outs.append(out)

        return outs
class ConvNeXt(BaseBackbone):
    """ConvNeXt.

    A PyTorch implementation of : `A ConvNet for the 2020s
    <https://arxiv.org/pdf/2201.03545.pdf>`_

    Modified from the `official repo
    <https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py>`_
    and `timm
    <https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convnext.py>`_.

    Args:
        arch (str | dict): The model's architecture. If string, it should be
            one of architecture in ``ConvNeXt.arch_settings``. And if dict, it
            should include the following two keys:

            - depths (list[int]): Number of blocks at each stage.
            - channels (list[int]): The number of channels at each stage.

            Defaults to 'tiny'.
        in_channels (int): Number of input image channels. Defaults to 3.
        stem_patch_size (int): The size of one patch in the stem layer.
            Defaults to 4.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='LN2d', eps=1e-6)``.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='GELU')``.
        linear_pw_conv (bool): Whether to use linear layer to do pointwise
            convolution. Defaults to True.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): Init value for Layer Scale.
            Defaults to 1e-6.
        out_indices (Sequence | int): Output from which stages. Negative
            values count from the last stage. The argument is copied, so
            tuples are accepted and the caller's sequence is never modified.
            Defaults to -1, means the last stage.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Defaults to 0, which means not freezing any parameters.
        gap_before_final_norm (bool): Whether to globally average the feature
            map before the final norm layer. In the official repo, it's only
            used in classification task. Defaults to True.
        init_cfg (dict, optional): Initialization config dict
    """  # noqa: E501
    arch_settings = {
        'tiny': {
            'depths': [3, 3, 9, 3],
            'channels': [96, 192, 384, 768]
        },
        'small': {
            'depths': [3, 3, 27, 3],
            'channels': [96, 192, 384, 768]
        },
        'base': {
            'depths': [3, 3, 27, 3],
            'channels': [128, 256, 512, 1024]
        },
        'large': {
            'depths': [3, 3, 27, 3],
            'channels': [192, 384, 768, 1536]
        },
        'xlarge': {
            'depths': [3, 3, 27, 3],
            'channels': [256, 512, 1024, 2048]
        },
    }

    def __init__(self,
                 arch='tiny',
                 in_channels=3,
                 stem_patch_size=4,
                 norm_cfg=dict(type='LN2d', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 linear_pw_conv=True,
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 out_indices=-1,
                 frozen_stages=0,
                 gap_before_final_norm=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            assert 'depths' in arch and 'channels' in arch, \
                f'The arch dict must have "depths" and "channels", ' \
                f'but got {list(arch.keys())}.'

        self.depths = arch['depths']
        self.channels = arch['channels']
        assert (isinstance(self.depths, Sequence)
                and isinstance(self.channels, Sequence)
                and len(self.depths) == len(self.channels)), \
            f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \
            'should be both sequence with the same length.'

        self.num_stages = len(self.depths)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Work on a private list: tuples are immutable and the caller's
        # sequence must not be modified while negative indices are resolved.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                # Count from the real number of stages so that custom archs
                # with a stage count other than 4 resolve correctly.
                out_indices[i] = self.num_stages + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.frozen_stages = frozen_stages
        self.gap_before_final_norm = gap_before_final_norm

        # stochastic depth decay rule
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(self.depths))
        ]
        block_idx = 0

        # One downsample layer in front of every stage; the first one is the
        # stem layer.
        self.downsample_layers = ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(
                in_channels,
                self.channels[0],
                kernel_size=stem_patch_size,
                stride=stem_patch_size),
            build_norm_layer(norm_cfg, self.channels[0])[1],
        )
        self.downsample_layers.append(stem)

        # Feature resolution stages, each consisting of multiple residual
        # blocks
        self.stages = nn.ModuleList()

        for i in range(self.num_stages):
            depth = self.depths[i]
            channels = self.channels[i]

            if i >= 1:
                downsample_layer = nn.Sequential(
                    LayerNorm2d(self.channels[i - 1]),
                    nn.Conv2d(
                        self.channels[i - 1],
                        channels,
                        kernel_size=2,
                        stride=2),
                )
                self.downsample_layers.append(downsample_layer)

            stage = Sequential(*[
                ConvNeXtBlock(
                    in_channels=channels,
                    drop_path_rate=dpr[block_idx + j],
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    linear_pw_conv=linear_pw_conv,
                    layer_scale_init_value=layer_scale_init_value)
                for j in range(depth)
            ])
            block_idx += depth

            self.stages.append(stage)

            # Only stages that produce an output get a norm layer.
            if i in self.out_indices:
                norm_layer = build_norm_layer(norm_cfg, channels)[1]
                self.add_module(f'norm{i}', norm_layer)

        self._freeze_stages()

    def forward(self, x):
        """Run the backbone.

        Returns:
            tuple: One normalized output per stage index in ``out_indices``.
            With ``gap_before_final_norm`` the feature map is averaged over
            its last two dimensions before the norm and flattened afterwards.
        """
        outs = []
        for i, stage in enumerate(self.stages):
            x = self.downsample_layers[i](x)
            x = stage(x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                if self.gap_before_final_norm:
                    gap = x.mean([-2, -1], keepdim=True)
                    outs.append(norm_layer(gap).flatten(1))
                else:
                    outs.append(norm_layer(x))

        return tuple(outs)

    def _freeze_stages(self):
        """Put the first ``frozen_stages`` stages (and their downsample
        layers) in eval mode and disable their gradients."""
        for i in range(self.frozen_stages):
            downsample_layer = self.downsample_layers[i]
            stage = self.stages[i]
            downsample_layer.eval()
            stage.eval()
            for param in chain(downsample_layer.parameters(),
                               stage.parameters()):
                param.requires_grad = False

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen stages frozen."""
        super(ConvNeXt, self).train(mode)
        self._freeze_stages()
class SwinBlockSequence(BaseModule):
    """One Swin Transformer stage: a run of blocks plus optional downsample.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        depth (int): The number of blocks in this stage.
        window_size (int): The local window scale. Default: 7.
        qkv_bias (bool): enable bias for qkv if True. Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop_rate (float, optional): Dropout rate. Default: 0.
        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
        drop_path_rate (float | list[float], optional): Stochastic depth
            rate. A list supplies one rate per block; any other value is
            repeated for every block. Default: 0.
        downsample (BaseModule | None, optional): The downsample operation
            module. Default: None.
        act_cfg (dict, optional): The config dict of activation function.
            Default: dict(type='GELU').
        norm_cfg (dict, optional): The config dict of normalization.
            Default: dict(type='LN').
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 depth,
                 window_size=7,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 downsample=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__()
        self.init_cfg = init_cfg

        # Expand a single rate into one (copied) rate per block.
        if isinstance(drop_path_rate, list):
            per_block_rates = drop_path_rate
        else:
            per_block_rates = [
                deepcopy(drop_path_rate) for _ in range(depth)
            ]

        self.blocks = ModuleList()
        for block_idx in range(depth):
            # Odd-numbered blocks use the shifted window.
            use_shift = block_idx % 2 != 0
            self.blocks.append(
                SwinBlock(
                    embed_dims=embed_dims,
                    num_heads=num_heads,
                    feedforward_channels=feedforward_channels,
                    window_size=window_size,
                    shift=use_shift,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=per_block_rates[block_idx],
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    init_cfg=None))

        self.downsample = downsample

    def forward(self, x, hw_shape):
        """Apply every block, then the downsample module when present.

        Returns:
            tuple: ``(x_next, hw_next, x, hw_shape)`` where the last two
            items are the block output before downsampling. Without a
            downsample module the first two items equal the last two.
        """
        for blk in self.blocks:
            x = blk(x, hw_shape)

        if not self.downsample:
            return x, hw_shape, x, hw_shape

        x_down, down_hw_shape = self.downsample(x, hw_shape)
        return x_down, down_hw_shape, x, hw_shape
class MlpMixer(BaseBackbone):
    """Mlp-Mixer backbone.

    Pytorch implementation of `MLP-Mixer: An all-MLP Architecture for Vision
    <https://arxiv.org/pdf/2105.01601.pdf>`_

    Args:
        arch (str | dict): MLP Mixer architecture
            Defaults to 'b'.
        img_size (int | tuple): Input image size.
        patch_size (int | tuple): The patch size.
        out_indices (Sequence | int): Output from which layer. Negative
            values count from the last layer. The argument is copied, so
            tuples are accepted and the caller's sequence is never modified.
            Defaults to -1, means the last layer.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        act_cfg (dict): The activation config for FFNs. Default GELU.
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each mixer block layer.
            Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    arch_zoo = {
        **dict.fromkeys(
            ['s', 'small'], {
                'embed_dims': 512,
                'num_layers': 8,
                'tokens_mlp_dims': 256,
                'channels_mlp_dims': 2048,
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'tokens_mlp_dims': 384,
                'channels_mlp_dims': 3072,
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'num_layers': 24,
                'tokens_mlp_dims': 512,
                'channels_mlp_dims': 4096,
            }),
    }

    def __init__(self,
                 arch='b',
                 img_size=224,
                 patch_size=16,
                 out_indices=-1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 patch_cfg=dict(),
                 layer_cfgs=dict(),
                 init_cfg=None):
        super(MlpMixer, self).__init__(init_cfg)

        # Resolve the architecture either from the zoo or a custom dict.
        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'num_layers', 'tokens_mlp_dims',
                'channels_mlp_dims'
            }
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.num_layers = self.arch_settings['num_layers']
        self.tokens_mlp_dims = self.arch_settings['tokens_mlp_dims']
        self.channels_mlp_dims = self.arch_settings['channels_mlp_dims']

        self.img_size = to_2tuple(img_size)

        # User-supplied ``patch_cfg`` overrides the defaults below.
        _patch_cfg = dict(
            img_size=img_size,
            embed_dims=self.embed_dims,
            conv_cfg=dict(
                type='Conv2d', kernel_size=patch_size, stride=patch_size),
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        num_patches = self.patch_embed.num_patches

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Work on a private list: tuples are immutable and the caller's
        # sequence must not be modified while negative indices are resolved.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_layers + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
            else:
                # A non-negative index must address an existing layer.
                assert index < self.num_layers, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.layers = ModuleList()
        if isinstance(layer_cfgs, dict):
            layer_cfgs = [layer_cfgs] * self.num_layers
        for i in range(self.num_layers):
            _layer_cfg = dict(
                num_tokens=num_patches,
                embed_dims=self.embed_dims,
                tokens_mlp_dims=self.tokens_mlp_dims,
                channels_mlp_dims=self.channels_mlp_dims,
                drop_rate=drop_rate,
                drop_path_rate=drop_path_rate,
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
            )
            _layer_cfg.update(layer_cfgs[i])
            self.layers.append(MixerBlock(**_layer_cfg))

        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, self.embed_dims, postfix=1)
        self.add_module(self.norm1_name, norm1)

    @property
    def norm1(self):
        """The norm layer applied after the last mixer block."""
        return getattr(self, self.norm1_name)

    def forward(self, x):
        """Run the backbone.

        Returns:
            tuple: One tensor per layer index in ``out_indices``, with
            dimensions 1 and 2 swapped relative to the layer output.
        """
        x = self.patch_embed(x)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)

            if i == len(self.layers) - 1:
                x = self.norm1(x)

            if i in self.out_indices:
                out = x.transpose(1, 2)
                outs.append(out)

        return tuple(outs)
class SwinBlockSequence(BaseModule):
    """Module with successive Swin Transformer blocks and downsample layer.

    Args:
        embed_dims (int): Number of input channels.
        input_resolution (Tuple[int, int]): The resolution of the input feature
            map.
        depth (int): Number of successive swin transformer blocks.
        num_heads (int): Number of attention heads.
        downsample (bool, optional): Downsample the output of blocks by patch
            merging. Defaults to False.
        downsample_cfg (dict, optional): The extra config of the patch merging
            layer. Defaults to empty dict.
        drop_paths (Sequence[float] | float, optional): The drop path rate in
            each block. Defaults to 0.
        block_cfgs (Sequence[dict] | dict, optional): The extra config of each
            block. A single dict is copied for every block; a sequence must
            provide one dict per block. Defaults to empty dicts.
        auto_pad (bool, optional): Auto pad the feature map to be divisible by
            window_size, Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 input_resolution,
                 depth,
                 num_heads,
                 downsample=False,
                 downsample_cfg=dict(),
                 drop_paths=0.,
                 block_cfgs=dict(),
                 auto_pad=False,
                 init_cfg=None):
        super().__init__(init_cfg)

        if not isinstance(drop_paths, Sequence):
            drop_paths = [drop_paths] * depth

        # Normalise to one config per block under a single name, so that a
        # caller-provided sequence is used as well (it was previously left
        # unbound under the name read below).
        if not isinstance(block_cfgs, Sequence):
            block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]

        self.blocks = ModuleList()
        for i in range(depth):
            # Per-block extras come last so they override the defaults.
            _block_cfg = {
                'embed_dims': embed_dims,
                'input_resolution': input_resolution,
                'num_heads': num_heads,
                'shift': False if i % 2 == 0 else True,
                'drop_path': drop_paths[i],
                'auto_pad': auto_pad,
                **block_cfgs[i]
            }
            block = SwinBlock(**_block_cfg)
            self.blocks.append(block)

        if downsample:
            _downsample_cfg = {
                'input_resolution': input_resolution,
                'in_channels': embed_dims,
                'expansion_ratio': 2,
                'norm_cfg': dict(type='LN'),
                **downsample_cfg
            }
            self.downsample = PatchMerging(**_downsample_cfg)
        else:
            self.downsample = None

    def forward(self, x):
        """Apply every block, then the patch merging layer when present."""
        for block in self.blocks:
            x = block(x)

        if self.downsample:
            x = self.downsample(x)
        return x
class TransformerLayerSequence(BaseModule):
    """Base class for TransformerEncoder and TransformerDecoder in vision
    transformer.

    Holds a stack of transformer layers built from configs and applies them
    one after another. Supports customization such as specifying a different
    kind of `transformer_layer` for every position in the stack.

    Args:
        transformerlayers (list[obj:`mmcv.ConfigDict`] |
            obj:`mmcv.ConfigDict`): Config of the transformer layers. A
            single dict is deep-copied `num_layers` times; a list must
            contain exactly `num_layers` entries. Default: None.
        num_layers (int): The number of `TransformerLayer`. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
        super(TransformerLayerSequence, self).__init__(init_cfg)
        if not isinstance(transformerlayers, dict):
            assert isinstance(transformerlayers, list) and \
                len(transformerlayers) == num_layers
        else:
            # One shared config: give every layer its own private copy.
            transformerlayers = [
                copy.deepcopy(transformerlayers) for _ in range(num_layers)
            ]

        self.num_layers = num_layers
        self.layers = ModuleList()
        for layer_idx in range(num_layers):
            layer_cfg = transformerlayers[layer_idx]
            self.layers.append(build_transformer_layer(layer_cfg))

        # The first layer is taken as representative of the whole stack.
        self.embed_dims = self.layers[0].embed_dims
        self.pre_norm = self.layers[0].pre_norm

    def forward(self,
                query,
                key,
                value,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerCoder`.

        The output of each layer becomes the `query` of the next one, while
        `key`, `value` and every other argument are passed unchanged.

        Args:
            query (Tensor): Input query with shape
                `(num_queries, bs, embed_dims)`.
            key (Tensor): The key tensor with shape
                `(num_keys, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_keys, bs, embed_dims)`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor], optional): Each element is 2D Tensor
                which is used in calculation of corresponding attention in
                operation_order. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in self-attention
                Default: None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Default: None.

        Returns:
            Tensor:  results with shape [num_queries, bs, embed_dims].
        """
        shared_kwargs = dict(
            query_pos=query_pos,
            key_pos=key_pos,
            attn_masks=attn_masks,
            query_key_padding_mask=query_key_padding_mask,
            key_padding_mask=key_padding_mask)
        for transformer_layer in self.layers:
            query = transformer_layer(query, key, value, **shared_kwargs,
                                      **kwargs)
        return query
class SwinBlockSequence(BaseModule):
    """Module with successive Swin Transformer blocks and downsample layer.

    Args:
        embed_dims (int): Number of input channels.
        depth (int): Number of successive swin transformer blocks.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window. Defaults to 7.
        downsample (bool): Downsample the output of blocks by patch merging.
            Defaults to False.
        downsample_cfg (dict): The extra config of the patch merging layer.
            Defaults to empty dict.
        drop_paths (Sequence[float] | float): The drop path rate in each
            block. A single float is repeated for every block.
            Defaults to 0.
        block_cfgs (Sequence[dict] | dict): The extra config of each block.
            A single dict is deep-copied for every block. Defaults to empty
            dicts.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If
            False, avoid shifting window and shrink the window size to the
            size of feature map, which is common used in classification.
            Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 depth,
                 num_heads,
                 window_size=7,
                 downsample=False,
                 downsample_cfg=dict(),
                 drop_paths=0.,
                 block_cfgs=dict(),
                 with_cp=False,
                 pad_small_map=False,
                 init_cfg=None):
        super().__init__(init_cfg)

        # Broadcast scalar settings so each block can be indexed uniformly.
        if not isinstance(drop_paths, Sequence):
            drop_paths = [drop_paths] * depth
        if not isinstance(block_cfgs, Sequence):
            block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]

        self.embed_dims = embed_dims
        self.blocks = ModuleList()
        for idx in range(depth):
            block_kwargs = dict(
                embed_dims=embed_dims,
                num_heads=num_heads,
                window_size=window_size,
                # Blocks alternate: even index unshifted, odd index shifted.
                shift=idx % 2 != 0,
                drop_path=drop_paths[idx],
                with_cp=with_cp,
                pad_small_map=pad_small_map)
            # Per-block extras take precedence over the defaults above.
            block_kwargs.update(block_cfgs[idx])
            self.blocks.append(SwinBlock(**block_kwargs))

        if not downsample:
            self.downsample = None
        else:
            merging_kwargs = dict(
                in_channels=embed_dims,
                out_channels=2 * embed_dims,
                norm_cfg=dict(type='LN'))
            merging_kwargs.update(downsample_cfg)
            self.downsample = PatchMerging(**merging_kwargs)

    def forward(self, x, in_shape):
        """Apply all blocks, then the optional patch merging.

        Returns:
            tuple: The output tensor and its spatial shape. Without a
            downsample layer the shape is ``in_shape`` unchanged.
        """
        for blk in self.blocks:
            x = blk(x, in_shape)

        if not self.downsample:
            return x, in_shape

        x, out_shape = self.downsample(x, in_shape)
        return x, out_shape

    @property
    def out_channels(self):
        """Channel count of the tensor returned by :meth:`forward`."""
        return (self.downsample.out_channels
                if self.downsample else self.embed_dims)
class STDCModule(BaseModule):
    """STDCModule.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels before scaling.
        stride (int): The number of stride for the first conv layer.
            Downsampling is enabled only when it equals 2.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Numbers of conv layers. Must be greater than 1.
        fusion_type (str): Type of fusion operation, 'add' or 'cat'.
            Default: 'add'.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 norm_cfg=None,
                 act_cfg=None,
                 num_convs=4,
                 fusion_type='add',
                 init_cfg=None):
        super(STDCModule, self).__init__(init_cfg=init_cfg)
        assert num_convs > 1
        assert fusion_type in ['add', 'cat']
        self.stride = stride
        self.with_downsample = bool(self.stride == 2)
        self.fusion_type = fusion_type
        self.layers = ModuleList()

        half_channels = out_channels // 2
        conv_0 = ConvModule(
            in_channels, half_channels, kernel_size=1, norm_cfg=norm_cfg)

        if not self.with_downsample:
            self.layers.append(conv_0)
        else:
            # Depthwise 3x3 stride-2 conv on the output of the first conv.
            self.downsample = ConvModule(
                half_channels,
                half_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=half_channels,
                norm_cfg=norm_cfg,
                act_cfg=None)

            if self.fusion_type == 'add':
                # For 'add' the downsample is fused into the first layer and
                # the identity path gets its own strided projection.
                self.layers.append(nn.Sequential(conv_0, self.downsample))
                self.skip = Sequential(
                    ConvModule(
                        in_channels,
                        in_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        groups=in_channels,
                        norm_cfg=norm_cfg,
                        act_cfg=None),
                    ConvModule(
                        in_channels,
                        out_channels,
                        1,
                        norm_cfg=norm_cfg,
                        act_cfg=None))
            else:
                self.layers.append(conv_0)
                self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

        final_idx = num_convs - 1
        for conv_idx in range(1, num_convs):
            # Each conv halves the width of the previous one, except the
            # last, which keeps the width of its input.
            if conv_idx == final_idx:
                shrink = 2**conv_idx
            else:
                shrink = 2**(conv_idx + 1)
            self.layers.append(
                ConvModule(
                    out_channels // 2**conv_idx,
                    out_channels // shrink,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, inputs):
        """Dispatch to the fusion routine selected by ``fusion_type``."""
        if self.fusion_type == 'add':
            return self.forward_add(inputs)
        return self.forward_cat(inputs)

    def forward_add(self, inputs):
        """Concatenate all layer outputs and add the (skip) input."""
        feats = []
        x = inputs.clone()
        for conv in self.layers:
            x = conv(x)
            feats.append(x)

        identity = self.skip(inputs) if self.with_downsample else inputs
        return torch.cat(feats, dim=1) + identity

    def forward_cat(self, inputs):
        """Concatenate all layer outputs along the channel dimension.

        With downsampling, the second layer consumes the downsampled first
        output, and the first entry of the concatenation is replaced by the
        pooled first output.
        """
        stem = self.layers[0](inputs)
        feats = [stem]
        x = stem
        for idx, conv in enumerate(self.layers[1:]):
            if idx == 0 and self.with_downsample:
                x = self.downsample(x)
            x = conv(x)
            feats.append(x)

        if self.with_downsample:
            feats[0] = self.skip(stem)
        return torch.cat(feats, dim=1)
class SwinTransformer(BaseBackbone):
    """Swin Transformer.

    A PyTorch implement of : `Swin Transformer:
    Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>`_

    Inspiration from
    https://github.com/microsoft/Swin-Transformer

    Args:
        arch (str | dict): Swin Transformer architecture. If use string, choose
            from 'tiny', 'small', 'base' and 'large'. If use dict, it should
            have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **depths** (List[int]): The number of blocks in each stage.
            - **num_heads** (List[int]): The number of heads in attention
              modules of each stage.

            Defaults to 'tiny'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        in_channels (int): The num of input channels. Defaults to 3.
        window_size (int): The height and width of the window. Defaults to 7.
        drop_rate (float): Dropout rate after embedding. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
        out_indices (Sequence[int]): Indices of the stages whose normalized
            outputs are returned by ``forward``. Defaults to ``(3, )``.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults to False.
        interpolate_mode (str): Select the interpolate mode for absolute
            position embeding vector resize. Defaults to "bicubic".
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If
            False, avoid shifting window and shrink the window size to the
            size of feature map, which is common used in classification.
            Defaults to False.
        norm_cfg (dict): Config dict for normalization layer for all output
            features. Defaults to ``dict(type='LN')``
        stage_cfgs (Sequence[dict] | dict): Extra config dict for each
            stage. Defaults to an empty dict.
        patch_cfg (dict): Extra config dict for patch embedding.
            Defaults to an empty dict.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmcls.models import SwinTransformer
        >>> import torch
        >>> self = SwinTransformer(arch='tiny')
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = self.forward(inputs)
        >>> # One feature map per entry of ``out_indices``.
        >>> print(len(outputs))
        1
    """
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': 96,
                         'depths':     [2, 2,  6,  2],
                         'num_heads':  [3, 6, 12, 24]}),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': 96,
                         'depths':     [2, 2, 18,  2],
                         'num_heads':  [3, 6, 12, 24]}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': 128,
                         'depths':     [2, 2, 18,  2],
                         'num_heads':  [4, 8, 16, 32]}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': 192,
                         'depths':     [2,  2, 18,  2],
                         'num_heads':  [6, 12, 24, 48]}),
    }  # yapf: disable

    # Checkpoint format version, consulted in ``_load_from_state_dict``.
    _version = 3
    # Number of non-patch tokens in the position embedding.
    num_extra_tokens = 0

    def __init__(self,
                 arch='tiny',
                 img_size=224,
                 patch_size=4,
                 in_channels=3,
                 window_size=7,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 out_indices=(3, ),
                 use_abs_pos_embed=False,
                 interpolate_mode='bicubic',
                 with_cp=False,
                 frozen_stages=-1,
                 norm_eval=False,
                 pad_small_map=False,
                 norm_cfg=dict(type='LN'),
                 stage_cfgs=dict(),
                 patch_cfg=dict(),
                 init_cfg=None):
        super(SwinTransformer, self).__init__(init_cfg=init_cfg)

        # Resolve the architecture: a named preset or a custom dict.
        if not isinstance(arch, str):
            essential_keys = {'embed_dims', 'depths', 'num_heads'}
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch
        else:
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]

        settings = self.arch_settings
        self.embed_dims = settings['embed_dims']
        self.depths = settings['depths']
        self.num_heads = settings['num_heads']
        self.num_layers = len(self.depths)
        self.out_indices = out_indices
        self.use_abs_pos_embed = use_abs_pos_embed
        self.interpolate_mode = interpolate_mode
        self.frozen_stages = frozen_stages

        # Patch embedding; entries of ``patch_cfg`` override the defaults.
        _patch_cfg = {
            'in_channels': in_channels,
            'input_size': img_size,
            'embed_dims': self.embed_dims,
            'conv_type': 'Conv2d',
            'kernel_size': patch_size,
            'stride': patch_size,
            'norm_cfg': dict(type='LN'),
        }
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size

        if self.use_abs_pos_embed:
            num_patches = self.patch_resolution[0] * self.patch_resolution[1]
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, self.embed_dims))
            # Lets checkpoints with a different grid size be resized on load.
            self._register_load_state_dict_pre_hook(
                self._prepare_abs_pos_embed)

        self.drop_after_pos = nn.Dropout(p=drop_rate)
        self.norm_eval = norm_eval

        # Stochastic depth decay rule: one rate per block over all stages,
        # linearly spaced from 0 to ``drop_path_rate``.
        total_depth = sum(self.depths)
        dpr = [
            rate.item()
            for rate in torch.linspace(0, drop_path_rate, total_depth)
        ]

        self.stages = ModuleList()
        # Entry ``k + 1`` is the output width of stage ``k``.
        embed_dims = [self.embed_dims]
        last_stage = self.num_layers - 1
        consumed = 0
        for i, (depth, num_heads) in enumerate(
                zip(self.depths, self.num_heads)):
            if isinstance(stage_cfgs, Sequence):
                stage_cfg = stage_cfgs[i]
            else:
                stage_cfg = deepcopy(stage_cfgs)

            _stage_cfg = dict(
                embed_dims=embed_dims[-1],
                depth=depth,
                num_heads=num_heads,
                window_size=window_size,
                # Every stage except the last one ends with a downsample.
                downsample=i < last_stage,
                drop_paths=dpr[consumed:consumed + depth],
                with_cp=with_cp,
                pad_small_map=pad_small_map)
            _stage_cfg.update(stage_cfg)

            stage = SwinBlockSequence(**_stage_cfg)
            self.stages.append(stage)

            consumed += depth
            embed_dims.append(stage.out_channels)

        # One output norm per requested stage, registered as ``norm{i}``.
        for i in out_indices:
            norm_layer = (
                build_norm_layer(norm_cfg, embed_dims[i + 1])[1]
                if norm_cfg is not None else nn.Identity())
            self.add_module(f'norm{i}', norm_layer)

    def init_weights(self):
        """Run the base initialization, then init the position embedding.

        The position embedding init is skipped when ``init_cfg`` is a dict
        of type 'Pretrained'.
        """
        super(SwinTransformer, self).init_weights()

        uses_pretrained = (
            isinstance(self.init_cfg, dict)
            and self.init_cfg['type'] == 'Pretrained')
        if not uses_pretrained and self.use_abs_pos_embed:
            trunc_normal_(self.absolute_pos_embed, std=0.02)

    def forward(self, x):
        """Return a tuple with one feature map per index in ``out_indices``.

        Each selected stage output is normalized, reshaped with the stage's
        spatial shape and permuted so channels come second.
        """
        x, hw_shape = self.patch_embed(x)
        if self.use_abs_pos_embed:
            pos_embed = resize_pos_embed(self.absolute_pos_embed,
                                         self.patch_resolution, hw_shape,
                                         self.interpolate_mode,
                                         self.num_extra_tokens)
            x = x + pos_embed
        x = self.drop_after_pos(x)

        outs = []
        for i, stage in enumerate(self.stages):
            x, hw_shape = stage(x, hw_shape)
            if i not in self.out_indices:
                continue
            normed = getattr(self, f'norm{i}')(x)
            feat = normed.view(-1, *hw_shape, stage.out_channels)
            outs.append(feat.permute(0, 3, 1, 2).contiguous())

        return tuple(outs)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, *args,
                              **kwargs):
        """load checkpoints."""
        # Names of some parameters in has been changed.
        version = local_metadata.get('version', None)
        is_exact_class = self.__class__ is SwinTransformer

        if is_exact_class and (version is None or version < 2):
            # Before version 2 the single output norm was stored as
            # ``norm``; move it to the norm of the final stage.
            final_stage_num = len(self.stages) - 1
            for k in list(state_dict.keys()):
                if k.startswith(('norm.', 'backbone.norm.')):
                    convert_key = k.replace('norm.', f'norm{final_stage_num}.')
                    state_dict[convert_key] = state_dict.pop(k)

        if is_exact_class and (version is None or version < 3):
            # Before version 3 checkpoints carried ``attn_mask`` entries;
            # they are dropped here.
            for k in list(state_dict.keys()):
                if 'attn_mask' in k:
                    del state_dict[k]

        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      *args, **kwargs)

    def _freeze_stages(self):
        """Freeze the patch embedding, stages and norms up to
        ``frozen_stages``."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        for stage_idx in range(0, self.frozen_stages + 1):
            frozen = self.stages[stage_idx]
            frozen.eval()
            for param in frozen.parameters():
                param.requires_grad = False

        for i in self.out_indices:
            if i > self.frozen_stages:
                continue
            for param in getattr(self, f'norm{i}').parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Switch mode, then re-apply stage freezing and ``norm_eval``."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()
        if not (mode and self.norm_eval):
            return
        for m in self.modules():
            # trick: eval have effect on BatchNorm only
            if isinstance(m, _BatchNorm):
                m.eval()

    def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
        """Pre-load hook: resize a checkpoint's absolute position embedding
        in ``state_dict`` when its shape differs from this model's."""
        name = prefix + 'absolute_pos_embed'
        if name not in state_dict:
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.absolute_pos_embed.shape == ckpt_pos_embed_shape:
            return

        from mmcls.utils import get_root_logger
        logger = get_root_logger()
        logger.info(
            'Resize the absolute_pos_embed shape from '
            f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')

        # The checkpoint grid is taken to be square: side = sqrt(#tokens).
        ckpt_pos_embed_shape = to_2tuple(
            int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
        pos_embed_shape = self.patch_embed.init_out_size

        state_dict[name] = resize_pos_embed(state_dict[name],
                                            ckpt_pos_embed_shape,
                                            pos_embed_shape,
                                            self.interpolate_mode,
                                            self.num_extra_tokens)
class STDCContextPathNet(BaseModule):
    """STDCNet with Context Path.

    The `outs` below is a list of three feature maps from deep to shallow,
    whose height and width is from small to big, respectively. The biggest
    feature map of `outs` is outputted for `STDCHead`, where Detail Loss would
    be calculated by Detail Ground-truth. The other two feature maps are used
    for Attention Refinement Module, respectively. Besides, the biggest
    feature map of `outs` and the last output of Attention Refinement Module
    are concatenated for Feature Fusion Module. Then, this fusion feature map
    `feat_fuse` would be outputted for `decode_head`. More details please
    refer to Figure 4 of original paper.

    Args:
        backbone_cfg (dict): Config dict for stdc backbone.
        last_in_channels (tuple(int)): The number of channels of last
            two feature maps from stdc backbone. Default: (1024, 512).
        out_channels (int): The channels of output feature maps.
            Default: 128.
        ffm_cfg (dict): Config dict for Feature Fusion Module. Default:
            `dict(in_channels=512, out_channels=256, scale_factor=4)`.
        upsample_mode (str): Algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'``. Default: ``'nearest'``.
        align_corners (str): align_corners argument of F.interpolate. It
            must be `None` if upsample_mode is ``'nearest'``. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Return:
        outputs (tuple): The tuple of list of output feature map for
            auxiliary heads and decoder head.
    """

    def __init__(self,
                 backbone_cfg,
                 last_in_channels=(1024, 512),
                 out_channels=128,
                 ffm_cfg=dict(
                     in_channels=512, out_channels=256, scale_factor=4),
                 upsample_mode='nearest',
                 align_corners=None,
                 norm_cfg=dict(type='BN'),
                 init_cfg=None):
        super(STDCContextPathNet, self).__init__(init_cfg=init_cfg)
        self.backbone = build_backbone(backbone_cfg)

        # One refinement module and one 3x3 conv per entry of
        # ``last_in_channels``, built pairwise.
        self.arms = ModuleList()
        self.convs = ModuleList()
        for in_ch in last_in_channels:
            arm = AttentionRefinementModule(in_ch, out_channels)
            self.arms.append(arm)
            smooth = ConvModule(
                out_channels, out_channels, 3, padding=1, norm_cfg=norm_cfg)
            self.convs.append(smooth)

        # 1x1 conv applied to the globally pooled deepest feature map.
        self.conv_avg = ConvModule(
            last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)

        self.ffm = FeatureFusionModule(**ffm_cfg)

        self.upsample_mode = upsample_mode
        self.align_corners = align_corners

    def forward(self, x):
        """Return ``(outs[0], *arms_out, feat_fuse)`` as a tuple."""
        outs = list(self.backbone(x))
        deepest = outs[-1]

        # Global context: pool to 1x1, project, resize back to the size of
        # the deepest map.
        pooled = F.adaptive_avg_pool2d(deepest, 1)
        feature_up = resize(
            self.conv_avg(pooled),
            size=deepest.shape[2:],
            mode=self.upsample_mode,
            align_corners=self.align_corners)

        last = len(outs) - 1
        arms_out = []
        for i, (arm, conv) in enumerate(zip(self.arms, self.convs)):
            # Refine the i-th deepest map, add the running context, then
            # upsample to the size of the next shallower map.
            refined = arm(outs[last - i]) + feature_up
            upsampled = resize(
                refined,
                size=outs[last - i - 1].shape[2:],
                mode=self.upsample_mode,
                align_corners=self.align_corners)
            feature_up = conv(upsampled)
            arms_out.append(feature_up)

        feat_fuse = self.ffm(outs[0], arms_out[1])

        # The `outputs` has four feature maps.
        # `outs[0]` is outputted for `STDCHead` auxiliary head.
        # Two feature maps of `arms_out` are outputted for auxiliary head.
        # `feat_fuse` is outputted for decoder head.
        return (outs[0], *arms_out, feat_fuse)
class SwinTransformer(BaseBackbone):
    """ Swin Transformer
    A PyTorch implement of : `Swin Transformer:
    Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>`_

    Inspiration from
    https://github.com/microsoft/Swin-Transformer

    Args:
        arch (str | dict): Swin Transformer architecture. A string selects a
            preset from ``arch_zoo`` (case-insensitive); a dict must have
            exactly the keys ``embed_dims``, ``depths`` and ``num_heads``.
            Defaults to 'T'.
        img_size (int | tuple): The size of input image.
            Defaults to 224.
        in_channels (int): The num of input channels.
            Defaults to 3.
        drop_rate (float): Dropout rate after embedding.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate.
            Defaults to 0.1.
        out_indices (Sequence[int]): Indices of the stages whose normalized
            outputs are returned by ``forward``. Defaults to ``(3, )``.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults to False.
        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
            will save some memory while slowing down the training speed.
            Defaults to False.
        auto_pad (bool): If True, auto pad feature map to fit window_size.
            Defaults to False.
        norm_cfg (dict, optional): Config dict for normalization layer at end
            of backone. Defaults to dict(type='LN')
        stage_cfgs (Sequence | dict, optional): Extra config dict for each
            stage. Defaults to empty dict.
        patch_cfg (dict, optional): Extra config dict for patch embedding.
            Defaults to empty dict.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmcls.models import SwinTransformer
        >>> import torch
        >>> self = SwinTransformer(arch='tiny', auto_pad=True)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = self.forward(inputs)
        >>> # One feature map per entry of ``out_indices``.
        >>> print(len(outputs))
        1
    """
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': 96,
                         'depths':     [2, 2,  6,  2],
                         'num_heads':  [3, 6, 12, 24]}),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': 96,
                         'depths':     [2, 2, 18,  2],
                         'num_heads':  [3, 6, 12, 24]}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': 128,
                         'depths':     [2, 2, 18,  2],
                         'num_heads':  [4, 8, 16, 32]}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': 192,
                         'depths':     [2,  2, 18,  2],
                         'num_heads':  [6, 12, 24, 48]}),
    }  # yapf: disable

    # Checkpoint format version, consulted in ``_load_from_state_dict``.
    _version = 2

    def __init__(self,
                 arch='T',
                 img_size=224,
                 in_channels=3,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 out_indices=(3, ),
                 use_abs_pos_embed=False,
                 auto_pad=False,
                 with_cp=False,
                 norm_cfg=dict(type='LN'),
                 stage_cfgs=dict(),
                 patch_cfg=dict(),
                 init_cfg=None):
        super(SwinTransformer, self).__init__(init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            # Bug fix: the required key was spelled 'num_head', but the
            # lookup below reads 'num_heads', so a dict that passed the
            # check crashed with KeyError and a correct dict was rejected.
            essential_keys = {'embed_dims', 'depths', 'num_heads'}
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.num_layers = len(self.depths)
        self.out_indices = out_indices
        self.use_abs_pos_embed = use_abs_pos_embed
        self.auto_pad = auto_pad

        # Patch embedding; entries of ``patch_cfg`` override the defaults.
        _patch_cfg = {
            'img_size': img_size,
            'in_channels': in_channels,
            'embed_dims': self.embed_dims,
            'conv_cfg': dict(type='Conv2d', kernel_size=4, stride=4),
            'norm_cfg': dict(type='LN'),
            **patch_cfg
        }
        self.patch_embed = PatchEmbed(**_patch_cfg)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        if self.use_abs_pos_embed:
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, self.embed_dims))

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule: one rate per block over all stages,
        # linearly spaced from 0 to ``drop_path_rate``.
        total_depth = sum(self.depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]

        self.stages = ModuleList()
        # Entry ``k + 1`` is the output width of stage ``k``; kept as a list
        # so every output norm below can be sized for its own stage.
        embed_dims = [self.embed_dims]
        input_resolution = patches_resolution
        for i, (depth,
                num_heads) in enumerate(zip(self.depths, self.num_heads)):
            if isinstance(stage_cfgs, Sequence):
                stage_cfg = stage_cfgs[i]
            else:
                stage_cfg = deepcopy(stage_cfgs)
            # Every stage except the last one ends with a downsample.
            downsample = True if i < self.num_layers - 1 else False
            _stage_cfg = {
                'embed_dims': embed_dims[-1],
                'depth': depth,
                'num_heads': num_heads,
                'downsample': downsample,
                'input_resolution': input_resolution,
                'drop_paths': dpr[:depth],
                'with_cp': with_cp,
                'auto_pad': auto_pad,
                **stage_cfg
            }

            stage = SwinBlockSequence(**_stage_cfg)
            self.stages.append(stage)

            dpr = dpr[depth:]
            embed_dims.append(stage.out_channels)
            input_resolution = stage.out_resolution

        for i in out_indices:
            if norm_cfg is not None:
                # Bug fix: the norm was always built with the width of the
                # final stage, which does not match the output of earlier
                # stages selected through ``out_indices``.
                norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1]
            else:
                norm_layer = nn.Identity()

            self.add_module(f'norm{i}', norm_layer)

    def init_weights(self):
        """Run the base initialization, then init the position embedding.

        The position embedding init is skipped when ``init_cfg`` is a dict
        of type 'Pretrained'.
        """
        super(SwinTransformer, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return

        if self.use_abs_pos_embed:
            trunc_normal_(self.absolute_pos_embed, std=0.02)

    def forward(self, x):
        """Return a tuple with one feature map per index in ``out_indices``.

        Each selected stage output is normalized, reshaped with the stage's
        ``out_resolution`` and permuted so channels come second.
        """
        x = self.patch_embed(x)
        if self.use_abs_pos_embed:
            x = x + self.absolute_pos_embed
        x = self.drop_after_pos(x)

        outs = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                out = norm_layer(x)
                out = out.view(-1, *stage.out_resolution,
                               stage.out_channels).permute(0, 3, 1,
                                                           2).contiguous()
                outs.append(out)

        return tuple(outs)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, *args,
                              **kwargs):
        """load checkpoints."""
        # Names of some parameters in has been changed.
        version = local_metadata.get('version', None)
        if (version is None
                or version < 2) and self.__class__ is SwinTransformer:
            # Before version 2 the single output norm was stored as
            # ``norm``; move it to the norm of the final stage.
            final_stage_num = len(self.stages) - 1
            state_dict_keys = list(state_dict.keys())
            for k in state_dict_keys:
                if k.startswith('norm.') or k.startswith('backbone.norm.'):
                    convert_key = k.replace('norm.', f'norm{final_stage_num}.')
                    state_dict[convert_key] = state_dict[k]
                    del state_dict[k]

        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      *args, **kwargs)