Example #1
 def _init_weights(self, m):
     if isinstance(m, nn.Linear):
         trunc_normal_(m.weight, std=.02)
         if m.bias is not None:
             nn.init.constant_(m.bias, 0)
     elif isinstance(m, nn.LayerNorm):
         nn.init.constant_(m.bias, 0)
         nn.init.constant_(m.weight, 1.0)
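
This `_init_weights` hook is written to be passed to `Module.apply`, which calls it once for every submodule recursively. A minimal standalone sketch of that wiring, assuming `trunc_normal_` comes from `timm.models.layers` as in the snippets below:

    import torch.nn as nn
    from timm.models.layers import trunc_normal_  # assumed source of trunc_normal_

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 8)
            self.ln = nn.LayerNorm(8)

        def _init_weights(self, m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    model = Tiny()
    model.apply(model._init_weights)  # visits Tiny itself, then fc and ln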
Example #2
    def __init__(self,
                 args,
                 img_size=32,
                 patch_size=None,
                 in_chans=3,
                 num_classes=1,
                 embed_dim=None,
                 depth=7,
                 num_heads=4,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 hybrid_backbone=None,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = embed_dim = self.embed_dim = args.df_dim  # num_features for consistency with other models
        depth = args.d_depth
        self.args = args
        patch_size = args.patch_size
        self.patch_embed = nn.Conv2d(in_chans,
                                     embed_dim,
                                     kernel_size=patch_size,
                                     stride=patch_size,
                                     padding=0)
        num_patches = (args.img_size // patch_size)**2

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim,
                  num_heads=num_heads,
                  mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias,
                  qk_scale=qk_scale,
                  drop=drop_rate,
                  attn_drop=attn_drop_rate,
                  drop_path=dpr[i],
                  norm_layer=norm_layer) for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(
            embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
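
A hedged usage sketch: this constructor reads `df_dim`, `d_depth`, `patch_size`, and `img_size` from `args` (overriding the keyword defaults), so any instantiation needs a namespace with those attributes. The class name `Discriminator` below is an assumption for illustration:

    from types import SimpleNamespace

    # hypothetical namespace; attribute names match what __init__ actually reads
    args = SimpleNamespace(df_dim=384, d_depth=7, patch_size=4, img_size=32)
    d = Discriminator(args)  # assumed class name for this __init__
    # (32 // 4) ** 2 = 64 patch tokens, so pos_embed is (1, 65, 384) with cls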
Example #3
    def __init__(self,
                 input_tokens=512,
                 token_dim=1024,
                 embed_dim=384,
                 depth=5,
                 n_cls=20,
                 n_subspaces=10,
                 num_heads=4,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm):
        super(Transformer, self).__init__()
        self.input_tokens = input_tokens
        self.token_dim = token_dim
        self.embed_dim = embed_dim
        self.n_cls = n_cls
        self.n_subspaces = n_subspaces
        self.l1 = nn.Linear(self.token_dim, self.embed_dim)
        self.pos_embed_1 = nn.Parameter(
            torch.zeros(1, input_tokens + 2, embed_dim))
        self.cls_embed = nn.Parameter(torch.zeros(1, 2, embed_dim))
        self.pos_embed = [self.pos_embed_1]
        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        is_mask = False
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim,
                  num_heads=num_heads,
                  mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias,
                  qk_scale=qk_scale,
                  drop=drop_rate,
                  attn_drop=attn_drop_rate,
                  drop_path=dpr[i],
                  norm_layer=norm_layer,
                  is_mask=is_mask) for i in range(depth)
        ])

        for i in range(len(self.pos_embed)):
            trunc_normal_(self.pos_embed[i], std=.02)

        trunc_normal_(self.cls_embed, std=.02)

        self.l2 = nn.Linear(self.embed_dim, 1, bias=False)
        self.l3 = nn.Linear(self.embed_dim, self.n_cls, bias=False)
        self.l4 = nn.Linear(self.embed_dim, 1, bias=False)
        self.l5 = nn.Linear(self.embed_dim, self.n_subspaces, bias=False)
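
One detail worth noting: `self.pos_embed` here is a plain Python list, not an `nn.ParameterList`. The tensor is still trained because it was first registered through the attribute assignment `self.pos_embed_1`; the list only aliases it for the `trunc_normal_` loop. A quick check, assuming `Block` is importable so the class constructs with its defaults:

    t = Transformer()
    # the parameter is registered under its attribute name, not via the list
    assert any(name == 'pos_embed_1' for name, _ in t.named_parameters())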
Example #4
    def __init__(self,
                 args,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=10,
                 embed_dim=384,
                 depth=5,
                 num_heads=4,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 hybrid_backbone=None,
                 norm_layer=nn.LayerNorm):
        super(Generator, self).__init__()
        self.args = args
        self.embed_dim = embed_dim = args.gf_dim
        self.ch = embed_dim  # channel width, mirrors the actual embed_dim
        self.bottom_width = args.bottom_width
        self.l1 = nn.Linear(args.latent_dim,
                            (self.bottom_width**2) * self.embed_dim)
        self.pos_embed_1 = nn.Parameter(
            torch.zeros(1, self.bottom_width**2, embed_dim))
        self.pos_embed_2 = nn.Parameter(
            torch.zeros(1, (self.bottom_width * 2)**2, embed_dim // 4))
        self.pos_embed_3 = nn.Parameter(
            torch.zeros(1, (self.bottom_width * 4)**2, embed_dim // 16))
        self.pos_embed_4 = nn.Parameter(
            torch.zeros(1, (self.bottom_width * 8)**2, embed_dim // 64))
        self.pos_embed = [
            self.pos_embed_1, self.pos_embed_2, self.pos_embed_3,
            self.pos_embed_4
        ]
        is_mask = True
        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim,
                  num_heads=num_heads,
                  mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias,
                  qk_scale=qk_scale,
                  drop=drop_rate,
                  attn_drop=attn_drop_rate,
                  drop_path=dpr[i],
                  norm_layer=norm_layer) for i in range(depth)
        ])
        self.upsample_blocks = nn.ModuleList([
            nn.ModuleList([
                Block(dim=embed_dim // 4,
                      num_heads=num_heads,
                      mlp_ratio=mlp_ratio,
                      qkv_bias=qkv_bias,
                      qk_scale=qk_scale,
                      drop=drop_rate,
                      attn_drop=attn_drop_rate,
                      drop_path=0,
                      norm_layer=norm_layer),
                Block(dim=embed_dim // 4,
                      num_heads=num_heads,
                      mlp_ratio=mlp_ratio,
                      qkv_bias=qkv_bias,
                      qk_scale=qk_scale,
                      drop=drop_rate,
                      attn_drop=attn_drop_rate,
                      drop_path=0,
                      norm_layer=norm_layer,
                      is_mask=0),
                Block(dim=embed_dim // 4,
                      num_heads=num_heads,
                      mlp_ratio=mlp_ratio,
                      qkv_bias=qkv_bias,
                      qk_scale=qk_scale,
                      drop=drop_rate,
                      attn_drop=attn_drop_rate,
                      drop_path=0,
                      norm_layer=norm_layer,
                      is_mask=0)
            ]),
            nn.ModuleList([
                Block(dim=embed_dim // 16,
                      num_heads=num_heads,
                      mlp_ratio=mlp_ratio,
                      qkv_bias=qkv_bias,
                      qk_scale=qk_scale,
                      drop=drop_rate,
                      attn_drop=attn_drop_rate,
                      drop_path=0,
                      norm_layer=norm_layer),
                Block(dim=embed_dim // 16,
                      num_heads=num_heads,
                      mlp_ratio=mlp_ratio,
                      qkv_bias=qkv_bias,
                      qk_scale=qk_scale,
                      drop=drop_rate,
                      attn_drop=attn_drop_rate,
                      drop_path=0,
                      norm_layer=norm_layer,
                      is_mask=0),
                Block(dim=embed_dim // 16,
                      num_heads=num_heads,
                      mlp_ratio=mlp_ratio,
                      qkv_bias=qkv_bias,
                      qk_scale=qk_scale,
                      drop=drop_rate,
                      attn_drop=attn_drop_rate,
                      drop_path=0,
                      norm_layer=norm_layer,
                      is_mask=0)
            ]),
            nn.ModuleList([
                # Block(
                #     dim=embed_dim//16, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                #     drop=drop_rate, attn_drop=attn_drop_rate, drop_path=0, norm_layer=norm_layer),
                Block(dim=embed_dim // 64,
                      num_heads=num_heads,
                      mlp_ratio=mlp_ratio,
                      qkv_bias=qkv_bias,
                      qk_scale=qk_scale,
                      drop=drop_rate,
                      attn_drop=attn_drop_rate,
                      drop_path=0,
                      norm_layer=norm_layer,
                      is_mask=0),
                Block(dim=embed_dim // 64,
                      num_heads=num_heads,
                      mlp_ratio=mlp_ratio,
                      qkv_bias=qkv_bias,
                      qk_scale=qk_scale,
                      drop=drop_rate,
                      attn_drop=attn_drop_rate,
                      drop_path=0,
                      norm_layer=norm_layer,
                      is_mask=(self.bottom_width * 8)**2)
            ])
        ])
        for i in range(len(self.pos_embed)):
            trunc_normal_(self.pos_embed[i], std=.02)

        self.to_rgb = nn.Sequential(
            nn.BatchNorm2d(args.gf_dim),
            nn.ReLU(),
            # nn.Conv2d(args.gf_dim, 3, 3, 1, 1),
            nn.Tanh())

        self.deconv = nn.Sequential(
            # nn.BatchNorm2d(self.embed_dim),
            # nn.ReLU(),
            nn.Conv2d(self.embed_dim // 64, 3, 1, 1, 0))
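
The staged positional embeddings encode the generator's upsampling schedule: each stage quadruples the token count (`bottom_width` doubles per side) while dividing the channel width by 4, which is exactly the shape arithmetic of `nn.PixelShuffle(2)`. A standalone sketch of that invariant (pixel shuffle is an inference from these shapes, not shown in this snippet):

    import torch
    import torch.nn as nn

    bottom_width, embed_dim = 8, 1024
    x = torch.randn(1, embed_dim, bottom_width, bottom_width)
    x = nn.PixelShuffle(2)(x)  # trades channels for spatial resolution
    assert x.shape == (1, embed_dim // 4, bottom_width * 2, bottom_width * 2)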
Example #5
    def __init__(self,
                 input_tokens=512,
                 token_dim=1024,
                 embed_dim=384,
                 encoder_depth=5,
                 decoder_depth=5,
                 n_cls=20,
                 num_heads=4,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 task_nums=6):
        super(FullTransformer, self).__init__()
        self.input_tokens = input_tokens
        self.token_dim = token_dim
        self.embed_dim = embed_dim
        self.task_nums = task_nums
        self.n_cls = n_cls
        self.l1 = nn.Linear(self.token_dim, self.embed_dim)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, input_tokens + 1, embed_dim))
        self.cls_embed = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.out_embed = nn.Parameter(
            torch.zeros(self.task_nums, input_tokens + 1, embed_dim))
        self.embeds = [self.pos_embed, self.out_embed]
        # stochastic depth decay rule; dpr is sized by encoder_depth but also
        # indexed by the decoder blocks below, so decoder_depth must not exceed it
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, encoder_depth)
        ]
        is_mask = False

        self.enc_blocks = nn.ModuleList([
            Block(dim=embed_dim,
                  num_heads=num_heads,
                  mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias,
                  qk_scale=qk_scale,
                  drop=drop_rate,
                  attn_drop=attn_drop_rate,
                  drop_path=dpr[i],
                  norm_layer=norm_layer,
                  is_mask=is_mask) for i in range(encoder_depth)
        ])

        self.dec_blocks = nn.ModuleList([
            DecBlock(dim=embed_dim,
                     num_heads=num_heads,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias,
                     qk_scale=qk_scale,
                     drop=drop_rate,
                     attn_drop=attn_drop_rate,
                     drop_path=dpr[i],
                     norm_layer=norm_layer,
                     is_mask=is_mask) for i in range(decoder_depth)
        ])

        self.l2 = nn.Linear(self.embed_dim, 1, bias=False)
        self.l3 = nn.Linear(self.embed_dim, self.n_cls, bias=False)
        self.l4 = nn.Linear(self.embed_dim, 1, bias=False)

        for i in range(len(self.embeds)):
            trunc_normal_(self.embeds[i], std=1.0)

        trunc_normal_(self.cls_embed, std=.02)
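
All constructor arguments here have defaults, so the encoder-decoder model builds with no arguments at all, which makes for an easy smoke test (assuming `Block` and `DecBlock` are in scope):

    ft = FullTransformer()  # 5 encoder + 5 decoder blocks by default
    n_params = sum(p.numel() for p in ft.parameters())
    print(f'{n_params:,} trainable parameters')  # rough sanity check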
Example #6
File: ViTrans.py Project: guanyuelee/PmSFC
    def __init__(self,
                 img_size=32,
                 patch_size=16,
                 in_chans=3,
                 num_dirs=20,
                 embed_dim=None,
                 depth=7,
                 num_heads=4,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 hybrid_backbone=None,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_dirs = num_dirs
        self.embed_dim = embed_dim
        self.depth = depth
        self.patch_size = patch_size
        self.img_size = img_size
        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(hybrid_backbone,
                                           img_size=img_size,
                                           in_chans=in_chans,
                                           embed_dim=embed_dim)
        else:
            self.patch_embed = nn.Conv2d(in_chans,
                                         embed_dim,
                                         kernel_size=patch_size,
                                         stride=patch_size,
                                         padding=0)
        num_patches = (img_size // patch_size)**2

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim,
                  num_heads=num_heads,
                  mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias,
                  qk_scale=qk_scale,
                  drop=drop_rate,
                  attn_drop=attn_drop_rate,
                  drop_path=dpr[i],
                  norm_layer=norm_layer) for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        self.head1 = nn.Linear(embed_dim,
                               1) if self.num_dirs > 0 else nn.Identity()
        self.head2 = nn.Linear(
            embed_dim, num_dirs) if self.num_dirs > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
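
Unlike most of the other examples, `embed_dim` defaults to `None` here, so it must be passed explicitly or `nn.Conv2d` and `norm_layer(embed_dim)` will fail. A minimal sketch (the class name `ViTrans` is taken from the file name and is an assumption):

    model = ViTrans(img_size=32, patch_size=4, embed_dim=384, depth=7)
    # (32 // 4) ** 2 = 64 patch tokens + 1 cls token -> pos_embed is (1, 65, 384)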
Example #7
    def __init__(
        self,
        args,
        img_size=32,
        patch_size=4,
        in_chans=3,
        num_classes=1,
        embed_dim=None,
        depth=7,
        num_heads=4,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        hybrid_backbone=None,
        norm_layer=nn.LayerNorm,
        augments="",
    ):
        super().__init__()
        self.num_classes = num_classes
        # num_features kept for consistency with other models
        self.num_features = self.embed_dim = embed_dim = args.df_dim
        depth = args.d_depth
        self.args = args
        patch_size = args.patch_size
        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone,
                img_size=img_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
        else:
            self.patch_embed = nn.Conv2d(in_chans,
                                         embed_dim,
                                         kernel_size=patch_size,
                                         stride=patch_size,
                                         padding=0)
        num_patches = (args.img_size // patch_size)**2

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
            ) for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
        # self.repr = nn.Linear(embed_dim, representation_size)
        # self.repr_act = nn.Tanh()

        # Classifier head
        self.head = (nn.Linear(embed_dim, num_classes)
                     if num_classes > 0 else nn.Identity())

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)
        self.augments = augments
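
The `dpr` list shared by all of these constructors implements linear stochastic-depth decay: the drop probability rises from 0 at the first block to `drop_path_rate` at the last, so deeper blocks are skipped more often during training. Worked numerically, independent of any class here:

    import torch

    drop_path_rate, depth = 0.1, 7
    dpr = [round(x.item(), 4) for x in torch.linspace(0, drop_path_rate, depth)]
    print(dpr)  # [0.0, 0.0167, 0.0333, 0.05, 0.0667, 0.0833, 0.1]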