Beispiel #1
0
    def forward_features(self, x):
        """Embed an image batch per pixel, run the blocks, and pool the tokens.

        Args:
            x: Input batch. The permute below moves axis 1 to the end, so this
                presumably is (batch, channels, height, width) — TODO confirm.

        Returns:
            ``self.avgpool`` applied to the (b, h * w, c) token tensor, with
            axis 1 squeezed out. This yields (b, c) only if avgpool reduces
            axis 1 with keepdim — NOTE(review): confirm against avgpool.
        """
        # Differentiable augmentation is skipped when the policy string
        # contains "None".
        if "None" not in self.args.diff_aug:
            x = DiffAugment(x, self.args.diff_aug, True)

        # Move axis 1 last so linear0 projects every spatial position
        # independently. The original config notes
        # (b, 32, 32, 3) -> (b, 32, 32, 256).
        x = x.permute(0, 2, 3, 1)
        x = self.linear0(x)

        for blk in self.blocks:
            x = blk(x)

        # Flatten the spatial grid into a single token axis. Use reshape rather
        # than view: a block may return a non-contiguous tensor (e.g. after a
        # transpose/permute), which view() rejects. reshape still returns a
        # view whenever one is possible.
        b, h, w, c = x.shape
        x = x.reshape(b, h * w, c)
        x = self.avgpool(x).squeeze(1)
        return x
Beispiel #2
0
    def forward_features(self, x):
        """Encode an image batch and return the class-token feature vector.

        Steps: optional DiffAugment, patch embedding, prepending the class
        token, adding positional embeddings, running the blocks, and applying
        the final norm.

        Args:
            x: Input batch; axis 0 is the batch axis. The exact layout is
                whatever ``self.patch_embed`` expects (presumably
                (batch, channels, height, width) — TODO confirm).

        Returns:
            The normalized token at sequence position 0 (the class token) for
            every sample in the batch.
        """
        policy = self.args.diff_aug
        # Augmentation is skipped when the policy string contains "None".
        if "None" not in policy:
            x = DiffAugment(x, policy, True)
        batch_size = x.shape[0]

        # Collapse every axis after the second into one, then swap the last
        # two axes so tokens run along axis 1 and features along axis 2.
        patch_tokens = self.patch_embed(x).flatten(2).transpose(1, 2)

        # Broadcast the shared class token across the batch and prepend it.
        # The cls-token expansion is adapted from Phil Wang's implementation.
        class_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([class_tokens, patch_tokens], dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)

        for block in self.blocks:
            tokens = block(tokens)

        normed = self.norm(tokens)
        return normed[:, 0]