Example #1
def deit_tiny_colab_patch16_224(pretrained=False, all_key_dim=None, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])

    # swap in collaborative attention on the GPU, then move the model back to the CPU
    model.cuda()
    collaborate_attention.swap(model, all_key_dim)
    model.cpu()

    return model
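
A hypothetical call to the factory above (the value of all_key_dim is illustrative, not from the source; the swap step moves the model to the GPU, so a CUDA device is required):

model = deit_tiny_colab_patch16_224(pretrained=True, all_key_dim=64)
model.eval()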
Example #2
def deit_base_patch16_384(pretrained=False, **kwargs):
    # build the Nystromformer variant when a num_landmarks argument is supplied,
    # otherwise fall back to the standard VisionTransformer
    if "num_landmarks" not in kwargs:
        model = VisionTransformer(patch_size=16,
                                  embed_dim=768,
                                  depth=12,
                                  num_heads=12,
                                  mlp_ratio=4,
                                  qkv_bias=True,
                                  norm_layer=partial(nn.LayerNorm, eps=1e-6),
                                  **kwargs)
    else:
        model = Nystromformer(patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              mlp_ratio=4,
                              qkv_bias=True,
                              norm_layer=partial(nn.LayerNorm, eps=1e-6),
                              **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # official DeiT-Base/16 checkpoint trained at 384x384 resolution;
        # img_size=384 must be supplied via **kwargs for the positional embeddings to match
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
            map_location="cpu",
            check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
Example #3
def deit_small_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
Example #4
def deit_tiny_patch16_224_ex6(pretrained=False, **kwargs):
    # the expanded DeiT-T in Table 1
    model = VisionTransformer(patch_size=16,
                              embed_dim=192,
                              depth=12,
                              num_heads=3,
                              mlp_ratio=6,
                              qkv_bias=True,
                              norm_layer=partial(nn.LayerNorm, eps=1e-6),
                              **kwargs)
    model.default_cfg = _cfg()
    return model
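
A minimal forward-pass sketch with the expanded DeiT-T above (illustrative only; assumes timm's VisionTransformer defaults, e.g. 1000 output classes):

import torch

model = deit_tiny_patch16_224_ex6()
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # expected shape: (1, 1000)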
Example #5
def deit_base3_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=3,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    assert not pretrained
    return model
Example #6
def deit_base_patch16_224_collab256(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    collaborate_attention.swap(model, compressed_key_dim=256, reparametrize=False)
    return model
Example #7
    def load_pretrained(self):
        model = VisionTransformer(
            patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
            norm_layer=nn.LayerNorm)
        model.default_cfg = _cfg()
        checkpoint = torch.hub.load_state_dict_from_url(
            url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint['model'])

        self.pos_embed = model.pos_embed
        self.patch_embed = model.patch_embed
        self.blocks = model.blocks
        self.norm = model.norm
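
The method above copies the pretrained DeiT-Base backbone modules onto the enclosing object. A rough sketch of how the copied attributes could be used for patch-token feature extraction (the helper function, its name, and the CLS-token handling are assumptions, not part of the source):

import torch

def extract_patch_tokens(module, images):
    # module: an object whose load_pretrained() has been called; images: (B, 3, 224, 224)
    x = module.patch_embed(images)        # (B, 196, 768) patch tokens
    x = x + module.pos_embed[:, 1:, :]    # drop the CLS position embedding
    for blk in module.blocks:
        x = blk(x)
    return module.norm(x)                 # (B, 196, 768) normalized features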
Example #8
def deit_base_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(img_size=384,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              mlp_ratio=4,
                              qkv_bias=True,
                              norm_layer=partial(nn.LayerNorm, eps=1e-6),
                              **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(url="",
                                                        map_location="cpu",
                                                        check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
Example #9
def deit_base3_patch16_224_key96(pretrained=False, **kwargs):
    import timm.models.vision_transformer
    from collaborate_attention import FlexibleKeyDimensionAttention

    # monkey-patch timm's Attention so every block built below uses the
    # flexible-key-dimension variant with a total key dimension of 96
    timm.models.vision_transformer.Attention = partial(FlexibleKeyDimensionAttention, all_key_dim=96)

    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=3,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    assert not pretrained
    return model
Example #10
def deit_base3_patch16_224_collab96(pretrained=False, models_directory=None, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=3,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    collaborate_attention.swap(model, compressed_key_dim=96, reparametrize=False)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint_path = pathlib.Path(models_directory) / "deit_base3_patch16_224_collab96.pth"
        print(f"Load model from '{checkpoint_path}'")
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
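
A hypothetical call to the factory above (the directory path is a placeholder; it must contain a checkpoint file named deit_base3_patch16_224_collab96.pth):

model = deit_base3_patch16_224_collab96(pretrained=True, models_directory="/path/to/models")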