Example #1
def deit_tiny_colab_patch16_224(pretrained=False, all_key_dim=None, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])

    model.cuda()
    collaborate_attention.swap(model, all_key_dim)
    model.cpu()

    return model
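A minimal usage sketch for this factory (the `models` import path is an assumption, a CUDA device must be available because the factory moves the model to GPU while swapping in collaborative attention, and all_key_dim=64 is only an illustrative value):

import torch
from models import deit_tiny_colab_patch16_224  # hypothetical import path

model = deit_tiny_colab_patch16_224(pretrained=True, all_key_dim=64)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # (batch, channels, height, width)
print(logits.shape)  # expected: torch.Size([1, 1000])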
Example #2
def deit_base_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
Example #3
def deit_tiny_patch16_224_ex6(pretrained=False, **kwargs):
    # the expanded DeiT-T in Table 1
    model = VisionTransformer(patch_size=16,
                              embed_dim=192,
                              depth=12,
                              num_heads=3,
                              mlp_ratio=6,
                              qkv_bias=True,
                              norm_layer=partial(nn.LayerNorm, eps=1e-6),
                              **kwargs)
    model.default_cfg = _cfg()
    return model
Example #4
def deit_base3_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=3,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    assert not pretrained
    return model
Example #5
def deit_base_patch16_224_collab256(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    collaborate_attention.swap(model, compressed_key_dim=256, reparametrize=False)
    return model
Example #6
    def load_pretrained(self):
        model = VisionTransformer(
            patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
            norm_layer=nn.LayerNorm)
        model.default_cfg = _cfg()
        checkpoint = torch.hub.load_state_dict_from_url(
            url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint['model'])

        self.pos_embed = model.pos_embed
        self.patch_embed = model.patch_embed
        self.blocks = model.blocks
        self.norm = model.norm
Example #7
def deit_base3_patch16_224_key96(pretrained=False, **kwargs):
    import timm.models.vision_transformer
    from collaborate_attention import FlexibleKeyDimensionAttention
    timm.models.vision_transformer.Attention = partial(FlexibleKeyDimensionAttention, all_key_dim=96)

    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=3,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    assert not pretrained
    return model
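The monkey patch above only takes effect because it is applied before the model is built: timm's transformer Block resolves `Attention` from the module globals when it constructs each layer, so the replacement class must already be in place. A quick check that the patch is active (a sketch, assuming the imports above have already run):

import timm.models.vision_transformer as vt

# After the patch, the module-level Attention is the partial wrapping
# FlexibleKeyDimensionAttention with all_key_dim=96 (presumably the total
# key dimension shared across the 12 heads).
print(vt.Attention)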
Example #8
def deit_132_3_24(pretrained=False, **kwargs):  # 1258.532028
    model = VisionTransformer(patch_size=16,
                              embed_dim=132,
                              depth=24,
                              num_heads=3,
                              mlp_ratio=4,
                              qkv_bias=True,
                              norm_layer=partial(nn.LayerNorm, eps=1e-6),
                              **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu",
            check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
Example #9
def deit_base3_patch16_224_collab96(pretrained=False, models_directory=None, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=3,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    collaborate_attention.swap(model, compressed_key_dim=96, reparametrize=False)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint_path = pathlib.Path(models_directory) / "deit_base3_patch16_224_collab96.pth"
        print(f"Load model from '{checkpoint_path}'")
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
Example #10
def deit_base_patch16_224(pretrained=False, **kwargs):
    if "num_landmarks" not in kwargs:
        model = VisionTransformer(patch_size=16,
                                  embed_dim=768,
                                  depth=12,
                                  num_heads=12,
                                  mlp_ratio=4,
                                  qkv_bias=True,
                                  norm_layer=partial(nn.LayerNorm, eps=1e-6),
                                  **kwargs)
    else:
        model = Nystromformer(patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              mlp_ratio=4,
                              qkv_bias=True,
                              norm_layer=partial(nn.LayerNorm, eps=1e-6),
                              **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="http://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu",
            check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
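Which backbone this factory returns depends only on whether `num_landmarks` is passed through the kwargs. A short sketch (the import path and the landmark count are illustrative assumptions):

from models import deit_base_patch16_224  # hypothetical import path

vit = deit_base_patch16_224()                       # plain VisionTransformer
nystrom = deit_base_patch16_224(num_landmarks=64)   # Nystromformer attention variant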
Example #11
    def __init__(
        self,
        arch: str,
        pretrained: bool,
        lr: float,
        weight_decay: float,
        data_path: str,
        batch_size: int,
        workers: int,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.arch = arch
        self.pretrained = pretrained
        self.lr = lr
        self.weight_decay = weight_decay
        self.data_path = data_path
        self.batch_size = batch_size
        self.workers = workers
        if self.arch == "tiny":
            # Tiny empty model for development purposes.
            img_size = [32, 32]
            self.model_cfg = _cfg(input_size=[3] + img_size)
            self.model = VisionTransformer(img_size=img_size,
                                           patch_size=4,
                                           in_chans=3,
                                           num_classes=1000,
                                           embed_dim=16,
                                           depth=2,
                                           num_heads=1)
        else:
            self.model: VisionTransformer = timm.create_model(
                self.arch, pretrained=self.pretrained)
            self.model_cfg = vision_transformer.default_cfgs[self.arch]
        # TODO: delete me. Hack so that auto_lr_find works
        self.model.reset_classifier(10)
Example #12
class ImageNetLightningModelForVit(LightningModule):
    # pull out vit models
    MODEL_NAMES = sorted(
        list(vision_transformer.default_cfgs.keys() | {"tiny"}))

    def __init__(
        self,
        arch: str,
        pretrained: bool,
        lr: float,
        weight_decay: float,
        data_path: str,
        batch_size: int,
        workers: int,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.arch = arch
        self.pretrained = pretrained
        self.lr = lr
        self.weight_decay = weight_decay
        self.data_path = data_path
        self.batch_size = batch_size
        self.workers = workers
        if self.arch == "tiny":
            # Tiny empty model for development purposes.
            img_size = [32, 32]
            self.model_cfg = _cfg(input_size=[3] + img_size)
            self.model = VisionTransformer(img_size=img_size,
                                           patch_size=4,
                                           in_chans=3,
                                           num_classes=1000,
                                           embed_dim=16,
                                           depth=2,
                                           num_heads=1)
        else:
            self.model: VisionTransformer = timm.create_model(
                self.arch, pretrained=self.pretrained)
            self.model_cfg = vision_transformer.default_cfgs[self.arch]
        # TODO: delete me. Hack so that auto_lr_find works
        self.model.reset_classifier(10)

    def setup(self, stage: str):
        # Configuring the head of the model to the number of classes
        train_dir = os.path.join(self.data_path, 'train')
        train_dataset = datasets.ImageFolder(train_dir)
        self.model.reset_classifier(len(train_dataset.classes))

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, target = batch
        output = self(images)
        loss_train = F.cross_entropy(output, target)
        acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
        self.log('train_loss',
                 loss_train,
                 on_step=True,
                 on_epoch=True,
                 logger=True)
        self.log('train_acc1',
                 acc1,
                 on_step=True,
                 prog_bar=True,
                 on_epoch=True,
                 logger=True)
        self.log('train_acc5', acc5, on_step=True, on_epoch=True, logger=True)
        return loss_train

    def validation_step(self, batch, batch_idx):
        images, target = batch
        output = self(images)
        loss_val = F.cross_entropy(output, target)
        acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
        self.log('val_loss', loss_val, on_step=True, on_epoch=True)
        self.log('val_acc1', acc1, on_step=True, prog_bar=True, on_epoch=True)
        self.log('val_acc5', acc5, on_step=True, on_epoch=True)

    @staticmethod
    def __accuracy(output, target, topk=(1, )):
        """Computes the accuracy over the k top predictions for the specified values of k"""
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in topk:
                correct_k = correct[:k].reshape(-1).float().sum(0,
                                                                keepdim=True)
                res.append(correct_k.mul_(100.0 / batch_size))
            return res

    def configure_optimizers(self):
        # optimizer = optim.SGD(
        #     self.parameters(),
        #     lr=self.lr,
        #     momentum=self.momentum,
        #     weight_decay=self.weight_decay
        # )
        optimizer = optim.Adam(self.parameters(),
                               lr=(self.lr or self.learning_rate),
                               betas=[0.9, 0.999],
                               weight_decay=self.weight_decay)
        scheduler = lr_scheduler.LambdaLR(optimizer,
                                          lambda epoch: 0.1**(epoch // 30))
        return [optimizer], [scheduler]

    def train_dataloader(self):
        cfg = self.model_cfg
        normalize = transforms.Normalize(
            mean=cfg['mean'],
            std=cfg['std'],
        )

        train_dir = os.path.join(self.data_path, 'train')
        train_dataset = datasets.ImageFolder(
            train_dir,
            transforms.Compose([
                transforms.RandomResizedCrop(cfg['input_size'][1:]),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.workers,
        )
        return train_loader

    def val_dataloader(self):
        cfg = self.model_cfg
        crop_pct = cfg['crop_pct']
        img_size = cfg['input_size'][1:]
        scale_size = tuple([int(x / crop_pct) for x in img_size])

        normalize = transforms.Normalize(
            mean=cfg['mean'],
            std=cfg['std'],
        )
        val_dir = os.path.join(self.data_path, 'val')
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(
                val_dir,
                transforms.Compose([
                    transforms.Resize(scale_size,
                                      _pil_interp(cfg['interpolation'])),
                    transforms.CenterCrop(img_size),
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.workers,
        )
        return val_loader

    def test_dataloader(self):
        return self.val_dataloader()

    def test_step(self, *args, **kwargs):
        return self.validation_step(*args, **kwargs)

    def test_epoch_end(self, *args, **kwargs):
        # Expects a `validation_epoch_end` that returns a dict with
        # 'val_loss', 'progress_bar' and 'log' keys (the dict-returning
        # style of the pre-1.0 Lightning API).
        outputs = self.validation_epoch_end(*args, **kwargs)

        def substitute_val_keys(out):
            return {k.replace('val', 'test'): v for k, v in out.items()}

        outputs = {
            'test_loss': outputs['val_loss'],
            'progress_bar': substitute_val_keys(outputs['progress_bar']),
            'log': substitute_val_keys(outputs['log']),
        }
        return outputs

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no-cover
        model_names = ImageNetLightningModelForVit.MODEL_NAMES
        parser = ArgumentParser(parents=[parent_parser])
        parser.add_argument(
            '-a',
            '--arch',
            metavar='ARCH',
            default=model_names[0],
            choices=model_names,
            help=('model architecture: ' + ' | '.join(model_names) +
                  ' (default: ' + model_names[0] + ')'))
        parser.add_argument('-j',
                            '--workers',
                            default=4,
                            type=int,
                            metavar='N',
                            help='number of data loading workers (default: 4)')
        parser.add_argument(
            '-b',
            '--batch-size',
            default=256,
            type=int,
            metavar='N',
            help='mini-batch size (default: 256), this is the total '
            'batch size of all GPUs on the current node when '
            'using Data Parallel or Distributed Data Parallel')
        parser.add_argument('--lr',
                            '--learning-rate',
                            default=0.3,
                            type=float,
                            metavar='LR',
                            help='initial learning rate',
                            dest='lr')
        parser.add_argument('--wd',
                            '--weight-decay',
                            default=1e-1,
                            type=float,
                            metavar='W',
                            help='weight decay (default: 1e-1)',
                            dest='weight_decay')
        parser.add_argument('--pretrained',
                            dest='pretrained',
                            action='store_true',
                            help='use pre-trained model')
        return parser
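A hedged sketch of how this LightningModule might be driven from a training script; the data path is a placeholder, the Trainer settings are illustrative only, and add_model_specific_args supplies the flags defined above:

from argparse import ArgumentParser
import pytorch_lightning as pl

parent = ArgumentParser(add_help=False)
parent.add_argument('--data-path', dest='data_path', default='/path/to/imagenet')  # placeholder
parser = ImageNetLightningModelForVit.add_model_specific_args(parent)
args = parser.parse_args()

model = ImageNetLightningModelForVit(**vars(args))
trainer = pl.Trainer(max_epochs=1)  # illustrative settings only
trainer.fit(model)  # expects an ImageNet-style folder layout under data_path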