Example #1
    def init_weights(self, pretrained=None):
        """Initiate the parameters either from existing checkpoint or from
        scratch."""
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)

        if pretrained:
            self.pretrained = pretrained
        if isinstance(self.pretrained, str):
            logger = get_root_logger()
            logger.info(f'load model from: {self.pretrained}')

            state_dict = _load_checkpoint(self.pretrained)
            if 'state_dict' in state_dict:
                state_dict = state_dict['state_dict']

            if self.attention_type == 'divided_space_time':
                # modify the key names of norm layers
                old_state_dict_keys = list(state_dict.keys())
                for old_key in old_state_dict_keys:
                    if 'norms' in old_key:
                        new_key = old_key.replace('norms.0',
                                                  'attentions.0.norm')
                        new_key = new_key.replace('norms.1', 'ffns.0.norm')
                        state_dict[new_key] = state_dict.pop(old_key)

                # copy the parameters of space attention to time attention
                old_state_dict_keys = list(state_dict.keys())
                for old_key in old_state_dict_keys:
                    if 'attentions.0' in old_key:
                        new_key = old_key.replace('attentions.0',
                                                  'attentions.1')
                        state_dict[new_key] = state_dict[old_key].clone()

            load_state_dict(self, state_dict, strict=False, logger=logger)
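The two renaming passes above are easiest to follow on a toy state dict. A minimal sketch; the key layout below is an illustrative placeholder, not the exact TimeSformer checkpoint naming:

import torch

state_dict = {
    'layers.0.norms.0.weight': torch.ones(8),
    'layers.0.norms.1.weight': torch.ones(8),
    'layers.0.attentions.0.proj.weight': torch.ones(8, 8),
}

# pass 1: rename the norm layers
for old_key in list(state_dict.keys()):
    if 'norms' in old_key:
        new_key = old_key.replace('norms.0', 'attentions.0.norm')
        new_key = new_key.replace('norms.1', 'ffns.0.norm')
        state_dict[new_key] = state_dict.pop(old_key)

# pass 2: copy space-attention weights to the new time-attention branch
for old_key in list(state_dict.keys()):
    if 'attentions.0' in old_key:
        new_key = old_key.replace('attentions.0', 'attentions.1')
        state_dict[new_key] = state_dict[old_key].clone()

# resulting keys:
#   layers.0.attentions.0.norm.weight     (renamed from norms.0)
#   layers.0.attentions.1.norm.weight     (copied in pass 2)
#   layers.0.ffns.0.norm.weight           (renamed from norms.1)
#   layers.0.attentions.0.proj.weight
#   layers.0.attentions.1.proj.weight     (copied in pass 2)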
Example #2
def _init_weights(m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:  # m is already known to be nn.Linear here
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
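Helpers like the one above follow the common timm-style pattern: a module-level function handed to nn.Module.apply, which calls it on every submodule recursively (children first). A minimal usage sketch, assuming torch's built-in trunc_normal_ (same role as the one imported in these snippets) and a hypothetical toy model:

import torch.nn as nn
from torch.nn.init import trunc_normal_

def _init_weights(m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)

# hypothetical model, just to show the traversal
model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 10))
model.apply(_init_weights)  # visits every submodule, children before parents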
Example #3
    def init_weights(self):
        super().init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress custom init when using a pretrained model.
            return

        trunc_normal_(self.cls_token, std=.02)
Example #4
    def init_weights(self):
        trunc_normal_(self.cls_emb, std=self.init_std)
        trunc_normal_init(self.patch_proj, std=self.init_std)
        trunc_normal_init(self.classes_proj, std=self.init_std)
        for n, m in self.named_modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=self.init_std, bias=0)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, val=1.0, bias=0.0)
Example #5
    def init_weights(self):
        super(VisionTransformerClsHead, self).init_weights()
        # Modified from ClassyVision
        if hasattr(self.layers, 'pre_logits'):
            # Lecun norm
            trunc_normal_(
                self.layers.pre_logits.weight,
                std=math.sqrt(1 / self.layers.pre_logits.in_features))
            nn.init.zeros_(self.layers.pre_logits.bias)
Example #6
    def init_weights(self):
        super(SwinTransformer, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init when using a pretrained model.
            return

        if self.use_abs_pos_embed:
            trunc_normal_(self.absolute_pos_embed, std=0.02)
Example #7
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            # Kaiming/He init written out by hand: fan_out over groups
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
Example #8
    def __init__(self,
                 dim,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the
        # window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] \
            - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :,
                        0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer('relative_position_index',
                             relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)
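The relative-position-index arithmetic above is dense; replaying it standalone for a 2x2 window shows exactly what the registered buffer holds. A sketch using only the lines from the constructor:

import torch

Wh, Ww = 2, 2
coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)]))
coords_flatten = torch.flatten(coords, 1)                      # 2, Wh*Ww
rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, 4, 4
rel = rel.permute(1, 2, 0).contiguous()                        # 4, 4, 2
rel[:, :, 0] += Wh - 1       # shift row offsets to start from 0
rel[:, :, 1] += Ww - 1       # shift col offsets to start from 0
rel[:, :, 0] *= 2 * Ww - 1   # make each (dh, dw) pair unique after summing
print(rel.sum(-1))
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])
# all (2*Wh-1)*(2*Ww-1) = 9 indices occur, matching the bias table rows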
Example #9
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight,
                                    mode='fan_out',
                                    nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1.)
            nn.init.constant_(m.bias, 0.)

        if hasattr(m, 'zero_init_last_bn'):
            m.zero_init_last_bn()
Example #10
    def init_weights(self):

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)
        self.fix_init_weight()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):
            logger = get_root_logger()
            checkpoint = _load_checkpoint(
                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
            state_dict = self.resize_rel_pos_embed(checkpoint)
            state_dict = self.resize_abs_pos_embed(state_dict)
            self.load_state_dict(state_dict, False)
        elif self.init_cfg is not None:
            super(MAE, self).init_weights()
        else:
            # We only reproduce the 'jax_impl' initialization from
            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
            # Copyright 2019 Ross Wightman
            # Licensed under the Apache License, Version 2.0 (the "License")
            trunc_normal_(self.cls_token, std=.02)
            for n, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_(m.weight, std=.02)
                    if m.bias is not None:
                        if 'ffn' in n:
                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                        else:
                            nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Conv2d):
                    kaiming_init(m, mode='fan_in', bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)
Example #11
    def init_weights(self):
        if (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):
            logger = get_root_logger()
            checkpoint = _load_checkpoint(
                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')

            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint

            if 'pos_embed' in state_dict.keys():
                if self.pos_embed.shape != state_dict['pos_embed'].shape:
                    logger.info(msg=f'Resize the pos_embed shape from '
                                f'{state_dict["pos_embed"].shape} to '
                                f'{self.pos_embed.shape}')
                    h, w = self.img_size
                    pos_size = int(
                        math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                    state_dict['pos_embed'] = self.resize_pos_embed(
                        state_dict['pos_embed'],
                        (h // self.patch_size, w // self.patch_size),
                        (pos_size, pos_size), self.interpolate_mode)

            self.load_state_dict(state_dict, False)
        elif self.init_cfg is not None:
            super(VisionTransformer, self).init_weights()
        else:
            # We only reproduce the 'jax_impl' initialization from
            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
            trunc_normal_(self.pos_embed, std=.02)
            trunc_normal_(self.cls_token, std=.02)
            for n, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_(m.weight, std=.02)
                    if m.bias is not None:
                        if 'ffn' in n:
                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                        else:
                            nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Conv2d):
                    kaiming_init(m, mode='fan_in', bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)
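The interesting step above is resize_pos_embed. A hedged sketch of what such a helper typically does, modeled on the mmseg version but not copied from it: split off the cls token, fold the rest back into a 2D grid, interpolate, and flatten again.

import torch
import torch.nn.functional as F

def resize_pos_embed_sketch(pos_embed, input_shape, pos_shape,
                            mode='bicubic'):
    # pos_embed: (1, 1 + N, C) with a leading cls token
    cls_token = pos_embed[:, 0:1]
    grid = pos_embed[:, 1:]
    C = grid.shape[-1]
    grid = grid.reshape(1, pos_shape[0], pos_shape[1], C).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=input_shape, mode=mode,
                         align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, -1, C)
    return torch.cat([cls_token, grid], dim=1)

# e.g. a 224px checkpoint (14x14 grid) loaded at 512px input (32x32 grid)
old = torch.zeros(1, 1 + 14 * 14, 768)
new = resize_pos_embed_sketch(old, (32, 32), (14, 14))
assert new.shape == (1, 1 + 32 * 32, 768)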
Example #12
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
Example #13
    def __init__(self,
                 pretrain_img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 use_checkpoint=False,
                 pretrained_window_sizes=[0, 0, 0, 0],
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [
                pretrain_img_size[0] // patch_size[0],
                pretrain_img_size[1] // patch_size[1]
            ]

            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, embed_dim, patches_resolution[0],
                            patches_resolution[1]))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if
                (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                pretrained_window_size=pretrained_window_sizes[i_layer])
            self.layers.append(layer)

        num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()
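The "stochastic depth decay rule" in this constructor just spreads drop-path rates linearly over all blocks and hands each stage its slice of the list. A quick check with the defaults above:

import torch

depths, drop_path_rate = [2, 2, 6, 2], 0.2
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
for i in range(len(depths)):
    print(i, [round(r, 3) for r in dpr[sum(depths[:i]):sum(depths[:i + 1])]])
# 0 [0.0, 0.018]
# 1 [0.036, 0.055]
# 2 [0.073, 0.091, 0.109, 0.127, 0.145, 0.164]
# 3 [0.182, 0.2]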
Example #14
    def init_weights(self):
        super(SwinTransformer, self).init_weights()

        if self.use_abs_pos_embed:
            trunc_normal_(self.absolute_pos_embed, std=0.02)
Example #15
    def __init__(self,
                 arch='b',
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 ffn_ratio=4,
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 first_stride=4,
                 num_fcs=2,
                 init_cfg=[
                     dict(type='TruncNormal', layer='Linear', std=.02),
                     dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
                 ]):
        super(TNT, self).__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims_outer', 'embed_dims_inner', 'num_layers',
                'num_heads_inner', 'num_heads_outer'
            }
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims_inner = self.arch_settings['embed_dims_inner']
        self.embed_dims_outer = self.arch_settings['embed_dims_outer']
        # embed_dims for consistency with other models
        self.embed_dims = self.embed_dims_outer
        self.num_layers = self.arch_settings['num_layers']
        self.num_heads_inner = self.arch_settings['num_heads_inner']
        self.num_heads_outer = self.arch_settings['num_heads_outer']

        self.pixel_embed = PixelEmbed(img_size=img_size,
                                      patch_size=patch_size,
                                      in_channels=in_channels,
                                      embed_dims_inner=self.embed_dims_inner,
                                      stride=first_stride)
        num_patches = self.pixel_embed.num_patches
        self.num_patches = num_patches
        new_patch_size = self.pixel_embed.new_patch_size
        num_pixel = new_patch_size[0] * new_patch_size[1]

        self.norm1_proj = build_norm_layer(norm_cfg, num_pixel *
                                           self.embed_dims_inner)[1]
        self.projection = nn.Linear(num_pixel * self.embed_dims_inner,
                                    self.embed_dims_outer)
        self.norm2_proj = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]

        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims_outer))
        self.patch_pos = nn.Parameter(
            torch.zeros(1, num_patches + 1, self.embed_dims_outer))
        self.pixel_pos = nn.Parameter(
            torch.zeros(1, self.embed_dims_inner, new_patch_size[0],
                        new_patch_size[1]))
        self.drop_after_pos = nn.Dropout(p=drop_rate)

        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, self.num_layers)
        ]  # stochastic depth decay rule
        self.layers = ModuleList()
        for i in range(self.num_layers):
            block_cfg = dict(ffn_ratio=ffn_ratio,
                             drop_rate=drop_rate,
                             attn_drop_rate=attn_drop_rate,
                             drop_path_rate=dpr[i],
                             num_fcs=num_fcs,
                             qkv_bias=qkv_bias,
                             norm_cfg=norm_cfg,
                             batch_first=True)
            self.layers.append(
                TnTLayer(num_pixel=num_pixel,
                         embed_dims_inner=self.embed_dims_inner,
                         embed_dims_outer=self.embed_dims_outer,
                         num_heads_inner=self.num_heads_inner,
                         num_heads_outer=self.num_heads_outer,
                         inner_block_cfg=block_cfg,
                         outer_block_cfg=block_cfg,
                         norm_cfg=norm_cfg))

        self.norm = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]

        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.patch_pos, std=.02)
        trunc_normal_(self.pixel_pos, std=.02)
Example #16
    def init_weights(self):
        trunc_normal_(self.relative_position_bias_table, std=0.02)
Example #17
    def init_weights(self):
        logger = get_root_logger()
        if self.init_cfg is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training starts from scratch')
            if self.use_abs_pos_embed:
                trunc_normal_(self.absolute_pos_embed, std=0.02)
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
        else:
            assert 'checkpoint' in self.init_cfg, \
                f'Only support specifying `Pretrained` in `init_cfg` in ' \
                f'{self.__class__.__name__}'
            ckpt = _load_checkpoint(self.init_cfg.checkpoint,
                                    logger=logger,
                                    map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt
            if self.convert_weights:
                # support loading weights from the original repo
                _state_dict = swin_converter(_state_dict)

            state_dict = OrderedDict()
            for k, v in _state_dict.items():
                if k.startswith('backbone.'):
                    state_dict[k[9:]] = v

            # strip prefix of state_dict
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}

            # reshape absolute position embedding
            if state_dict.get('absolute_pos_embed') is not None:
                absolute_pos_embed = state_dict['absolute_pos_embed']
                N1, L, C1 = absolute_pos_embed.size()
                N2, C2, H, W = self.absolute_pos_embed.size()
                if N1 != N2 or C1 != C2 or L != H * W:
                    logger.warning('Error in loading absolute_pos_embed, pass')
                else:
                    state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                        N2, H, W, C2).permute(0, 3, 1, 2).contiguous()

            # interpolate position bias table if needed
            relative_position_bias_table_keys = [
                k for k in state_dict.keys()
                if 'relative_position_bias_table' in k
            ]
            for table_key in relative_position_bias_table_keys:
                table_pretrained = state_dict[table_key]
                table_current = self.state_dict()[table_key]
                L1, nH1 = table_pretrained.size()
                L2, nH2 = table_current.size()
                if nH1 != nH2:
                    logger.warning(f'Error in loading {table_key}, pass')
                elif L1 != L2:
                    S1 = int(L1**0.5)
                    S2 = int(L2**0.5)
                    table_pretrained_resized = F.interpolate(
                        table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
                        size=(S2, S2),
                        mode='bicubic')
                    state_dict[table_key] = table_pretrained_resized.view(
                        nH2, L2).permute(1, 0).contiguous()

            # load state_dict
            self.load_state_dict(state_dict, False)
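The bias-table interpolation near the end is worth replaying in isolation: the table is a flat (L, nH) tensor, so it is reshaped into a square image per head before bicubic resizing. For example, going from window 7 (13x13 offsets) to window 12 (23x23):

import torch
import torch.nn.functional as F

nH = 4
table_pretrained = torch.randn((2 * 7 - 1) ** 2, nH)   # (169, 4)
S1, S2 = 13, 23
resized = F.interpolate(
    table_pretrained.permute(1, 0).reshape(1, nH, S1, S1),
    size=(S2, S2),
    mode='bicubic')
table_current = resized.view(nH, S2 * S2).permute(1, 0).contiguous()
assert table_current.shape == (529, 4)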
Example #18
    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
Example #19
    def init_weights(self):
        super(WindowMSA, self).init_weights()

        trunc_normal_(self.relative_position_bias_table, std=0.02)
Example #20
def _init_weights(m):
    if isinstance(m, nn.Conv2d):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
Example #21
    def init_weights(self):
        super(VisionTransformer, self).init_weights()

        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            trunc_normal_(self.pos_embed, std=0.02)
Example #22
    def __init__(self,
                 arch='tiny',
                 patch_size=16,
                 base_channels=64,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 with_cls_token=True,
                 drop_path_rate=0.,
                 norm_eval=True,
                 frozen_stages=0,
                 out_indices=-1,
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'depths', 'num_heads', 'channel_ratio'
            }
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.num_features = self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.channel_ratio = self.arch_settings['channel_ratio']

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.depths + index + 1
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.norm_eval = norm_eval
        self.frozen_stages = frozen_stages

        self.with_cls_token = with_cls_token
        if self.with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

        # stochastic depth decay rule
        self.trans_dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, self.depths)
        ]

        # Stem stage: get the feature maps by conv block
        self.conv1 = nn.Conv2d(3,
                               64,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)  # 1 / 2 [112, 112]
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2,
                                    padding=1)  # 1 / 4 [56, 56]

        assert patch_size % 16 == 0, 'The patch size of Conformer must ' \
            'be divisible by 16.'
        trans_down_stride = patch_size // 4

        # To solve the issue #680
        # Auto pad the feature map to be divisible by trans_down_stride
        self.auto_pad = AdaptivePadding(trans_down_stride, trans_down_stride)

        # 1 stage
        stage1_channels = int(base_channels * self.channel_ratio)
        self.conv_1 = ConvBlock(in_channels=64,
                                out_channels=stage1_channels,
                                with_residual_conv=True,
                                stride=1)
        self.trans_patch_conv = nn.Conv2d(64,
                                          self.embed_dims,
                                          kernel_size=trans_down_stride,
                                          stride=trans_down_stride,
                                          padding=0)

        self.trans_1 = TransformerEncoderLayer(
            embed_dims=self.embed_dims,
            num_heads=self.num_heads,
            feedforward_channels=int(self.embed_dims * mlp_ratio),
            drop_path_rate=self.trans_dpr[0],
            qkv_bias=qkv_bias,
            norm_cfg=dict(type='LN', eps=1e-6))

        # 2~4 stage
        init_stage = 2
        fin_stage = self.depths // 3 + 1
        for i in range(init_stage, fin_stage):
            self.add_module(
                f'conv_trans_{i}',
                ConvTransBlock(in_channels=stage1_channels,
                               out_channels=stage1_channels,
                               embed_dims=self.embed_dims,
                               conv_stride=1,
                               with_residual_conv=False,
                               down_stride=trans_down_stride,
                               num_heads=self.num_heads,
                               mlp_ratio=mlp_ratio,
                               qkv_bias=qkv_bias,
                               drop_path_rate=self.trans_dpr[i - 1],
                               with_cls_token=self.with_cls_token))

        stage2_channels = int(base_channels * self.channel_ratio * 2)
        # 5~8 stage
        init_stage = fin_stage  # 5
        fin_stage = fin_stage + self.depths // 3  # 9
        for i in range(init_stage, fin_stage):
            if i == init_stage:
                conv_stride = 2
                in_channels = stage1_channels
            else:
                conv_stride = 1
                in_channels = stage2_channels

            with_residual_conv = (i == init_stage)
            self.add_module(
                f'conv_trans_{i}',
                ConvTransBlock(in_channels=in_channels,
                               out_channels=stage2_channels,
                               embed_dims=self.embed_dims,
                               conv_stride=conv_stride,
                               with_residual_conv=with_residual_conv,
                               down_stride=trans_down_stride // 2,
                               num_heads=self.num_heads,
                               mlp_ratio=mlp_ratio,
                               qkv_bias=qkv_bias,
                               drop_path_rate=self.trans_dpr[i - 1],
                               with_cls_token=self.with_cls_token))

        stage3_channels = int(base_channels * self.channel_ratio * 2 * 2)
        # 9~12 stage
        init_stage = fin_stage  # 9
        fin_stage = fin_stage + self.depths // 3  # 13
        for i in range(init_stage, fin_stage):
            if i == init_stage:
                conv_stride = 2
                in_channels = stage2_channels
                with_residual_conv = True
            else:
                conv_stride = 1
                in_channels = stage3_channels
                with_residual_conv = False

            last_fusion = (i == self.depths)

            self.add_module(
                f'conv_trans_{i}',
                ConvTransBlock(in_channels=in_channels,
                               out_channels=stage3_channels,
                               embed_dims=self.embed_dims,
                               conv_stride=conv_stride,
                               with_residual_conv=with_residual_conv,
                               down_stride=trans_down_stride // 4,
                               num_heads=self.num_heads,
                               mlp_ratio=mlp_ratio,
                               qkv_bias=qkv_bias,
                               drop_path_rate=self.trans_dpr[i - 1],
                               with_cls_token=self.with_cls_token,
                               last_fusion=last_fusion))
        self.fin_stage = fin_stage

        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.trans_norm = nn.LayerNorm(self.embed_dims)

        if self.with_cls_token:
            trunc_normal_(self.cls_token, std=.02)
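The stage bookkeeping in this constructor is easier to follow numerically. Assuming the typical depths=12, the three loops cover transformer blocks 2..4, 5..8 and 9..12, matching the "2~4", "5~8" and "9~12" comments:

depths = 12
init_stage, fin_stage = 2, depths // 3 + 1              # 2, 5
print(list(range(init_stage, fin_stage)))               # [2, 3, 4]
init_stage, fin_stage = fin_stage, fin_stage + depths // 3
print(list(range(init_stage, fin_stage)))               # [5, 6, 7, 8]
init_stage, fin_stage = fin_stage, fin_stage + depths // 3
print(list(range(init_stage, fin_stage)))               # [9, 10, 11, 12]
# only the final block, where i == depths, gets last_fusion=True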