def init_weights(self, pretrained=None):
    """Initialize the parameters either from an existing checkpoint or from
    scratch."""
    trunc_normal_(self.pos_embed, std=.02)
    trunc_normal_(self.cls_token, std=.02)

    if pretrained:
        self.pretrained = pretrained
    if isinstance(self.pretrained, str):
        logger = get_root_logger()
        logger.info(f'load model from: {self.pretrained}')

        state_dict = _load_checkpoint(self.pretrained)
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']

        if self.attention_type == 'divided_space_time':
            # modify the key names of norm layers
            old_state_dict_keys = list(state_dict.keys())
            for old_key in old_state_dict_keys:
                if 'norms' in old_key:
                    new_key = old_key.replace('norms.0',
                                              'attentions.0.norm')
                    new_key = new_key.replace('norms.1', 'ffns.0.norm')
                    state_dict[new_key] = state_dict.pop(old_key)

            # copy the parameters of space attention to time attention
            old_state_dict_keys = list(state_dict.keys())
            for old_key in old_state_dict_keys:
                if 'attentions.0' in old_key:
                    new_key = old_key.replace('attentions.0',
                                              'attentions.1')
                    state_dict[new_key] = state_dict[old_key].clone()

        load_state_dict(self, state_dict, strict=False, logger=logger)
def _init_weights(m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
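# Usage sketch (illustrative, not from the source): a standalone helper like
# `_init_weights` above is typically passed to `nn.Module.apply`, which calls
# it on every submodule recursively, so nested Linear/LayerNorm layers are all
# covered by the isinstance checks. The import path for `trunc_normal_` is an
# assumption.
import torch.nn as nn
from mmcv.cnn.utils.weight_init import trunc_normal_  # assumed import path

model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
model.apply(_init_weights)  # Linear -> trunc normal; LayerNorm -> (1, 0)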
def init_weights(self):
    super().init_weights()

    if (isinstance(self.init_cfg, dict)
            and self.init_cfg['type'] == 'Pretrained'):
        # Suppress custom init when using a pretrained model.
        return

    trunc_normal_(self.cls_token, std=.02)
def init_weights(self):
    trunc_normal_(self.cls_emb, std=self.init_std)
    trunc_normal_init(self.patch_proj, std=self.init_std)
    trunc_normal_init(self.classes_proj, std=self.init_std)
    for n, m in self.named_modules():
        if isinstance(m, nn.Linear):
            trunc_normal_init(m, std=self.init_std, bias=0)
        elif isinstance(m, nn.LayerNorm):
            constant_init(m, val=1.0, bias=0.0)
def init_weights(self):
    super(VisionTransformerClsHead, self).init_weights()
    # Modified from ClassyVision
    if hasattr(self.layers, 'pre_logits'):
        # Lecun norm
        trunc_normal_(
            self.layers.pre_logits.weight,
            std=math.sqrt(1 / self.layers.pre_logits.in_features))
        nn.init.zeros_(self.layers.pre_logits.bias)
def init_weights(self):
    super(SwinTransformer, self).init_weights()

    if (isinstance(self.init_cfg, dict)
            and self.init_cfg['type'] == 'Pretrained'):
        # Suppress default init when using a pretrained model.
        return

    if self.use_abs_pos_embed:
        trunc_normal_(self.absolute_pos_embed, std=0.02)
def _init_weights(self, m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.weight, 1.0)
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        fan_out //= m.groups
        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            m.bias.data.zero_()
def __init__(self,
             dim,
             window_size,
             num_heads,
             qkv_bias=True,
             qk_scale=None,
             attn_drop=0.,
             proj_drop=0.):
    super().__init__()
    self.dim = dim
    self.window_size = window_size  # Wh, Ww
    self.num_heads = num_heads
    head_dim = dim // num_heads
    self.scale = qk_scale or head_dim**-0.5

    # define a parameter table of relative position bias
    self.relative_position_bias_table = nn.Parameter(
        torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                    num_heads))  # 2*Wh-1 * 2*Ww-1, nH

    # get pair-wise relative position index for each token inside the
    # window
    coords_h = torch.arange(self.window_size[0])
    coords_w = torch.arange(self.window_size[1])
    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] \
        - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(
        1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += self.window_size[1] - 1
    relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
    relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
    self.register_buffer('relative_position_index',
                         relative_position_index)

    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    self.attn_drop = nn.Dropout(attn_drop)
    self.proj = nn.Linear(dim, dim)
    self.proj_drop = nn.Dropout(proj_drop)

    trunc_normal_(self.relative_position_bias_table, std=.02)
    self.softmax = nn.Softmax(dim=-1)
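# Forward-side sketch (illustrative, not from the source): the table and index
# buffers built in the `__init__` above are consumed in attention roughly like
# this, yielding one learned bias per (query, key) pair and per head that is
# added to the attention logits before softmax.
def gather_relative_position_bias(table, index, window_size):
    Wh, Ww = window_size
    bias = table[index.view(-1)].view(Wh * Ww, Wh * Ww, -1)  # N, N, nH
    return bias.permute(2, 0, 1).contiguous().unsqueeze(0)  # 1, nH, N, N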
def _init_weights(self, m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(
            m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1.)
        nn.init.constant_(m.bias, 0.)
    if hasattr(m, 'zero_init_last_bn'):
        m.zero_init_last_bn()
def init_weights(self):

    def _init_weights(m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    self.apply(_init_weights)
    self.fix_init_weight()

    if (isinstance(self.init_cfg, dict)
            and self.init_cfg.get('type') == 'Pretrained'):
        logger = get_root_logger()
        checkpoint = _load_checkpoint(
            self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
        state_dict = self.resize_rel_pos_embed(checkpoint)
        state_dict = self.resize_abs_pos_embed(state_dict)
        self.load_state_dict(state_dict, False)
    elif self.init_cfg is not None:
        super(MAE, self).init_weights()
    else:
        # We only implement the 'jax_impl' initialization implemented at
        # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
        # Copyright 2019 Ross Wightman
        # Licensed under the Apache License, Version 2.0 (the "License")
        trunc_normal_(self.cls_token, std=.02)
        for n, m in self.named_modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    if 'ffn' in n:
                        nn.init.normal_(m.bias, mean=0., std=1e-6)
                    else:
                        nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                kaiming_init(m, mode='fan_in', bias=0.)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                constant_init(m, val=1.0, bias=0.)
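# Sketch of `fix_init_weight` (an assumption: the method is called above but
# defined elsewhere in the class; BEiT/MAE-style implementations rescale the
# residual-branch output weights by depth so deeper blocks start with smaller
# contributions; the `layer.attn.proj` and `layer.ffn.layers[1]` attribute
# names are assumptions).
import math

def fix_init_weight(self):

    def rescale(param, layer_id):
        param.div_(math.sqrt(2.0 * layer_id))

    for layer_id, layer in enumerate(self.layers):
        rescale(layer.attn.proj.weight.data, layer_id + 1)
        rescale(layer.ffn.layers[1].weight.data, layer_id + 1)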
def init_weights(self):
    if (isinstance(self.init_cfg, dict)
            and self.init_cfg.get('type') == 'Pretrained'):
        logger = get_root_logger()
        checkpoint = _load_checkpoint(
            self.init_cfg['checkpoint'], logger=logger, map_location='cpu')

        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        if 'pos_embed' in state_dict.keys():
            if self.pos_embed.shape != state_dict['pos_embed'].shape:
                logger.info(f'Resize the pos_embed shape from '
                            f'{state_dict["pos_embed"].shape} to '
                            f'{self.pos_embed.shape}')
                h, w = self.img_size
                pos_size = int(
                    math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                state_dict['pos_embed'] = self.resize_pos_embed(
                    state_dict['pos_embed'],
                    (h // self.patch_size, w // self.patch_size),
                    (pos_size, pos_size), self.interpolate_mode)

        self.load_state_dict(state_dict, False)
    elif self.init_cfg is not None:
        super(VisionTransformer, self).init_weights()
    else:
        # We only implement the 'jax_impl' initialization implemented at
        # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        for n, m in self.named_modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    if 'ffn' in n:
                        nn.init.normal_(m.bias, mean=0., std=1e-6)
                    else:
                        nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                kaiming_init(m, mode='fan_in', bias=0.)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                constant_init(m, val=1.0, bias=0.)
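# Sketch of `resize_pos_embed` (an assumption: the helper is called above but
# defined elsewhere; a typical implementation splits off the cls token,
# 2D-interpolates the grid part of the embedding, and concatenates the two
# back together; `mode` is assumed to be 'bilinear' or 'bicubic').
import torch
import torch.nn.functional as F

def resize_pos_embed_sketch(pos_embed, input_shape, pos_shape, mode):
    """pos_embed: [1, 1 + pos_h * pos_w, C] -> [1, 1 + new_h * new_w, C]."""
    pos_h, pos_w = pos_shape
    cls_token = pos_embed[:, :1]
    grid = pos_embed[:, 1:].reshape(1, pos_h, pos_w, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(
        grid, size=input_shape, mode=mode, align_corners=False)
    grid = grid.flatten(2).transpose(1, 2)  # 1, new_h * new_w, C
    return torch.cat((cls_token, grid), dim=1)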
def _init_weights(self, m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
def __init__(self,
             pretrain_img_size=224,
             patch_size=4,
             in_chans=3,
             embed_dim=96,
             depths=[2, 2, 6, 2],
             num_heads=[3, 6, 12, 24],
             window_size=7,
             mlp_ratio=4.,
             qkv_bias=True,
             drop_rate=0.,
             attn_drop_rate=0.,
             drop_path_rate=0.2,
             norm_layer=nn.LayerNorm,
             ape=False,
             patch_norm=True,
             out_indices=(0, 1, 2, 3),
             frozen_stages=-1,
             use_checkpoint=False,
             pretrained_window_sizes=[0, 0, 0, 0],
             init_cfg=None):
    super().__init__(init_cfg=init_cfg)
    self.pretrain_img_size = pretrain_img_size
    self.num_layers = len(depths)
    self.embed_dim = embed_dim
    self.ape = ape
    self.patch_norm = patch_norm
    self.out_indices = out_indices
    self.frozen_stages = frozen_stages

    # split image into non-overlapping patches
    self.patch_embed = PatchEmbed(
        patch_size=patch_size,
        in_chans=in_chans,
        embed_dim=embed_dim,
        norm_layer=norm_layer if self.patch_norm else None)

    # absolute position embedding
    if self.ape:
        pretrain_img_size = to_2tuple(pretrain_img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [
            pretrain_img_size[0] // patch_size[0],
            pretrain_img_size[1] // patch_size[1]
        ]

        self.absolute_pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, patches_resolution[0],
                        patches_resolution[1]))
        trunc_normal_(self.absolute_pos_embed, std=.02)

    self.pos_drop = nn.Dropout(p=drop_rate)

    # stochastic depth
    dpr = [
        x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
    ]  # stochastic depth decay rule

    # build layers
    self.layers = nn.ModuleList()
    for i_layer in range(self.num_layers):
        layer = BasicLayer(
            dim=int(embed_dim * 2**i_layer),
            depth=depths[i_layer],
            num_heads=num_heads[i_layer],
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
            norm_layer=norm_layer,
            downsample=PatchMerging if
            (i_layer < self.num_layers - 1) else None,
            use_checkpoint=use_checkpoint,
            pretrained_window_size=pretrained_window_sizes[i_layer])
        self.layers.append(layer)

    num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
    self.num_features = num_features

    # add a norm layer for each output
    for i_layer in out_indices:
        layer = norm_layer(num_features[i_layer])
        layer_name = f'norm{i_layer}'
        self.add_module(layer_name, layer)

    self._freeze_stages()
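# Sketch of `_freeze_stages` (an assumption: the method is called at the end
# of the `__init__` above but defined elsewhere; the common Swin pattern puts
# the frozen parts in eval mode and disables their gradients).
def _freeze_stages_sketch(self):
    if self.frozen_stages >= 0:
        self.patch_embed.eval()
        for param in self.patch_embed.parameters():
            param.requires_grad = False
    if self.frozen_stages >= 1 and self.ape:
        self.absolute_pos_embed.requires_grad = False
    if self.frozen_stages >= 2:
        self.pos_drop.eval()
        for i in range(0, self.frozen_stages - 1):
            m = self.layers[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False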
def init_weights(self):
    super(SwinTransformer, self).init_weights()
    if self.use_abs_pos_embed:
        trunc_normal_(self.absolute_pos_embed, std=0.02)
def __init__(self,
             arch='b',
             img_size=224,
             patch_size=16,
             in_channels=3,
             ffn_ratio=4,
             qkv_bias=False,
             drop_rate=0.,
             attn_drop_rate=0.,
             drop_path_rate=0.,
             act_cfg=dict(type='GELU'),
             norm_cfg=dict(type='LN'),
             first_stride=4,
             num_fcs=2,
             init_cfg=[
                 dict(type='TruncNormal', layer='Linear', std=.02),
                 dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
             ]):
    super(TNT, self).__init__(init_cfg=init_cfg)

    if isinstance(arch, str):
        arch = arch.lower()
        assert arch in set(self.arch_zoo), \
            f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
        self.arch_settings = self.arch_zoo[arch]
    else:
        essential_keys = {
            'embed_dims_outer', 'embed_dims_inner', 'num_layers',
            'num_heads_inner', 'num_heads_outer'
        }
        assert isinstance(arch, dict) and set(arch) == essential_keys, \
            f'Custom arch needs a dict with keys {essential_keys}'
        self.arch_settings = arch

    self.embed_dims_inner = self.arch_settings['embed_dims_inner']
    self.embed_dims_outer = self.arch_settings['embed_dims_outer']
    # embed_dims for consistency with other models
    self.embed_dims = self.embed_dims_outer
    self.num_layers = self.arch_settings['num_layers']
    self.num_heads_inner = self.arch_settings['num_heads_inner']
    self.num_heads_outer = self.arch_settings['num_heads_outer']

    self.pixel_embed = PixelEmbed(
        img_size=img_size,
        patch_size=patch_size,
        in_channels=in_channels,
        embed_dims_inner=self.embed_dims_inner,
        stride=first_stride)
    num_patches = self.pixel_embed.num_patches
    self.num_patches = num_patches
    new_patch_size = self.pixel_embed.new_patch_size
    num_pixel = new_patch_size[0] * new_patch_size[1]

    self.norm1_proj = build_norm_layer(norm_cfg,
                                       num_pixel * self.embed_dims_inner)[1]
    self.projection = nn.Linear(num_pixel * self.embed_dims_inner,
                                self.embed_dims_outer)
    self.norm2_proj = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]

    self.cls_token = nn.Parameter(
        torch.zeros(1, 1, self.embed_dims_outer))
    self.patch_pos = nn.Parameter(
        torch.zeros(1, num_patches + 1, self.embed_dims_outer))
    self.pixel_pos = nn.Parameter(
        torch.zeros(1, self.embed_dims_inner, new_patch_size[0],
                    new_patch_size[1]))
    self.drop_after_pos = nn.Dropout(p=drop_rate)

    dpr = [
        x.item()
        for x in torch.linspace(0, drop_path_rate, self.num_layers)
    ]  # stochastic depth decay rule
    self.layers = ModuleList()
    for i in range(self.num_layers):
        block_cfg = dict(
            ffn_ratio=ffn_ratio,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=dpr[i],
            num_fcs=num_fcs,
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            batch_first=True)
        self.layers.append(
            TnTLayer(
                num_pixel=num_pixel,
                embed_dims_inner=self.embed_dims_inner,
                embed_dims_outer=self.embed_dims_outer,
                num_heads_inner=self.num_heads_inner,
                num_heads_outer=self.num_heads_outer,
                inner_block_cfg=block_cfg,
                outer_block_cfg=block_cfg,
                norm_cfg=norm_cfg))

    self.norm = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]

    trunc_normal_(self.cls_token, std=.02)
    trunc_normal_(self.patch_pos, std=.02)
    trunc_normal_(self.pixel_pos, std=.02)
def init_weights(self):
    trunc_normal_(self.relative_position_bias_table, std=0.02)
def init_weights(self):
    logger = get_root_logger()
    if self.init_cfg is None:
        logger.warning(f'No pre-trained weights for '
                       f'{self.__class__.__name__}, '
                       f'training starts from scratch')
        if self.use_abs_pos_embed:
            trunc_normal_(self.absolute_pos_embed, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0.)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, 1.0)
    else:
        assert 'checkpoint' in self.init_cfg, \
            f'Only support specifying `Pretrained` in `init_cfg` of ' \
            f'{self.__class__.__name__}'
        ckpt = _load_checkpoint(
            self.init_cfg.checkpoint, logger=logger, map_location='cpu')
        if 'state_dict' in ckpt:
            _state_dict = ckpt['state_dict']
        elif 'model' in ckpt:
            _state_dict = ckpt['model']
        else:
            _state_dict = ckpt
        if self.convert_weights:
            # support loading weights from the original repo
            _state_dict = swin_converter(_state_dict)

        state_dict = OrderedDict()
        for k, v in _state_dict.items():
            if k.startswith('backbone.'):
                state_dict[k[9:]] = v

        # strip prefix of state_dict
        if list(state_dict.keys())[0].startswith('module.'):
            state_dict = {k[7:]: v for k, v in state_dict.items()}

        # reshape absolute position embedding
        if state_dict.get('absolute_pos_embed') is not None:
            absolute_pos_embed = state_dict['absolute_pos_embed']
            N1, L, C1 = absolute_pos_embed.size()
            N2, C2, H, W = self.absolute_pos_embed.size()
            if N1 != N2 or C1 != C2 or L != H * W:
                logger.warning('Error in loading absolute_pos_embed, pass')
            else:
                state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                    N2, H, W, C2).permute(0, 3, 1, 2).contiguous()

        # interpolate position bias table if needed
        relative_position_bias_table_keys = [
            k for k in state_dict.keys()
            if 'relative_position_bias_table' in k
        ]
        for table_key in relative_position_bias_table_keys:
            table_pretrained = state_dict[table_key]
            table_current = self.state_dict()[table_key]
            L1, nH1 = table_pretrained.size()
            L2, nH2 = table_current.size()
            if nH1 != nH2:
                logger.warning(f'Error in loading {table_key}, pass')
            elif L1 != L2:
                S1 = int(L1**0.5)
                S2 = int(L2**0.5)
                table_pretrained_resized = F.interpolate(
                    table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
                    size=(S2, S2),
                    mode='bicubic')
                state_dict[table_key] = table_pretrained_resized.view(
                    nH2, L2).permute(1, 0).contiguous()

        # load state_dict
        self.load_state_dict(state_dict, False)
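# Usage sketch (illustrative; the constructor arguments and checkpoint path
# are placeholders, not from the source): the pretrained branch above is
# selected by building the backbone with a `Pretrained` init_cfg and then
# calling `init_weights`, which loads, converts, and reshapes the state dict.
#
#   backbone = SwinTransformer(
#       convert_weights=True,
#       init_cfg=dict(type='Pretrained',
#                     checkpoint='path/to/swin_checkpoint.pth'))
#   backbone.init_weights()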
def init_weights(self):
    trunc_normal_(self.pos_embed, std=0.02)
def init_weights(self):
    super(WindowMSA, self).init_weights()
    trunc_normal_(self.relative_position_bias_table, std=0.02)
def _init_weights(m):
    if isinstance(m, nn.Conv2d):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
def init_weights(self):
    super(VisionTransformer, self).init_weights()

    if not (isinstance(self.init_cfg, dict)
            and self.init_cfg['type'] == 'Pretrained'):
        trunc_normal_(self.pos_embed, std=0.02)
def __init__(self,
             arch='tiny',
             patch_size=16,
             base_channels=64,
             mlp_ratio=4.,
             qkv_bias=True,
             with_cls_token=True,
             drop_path_rate=0.,
             norm_eval=True,
             frozen_stages=0,
             out_indices=-1,
             init_cfg=None):
    super().__init__(init_cfg=init_cfg)

    if isinstance(arch, str):
        arch = arch.lower()
        assert arch in set(self.arch_zoo), \
            f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
        self.arch_settings = self.arch_zoo[arch]
    else:
        essential_keys = {
            'embed_dims', 'depths', 'num_heads', 'channel_ratio'
        }
        assert isinstance(arch, dict) and set(arch) == essential_keys, \
            f'Custom arch needs a dict with keys {essential_keys}'
        self.arch_settings = arch

    self.num_features = self.embed_dims = self.arch_settings['embed_dims']
    self.depths = self.arch_settings['depths']
    self.num_heads = self.arch_settings['num_heads']
    self.channel_ratio = self.arch_settings['channel_ratio']

    if isinstance(out_indices, int):
        out_indices = [out_indices]
    assert isinstance(out_indices, Sequence), \
        f'"out_indices" must be a sequence or int, ' \
        f'got {type(out_indices)} instead.'
    for i, index in enumerate(out_indices):
        if index < 0:
            out_indices[i] = self.depths + index + 1
            assert out_indices[i] >= 0, f'Invalid out_indices {index}'
    self.out_indices = out_indices

    self.norm_eval = norm_eval
    self.frozen_stages = frozen_stages

    self.with_cls_token = with_cls_token
    if self.with_cls_token:
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

    # stochastic depth decay rule
    self.trans_dpr = [
        x.item() for x in torch.linspace(0, drop_path_rate, self.depths)
    ]

    # Stem stage: get the feature maps by conv block
    self.conv1 = nn.Conv2d(
        3, 64, kernel_size=7, stride=2, padding=3,
        bias=False)  # 1 / 2 [112, 112]
    self.bn1 = nn.BatchNorm2d(64)
    self.act1 = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(
        kernel_size=3, stride=2, padding=1)  # 1 / 4 [56, 56]

    assert patch_size % 16 == 0, 'The patch size of Conformer must ' \
        'be divisible by 16.'
    trans_down_stride = patch_size // 4

    # To solve the issue #680, auto pad the feature map to be divisible
    # by trans_down_stride
    self.auto_pad = AdaptivePadding(trans_down_stride, trans_down_stride)

    # 1 stage
    stage1_channels = int(base_channels * self.channel_ratio)
    self.conv_1 = ConvBlock(
        in_channels=64,
        out_channels=stage1_channels,
        with_residual_conv=True,
        stride=1)
    self.trans_patch_conv = nn.Conv2d(
        64,
        self.embed_dims,
        kernel_size=trans_down_stride,
        stride=trans_down_stride,
        padding=0)
    self.trans_1 = TransformerEncoderLayer(
        embed_dims=self.embed_dims,
        num_heads=self.num_heads,
        feedforward_channels=int(self.embed_dims * mlp_ratio),
        drop_path_rate=self.trans_dpr[0],
        qkv_bias=qkv_bias,
        norm_cfg=dict(type='LN', eps=1e-6))

    # 2~4 stage
    init_stage = 2
    fin_stage = self.depths // 3 + 1
    for i in range(init_stage, fin_stage):
        self.add_module(
            f'conv_trans_{i}',
            ConvTransBlock(
                in_channels=stage1_channels,
                out_channels=stage1_channels,
                embed_dims=self.embed_dims,
                conv_stride=1,
                with_residual_conv=False,
                down_stride=trans_down_stride,
                num_heads=self.num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path_rate=self.trans_dpr[i - 1],
                with_cls_token=self.with_cls_token))

    stage2_channels = int(base_channels * self.channel_ratio * 2)
    # 5~8 stage
    init_stage = fin_stage  # 5
    fin_stage = fin_stage + self.depths // 3  # 9
    for i in range(init_stage, fin_stage):
        if i == init_stage:
            conv_stride = 2
            in_channels = stage1_channels
        else:
            conv_stride = 1
            in_channels = stage2_channels
        with_residual_conv = True if i == init_stage else False
        self.add_module(
            f'conv_trans_{i}',
            ConvTransBlock(
                in_channels=in_channels,
                out_channels=stage2_channels,
                embed_dims=self.embed_dims,
                conv_stride=conv_stride,
                with_residual_conv=with_residual_conv,
                down_stride=trans_down_stride // 2,
                num_heads=self.num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path_rate=self.trans_dpr[i - 1],
                with_cls_token=self.with_cls_token))

    stage3_channels = int(base_channels * self.channel_ratio * 2 * 2)
    # 9~12 stage
    init_stage = fin_stage  # 9
    fin_stage = fin_stage + self.depths // 3  # 13
    for i in range(init_stage, fin_stage):
        if i == init_stage:
            conv_stride = 2
            in_channels = stage2_channels
            with_residual_conv = True
        else:
            conv_stride = 1
            in_channels = stage3_channels
            with_residual_conv = False

        last_fusion = (i == self.depths)
        self.add_module(
            f'conv_trans_{i}',
            ConvTransBlock(
                in_channels=in_channels,
                out_channels=stage3_channels,
                embed_dims=self.embed_dims,
                conv_stride=conv_stride,
                with_residual_conv=with_residual_conv,
                down_stride=trans_down_stride // 4,
                num_heads=self.num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path_rate=self.trans_dpr[i - 1],
                with_cls_token=self.with_cls_token,
                last_fusion=last_fusion))
    self.fin_stage = fin_stage

    self.pooling = nn.AdaptiveAvgPool2d(1)
    self.trans_norm = nn.LayerNorm(self.embed_dims)

    if self.with_cls_token:
        trunc_normal_(self.cls_token, std=.02)