import pathlib
from functools import partial

import torch
import torch.nn as nn
from timm.models.vision_transformer import VisionTransformer, _cfg

import collaborate_attention  # project-local collaborative-attention helpers
from nystromformer import Nystromformer  # assumed import path for the project-local variant


def deit_tiny_colab_patch16_224(pretrained=False, all_key_dim=None, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    # Swap in collaborative attention with a shared key space of all_key_dim
    # dimensions; the swap runs on GPU and the model is returned on CPU.
    model.cuda()
    collaborate_attention.swap(model, all_key_dim)
    model.cpu()
    return model

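# Usage sketch for the factory above (the helper name and the all_key_dim
# value are illustrative, and a CUDA device is required because the swap runs
# on GPU): build the collaborative DeiT-Ti with a 64-dimensional shared key
# space and check the classifier output shape on a dummy batch.
def _demo_deit_tiny_colab():
    model = deit_tiny_colab_patch16_224(pretrained=True, all_key_dim=64)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    assert logits.shape == (1, 1000)  # default ImageNet-1k head
    return logits
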
def deit_base_patch16_384(pretrained=False, **kwargs):
    if "num_landmarks" not in kwargs:
        model = VisionTransformer(
            img_size=384,  # must match the 384x384 checkpoint below
            patch_size=16,
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4,
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            **kwargs,
        )
    else:
        # Passing num_landmarks selects the Nystromformer variant instead.
        model = Nystromformer(
            img_size=384,
            patch_size=16,
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4,
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            **kwargs,
        )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model

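# Dispatch sketch for the factory above: the presence of num_landmarks in the
# kwargs selects the Nystromformer, whose attention is approximated with that
# many landmark tokens (the value 64 below is illustrative, not from the
# source).
def _demo_deit_base_384_variants():
    dense = deit_base_patch16_384()                     # exact softmax attention
    nystrom = deit_base_patch16_384(num_landmarks=64)   # Nystrom-approximated attention
    return dense, nystrom
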
def deit_small_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model

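# Inference sketch for the plain pretrained factories. Per timm's _cfg
# defaults, default_cfg carries the ImageNet normalization statistics the
# checkpoint expects; the random tensor below stands in for a batch of
# already-normalized 224x224 images.
def _demo_deit_small_inference():
    model = deit_small_patch16_224(pretrained=True)
    model.eval()
    x = torch.randn(2, 3, 224, 224)  # stand-in for normalized images
    with torch.no_grad():
        probs = model(x).softmax(dim=-1)
    return probs, model.default_cfg["mean"], model.default_cfg["std"]
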
def deit_tiny_patch16_224_ex6(pretrained=False, **kwargs):
    # The expanded DeiT-T in Table 1: mlp_ratio raised from the usual 4 to 6.
    # No pretrained checkpoint is loaded for this variant.
    model = VisionTransformer(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=6,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model

def deit_base3_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=3,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    assert not pretrained
    return model

def deit_base_patch16_224_collab256(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    # Replace every attention layer with collaborative attention sharing a
    # 256-dimensional key space; reparametrization is skipped because the
    # weights are randomly initialized.
    collaborate_attention.swap(model, compressed_key_dim=256, reparametrize=False)
    return model

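# Sanity-check sketch for the key compression above: with the shared key space
# reduced from 768 to 256 dimensions, the swapped model should end up with
# fewer parameters than an unmodified DeiT-B. The exact saving depends on how
# collaborate_attention factorizes the projections, so only the direction of
# the change is asserted here.
def _compare_collab256_params():
    base = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    collab = deit_base_patch16_224_collab256()
    n_base = sum(p.numel() for p in base.parameters())
    n_collab = sum(p.numel() for p in collab.parameters())
    assert n_collab < n_base
    return n_base, n_collab
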
def load_pretrained(self):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=nn.LayerNorm,
    )
    model.default_cfg = _cfg()
    checkpoint = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
        map_location="cpu",
        check_hash=True,
    )
    model.load_state_dict(checkpoint["model"])
    # Adopt the pretrained DeiT-B submodules on the host module.
    self.pos_embed = model.pos_embed
    self.patch_embed = model.patch_embed
    self.blocks = model.blocks
    self.norm = model.norm

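# load_pretrained is written as a method and expects its host to expose the
# attribute names it assigns (note the cls_token is not copied). A
# hypothetical host class illustrating that contract:
class _PretrainedBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        # Placeholders that load_pretrained overwrites with pretrained modules.
        self.pos_embed = None
        self.patch_embed = None
        self.blocks = None
        self.norm = None

    load_pretrained = load_pretrained  # reuse the module-level function above
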
def deit_base_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(url="", map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model

def deit_base3_patch16_224_key96(pretrained=False, **kwargs):
    import timm.models.vision_transformer
    from collaborate_attention import FlexibleKeyDimensionAttention

    # Monkey-patch timm's Attention so every block built below uses a total
    # key dimension of 96 instead of the default embed_dim. Note the patch is
    # global: it affects any VisionTransformer constructed afterwards.
    timm.models.vision_transformer.Attention = partial(
        FlexibleKeyDimensionAttention, all_key_dim=96
    )
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=3,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    assert not pretrained
    return model

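# Because the patch above leaks process-wide, a defensive wrapper can restore
# whatever Attention class was installed before the call (sketch; assumes no
# concurrent model construction in other threads).
def _build_key96_isolated(**kwargs):
    import timm.models.vision_transformer as tvt
    saved_attention = tvt.Attention  # whatever class is installed right now
    try:
        return deit_base3_patch16_224_key96(**kwargs)
    finally:
        tvt.Attention = saved_attention
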
def deit_base3_patch16_224_collab96(pretrained=False, models_directory=None, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=3,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    collaborate_attention.swap(model, compressed_key_dim=96, reparametrize=False)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint_path = pathlib.Path(models_directory) / "deit_base3_patch16_224_collab96.pth"
        print(f"Loading model from '{checkpoint_path}'")
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

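# Usage sketch for restoring the collaborative checkpoint from disk: the
# directory name is illustrative and must contain the .pth file saved by the
# training run, with the state dict stored under the "model" key.
def _demo_load_collab96(models_directory="checkpoints"):
    model = deit_base3_patch16_224_collab96(
        pretrained=True, models_directory=models_directory
    )
    model.eval()
    return model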