コード例 #1
0
def deit_tiny_colab_patch16_224(pretrained=False, all_key_dim=None, **kwargs):
    """Build a DeiT-Tiny (patch 16) ViT and swap in collaborative attention.

    Args:
        pretrained: If True, download the official DeiT-Tiny checkpoint
            (hash-checked) and load it *before* the attention swap.
        all_key_dim: Passed positionally as the second argument of
            ``collaborate_attention.swap``; presumably the shared compressed
            key dimension (TODO confirm against ``swap``'s signature).
        **kwargs: Extra keyword arguments forwarded to ``VisionTransformer``.

    Returns:
        The swapped model, placed on CPU.
    """
    model = VisionTransformer(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])

    # Run the swap on the GPU when one exists (presumably for speed). The
    # guard keeps CPU-only hosts from raising on an unconditional .cuda().
    if torch.cuda.is_available():
        model.cuda()
    collaborate_attention.swap(model, all_key_dim)
    # Always return the model on CPU so callers see a consistent device.
    model.cpu()

    return model
コード例 #2
0
def deit_base_patch16_224_collab768(pretrained=False, models_directory="./models", **kwargs):
    """Build DeiT-Base with collaborative attention (compressed key dim 768).

    Args:
        pretrained: If True, load ``deit_base_patch16_224_collab768.pth`` from
            ``models_directory`` after the attention swap.
        models_directory: Directory holding the local checkpoint file.
        **kwargs: Forwarded to ``deit_base_patch16_224`` (e.g. ``num_classes``),
            consistent with the other factories, instead of being dropped.

    Returns:
        The model with collaborative attention.
    """
    # The base model is never built with its own pretrained weights. The local
    # checkpoint is loaded after the swap, so it must match the swapped layout.
    model = deit_base_patch16_224(pretrained=False, **kwargs)
    collaborate_attention.swap(model, compressed_key_dim=768, reparametrize=False)
    if pretrained:
        checkpoint_path = pathlib.Path(models_directory) / "deit_base_patch16_224_collab768.pth"
        print(f"Load model from '{checkpoint_path}'")
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
コード例 #3
0
def deit_base_patch16_224_collab256(pretrained=False, **kwargs):
    """Return a 12-layer DeiT-Base ViT with collaborative attention (key dim 256).

    ``pretrained`` is accepted only to keep the common factory signature; this
    function loads no checkpoint. Extra ``kwargs`` go to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    architecture = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
    )
    net = VisionTransformer(**architecture, **kwargs)
    net.default_cfg = _cfg()
    # Replace the attention layers with the collaborative variant, no reparametrization.
    collaborate_attention.swap(net, compressed_key_dim=256, reparametrize=False)
    return net
コード例 #4
0
def deit_base3_patch16_224_collab96(pretrained=False, models_directory=None, **kwargs):
    """Build a 3-layer DeiT-Base ViT with collaborative attention (key dim 96).

    Args:
        pretrained: If True, load ``deit_base3_patch16_224_collab96.pth`` after
            the attention swap.
        models_directory: Directory holding the checkpoint. ``None`` falls back
            to ``"./models"``; previously ``pathlib.Path(None)`` raised TypeError.
        **kwargs: Extra keyword arguments forwarded to ``VisionTransformer``.

    Returns:
        The model with collaborative attention.
    """
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=3,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    collaborate_attention.swap(model, compressed_key_dim=96, reparametrize=False)
    model.default_cfg = _cfg()
    if pretrained:
        if models_directory is None:
            models_directory = "./models"
        checkpoint_path = pathlib.Path(models_directory) / "deit_base3_patch16_224_collab96.pth"
        print(f"Load model from '{checkpoint_path}'")
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model