def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="",  # NOTE: checkpoint URL missing; loading will fail until one is supplied
            map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model

def deit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model

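# Usage sketch (hedged: assumes the standard DeiT preprocessing of 224x224 RGB inputs
# normalized with ImageNet statistics; the random tensor stands in for a real batch):
#
#     model = deit_base_patch16_224(pretrained=True)
#     model.eval()
#     with torch.no_grad():
#         logits = model(torch.randn(1, 3, 224, 224))  # -> (1, 1000) ImageNet logits
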
def base_patch16_384_token(pretrained=False, **kwargs):
    model = VisionTransformer_token(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # Download from https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth
        checkpoint = torch.load('./Networks/deit_base_patch16_384-8de9b5d1.pth')
        model.load_state_dict(checkpoint["model"], strict=False)
        print("loaded pretrained transformer weights")
    return model

def deit_base3_patch16_224_key96(pretrained=False, **kwargs):
    import timm.models.vision_transformer
    from collaborate_attention import FlexibleKeyDimensionAttention

    # Monkey-patch timm's Attention so every block uses a shared 96-dim key space.
    timm.models.vision_transformer.Attention = partial(
        FlexibleKeyDimensionAttention, all_key_dim=96)
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=3, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    assert not pretrained, "no pretrained checkpoint is available for this variant"
    return model

def deit_base3_patch16_224_collab96(pretrained=False, models_directory=None, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=3, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Swap the attention layers for collaborative attention with 96 compressed key dims.
    collaborate_attention.swap(model, compressed_key_dim=96, reparametrize=False)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint_path = pathlib.Path(models_directory) / "deit_base3_patch16_224_collab96.pth"
        print(f"Load model from '{checkpoint_path}'")
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

def resmlp_36(pretrained=False, dist=False, **kwargs):
    model = resmlp_models(
        patch_size=16, embed_dim=384, depth=36, Patch_layer=PatchEmbed,
        init_scale=1e-6, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # `dist` selects the distilled checkpoint over the plain one.
        if dist:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth"
        else:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth"
        checkpoint = torch.hub.load_state_dict_from_url(
            url=url_path, map_location="cpu", check_hash=True)
        # Unlike the DeiT factories, this state dict is stored at the top level,
        # so it is loaded directly rather than via checkpoint["model"].
        model.load_state_dict(checkpoint)
    return model

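# Usage sketch (the two checkpoint URLs above are the only difference between the
# variants; `dist=True` fetches the distilled weights):
#
#     model = resmlp_36(pretrained=True, dist=True)
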
def pvt_small(pretrained=False, **kwargs):
    model = PyramidVisionTransformer(
        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1], **kwargs)
    model.default_cfg = _cfg()
    # No pretrained checkpoint URL is available yet:
    # if pretrained:
    #     checkpoint = torch.hub.load_state_dict_from_url(
    #         url=None, map_location="cpu", check_hash=True)
    #     model.load_state_dict(checkpoint["model"])
    return model

def vitstr_small_distilled_patch16_224(pretrained=False, **kwargs):
    kwargs['in_chans'] = 1  # single-channel (grayscale) text images
    kwargs['distilled'] = True
    model = ViTSTR(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, **kwargs)
    model.default_cfg = _cfg(
        url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth")
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes,
            in_chans=kwargs.get('in_chans', 1), filter_fn=_conv_filter)
    return model

def vitstr_base_patch16_224(pretrained=False, **kwargs):
    kwargs['in_chans'] = 1
    model = ViTSTR(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, **kwargs)
    model.default_cfg = _cfg(
        # url='https://github.com/roatienza/public/releases/download/v0.1-deit-base/deit_base_patch16_224-b5f2ef4d.pth'
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth')
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes,
            in_chans=kwargs.get('in_chans', 1), filter_fn=_conv_filter)
    return model

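# Usage sketch for the ViTSTR factories (hedged: these models consume single-channel
# text-line images, per in_chans=1 above; the 224x224 shape mirrors the reused DeiT
# configuration and is an assumption here, and the full repo's forward signature may
# take extra decoding arguments):
#
#     model = vitstr_base_patch16_224(pretrained=False)
#     out = model(torch.randn(1, 1, 224, 224))
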
def goodit3_base_patch16_224(pretrained=False, **kwargs):
    """Convolutional + pooling stem, locally enhanced feedforward,
    attention over cls_tokens.
    """
    backbone = ConvStem3()
    model = VisionTransformer(
        hybrid_backbone=backbone, patch_size=4, embed_dim=768, depth=12,
        num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), feedforward_type='conv',
        **kwargs)
    model.default_cfg = _cfg()
    print(model)
    return model

def __init__(
    self,
    arch: str,
    pretrained: bool,
    lr: float,
    weight_decay: float,
    data_path: str,
    batch_size: int,
    workers: int,
    **kwargs,
):
    super().__init__()
    self.save_hyperparameters()
    self.arch = arch
    self.pretrained = pretrained
    self.lr = lr
    self.weight_decay = weight_decay
    self.data_path = data_path
    self.batch_size = batch_size
    self.workers = workers
    if self.arch == "tiny":
        # Tiny empty model for development purposes.
        img_size = [32, 32]
        self.model_cfg = _cfg(input_size=[3] + img_size)
        self.model = VisionTransformer(
            img_size=img_size, patch_size=4, in_chans=3, num_classes=1000,
            embed_dim=16, depth=2, num_heads=1)
    else:
        self.model: VisionTransformer = timm.create_model(
            self.arch, pretrained=self.pretrained)
        self.model_cfg = vision_transformer.default_cfgs[self.arch]
    # TODO: delete me. Hack so that auto_lr_find works.
    self.model.reset_classifier(10)

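# Construction sketch (hedged: this __init__ reads like a pytorch_lightning.LightningModule
# subclass, given the save_hyperparameters() call; the class name `ViTModule` is
# hypothetical):
#
#     module = ViTModule(arch="tiny", pretrained=False, lr=1e-3, weight_decay=0.05,
#                        data_path="/data/imagenet", batch_size=128, workers=4)
#     print(module.model_cfg["input_size"])  # [3, 32, 32] for the tiny dev model
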
def deit_small_distilled_patch16_384(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        img_size=384, patch_size=16, embed_dim=384, depth=12, num_heads=6,
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # Only a 224x224 checkpoint exists, so its position embeddings must be
        # adapted to the 384x384 sequence length.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
            map_location="cpu", check_hash=True)
        model_seq_len = model.state_dict()['pos_embed'].shape[1]
        ckpt_seq_len = checkpoint['model']['pos_embed'].shape[1]
        logger.warning('Deit load {:d} seq len to {:d} APE {}'.format(
            ckpt_seq_len, model_seq_len, str(model.ape)))
        if not model.ape:
            # Truncate or zero-extend the checkpoint's pos_embed to fit the model.
            if model_seq_len <= ckpt_seq_len:
                checkpoint['model']['pos_embed'] = checkpoint['model']['pos_embed'][:, :model_seq_len, :]
            else:
                t = model.state_dict()['pos_embed']
                t[:, :ckpt_seq_len, :] = checkpoint['model']['pos_embed']
                checkpoint['model']['pos_embed'] = t
        model.load_state_dict(checkpoint["model"])
    return model

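# The truncate/zero-pad adaptation above changes sequence length but ignores the 2D
# layout of the patch grid. A common alternative (a sketch, not part of the original
# code; the helper name is hypothetical) is to bicubically resize the patch position
# embeddings to the new grid while keeping the class/distillation tokens untouched:
def _interpolate_pos_embed(pos_embed, num_extra_tokens, new_grid_size):
    """Resize (1, N, C) position embeddings to a new square patch grid (sketch)."""
    extra = pos_embed[:, :num_extra_tokens]   # cls/dist tokens pass through unchanged
    patch = pos_embed[:, num_extra_tokens:]   # (1, old_grid**2, C)
    old_grid = int(patch.shape[1] ** 0.5)
    channels = patch.shape[2]
    # Restore the 2D grid, resize with bicubic interpolation, then flatten back.
    patch = patch.reshape(1, old_grid, old_grid, channels).permute(0, 3, 1, 2)
    patch = torch.nn.functional.interpolate(
        patch, size=(new_grid_size, new_grid_size), mode="bicubic", align_corners=False)
    patch = patch.permute(0, 2, 3, 1).reshape(1, new_grid_size ** 2, channels)
    return torch.cat([extra, patch], dim=1)
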
def beit_large_patch16_384(pretrained=False, **kwargs):
    model = AdaptedVisionTransformer(
        img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        mlp_ratio=4, qkv_bias=False, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    return model

def linear_large(pretrained=False, **kwargs):
    model = LinearVisionTransformer(
        patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model
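
# Registry note (an assumption: DeiT-style model files conventionally decorate these
# factories with timm's @register_model; if that is the case here, the models can
# also be built by name through timm's factory):
#
#     import timm
#     model = timm.create_model("deit_base_patch16_224", pretrained=True)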