Example 1
 def _init_weights(self, m):
     # Dispatched once per sub-module via self.apply(self._init_weights),
     # so no per-call printing here.
     if isinstance(m, nn.Linear):
         # ViT-style truncated-normal init for linear weights.
         trunc_normal_(m.weight, std=.02)
         if m.bias is not None:
             nn.init.constant_(m.bias, 0)
     elif isinstance(m, nn.LayerNorm):
         # LayerNorm: zero bias, unit scale.
         nn.init.constant_(m.bias, 0)
         nn.init.constant_(m.weight, 1.0)
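
Note that `trunc_normal_` is not part of core torch; it presumably comes from `timm.models.layers` (some repositories vendor a copy of it). The method is meant to be dispatched over the whole module tree with `nn.Module.apply`. A minimal, self-contained sketch of that pattern:

    import torch.nn as nn
    from timm.models.layers import trunc_normal_  # assumed source of trunc_normal_

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 4)
            self.norm = nn.LayerNorm(4)

        def _init_weights(self, m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    model = Toy()
    model.apply(model._init_weights)  # visits Toy, fc, and norm once each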
Example 2
 def _make_position_embedding(self, w, h, d_model, pe_type='sine'):
     '''
     d_model: embedding size in the transformer encoder
     '''
     assert pe_type in ['none', 'learnable', 'sine', 'sine-full']
     if pe_type == 'none':
         self.pos_embedding = None
         print("==> No position embedding")
     else:
         self.pe_h = h
         self.pe_w = w
         if pe_type == 'learnable':
             # Learnable embedding covers patch tokens and keypoint tokens.
             self.pos_embedding = nn.Parameter(
                 torch.zeros(1, self.num_patches + self.num_keypoints,
                             d_model))
             trunc_normal_(self.pos_embedding, std=.02)
             print("==> Added learnable position embedding")
         else:
             # Fixed sine embedding; frozen, so it is excluded from training.
             self.pos_embedding = nn.Parameter(
                 self._make_sine_position_embedding(d_model),
                 requires_grad=False)
             print("==> Added sine position embedding")
Example 3
    def __init__(self,
                 *,
                 feature_size,
                 patch_size,
                 num_keypoints,
                 dim,
                 depth,
                 heads,
                 mlp_dim,
                 apply_init=False,
                 apply_multi=True,
                 hidden_heatmap_dim=64 * 6,
                 heatmap_dim=64 * 48,
                 heatmap_size=[64, 48],
                 channels=3,
                 dropout=0.,
                 emb_dropout=0.,
                 pos_embedding_type="learnable"):
        super().__init__()
        assert isinstance(feature_size, list) and isinstance(
            patch_size, list), 'feature_size and patch_size should be lists'
        assert feature_size[0] % patch_size[0] == 0 and \
            feature_size[1] % patch_size[1] == 0, \
            'Feature dimensions must be divisible by the patch size.'
        num_patches = (feature_size[0] // patch_size[0]) * \
            (feature_size[1] // patch_size[1])
        patch_dim = channels * patch_size[0] * patch_size[1]
        assert pos_embedding_type in ['sine', 'learnable', 'sine-full']

        self.inplanes = 64
        self.patch_size = patch_size
        self.heatmap_size = heatmap_size
        self.num_keypoints = num_keypoints
        self.num_patches = num_patches
        self.pos_embedding_type = pos_embedding_type
        self.all_attn = (self.pos_embedding_type == "sine-full")

        self.keypoint_token = nn.Parameter(
            torch.zeros(1, self.num_keypoints, dim))
        h, w = feature_size[0] // self.patch_size[0], \
            feature_size[1] // self.patch_size[1]

        # position embedding over the patch grid
        self._make_position_embedding(w, h, dim, pos_embedding_type)

        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.dropout = nn.Dropout(emb_dropout)

        # three transformer stages; their keypoint tokens are concatenated,
        # hence the dim * 3 input of the head below
        self.transformer1 = Transformer(dim,
                                        depth,
                                        heads,
                                        mlp_dim,
                                        dropout,
                                        num_keypoints=num_keypoints,
                                        all_attn=self.all_attn,
                                        scale_with_head=True)
        self.transformer2 = Transformer(dim,
                                        depth,
                                        heads,
                                        mlp_dim,
                                        dropout,
                                        num_keypoints=num_keypoints,
                                        all_attn=self.all_attn,
                                        scale_with_head=True)
        self.transformer3 = Transformer(dim,
                                        depth,
                                        heads,
                                        mlp_dim,
                                        dropout,
                                        num_keypoints=num_keypoints,
                                        all_attn=self.all_attn,
                                        scale_with_head=True)

        self.to_keypoint_token = nn.Identity()

        if dim * 3 <= hidden_heatmap_dim * 0.5 and apply_multi:
            # Two-layer head when the hidden width is large enough
            # relative to the concatenated token dimension.
            self.mlp_head = nn.Sequential(
                nn.LayerNorm(dim * 3),
                nn.Linear(dim * 3, hidden_heatmap_dim),
                nn.LayerNorm(hidden_heatmap_dim),
                nn.Linear(hidden_heatmap_dim, heatmap_dim))
        else:
            # Single-layer head mapping tokens directly to heatmap pixels.
            self.mlp_head = nn.Sequential(
                nn.LayerNorm(dim * 3),
                nn.Linear(dim * 3, heatmap_dim))
        trunc_normal_(self.keypoint_token, std=.02)
        if apply_init:
            self.apply(self._init_weights)
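
A hypothetical instantiation of this module (the enclosing class is not named in the snippet, so `TokenPoseHead` is a placeholder; `feature_size` is the spatial size of the backbone feature map it consumes):

    # Placeholder class name; argument values are illustrative only.
    model = TokenPoseHead(
        feature_size=[64, 48],  # must be divisible by patch_size entrywise
        patch_size=[4, 3],      # -> (64/4) * (48/3) = 256 patch tokens
        num_keypoints=17,       # e.g. COCO keypoints
        dim=192,
        depth=12,
        heads=8,
        mlp_dim=768,
        apply_init=True,
        pos_embedding_type='sine')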
Example 4
    def __init__(self,
                 *,
                 image_size,
                 patch_size,
                 num_keypoints,
                 dim,
                 depth,
                 heads,
                 mlp_dim,
                 apply_init=False,
                 apply_multi=True,
                 hidden_heatmap_dim=64 * 6,
                 heatmap_dim=64 * 48,
                 heatmap_size=[64, 48],
                 channels=3,
                 dropout=0.,
                 emb_dropout=0.,
                 pos_embedding_type="learnable"):
        super().__init__()
        assert isinstance(image_size, list) and isinstance(
            patch_size, list), 'image_size and patch_size should be lists'
        assert image_size[0] % patch_size[0] == 0 and \
            image_size[1] % patch_size[1] == 0, \
            'Image dimensions must be divisible by the patch size.'
        # The CNN stem below downsamples the image by 4 (two stride-2 convs),
        # hence the factor of 4 in the patch-grid size.
        num_patches = (image_size[0] // (4 * patch_size[0])) * \
            (image_size[1] // (4 * patch_size[1]))
        patch_dim = channels * patch_size[0] * patch_size[1]
        assert num_patches > MIN_NUM_PATCHES, \
            f'number of patches ({num_patches}) is too small for attention ' \
            f'to be effective (need more than {MIN_NUM_PATCHES}); ' \
            'try decreasing the patch size'
        assert pos_embedding_type in ['sine', 'none', 'learnable', 'sine-full']

        self.inplanes = 64
        self.patch_size = patch_size
        self.heatmap_size = heatmap_size
        self.num_keypoints = num_keypoints
        self.num_patches = num_patches
        self.pos_embedding_type = pos_embedding_type
        self.all_attn = (self.pos_embedding_type == "sine-full")

        self.keypoint_token = nn.Parameter(
            torch.zeros(1, self.num_keypoints, dim))
        h, w = image_size[0] // (4 * self.patch_size[0]), \
            image_size[1] // (4 * self.patch_size[1])
        self._make_position_embedding(w, h, dim, pos_embedding_type)

        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.dropout = nn.Dropout(emb_dropout)

        # CNN stem: two stride-2 3x3 convs (4x total downsampling), then
        # residual blocks; expects 3-channel RGB input.
        self.conv1 = nn.Conv2d(3,
                               64,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64,
                               64,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(Bottleneck, 64, 4)  # 4 Bottleneck blocks

        # transformer
        self.transformer = Transformer(dim,
                                       depth,
                                       heads,
                                       mlp_dim,
                                       dropout,
                                       num_keypoints=num_keypoints,
                                       all_attn=self.all_attn)

        self.to_keypoint_token = nn.Identity()

        if dim <= hidden_heatmap_dim * 0.5 and apply_multi:
            # Two-layer head when the hidden width is large enough
            # relative to the token dimension.
            self.mlp_head = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, hidden_heatmap_dim),
                nn.LayerNorm(hidden_heatmap_dim),
                nn.Linear(hidden_heatmap_dim, heatmap_dim))
        else:
            # Single-layer head mapping tokens directly to heatmap pixels.
            self.mlp_head = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, heatmap_dim))
        trunc_normal_(self.keypoint_token, std=.02)
        if apply_init:
            self.apply(self._init_weights)
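
And a matching hypothetical instantiation for this CNN-stem variant (again, `TokenPoseCNN` is a placeholder name; `image_size` is the raw input resolution, which the stem reduces by 4 before patchification):

    # Placeholder class name; argument values are illustrative only.
    model = TokenPoseCNN(
        image_size=[256, 192],  # stem downsamples to a 64 x 48 feature map
        patch_size=[4, 3],      # -> (256/16) * (192/12) = 256 patch tokens
        num_keypoints=17,
        dim=192,
        depth=12,
        heads=8,
        mlp_dim=768,
        apply_init=True,
        pos_embedding_type='sine')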