Example #1
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(patch_size=16,
                                       embed_dim=384,
                                       depth=12,
                                       num_heads=6,
                                       mlp_ratio=4,
                                       qkv_bias=True,
                                       norm_layer=partial(nn.LayerNorm,
                                                          eps=1e-6),
                                       **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # The original snippet left the URL empty; the official DeiT release
        # for this model (the same weights referenced in Example #8 below)
        # is filled in here.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
            map_location="cpu",
            check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
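These factory functions follow the timm convention: calling one returns an nn.Module, and pretrained=True additionally pulls released weights. A minimal usage sketch (the 224x224 input shape follows from patch_size=16 above; the eval-mode behavior noted in the comments is an assumption based on DeiT's released code, which averages the cls and dist heads at inference):

import torch

model = deit_small_distilled_patch16_224(pretrained=False)
model.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 224, 224)  # one 224x224 RGB image
    logits = model(x)                # in train mode, distilled ViTs instead
                                     # return (cls_logits, dist_logits)
print(logits.shape)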
Example #2
def deit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              mlp_ratio=4,
                              qkv_bias=True,
                              norm_layer=partial(nn.LayerNorm, eps=1e-6),
                              **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu",
            check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
Example #3
def base_patch16_384_token(pretrained=False, **kwargs):
    model = VisionTransformer_token(img_size=384,
                                    patch_size=16,
                                    embed_dim=768,
                                    depth=12,
                                    num_heads=12,
                                    mlp_ratio=4,
                                    qkv_bias=True,
                                    norm_layer=partial(nn.LayerNorm, eps=1e-6),
                                    **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # Download from https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth
        checkpoint = torch.load(
            './Networks/deit_base_patch16_384-8de9b5d1.pth',
            map_location='cpu')
        model.load_state_dict(checkpoint["model"], strict=False)
        print("Loaded pretrained transformer weights")
    return model
Example #4
def deit_base3_patch16_224_key96(pretrained=False, **kwargs):
    import timm.models.vision_transformer
    from collaborate_attention import FlexibleKeyDimensionAttention

    # Patch timm's Attention at module level *before* the model is built, so
    # every block is constructed with the fixed-key-dimension variant.
    timm.models.vision_transformer.Attention = partial(
        FlexibleKeyDimensionAttention, all_key_dim=96)

    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=3,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    # No pretrained checkpoint is published for this truncated (depth=3) variant.
    assert not pretrained
    return model
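Because the patch above rebinds a module-level name, it affects every VisionTransformer built afterwards, not just this one. A quick sanity-check sketch (saving and restoring the original class is an addition here, not part of the project code):

import timm.models.vision_transformer as vit

original_attention = vit.Attention          # keep a handle to undo the patch

model = deit_base3_patch16_224_key96()
print(type(model.blocks[0].attn).__name__)  # expect: FlexibleKeyDimensionAttention

vit.Attention = original_attention          # restore timm's stock Attention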
Example #5
def deit_base3_patch16_224_collab96(pretrained=False, models_directory=None, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=3,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    # Swap the stock attention modules for collaborative attention with a
    # compressed key dimension of 96, without reparametrizing the weights.
    collaborate_attention.swap(model, compressed_key_dim=96, reparametrize=False)
    model.default_cfg = _cfg()
    if pretrained:
        assert models_directory is not None, \
            "models_directory is required when pretrained=True"
        checkpoint_path = pathlib.Path(models_directory) / "deit_base3_patch16_224_collab96.pth"
        print(f"Load model from '{checkpoint_path}'")
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
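Note the contrast with Example #4: there, timm's Attention class is patched globally before construction, while here a stock VisionTransformer is built first and swap() then replaces its attention modules in place, which keeps the change local to this one model.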
Example #6
def resmlp_36(pretrained=False, dist=False, **kwargs):
    model = resmlp_models(patch_size=16,
                          embed_dim=384,
                          depth=36,
                          Patch_layer=PatchEmbed,
                          init_scale=1e-6,
                          **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        if dist:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth"
        else:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url=url_path,
                                                        map_location="cpu",
                                                        check_hash=True)
        # ResMLP checkpoints are flat state dicts, hence no ["model"] key here.
        model.load_state_dict(checkpoint)
    return model
Example #7
File: pvt.py Project: zlapp/PVT
def pvt_small(pretrained=False, **kwargs):
    model = PyramidVisionTransformer(patch_size=4,
                                     embed_dims=[64, 128, 320, 512],
                                     num_heads=[1, 2, 5, 8],
                                     mlp_ratios=[8, 8, 4, 4],
                                     qkv_bias=True,
                                     norm_layer=partial(nn.LayerNorm,
                                                        eps=1e-6),
                                     depths=[3, 4, 6, 3],
                                     sr_ratios=[8, 4, 2, 1],
                                     **kwargs)
    model.default_cfg = _cfg()
    # if pretrained:
    #     checkpoint = torch.hub.load_state_dict_from_url(
    #         url=None,
    #         map_location="cpu", check_hash=True
    #     )
    #     model.load_state_dict(checkpoint["model"])

    return model
Example #8
def vitstr_small_distilled_patch16_224(pretrained=False, **kwargs):
    kwargs['in_chans'] = 1
    kwargs['distilled'] = True
    model = ViTSTR(patch_size=16,
                   embed_dim=384,
                   depth=12,
                   num_heads=6,
                   mlp_ratio=4,
                   qkv_bias=True,
                   **kwargs)
    model.default_cfg = _cfg(
        url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth"
    )
    if pretrained:
        # _conv_filter reshapes the checkpoint's patch-embedding weights, and
        # load_pretrained further adapts them to the grayscale (in_chans=1) input.
        load_pretrained(model,
                        num_classes=model.num_classes,
                        in_chans=kwargs.get('in_chans', 1),
                        filter_fn=_conv_filter)
    return model
Example #9
def vitstr_base_patch16_224(pretrained=False, **kwargs):
    kwargs['in_chans'] = 1
    model = ViTSTR(patch_size=16,
                   embed_dim=768,
                   depth=12,
                   num_heads=12,
                   mlp_ratio=4,
                   qkv_bias=True,
                   **kwargs)
    model.default_cfg = _cfg(
        # url='https://github.com/roatienza/public/releases/download/v0.1-deit-base/deit_base_patch16_224-b5f2ef4d.pth'
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'
    )
    if pretrained:
        load_pretrained(model,
                        num_classes=model.num_classes,
                        in_chans=kwargs.get('in_chans', 1),
                        filter_fn=_conv_filter)
    return model
Example #10
def goodit3_base_patch16_224(pretrained=False, **kwargs):
    """
    convolutional + pooling stem
    local enhanced feedforward
    attention over cls_tokens
    """
    backbone = ConvStem3()
    model = VisionTransformer(hybrid_backbone=backbone,
                              patch_size=4,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              mlp_ratio=4,
                              qkv_bias=True,
                              norm_layer=partial(nn.LayerNorm, eps=1e-6),
                              feedforward_type='conv',
                              **kwargs)
    model.default_cfg = _cfg()
    print(model)  # debug: dump the module tree
    return model
Example #11
    def __init__(
        self,
        arch: str,
        pretrained: bool,
        lr: float,
        weight_decay: float,
        data_path: str,
        batch_size: int,
        workers: int,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.arch = arch
        self.pretrained = pretrained
        self.lr = lr
        self.weight_decay = weight_decay
        self.data_path = data_path
        self.batch_size = batch_size
        self.workers = workers
        if self.arch == "tiny":
            # Tiny empty model for development purposes.
            img_size = [32, 32]
            self.model_cfg = _cfg(input_size=[3] + img_size)
            self.model = VisionTransformer(img_size=img_size,
                                           patch_size=4,
                                           in_chans=3,
                                           num_classes=1000,
                                           embed_dim=16,
                                           depth=2,
                                           num_heads=1)
        else:
            self.model: VisionTransformer = timm.create_model(
                self.arch, pretrained=self.pretrained)
            self.model_cfg = vision_transformer.default_cfgs[self.arch]
        # TODO: delete me. Hack so that auto_lr_find works
        self.model.reset_classifier(10)
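A hypothetical instantiation of this module (the class name LitViT is a placeholder, since the enclosing LightningModule definition is not shown above):

module = LitViT(                        # placeholder name for the class above
    arch="deit_small_patch16_224",      # any timm ViT name, or "tiny" for the stub
    pretrained=True,
    lr=5e-4,
    weight_decay=0.05,
    data_path="/data/imagenet",
    batch_size=128,
    workers=8,
)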
Example #12
def deit_small_distilled_patch16_384(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        img_size=384, patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
            map_location="cpu", check_hash=True
        )
        # Adapt the 224-resolution checkpoint to 384 by truncating the
        # positional embeddings, or padding them with the model's own
        # (freshly initialized) embeddings when the target is longer.
        model_seq_len = model.state_dict()['pos_embed'].shape[1]
        ckpt_seq_len = checkpoint['model']['pos_embed'].shape[1]
        logger.warning('Deit load {:d} seq len to {:d} APE {}'.format(ckpt_seq_len, model_seq_len, str(model.ape)))
        if not model.ape:
            if model_seq_len <= ckpt_seq_len:
                checkpoint['model']['pos_embed'] = checkpoint['model']['pos_embed'][:, :model_seq_len, :]
            else:
                t = model.state_dict()['pos_embed']
                t[:, :ckpt_seq_len, :] = checkpoint['model']['pos_embed']
                checkpoint['model']['pos_embed'] = t

        model.load_state_dict(checkpoint["model"])
    return model
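Truncating or padding positional embeddings like this is a crude fit; the approach used by DeiT's own fine-tuning code and timm is to bicubically resize the patch-grid part of pos_embed. A minimal sketch, assuming a square patch grid and two prefix tokens (cls and dist):

import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, new_grid, num_prefix_tokens=2):
    # Split off the cls/dist tokens, resize the (H, W) grid, re-concatenate.
    prefix = pos_embed[:, :num_prefix_tokens]
    grid = pos_embed[:, num_prefix_tokens:]               # (1, H*W, C)
    old_grid = int(grid.shape[1] ** 0.5)                  # assumes H == W
    c = grid.shape[2]
    grid = grid.reshape(1, old_grid, old_grid, c).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(new_grid, new_grid),
                         mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, c)
    return torch.cat([prefix, grid], dim=1)

# 224/16 = 14 patches per side -> 384/16 = 24:
# checkpoint['model']['pos_embed'] = resize_pos_embed(
#     checkpoint['model']['pos_embed'], new_grid=24)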
Example #13
def beit_large_patch16_384(pretrained=False, **kwargs):
    model = AdaptedVisionTransformer(
        img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    # Note: `pretrained` is accepted for API compatibility, but no checkpoint
    # loading is wired up for this model.
    return model
def linear_large(pretrained=False, **kwargs):
    model = LinearVisionTransformer(
        patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model