def _init_weights(self, m):
    """Truncated-normal init for Linear weights, constant init for LayerNorm."""
    print("Initialization...")
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:  # the redundant second isinstance check is dropped
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
def _make_position_embedding(self, w, h, d_model, pe_type='sine'):
    '''
    d_model: embedding size in transformer encoder
    '''
    assert pe_type in ['none', 'learnable', 'sine', 'sine-full']
    if pe_type == 'none':
        self.pos_embedding = None
        print("==> Without any PositionEmbedding~")
    else:
        with torch.no_grad():
            self.pe_h = h
            self.pe_w = w
            length = self.pe_h * self.pe_w
        if pe_type == 'learnable':
            self.pos_embedding = nn.Parameter(
                torch.zeros(1, self.num_patches + self.num_keypoints,
                            d_model))
            trunc_normal_(self.pos_embedding, std=.02)
            print("==> Add Learnable PositionEmbedding~")
        else:
            self.pos_embedding = nn.Parameter(
                self._make_sine_position_embedding(d_model),
                requires_grad=False)
            print("==> Add Sine PositionEmbedding~")
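# `_make_sine_position_embedding` is called above but not defined in this
# section. Below is a minimal sketch, assuming the DETR-style fixed 2D
# sine/cosine embedding (temperature 10000, half the channels for y, half
# for x, d_model divisible by 4, `import math` at module level). It returns
# a (1, h*w, d_model) tensor matching what `_make_position_embedding`
# wraps in an nn.Parameter above.
def _make_sine_position_embedding(self, d_model, temperature=10000,
                                  scale=2 * math.pi):
    h, w = self.pe_h, self.pe_w
    area = torch.ones(1, h, w)  # [b, h, w]
    y_embed = area.cumsum(1, dtype=torch.float32)
    x_embed = area.cumsum(2, dtype=torch.float32)

    one_direction_feats = d_model // 2
    eps = 1e-6
    # normalize row/column coordinates to [0, scale]
    y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale
    x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale

    dim_t = torch.arange(one_direction_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * (dim_t // 2) / one_direction_feats)

    pos_x = x_embed[:, :, :, None] / dim_t
    pos_y = y_embed[:, :, :, None] / dim_t
    # interleave sin/cos over the channel dimension
    pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(),
                         pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(),
                         pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
    pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)  # (1, d, h, w)
    pos = pos.flatten(2).permute(0, 2, 1)  # (1, h*w, d_model)
    return pos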
def __init__(self, *, feature_size, patch_size, num_keypoints, dim, depth,
             heads, mlp_dim, apply_init=False, apply_multi=True,
             hidden_heatmap_dim=64 * 6, heatmap_dim=64 * 48,
             heatmap_size=[64, 48], channels=3, dropout=0., emb_dropout=0.,
             pos_embedding_type="learnable"):
    # apply_multi is added to the signature: it is used below to select the
    # head but was missing from the parameter list, raising a NameError.
    super().__init__()
    assert isinstance(feature_size, list) and isinstance(patch_size, list), \
        'feature_size and patch_size should be list'
    assert feature_size[0] % patch_size[0] == 0 and \
        feature_size[1] % patch_size[1] == 0, \
        'Image dimensions must be divisible by the patch size.'
    num_patches = (feature_size[0] // patch_size[0]) * \
                  (feature_size[1] // patch_size[1])
    patch_dim = channels * patch_size[0] * patch_size[1]
    assert pos_embedding_type in ['sine', 'learnable', 'sine-full']

    self.inplanes = 64
    self.patch_size = patch_size
    self.heatmap_size = heatmap_size
    self.num_keypoints = num_keypoints
    self.num_patches = num_patches
    self.pos_embedding_type = pos_embedding_type
    self.all_attn = (self.pos_embedding_type == "sine-full")

    self.keypoint_token = nn.Parameter(
        torch.zeros(1, self.num_keypoints, dim))
    h, w = feature_size[0] // self.patch_size[0], \
           feature_size[1] // self.patch_size[1]  # for normal
    self._make_position_embedding(w, h, dim, pos_embedding_type)

    self.patch_to_embedding = nn.Linear(patch_dim, dim)
    self.dropout = nn.Dropout(emb_dropout)

    # transformer: three stacked stages whose keypoint-token outputs are
    # concatenated (hence dim * 3 in the head below)
    self.transformer1 = Transformer(dim, depth, heads, mlp_dim, dropout,
                                    num_keypoints=num_keypoints,
                                    all_attn=self.all_attn,
                                    scale_with_head=True)
    self.transformer2 = Transformer(dim, depth, heads, mlp_dim, dropout,
                                    num_keypoints=num_keypoints,
                                    all_attn=self.all_attn,
                                    scale_with_head=True)
    self.transformer3 = Transformer(dim, depth, heads, mlp_dim, dropout,
                                    num_keypoints=num_keypoints,
                                    all_attn=self.all_attn,
                                    scale_with_head=True)

    self.to_keypoint_token = nn.Identity()

    # two-layer head only when the hidden dim is comfortably larger than
    # the token dim and apply_multi is set; otherwise a single projection
    self.mlp_head = nn.Sequential(
        nn.LayerNorm(dim * 3),
        nn.Linear(dim * 3, hidden_heatmap_dim),
        nn.LayerNorm(hidden_heatmap_dim),
        nn.Linear(hidden_heatmap_dim, heatmap_dim)
    ) if (dim * 3 <= hidden_heatmap_dim * 0.5 and apply_multi) else \
        nn.Sequential(
            nn.LayerNorm(dim * 3),
            nn.Linear(dim * 3, heatmap_dim))
    trunc_normal_(self.keypoint_token, std=.02)
    if apply_init:
        self.apply(self._init_weights)
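# Construction sketch for this feature-map variant (class name hypothetical,
# standing in for the enclosing class): a 64x48 backbone feature map with
# 4x3 patches yields (64 // 4) * (48 // 3) = 16 * 16 = 256 patch tokens
# plus 17 keypoint tokens for COCO. Note the head selection above with
# these numbers: dim * 3 = 576 > 0.5 * hidden_heatmap_dim = 192, so the
# single LayerNorm + Linear head is built even with apply_multi=True.
#
# model = TokenPoseFeatureMapModel(feature_size=[64, 48], patch_size=[4, 3],
#                                  num_keypoints=17, dim=192, depth=4,
#                                  heads=8, mlp_dim=768,
#                                  pos_embedding_type='sine')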
def __init__(self, *, image_size, patch_size, num_keypoints, dim, depth,
             heads, mlp_dim, apply_init=False, apply_multi=True,
             hidden_heatmap_dim=64 * 6, heatmap_dim=64 * 48,
             heatmap_size=[64, 48], channels=3, dropout=0., emb_dropout=0.,
             pos_embedding_type="learnable"):
    super().__init__()
    assert isinstance(image_size, list) and isinstance(patch_size, list), \
        'image_size and patch_size should be list'
    assert image_size[0] % patch_size[0] == 0 and \
        image_size[1] % patch_size[1] == 0, \
        'Image dimensions must be divisible by the patch size.'
    # the stem downsamples the input 4x, so patches are taken on the stem
    # output feature map
    num_patches = (image_size[0] // (4 * patch_size[0])) * \
                  (image_size[1] // (4 * patch_size[1]))
    patch_dim = channels * patch_size[0] * patch_size[1]
    assert num_patches > MIN_NUM_PATCHES, \
        f'your number of patches ({num_patches}) is way too small for ' \
        f'attention to be effective (at least 16). ' \
        f'Try decreasing your patch size'
    assert pos_embedding_type in ['sine', 'none', 'learnable', 'sine-full']

    self.inplanes = 64
    self.patch_size = patch_size
    self.heatmap_size = heatmap_size
    self.num_keypoints = num_keypoints
    self.num_patches = num_patches
    self.pos_embedding_type = pos_embedding_type
    self.all_attn = (self.pos_embedding_type == "sine-full")

    self.keypoint_token = nn.Parameter(
        torch.zeros(1, self.num_keypoints, dim))
    h, w = image_size[0] // (4 * self.patch_size[0]), \
           image_size[1] // (4 * self.patch_size[1])
    self._make_position_embedding(w, h, dim, pos_embedding_type)

    self.patch_to_embedding = nn.Linear(patch_dim, dim)
    self.dropout = nn.Dropout(emb_dropout)

    # stem net: two stride-2 convs plus a bottleneck stage
    self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
                           bias=False)
    self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
    self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
                           bias=False)
    self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
    self.relu = nn.ReLU(inplace=True)
    self.layer1 = self._make_layer(Bottleneck, 64, 4)

    # transformer
    self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout,
                                   num_keypoints=num_keypoints,
                                   all_attn=self.all_attn)

    self.to_keypoint_token = nn.Identity()

    self.mlp_head = nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_heatmap_dim),
        nn.LayerNorm(hidden_heatmap_dim),
        nn.Linear(hidden_heatmap_dim, heatmap_dim)
    ) if (dim <= hidden_heatmap_dim * 0.5 and apply_multi) else \
        nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, heatmap_dim))
    trunc_normal_(self.keypoint_token, std=.02)
    if apply_init:
        self.apply(self._init_weights)
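# Construction sketch for the stem-net variant (class name hypothetical,
# standing in for the enclosing class): a 256x192 input is downsampled 4x
# by the two stride-2 convs, so with patch_size=[4, 3] the transformer
# sees (256 // 16) * (192 // 12) = 16 * 16 = 256 patch tokens, which
# clears the MIN_NUM_PATCHES assertion above.
#
# model = TokenPoseStemModel(image_size=[256, 192], patch_size=[4, 3],
#                            num_keypoints=17, dim=192, depth=12, heads=8,
#                            mlp_dim=768, pos_embedding_type='sine')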