Example No. 1
    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [
            img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        ]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans,
                              embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None
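All of these examples lean on to_2tuple from timm to normalize a scalar size into an (h, w) pair before computing the patch grid. As a minimal, self-contained sketch (assuming timm is installed; the Conv2d below mirrors the projection in Example No. 1 rather than reusing its class):

import torch
import torch.nn as nn
from timm.layers import to_2tuple  # exposed as timm.models.layers.to_2tuple in older timm releases

# to_2tuple normalizes a scalar into an (h, w) pair and leaves pairs untouched.
assert to_2tuple(224) == (224, 224)
assert to_2tuple((224, 96)) == (224, 96)

# Non-overlapping patch projection: kernel_size == stride == patch_size, so a
# 224x224 image becomes a 56x56 grid of 96-dim patch tokens (224 // 4 = 56).
proj = nn.Conv2d(3, 96, kernel_size=4, stride=4)
print(proj(torch.zeros(1, 3, 224, 224)).shape)  # torch.Size([1, 96, 56, 56])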
Example No. 2
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=768,
                 isBin=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
            f"img_size {img_size} should be divided by patch_size {patch_size}."
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.norm = nn.LayerNorm(embed_dim)
        self.isBin = isBin
        if self.isBin:
            self.binary_act = BinaryActivation()
            self.proj = HardBinaryConv(in_chans,
                                       embed_dim,
                                       kernel_size=patch_size[0],
                                       stride=patch_size)
        else:
            self.proj = nn.Conv2d(in_chans,
                                  embed_dim,
                                  kernel_size=patch_size,
                                  stride=patch_size)
Example No. 3
 def __init__(self,
              backbone,
              img_size=224,
              feature_size=None,
              in_chans=3,
              embed_dim=768):
     super().__init__()
     assert isinstance(backbone, nn.Module)
     img_size = to_2tuple(img_size)
     self.img_size = img_size
     self.backbone = backbone
     if feature_size is None:
         with torch.no_grad():
             training = backbone.training
             if training:
                 backbone.eval()
             o = self.backbone(
                 torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
             feature_size = o.shape[-2:]
             feature_dim = o.shape[1]
             backbone.train(training)
     else:
         feature_size = to_2tuple(feature_size)
         feature_dim = self.backbone.feature_info.channels()[-1]
     self.num_patches = feature_size[0] * feature_size[1]
     self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1)
Example No. 4
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 outer_dim=768,
                 inner_dim=24,
                 inner_stride=4):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] //
                                                        patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.inner_dim = inner_dim
        self.num_words = math.ceil(patch_size[0] / inner_stride) * math.ceil(
            patch_size[1] / inner_stride)

        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.proj = nn.Conv2d(in_chans,
                              inner_dim,
                              kernel_size=7,
                              padding=3,
                              stride=inner_stride)
Example No. 5
 def __init__(self,
              backbone,
              img_size=224,
              feature_size=None,
              in_chans=3,
              embed_dim=768):
     super().__init__()
     assert isinstance(backbone, nn.Module)
     img_size = to_2tuple(img_size)
     self.img_size = img_size
     self.backbone = backbone
     if feature_size is None:
         with torch.no_grad():
             # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
             # map for all networks, the feature metadata has reliable channel and stride info, but using
             # stride to calc feature dim requires info about padding of each stage that isn't captured.
             training = backbone.training
             if training:
                 backbone.eval()
             o = self.backbone(
                 torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
             feature_size = o.shape[-2:]
             feature_dim = o.shape[1]
             backbone.train(training)
     else:
         feature_size = to_2tuple(feature_size)
         feature_dim = self.backbone.feature_info.channels()[-1]
     self.num_patches = feature_size[0] * feature_size[1]
     self.proj = nn.Linear(feature_dim, embed_dim)
Example No. 6
 def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
     super().__init__()
     img_size = to_2tuple(img_size)
     patch_size = to_2tuple(patch_size)
     num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
     self.img_size = img_size
     self.patch_size = patch_size
     self.num_patches = num_patches
     self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
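Constructors like Example No. 6 are typically paired with a forward pass that flattens the conv output into a token sequence. The forward method is not shown in these excerpts; the following is only a sketch of that common step:

import torch
import torch.nn as nn

# Project, then flatten the spatial grid into a sequence of patch tokens.
proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
x = torch.zeros(2, 3, 224, 224)                 # (B, C, H, W)
tokens = proj(x).flatten(2).transpose(1, 2)     # (B, num_patches, embed_dim)
print(tokens.shape)                             # torch.Size([2, 196, 768])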
Example No. 7
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim
Example No. 8
    def __init__(self, img_size=256, patch_size=16, in_channels=3, embed_dim=768):
        super(PatchEmbed, self).__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, f"img_size {img_size} should be divided by patch_size {patch_size}"
        self.H,self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)
Example No. 9
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
            f"img_size {img_size} should be divided by patch_size {patch_size}."
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]   # Note: self.H, self.W and self.num_patches are not used
        self.num_patches = self.H * self.W                                            #       since the image size may change on the fly.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)
Example No. 10
 def __init__(self, img_size=224, patch_size=4, embed_dim=96):
     super().__init__()
     img_size = to_2tuple(img_size)
     patch_size = to_2tuple(patch_size)
     patches_resolution = [
         img_size[0] // patch_size[0], img_size[1] // patch_size[1]
     ]
     self.img_size = img_size
     self.patch_size = patch_size
     self.patches_resolution = patches_resolution
     self.num_patches = patches_resolution[0] * patches_resolution[1]
     self.embed_dim = embed_dim
     self.proj = nn.Conv2d(embed_dim // 2,
                           embed_dim,
                           kernel_size=4,
                           stride=4)
Example No. 11
    def __init__(self, patch_size, nx, ny, in_chans=3, embed_dim=768, nglo=1,
                 norm_layer=nn.LayerNorm, norm_embed=True, drop_rate=0.0,
                 ape=True):
        # maximal global/x-direction/y-direction tokens: nglo, nx, ny
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
                              stride=patch_size)

        self.norm_embed = norm_layer(embed_dim) if norm_embed else None

        self.nx = nx
        self.ny = ny
        self.Nglo = nglo
        if nglo >= 1:
            self.cls_token = nn.Parameter(torch.zeros(1, nglo, embed_dim))
            trunc_normal_(self.cls_token, std=.02)
        else:
            self.cls_token = None
        self.ape = ape
        if ape:
            self.cls_pos_embed = nn.Parameter(torch.zeros(1, nglo, embed_dim))
            self.x_pos_embed = nn.Parameter(torch.zeros(1, nx, embed_dim // 2))
            self.y_pos_embed = nn.Parameter(torch.zeros(1, ny, embed_dim // 2))
            trunc_normal_(self.cls_pos_embed, std=.02)
            trunc_normal_(self.x_pos_embed, std=.02)
            trunc_normal_(self.y_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)
Example No. 12
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        gate_layer=None,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        if gate_layer is not None:
            assert hidden_features % 2 == 0
            self.gate = gate_layer(hidden_features)
            hidden_features = hidden_features // 2  # FIXME base reduction on gate property?
        else:
            self.gate = nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop_probs[1])
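In Example No. 12, hidden_features is halved after the gate because fc2 is expected to see only half of the gated activations; this matches gate layers that split their input channels in two and gate one half with the other. A minimal sketch of such a gate, with an illustrative name (HalfGate is not part of the source):

import torch
import torch.nn as nn

class HalfGate(nn.Module):
    # Illustrative gate: split channels in half and gate one half with the other,
    # so the output width is dim // 2, matching fc2's reduced input size above.
    def __init__(self, dim):
        super().__init__()
        assert dim % 2 == 0

    def forward(self, x):
        u, v = x.chunk(2, dim=-1)
        return u * torch.sigmoid(v)

gate = HalfGate(512)
print(gate(torch.zeros(4, 512)).shape)  # torch.Size([4, 256])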
Example No. 13
    def __init__(self, patch_size=16, embed_dim=768):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        if patch_size[0] == 16:
            self.proj = torch.nn.Sequential(
                conv3x3(3, embed_dim // 8, 2),
                nn.GELU(),
                conv3x3(embed_dim // 8, embed_dim // 4, 2),
                nn.GELU(),
                conv3x3(embed_dim // 4, embed_dim // 2, 2),
                nn.GELU(),
                conv3x3(embed_dim // 2, embed_dim, 2),
            )
        elif patch_size[0] == 8:
            self.proj = torch.nn.Sequential(
                conv3x3(3, embed_dim // 4, 2),
                nn.GELU(),
                conv3x3(embed_dim // 4, embed_dim // 2, 2),
                nn.GELU(),
                conv3x3(embed_dim // 2, embed_dim, 2),
            )
        else:
            raise ValueError(
                "For convolutional projection, patch size has to be in [8, 16]"
            )
Example No. 14
 def __init__(
     self,
     dim,
     seq_len,
     mlp_ratio=(0.5, 4.0),
     mlp_layer=Mlp,
     norm_layer=partial(nn.LayerNorm, eps=1e-6),
     act_layer=nn.GELU,
     drop=0.0,
     drop_path=0.0,
 ):
     super().__init__()
     tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
     self.norm1 = norm_layer(dim)
     self.mlp_tokens = mlp_layer(seq_len,
                                 tokens_dim,
                                 act_layer=act_layer,
                                 drop=drop)
     self.drop_path = DropPath(
         drop_path) if drop_path > 0.0 else nn.Identity()
     self.norm2 = norm_layer(dim)
     self.mlp_channels = mlp_layer(dim,
                                   channels_dim,
                                   act_layer=act_layer,
                                   drop=drop)
Example No. 15
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)
Example No. 16
 def __init__(self, img_size=224, in_chans=3, outer_dim=768, inner_dim=24):
     super().__init__()
     img_size = to_2tuple(img_size)
     self.img_size = img_size
     self.inner_dim = inner_dim
     self.num_patches = img_size[0] // 8 * img_size[1] // 8
     self.num_words = 16
     
     self.common_conv = nn.Sequential(
         nn.Conv2d(in_chans, inner_dim*2, 3, stride=2, padding=1),
         nn.BatchNorm2d(inner_dim*2),
         nn.ReLU(inplace=True),
     )
     self.inner_convs = nn.Sequential(
         nn.Conv2d(inner_dim*2, inner_dim, 3, stride=1, padding=1),
         nn.BatchNorm2d(inner_dim),
         nn.ReLU(inplace=False),
     )
     self.outer_convs = nn.Sequential(
         nn.Conv2d(inner_dim*2, inner_dim*4, 3, stride=2, padding=1),
         nn.BatchNorm2d(inner_dim*4),
         nn.ReLU(inplace=True),
         nn.Conv2d(inner_dim*4, inner_dim*8, 3, stride=2, padding=1),
         nn.BatchNorm2d(inner_dim*8),
         nn.ReLU(inplace=True),
         nn.Conv2d(inner_dim*8, outer_dim, 3, stride=1, padding=1),
         nn.BatchNorm2d(outer_dim),
         nn.ReLU(inplace=False),
     )
     
     self.unfold = nn.Unfold(kernel_size=4, padding=0, stride=4)
Example No. 17
    def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = to_2tuple(patch_size)

        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)
Example No. 18
    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        
        assert max(patch_size) > stride, "Set larger patch_size than stride"
        
        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // stride, img_size[1] // stride
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)
Example No. 19
    def __init__(self, patch_size=16, in_ch=3, out_ch=768, with_pos=False):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=patch_size+1, stride=patch_size, padding=patch_size // 2)
        self.norm = nn.BatchNorm2d(out_ch)

        self.with_pos = with_pos
        if self.with_pos:
            self.pos = PA(out_ch)
Example No. 20
    def __init__(self, img_size: int = 224, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Sequential(
            conv3x3(in_chans, embed_dim // 8, 2),
            nn.GELU(),
            conv3x3(embed_dim // 8, embed_dim // 4, 2),
            nn.GELU(),
            conv3x3(embed_dim // 4, embed_dim // 2, 2),
            nn.GELU(),
            conv3x3(embed_dim // 2, embed_dim, 2),
        )
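Examples No. 13 and No. 20 build the patch projection from four stride-2 conv3x3 stages, whose combined stride of 2**4 = 16 reproduces the 16x16 patch grid of a plain convolutional embedding. A sketch of that arithmetic, assuming conv3x3 is the usual 3x3 convolution with padding 1 (its definition is not shown in the excerpts):

import torch
import torch.nn as nn

def conv3x3(in_ch, out_ch, stride=1):
    # Assumed helper: standard 3x3 convolution with padding 1.
    return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)

# Four stride-2 stages -> total stride 16, so a 224x224 input yields a 14x14 grid,
# the same resolution as a single 16x16 patchify convolution.
stem = nn.Sequential(
    conv3x3(3, 96, 2), nn.GELU(),
    conv3x3(96, 192, 2), nn.GELU(),
    conv3x3(192, 384, 2), nn.GELU(),
    conv3x3(384, 768, 2),
)
print(stem(torch.zeros(1, 3, 224, 224)).shape)  # torch.Size([1, 768, 14, 14])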
Example No. 21
    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        group_size=7,
        lsda_flag=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        num_patch_size=1,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.group_size = group_size
        self.lsda_flag = lsda_flag
        self.mlp_ratio = mlp_ratio
        self.num_patch_size = num_patch_size
        if min(self.input_resolution) <= self.group_size:
            # if group size is larger than input resolution, we don't partition groups
            self.lsda_flag = 0
            self.group_size = min(self.input_resolution)

        self.norm1 = norm_layer(dim)

        self.attn = Attention(
            dim,
            group_size=to_2tuple(self.group_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            position_bias=True,
        )

        self.drop_path = DropPath(
            drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        attn_mask = None
        self.register_buffer("attn_mask", attn_mask)
Example No. 22
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=768,
                 norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.patch_grid = (img_size[0] // patch_size[0],
                           img_size[1] // patch_size[1])
        self.num_patches = self.patch_grid[0] * self.patch_grid[1]

        self.proj = nn.Conv2d(in_chans,
                              embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
Example No. 23
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()

        new_patch_size = to_2tuple(patch_size // 2)

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] //
                                                        patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.embed_dim = embed_dim

        self.conv1 = nn.Conv2d(in_chans,
                               128,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)  # 112x112
        self.bn1 = nn.BatchNorm2d(128)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(128,
                               128,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)  # 112x112
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128,
                               128,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(128)

        self.proj = nn.Conv2d(128,
                              embed_dim,
                              kernel_size=new_patch_size,
                              stride=new_patch_size)
Example No. 24
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=768,
                 attn_cfg=None):
        super().__init__(**(attn_cfg or {}))

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
            f"img_size {img_size} should be divided by patch_size {patch_size}."
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(in_chans, embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
Example No. 25
 def __init__(self,
              backbone,
              img_size=224,
              patch_size=16,
              feature_size=None,
              in_chans=3,
              embed_dim=768):
     super().__init__()
     assert isinstance(backbone, nn.Module)
     img_size = to_2tuple(img_size)
     self.img_size = img_size
     self.backbone = backbone
     if feature_size is None:
         with torch.no_grad():
             # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
             # map for all networks, the feature metadata has reliable channel and stride info, but using
             # stride to calc feature dim requires info about padding of each stage that isn't captured.
             training = backbone.training
             if training:
                 backbone.eval()
             o = self.backbone(
                 torch.zeros(1, in_chans, img_size[0], img_size[1]))
             if isinstance(o, (list, tuple)):
                 o = o[-1]  # last feature if backbone outputs list/tuple of features
             feature_size = o.shape[-2:]
             feature_dim = o.shape[1]
             backbone.train(training)
     else:
         feature_size = to_2tuple(feature_size)
         feature_dim = self.backbone.feature_info.channels()[-1]
     print('feature_size is {}, feature_dim is {}, patch_size is {}'.format(
         feature_size, feature_dim, patch_size))
     self.num_patches = (feature_size[0] //
                         patch_size) * (feature_size[1] // patch_size)
     self.proj = nn.Conv2d(feature_dim,
                           embed_dim,
                           kernel_size=patch_size,
                           stride=patch_size)
Example No. 26
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None
Example No. 27
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)
Example No. 28
 def __init__(self, in_chans, embed_dim, resolution, activation):
     super().__init__()
     img_size: Tuple[int, int] = to_2tuple(resolution)
     self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
     self.num_patches = self.patches_resolution[0] * \
         self.patches_resolution[1]
     self.in_chans = in_chans
     self.embed_dim = embed_dim
     n = embed_dim
     self.seq = nn.Sequential(
         Conv2d_BN(in_chans, n // 2, 3, 2, 1),
         activation(),
         Conv2d_BN(n // 2, n, 3, 2, 1),
     )
Example No. 29
 def __init__(self,
              img_size=224,
              patch_size=16,
              in_chans=3,
              embed_dim=768,
              norm_layer=nn.SyncBatchNorm):
     super().__init__()
     img_size = to_2tuple(img_size)
     patch_size = to_2tuple(patch_size)
     num_patches = (img_size[1] // patch_size[1]) * (img_size[0] //
                                                     patch_size[0])
     self.img_size = img_size
     self.patch_size = patch_size
     self.num_patches = num_patches
     self.proj = torch.nn.Sequential(*[
         nn.Conv2d(in_chans, embed_dim // 4, kernel_size=4, stride=4),
         norm_layer(embed_dim // 4),
         nn.GELU(),
         nn.Conv2d(embed_dim // 4, embed_dim // 4, kernel_size=2, stride=2),
         norm_layer(embed_dim // 4),
         nn.GELU(),
         nn.Conv2d(embed_dim // 4, embed_dim, kernel_size=2, stride=2),
         norm_layer(embed_dim),
     ])
Example No. 30
    def forward(self, x):
        x = self.forward_features(x)
        
        x_rot = self.pre_logits_rot(x[:, 0])
        x_rot = self.rot_head(x_rot)
        
        x_contrastive = self.pre_logits_contrastive(x[:, 1])
        x_contrastive = self.contrastive_head(x_contrastive)

        if self.training_mode == 'finetune':
            return x_rot, x_contrastive
    
        x_rec = x[:, 2:].transpose(1, 2)
        x_rec = self.convTrans(x_rec.unflatten(2, to_2tuple(int(math.sqrt(x_rec.size()[2])))))
        
        return x_rot, x_contrastive, x_rec, self.rot_w, self.contrastive_w, self.recons_w