def __init__(self):
    super(L2L, self).__init__()
    self.num_features = 4
    self.pos_embed = nn.Parameter(torch.zeros(1, 68, 4))
    self.pos_drop = nn.Dropout(p=0.1)
    trunc_normal_(self.pos_embed, std=.02)

    self.audio_encoder = nn.Sequential(
        Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
        Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
        Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

        Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
        Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
        Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

        Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
        Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
        Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

        Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
        Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

        Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
        Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
    )

    self.encoder_layer1 = nn.TransformerEncoderLayer(d_model=68, nhead=2)
    # self.encoder_layer2 = nn.TransformerEncoderLayer(d_model=208, nhead=2)
    self.transformer_encoder1 = nn.TransformerEncoder(self.encoder_layer1, num_layers=3)
    # self.transformer_encoder2 = nn.TransformerEncoder(self.encoder_layer2, num_layers=3)
    self.head1 = nn.Linear(self.num_features, 2)
    self.mlp = nn.Linear(512, 40)  # mouth region only
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
    super(WMSA, self).__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.head_dim = head_dim
    self.scale = self.head_dim ** -0.5
    self.n_heads = input_dim // head_dim
    self.window_size = window_size
    self.type = type
    self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)

    # TODO recover
    # self.relative_position_params = nn.Parameter(torch.zeros(self.n_heads, 2 * window_size - 1, 2 * window_size - 1))
    self.relative_position_params = nn.Parameter(
        torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))

    self.linear = nn.Linear(self.input_dim, self.output_dim)

    trunc_normal_(self.relative_position_params, std=.02)
    self.relative_position_params = torch.nn.Parameter(
        self.relative_position_params.view(
            2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1, 2).transpose(0, 1))
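# For context: after the view/transpose above the parameter has shape
# (n_heads, 2*window_size-1, 2*window_size-1). Reference implementations of
# this WMSA block (e.g. SCUNet/SwinIR-style code) read it back through a
# relative_embedding() helper roughly like the sketch below; treat this as
# a paraphrase, not the exact source.
def relative_embedding(self):
    cord = torch.tensor(
        [[i, j] for i in range(self.window_size) for j in range(self.window_size)])
    # pairwise offsets, shifted so they become non-negative table indices
    relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
    # result: (n_heads, window_size**2, window_size**2) bias to add to attention logits
    return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]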
def __init__(self, img_size=224, tokens_type='performer', in_chans=3, num_classes=1000, embed_dim=768,
             depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0.,
             attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, token_dim=64):
    super().__init__()
    self.num_classes = num_classes
    self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

    self.tokens_to_token = T2T_module(
        img_size=img_size, tokens_type=tokens_type, in_chans=in_chans, embed_dim=embed_dim,
        token_dim=token_dim)
    num_patches = self.tokens_to_token.num_patches

    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embed = nn.Parameter(
        data=get_sinusoid_encoding(n_position=num_patches + 1, d_hid=embed_dim), requires_grad=False)
    self.pos_drop = nn.Dropout(p=drop_rate)

    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
    self.blocks = nn.ModuleList([
        Block(
            dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
        for i in range(depth)])
    self.norm = norm_layer(embed_dim)

    # Classifier head
    self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    trunc_normal_(self.cls_token, std=.02)
    self.apply(self._init_weights)
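# The get_sinusoid_encoding helper used above is not shown in this snippet.
# A minimal torch-only sketch of the usual T2T-ViT-style fixed sin/cos table
# (shape 1 x n_position x d_hid, frozen via requires_grad=False above):
import torch

def get_sinusoid_encoding(n_position, d_hid):
    # angle for position p, channel j: p / 10000^(2*(j//2)/d_hid)
    position = torch.arange(n_position, dtype=torch.float32).unsqueeze(1)
    div = torch.pow(10000, 2 * (torch.arange(d_hid) // 2) / d_hid)
    table = position / div
    table[:, 0::2] = torch.sin(table[:, 0::2])  # even channels: sine
    table[:, 1::2] = torch.cos(table[:, 1::2])  # odd channels: cosine
    return table.unsqueeze(0)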
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
    super().__init__()
    self.dim = dim
    self.window_size = window_size  # Wh, Ww
    self.num_heads = num_heads
    head_dim = dim // num_heads
    self.scale = qk_scale or head_dim ** -0.5

    # define a parameter table of relative position bias
    self.relative_position_bias_table = nn.Parameter(
        torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

    # get pair-wise relative position index for each token inside the window
    coords_h = torch.arange(self.window_size[0])
    coords_w = torch.arange(self.window_size[1])
    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += self.window_size[1] - 1
    relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
    relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
    self.register_buffer("relative_position_index", relative_position_index)

    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    self.attn_drop = nn.Dropout(attn_drop)
    self.proj = nn.Linear(dim, dim)
    self.proj_drop = nn.Dropout(proj_drop)

    trunc_normal_(self.relative_position_bias_table, std=.02)
    self.softmax = nn.Softmax(dim=-1)
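# For context: in the reference Swin implementation, forward() consumes the
# buffer registered above to gather a per-head bias from the table and add
# it to the attention logits. The same steps, wrapped here as a hypothetical
# helper method (_add_relative_bias is illustrative, not a Swin method name):
def _add_relative_bias(self, attn):
    # attn: pre-softmax scores of shape (num_windows*B, nH, Wh*Ww, Wh*Ww)
    relative_position_bias = self.relative_position_bias_table[
        self.relative_position_index.view(-1)].view(
        self.window_size[0] * self.window_size[1],
        self.window_size[0] * self.window_size[1], -1)  # Wh*Ww, Wh*Ww, nH
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    return attn + relative_position_bias.unsqueeze(0)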
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, depths=[2, 2, 6, 2],
             num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
             drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm,
             ape=False, patch_norm=True, use_checkpoint=False, **kwargs):
    super().__init__()
    self.num_layers = len(depths)
    self.embed_dim = embed_dim
    self.ape = ape
    self.patch_norm = patch_norm
    self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
    self.mlp_ratio = mlp_ratio

    # split image into non-overlapping patches
    self.patch_embed = PatchEmbed(
        img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
        norm_layer=norm_layer if self.patch_norm else None)
    num_patches = self.patch_embed.num_patches
    patches_resolution = self.patch_embed.patches_resolution
    self.patches_resolution = patches_resolution

    # absolute position embedding
    if self.ape:
        self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        trunc_normal_(self.absolute_pos_embed, std=.02)

    self.pos_drop = nn.Dropout(p=drop_rate)

    # stochastic depth
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

    # build layers
    self.layers = nn.ModuleList()
    for i_layer in range(self.num_layers):
        layer = BasicLayer(
            dim=int(embed_dim * 2 ** i_layer),
            input_resolution=(patches_resolution[0] // (2 ** i_layer),
                              patches_resolution[1] // (2 ** i_layer)),
            depth=depths[i_layer],
            num_heads=num_heads[i_layer],
            window_size=window_size,
            mlp_ratio=self.mlp_ratio,
            qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate,
            drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
            norm_layer=norm_layer,
            downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
            use_checkpoint=use_checkpoint)
        self.layers.append(layer)

    # self.norm = norm_layer(self.num_features)
    # self.avgpool = nn.AdaptiveAvgPool1d(1)
    # 49 * 768 is hard-coded for the default config: a 224 input yields a
    # 7x7 final token grid with num_features = 96 * 2**3 = 768 channels
    self.output_layer = nn.Sequential(norm_layer(self.num_features), Flatten(),
                                      nn.Linear(49 * 768, 512), nn.BatchNorm1d(512))
    self.apply(self._init_weights)
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
    num_patches = self.patch_embed.num_patches
    # +2: one class token and one distillation token
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
    trunc_normal_(self.dist_token, std=0.02)
    trunc_normal_(self.pos_embed, std=0.02)
def __init__(self, height, width, embed_dim):
    super().__init__()
    self.height = height
    self.width = width
    self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, height, width))
    self.cls_pos_embed = nn.Parameter(torch.zeros(1, 1, embed_dim))
    trunc_normal_(self.pos_embed, std=.02)
    trunc_normal_(self.cls_pos_embed, std=.02)
def _init_weights(self, m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
def weight_initialization(self):
    for m in self.modules():
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
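# The two snippets above apply the same per-module rule in two equivalent
# ways: passing a hook to nn.Module.apply versus iterating self.modules()
# by hand. A minimal self-contained sketch of the apply style (TinyHead is
# a hypothetical module; trunc_normal_ is assumed to come from
# timm.models.layers, as in most of these codebases):
import torch.nn as nn
from timm.models.layers import trunc_normal_

class TinyHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(16)
        self.fc = nn.Linear(16, 4)
        self.apply(self._init_weights)  # visits self and every submodule recursively

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)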
def __init__(self, args, dictionary):
    super().__init__(dictionary)
    img_size = args.vit_img_size
    patch_size = args.vit_patch_size
    in_chans = args.vit_channels
    embed_dim = args.vit_dim
    depth = args.vit_depth
    num_heads = args.vit_heads
    mlp_ratio = 4.
    qkv_bias = True
    qk_scale = None
    drop_rate = args.vit_dropout
    attn_drop_rate = args.vit_atten_dropout
    drop_path_rate = 0.
    hybrid_backbone = None
    norm_layer = None

    self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
    norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

    if hybrid_backbone is not None:
        self.patch_embed = HybridEmbed(
            hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
    else:
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
    num_patches = self.patch_embed.num_patches

    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
    self.pos_drop = nn.Dropout(p=drop_rate)

    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
    self.blocks = nn.ModuleList([
        Block(
            dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
        for i in range(depth)])
    self.norm = norm_layer(embed_dim)

    trunc_normal_(self.pos_embed, std=.02)
    trunc_normal_(self.cls_token, std=.02)
    self.apply(self._init_weights)
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
    num_patches = self.patch_embed.num_patches
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
    self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

    trunc_normal_(self.dist_token, std=.02)
    trunc_normal_(self.pos_embed, std=.02)
    self.head_dist.apply(self._init_weights)
def _init_weights(self, m):
    if isinstance(m, nn.Conv2d):
        trunc_normal_(m.weight, std=0.02)
    elif isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
        nn.init.constant_(m.weight, 1.0)
        nn.init.constant_(m.bias, 0)
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    num_patches = self.patch_embed.num_patches
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
    trunc_normal_(self.pos_embed, std=.02)
    self.output1 = nn.Sequential(nn.ReLU(), nn.Dropout(0.5), nn.Linear(1000, 1))
    self.output1.apply(self._init_weights)
def __init__(self, head_embed_dim, length=14) -> None:
    super().__init__()
    self.head_embed_dim = head_embed_dim
    self.length = length
    # separate relative-position tables for the vertical and horizontal axes
    self.embeddings_table_v = nn.Parameter(torch.randn(length * 2 + 2, head_embed_dim))
    self.embeddings_table_h = nn.Parameter(torch.randn(length * 2 + 2, head_embed_dim))
    trunc_normal_(self.embeddings_table_v, std=.02)
    trunc_normal_(self.embeddings_table_h, std=.02)
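# A hypothetical lookup for the decomposed tables above (the names and
# indexing scheme below are illustrative, not from the source): per-axis
# relative offsets are shifted to be non-negative and index a table of
# O(length) rows per axis, instead of one O(length^2) joint table.
import torch

length = 14
coords = torch.arange(length)
rel = coords[None, :] - coords[:, None] + length  # offsets shifted into [1, 2*length-1]
# bias_v = self.embeddings_table_v[rel]  # -> (length, length, head_embed_dim), vertical axis
# bias_h = self.embeddings_table_h[rel]  # -> same shape, horizontal axis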
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # two tokens: class token and distillation token
    self.cls_token = nn.Parameter(
        torch.randn(1, 2, self.base_dims[0] * self.heads[0]), requires_grad=True)
    if self.num_classes > 0:
        self.head_dist = nn.Linear(self.base_dims[-1] * self.heads[-1], self.num_classes)
    else:
        self.head_dist = nn.Identity()

    trunc_normal_(self.cls_token, std=0.02)
    self.head_dist.apply(self._init_weights)
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        # xavier(m.weight.data)
        m.weight.data.normal_(0, 0.01)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
def _init_weights(self, m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
    elif isinstance(m, nn.Conv2d):
        # He initialization based on the conv's fan-out
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        fan_out //= m.groups
        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            m.bias.data.zero_()
def _init_weights(m):
    if isinstance(m, nn.Conv2d):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
    else:
        print(f'Warning: {type(m)} uses default initialization...')
def __init__(self, patch_size, nx, ny, in_chans=3, embed_dim=768, nglo=1,
             norm_layer=nn.LayerNorm, norm_embed=True, drop_rate=0.0, ape=True):
    # maximal global/x-direction/y-direction tokens: nglo, nx, ny
    super().__init__()
    patch_size = to_2tuple(patch_size)
    self.patch_size = patch_size

    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    self.norm_embed = norm_layer(embed_dim) if norm_embed else None
    self.nx = nx
    self.ny = ny
    self.Nglo = nglo

    if nglo >= 1:
        self.cls_token = nn.Parameter(torch.zeros(1, nglo, embed_dim))
        trunc_normal_(self.cls_token, std=.02)
    else:
        self.cls_token = None

    self.ape = ape
    if ape:
        # separable absolute position embedding: half the channels encode x, half encode y
        self.cls_pos_embed = nn.Parameter(torch.zeros(1, nglo, embed_dim))
        self.x_pos_embed = nn.Parameter(torch.zeros(1, nx, embed_dim // 2))
        self.y_pos_embed = nn.Parameter(torch.zeros(1, ny, embed_dim // 2))
        trunc_normal_(self.cls_pos_embed, std=.02)
        trunc_normal_(self.x_pos_embed, std=.02)
        trunc_normal_(self.y_pos_embed, std=.02)

    self.pos_drop = nn.Dropout(p=drop_rate)
def __init__(self, configs=None, img_size=224, in_chans=3, num_classes=1000, mlp_ratio=4.,
             qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
             norm_layer=nn.LayerNorm, se=0):
    super().__init__()
    self.num_classes = num_classes
    depths = configs['depths']
    outer_dims = configs['outer_dims']
    inner_dims = configs['inner_dims']
    outer_heads = configs['outer_heads']
    inner_heads = configs['inner_heads']
    sr_ratios = [4, 2, 1, 1]
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
    self.num_features = outer_dims[-1]  # num_features for consistency with other models

    self.patch_embed = Stem(
        img_size=img_size, in_chans=in_chans, outer_dim=outer_dims[0], inner_dim=inner_dims[0])
    num_patches = self.patch_embed.num_patches
    num_words = self.patch_embed.num_words

    self.outer_pos = nn.Parameter(torch.zeros(1, num_patches, outer_dims[0]))
    self.inner_pos = nn.Parameter(torch.zeros(1, num_words, inner_dims[0]))
    self.pos_drop = nn.Dropout(p=drop_rate)

    depth = 0
    self.word_merges = nn.ModuleList([])
    self.sentence_merges = nn.ModuleList([])
    self.stages = nn.ModuleList([])
    for i in range(4):
        if i > 0:
            self.word_merges.append(WordAggregation(inner_dims[i - 1], inner_dims[i], stride=2))
            self.sentence_merges.append(SentenceAggregation(outer_dims[i - 1], outer_dims[i], stride=2))
        self.stages.append(Stage(
            depths[i], outer_dim=outer_dims[i], inner_dim=inner_dims[i],
            outer_head=outer_heads[i], inner_head=inner_heads[i],
            num_patches=num_patches // (2 ** i) // (2 ** i), num_words=num_words,
            mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
            attn_drop=attn_drop_rate, drop_path=dpr[depth:depth + depths[i]],
            norm_layer=norm_layer, se=se, sr_ratio=sr_ratios[i]))
        depth += depths[i]

    self.norm = norm_layer(outer_dims[-1])

    # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
    # self.repr = nn.Linear(outer_dim, representation_size)
    # self.repr_act = nn.Tanh()

    # Classifier head
    self.head = nn.Linear(outer_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

    trunc_normal_(self.outer_pos, std=.02)
    trunc_normal_(self.inner_pos, std=.02)
    self.apply(self._init_weights)
def _init_weights(self, m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1.)
        nn.init.constant_(m.bias, 0.)
    elif isinstance(m, nn.GroupNorm):
        nn.init.constant_(m.weight, 1.)
        nn.init.constant_(m.bias, 0.)
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
             num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
             drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
             ape=False, mask_ratio=0.0):
    super().__init__(
        img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=num_classes,
        embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
        qk_scale=qk_scale, representation_size=representation_size, drop_rate=drop_rate,
        attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, hybrid_backbone=hybrid_backbone,
        norm_layer=norm_layer)
    self.ape = ape
    self.mask_ratio = mask_ratio
    self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
    self.patch_embed = PatchEmbedForApe(
        img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
    # 576 is a hard-coded grid size for the ape path, presumably a 24x24
    # token grid (e.g. a 384 input with patch size 16)
    num_patches = self.patch_embed.num_patches if not self.ape else 576
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
    self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

    trunc_normal_(self.dist_token, std=.02)
    trunc_normal_(self.pos_embed, std=.02)
    self.head_dist.apply(self._init_weights)
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
             num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
             drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, group=False, re_atten=True,
             cos_reg=False, use_cnn_embed=False, apply_transform=None, transform_scale=False,
             scale_adjustment=1.):
    super().__init__()
    self.num_classes = num_classes
    self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
    # use cosine similarity as a regularization term
    self.cos_reg = cos_reg

    if hybrid_backbone is not None:
        self.patch_embed = HybridEmbed(
            hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
    elif use_cnn_embed:
        self.patch_embed = PatchEmbed_CNN(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
    else:
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
    num_patches = self.patch_embed.num_patches

    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
    self.pos_drop = nn.Dropout(p=drop_rate)

    d = depth if isinstance(depth, int) else len(depth)
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, d)]  # stochastic depth decay rule
    self.blocks = nn.ModuleList([
        Block(
            dim=embed_dim, share=depth[i], num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
            qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
            norm_layer=norm_layer, group=group, re_atten=re_atten, apply_transform=apply_transform[i],
            transform_scale=transform_scale, scale_adjustment=scale_adjustment)
        for i in range(len(depth))])
    self.norm = norm_layer(embed_dim)

    # Classifier head
    self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    trunc_normal_(self.pos_embed, std=.02)
    trunc_normal_(self.cls_token, std=.02)
    self.apply(self._init_weights)
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
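# For reference, timm builds its lecun_normal_ initializer on this same
# helper; a minimal equivalent of that specialization (fan-in scaling with a
# truncated-normal draw):
def lecun_normal_(tensor):
    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')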
def __init__(
    self,
    image_size,
    patch_size,
    n_layers,
    d_model,
    d_ff,
    n_heads,
    n_cls,
    dropout=0.1,
    drop_path_rate=0.0,
    distilled=False,
    channels=3,
):
    super().__init__()
    self.patch_embed = PatchEmbedding(
        image_size,
        patch_size,
        d_model,
        channels,
    )
    self.patch_size = patch_size
    self.n_layers = n_layers
    self.d_model = d_model
    self.d_ff = d_ff
    self.n_heads = n_heads
    self.dropout = nn.Dropout(dropout)
    self.n_cls = n_cls

    # cls and pos tokens
    self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
    self.distilled = distilled
    if self.distilled:
        self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.patch_embed.num_patches + 2, d_model)
        )
        self.head_dist = nn.Linear(d_model, n_cls)
    else:
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.patch_embed.num_patches + 1, d_model)
        )

    # transformer blocks
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
    self.blocks = nn.ModuleList(
        [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)]
    )

    # output head
    self.norm = nn.LayerNorm(d_model)
    self.head = nn.Linear(d_model, n_cls)

    trunc_normal_(self.pos_embed, std=0.02)
    trunc_normal_(self.cls_token, std=0.02)
    if self.distilled:
        trunc_normal_(self.dist_token, std=0.02)
    self.pre_logits = nn.Identity()

    self.apply(init_weights)
def __init__(
    self,
    n_cls,
    patch_size,
    d_encoder,
    n_layers,
    n_heads,
    d_model,
    d_ff,
    drop_path_rate,
    dropout,
):
    super().__init__()
    self.d_encoder = d_encoder
    self.patch_size = patch_size
    self.n_cls = n_cls
    self.d_model = d_model
    self.d_ff = d_ff
    self.scale = d_model ** -0.5

    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
    self.blocks = nn.ModuleList(
        [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)]
    )

    # one learnable embedding per output class
    self.cls_emb = nn.Parameter(torch.randn(1, n_cls, d_model))
    self.proj_dec = nn.Linear(d_encoder, d_model)

    self.proj_patch = nn.Parameter(self.scale * torch.randn(d_model, d_model))
    self.proj_classes = nn.Parameter(self.scale * torch.randn(d_model, d_model))

    self.decoder_norm = nn.LayerNorm(d_model)
    self.mask_norm = nn.LayerNorm(n_cls)

    self.apply(init_weights)
    trunc_normal_(self.cls_emb, std=0.02)
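# For context: this is Segmenter's mask decoder, where the class embeddings
# attend alongside the patch tokens and segmentation masks come out as dot
# products between L2-normalized patch and class features. A sketch of that
# forward pass, paraphrased from the Segmenter reference code (shapes in
# comments; details may differ from the exact source):
def forward(self, x, im_size):
    H, W = im_size
    GS = H // self.patch_size                          # patch-grid height
    x = self.proj_dec(x)                               # (B, N, d_model)
    cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
    x = torch.cat((x, cls_emb), 1)                     # append n_cls class tokens
    for blk in self.blocks:
        x = blk(x)
    x = self.decoder_norm(x)
    patches, cls_feat = x[:, :-self.n_cls], x[:, -self.n_cls:]
    patches = patches @ self.proj_patch
    cls_feat = cls_feat @ self.proj_classes
    patches = patches / patches.norm(dim=-1, keepdim=True)  # L2-normalize features
    cls_feat = cls_feat / cls_feat.norm(dim=-1, keepdim=True)
    masks = self.mask_norm(patches @ cls_feat.transpose(1, 2))  # (B, N, n_cls)
    return masks.transpose(1, 2).reshape(masks.size(0), self.n_cls, GS, -1)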
def __init__(self, *, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12,
             mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
             drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm,
             positional_encoding='learned', learned_positional_encoding_size=(14, 14),
             block_cls=LinearBlock):
    super().__init__()

    # Config
    self.num_classes = num_classes
    self.patch_size = patch_size
    self.num_features = self.embed_dim = embed_dim

    # Patch embedding
    self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)

    # Class token
    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    # Positional encoding
    if positional_encoding == 'learned':
        height, width = self.learned_positional_encoding_size = learned_positional_encoding_size
        self.pos_encoding = LearnedPositionalEncoding(height, width, embed_dim)
    else:
        raise NotImplementedError('Unsupported positional encoding')
    self.pos_drop = nn.Dropout(p=drop_rate)

    # Stochastic depth
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    self.blocks = nn.ModuleList([
        block_cls(
            dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
            num_tokens=1 + (224 // patch_size) ** 2)
        for i in range(depth)])
    self.norm = norm_layer(embed_dim)

    # Classifier head
    self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    # Init
    trunc_normal_(self.cls_token, std=.02)
    self.apply(self._init_weights)
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
             num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
             drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, rpe_config=None):
    super().__init__()
    self.num_classes = num_classes
    self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

    if hybrid_backbone is not None:
        self.patch_embed = HybridEmbed(
            hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
    else:
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
    num_patches = self.patch_embed.num_patches

    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
    self.pos_drop = nn.Dropout(p=drop_rate)

    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
    self.blocks = nn.ModuleList([
        RPEBlock(
            dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
            rpe_config=rpe_config)
        for i in range(depth)])
    self.norm = norm_layer(embed_dim)

    # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
    # self.repr = nn.Linear(embed_dim, representation_size)
    # self.repr_act = nn.Tanh()

    # Classifier head
    self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    trunc_normal_(self.pos_embed, std=.02)
    trunc_normal_(self.cls_token, std=.02)
    self.apply(self._init_weights)
def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
    """ViT weight initialization

    * When called without n, head_bias, jax_impl args it will behave exactly the same
      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
    * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
    """
    if isinstance(m, nn.Linear):
        if n.startswith('head'):
            nn.init.zeros_(m.weight)
            nn.init.constant_(m.bias, head_bias)
        elif n.startswith('pre_logits'):
            lecun_normal_(m.weight)
            nn.init.zeros_(m.bias)
        else:
            if jax_impl:
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    if 'mlp' in n:
                        nn.init.normal_(m.bias, std=1e-6)
                    else:
                        nn.init.zeros_(m.bias)
            else:
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    elif jax_impl and isinstance(m, nn.Conv2d):
        # NOTE conv was left to pytorch default in my original init
        lecun_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)
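# Roughly how timm drives this helper: the jax path needs module names, so
# it iterates named_modules(); the name-free default path works with a plain
# apply. The wrapper name init_vit and its arguments are illustrative only.
def init_vit(model, mode='', head_bias=0.):
    if mode.startswith('jax'):
        for n, m in model.named_modules():
            _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
    else:
        model.apply(_init_vit_weights)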
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
             rpe=False, wx=14, wy=14, nglo=1):
    super().__init__()
    self.num_heads = num_heads
    head_dim = dim // num_heads
    # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
    self.scale = qk_scale or head_dim ** -0.5

    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    self.attn_drop = nn.Dropout(attn_drop)
    self.proj = nn.Linear(dim, dim)
    self.proj_drop = nn.Dropout(proj_drop)

    # Inspired by swin transformer:
    # https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L88-L103
    # define parameter tables for local and global relative position bias
    self.rpe = rpe
    if rpe:
        self.wx = wx
        self.wy = wy
        self.nglo = nglo
        self.local_relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * wx - 1) * (2 * wy - 1), num_heads))  # (2*wx-1, 2*wy-1, nH)
        trunc_normal_(self.local_relative_position_bias_table, std=.02)
        if nglo >= 1:
            self.g2l_relative_position_bias = nn.Parameter(
                torch.zeros(2, num_heads, nglo))  # (2, nH, nglo)
            self.g2g_relative_position_bias = nn.Parameter(
                torch.zeros(num_heads, nglo, nglo))  # (nH, nglo, nglo)
            trunc_normal_(self.g2l_relative_position_bias, std=.02)
            trunc_normal_(self.g2g_relative_position_bias, std=.02)

        # get pair-wise relative position index
        coords_h = torch.arange(wx)
        coords_w = torch.arange(wy)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, wx, wy
        coords_flatten = torch.flatten(coords, 1)  # 2, Wx*Wy
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wx*Wy, Wx*Wy
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wx*Wy, Wx*Wy, 2
        relative_coords[:, :, 0] += wx - 1  # shift to start from 0
        relative_coords[:, :, 1] += wy - 1
        relative_coords[:, :, 0] *= 2 * wy - 1
        relative_position_index = relative_coords.sum(-1)  # Wx*Wy, Wx*Wy
        self.register_buffer("relative_position_index", relative_position_index)