def deit_tiny_colab_patch16_224(pretrained=False, all_key_dim=None, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    # Swap standard attention for collaborative attention. The swap is run on
    # GPU and the model is moved back to CPU for the caller.
    model.cuda()
    collaborate_attention.swap(model, all_key_dim)
    model.cpu()
    return model
def deit_base_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
def deit_tiny_patch16_224_ex6(pretrained=False, **kwargs):
    # The expanded DeiT-T (mlp_ratio=6) reported in Table 1.
    model = VisionTransformer(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=6,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model
def deit_base3_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=3,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    assert not pretrained
    return model
def deit_base_patch16_224_collab256(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    # Swap standard attention for collaborative attention with
    # compressed_key_dim=256, without re-parametrizing existing weights.
    collaborate_attention.swap(model, compressed_key_dim=256, reparametrize=False)
    return model
def load_pretrained(self):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=nn.LayerNorm,
    )
    model.default_cfg = _cfg()
    checkpoint = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
        map_location="cpu",
        check_hash=True,
    )
    model.load_state_dict(checkpoint["model"])
    # Copy the pretrained backbone components into the current module.
    self.pos_embed = model.pos_embed
    self.patch_embed = model.patch_embed
    self.blocks = model.blocks
    self.norm = model.norm
def deit_base3_patch16_224_key96(pretrained=False, **kwargs):
    import timm.models.vision_transformer
    from collaborate_attention import FlexibleKeyDimensionAttention

    # Monkey-patch timm's Attention so that every block built below uses a
    # 96-dimensional key space. Note that the patch is global and remains in
    # effect after this call.
    timm.models.vision_transformer.Attention = partial(
        FlexibleKeyDimensionAttention, all_key_dim=96)
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=3,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    assert not pretrained
    return model
def deit_132_3_24(pretrained=False, **kwargs):  # 1258.532028
    model = VisionTransformer(
        patch_size=16,
        embed_dim=132,
        depth=24,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        # Note: this URL is the DeiT-Tiny checkpoint (embed_dim=192, depth=12),
        # which does not match this 132-dim, 24-layer configuration.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
def deit_base3_patch16_224_collab96(pretrained=False, models_directory=None, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=3,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    collaborate_attention.swap(model, compressed_key_dim=96, reparametrize=False)
    model.default_cfg = _cfg()
    if pretrained:
        # Pretrained weights for this variant are loaded from a local directory
        # rather than from a URL.
        checkpoint_path = pathlib.Path(models_directory) / "deit_base3_patch16_224_collab96.pth"
        print(f"Load model from '{checkpoint_path}'")
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
def deit_base_patch16_224(pretrained=False, **kwargs):
    if "num_landmarks" not in kwargs:
        model = VisionTransformer(
            patch_size=16,
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4,
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            **kwargs,
        )
    else:
        # Passing `num_landmarks` selects the Nystromformer variant instead of
        # the standard ViT attention.
        model = Nystromformer(
            patch_size=16,
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4,
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            **kwargs,
        )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
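# Illustrative usage sketch (not part of the original factories): build the
# pretrained DeiT-Base defined above and run a forward pass on a random batch.
# The helper name `_demo_deit_base_forward` and the dummy input are assumptions
# made for illustration only.
def _demo_deit_base_forward():
    model = deit_base_patch16_224(pretrained=True)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    # The DeiT-Base checkpoint is trained on ImageNet-1k, so the head has 1000 classes.
    assert logits.shape == (1, 1000)
    return logits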
class ImageNetLightningModelForVit(LightningModule):
    # ViT model names available in timm, plus a local "tiny" development config.
    MODEL_NAMES = sorted(list(vision_transformer.default_cfgs.keys() | {"tiny"}))

    def __init__(
        self,
        arch: str,
        pretrained: bool,
        lr: float,
        weight_decay: float,
        data_path: str,
        batch_size: int,
        workers: int,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.arch = arch
        self.pretrained = pretrained
        self.lr = lr
        self.weight_decay = weight_decay
        self.data_path = data_path
        self.batch_size = batch_size
        self.workers = workers
        if self.arch == "tiny":
            # Tiny empty model for development purposes.
            img_size = [32, 32]
            self.model_cfg = _cfg(input_size=[3] + img_size)
            self.model = VisionTransformer(
                img_size=img_size,
                patch_size=4,
                in_chans=3,
                num_classes=1000,
                embed_dim=16,
                depth=2,
                num_heads=1,
            )
        else:
            self.model: VisionTransformer = timm.create_model(
                self.arch, pretrained=self.pretrained)
            self.model_cfg = vision_transformer.default_cfgs[self.arch]
        # TODO: delete me. Hack so that auto_lr_find works.
        self.model.reset_classifier(10)

    def setup(self, stage: str):
        # Configure the classification head to the number of classes in the dataset.
        train_dir = os.path.join(self.data_path, 'train')
        train_dataset = datasets.ImageFolder(train_dir)
        self.model.reset_classifier(len(train_dataset.classes))

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, target = batch
        output = self(images)
        loss_train = F.cross_entropy(output, target)
        acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
        self.log('train_loss', loss_train, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc1', acc1, on_step=True, prog_bar=True, on_epoch=True, logger=True)
        self.log('train_acc5', acc5, on_step=True, on_epoch=True, logger=True)
        return loss_train

    def validation_step(self, batch, batch_idx):
        images, target = batch
        output = self(images)
        loss_val = F.cross_entropy(output, target)
        acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
        self.log('val_loss', loss_val, on_step=True, on_epoch=True)
        self.log('val_acc1', acc1, on_step=True, prog_bar=True, on_epoch=True)
        self.log('val_acc5', acc5, on_step=True, on_epoch=True)

    @staticmethod
    def __accuracy(output, target, topk=(1,)):
        """Computes the accuracy over the k top predictions for the specified values of k."""
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in topk:
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res.append(correct_k.mul_(100.0 / batch_size))
            return res

    def configure_optimizers(self):
        optimizer = optim.Adam(
            self.parameters(),
            # Fall back to the learning rate set by auto_lr_find when --lr is unset.
            lr=(self.lr or self.learning_rate),
            betas=(0.9, 0.999),
            weight_decay=self.weight_decay,
        )
        # Step decay: divide the learning rate by 10 every 30 epochs.
        scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.1 ** (epoch // 30))
        return [optimizer], [scheduler]

    def train_dataloader(self):
        cfg = self.model_cfg
        normalize = transforms.Normalize(mean=cfg['mean'], std=cfg['std'])
        train_dir = os.path.join(self.data_path, 'train')
        train_dataset = datasets.ImageFolder(
            train_dir,
            transforms.Compose([
                transforms.RandomResizedCrop(cfg['input_size'][1:]),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.workers,
        )
        return train_loader

    def val_dataloader(self):
        cfg = self.model_cfg
        crop_pct = cfg['crop_pct']
        img_size = cfg['input_size'][1:]
        # Resize so that the center crop of `img_size` covers `crop_pct` of the image.
        scale_size = tuple([int(x / crop_pct) for x in img_size])
        normalize = transforms.Normalize(mean=cfg['mean'], std=cfg['std'])
        val_dir = os.path.join(self.data_path, 'val')
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(
                val_dir,
                transforms.Compose([
                    transforms.Resize(scale_size, _pil_interp(cfg['interpolation'])),
                    transforms.CenterCrop(img_size),
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.workers,
        )
        return val_loader

    def test_dataloader(self):
        return self.val_dataloader()

    def test_step(self, *args, **kwargs):
        return self.validation_step(*args, **kwargs)

    def test_epoch_end(self, *args, **kwargs):
        outputs = self.validation_epoch_end(*args, **kwargs)

        def substitute_val_keys(out):
            return {k.replace('val', 'test'): v for k, v in out.items()}

        outputs = {
            'test_loss': outputs['val_loss'],
            'progress_bar': substitute_val_keys(outputs['progress_bar']),
            'log': substitute_val_keys(outputs['log']),
        }
        return outputs

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no cover
        model_names = ImageNetLightningModelForVit.MODEL_NAMES
        parser = ArgumentParser(parents=[parent_parser])
        parser.add_argument('-a', '--arch', metavar='ARCH', default=model_names[0],
                            choices=model_names,
                            help='model architecture: ' + ' | '.join(model_names) +
                                 ' (default: ' + model_names[0] + ')')
        parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                            help='number of data loading workers (default: 4)')
        parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N',
                            help='mini-batch size (default: 256), this is the total '
                                 'batch size of all GPUs on the current node when '
                                 'using Data Parallel or Distributed Data Parallel')
        parser.add_argument('--lr', '--learning-rate', default=0.3, type=float,
                            metavar='LR', help='initial learning rate', dest='lr')
        parser.add_argument('--wd', '--weight-decay', default=1e-1, type=float,
                            metavar='W', help='weight decay (default: 1e-1)',
                            dest='weight_decay')
        parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                            help='use pre-trained model')
        return parser
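# Minimal training sketch (illustration only, not part of the original module):
# wires `ImageNetLightningModelForVit` into a PyTorch Lightning `Trainer` using
# the argparse options defined above. The `--data-path` and `--epochs` flags,
# the seed value, and the Trainer settings are assumptions for this example.
if __name__ == "__main__":
    from pytorch_lightning import Trainer, seed_everything

    parent = ArgumentParser(add_help=False)
    parser = ImageNetLightningModelForVit.add_model_specific_args(parent)
    parser.add_argument('--data-path', default='./imagenet', type=str,
                        help='root folder with train/ and val/ subdirectories')
    parser.add_argument('--epochs', default=90, type=int,
                        help='number of training epochs (assumed default)')
    args = parser.parse_args()

    seed_everything(42)
    model = ImageNetLightningModelForVit(**vars(args))
    trainer = Trainer(max_epochs=args.epochs)
    trainer.fit(model)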