from argparse import ArgumentParser
from copy import deepcopy
from typing import Any

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.optim import Adam

# The pl_bolts imports below are assumed; exact module paths may differ between
# pl_bolts versions.
from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate
from pl_bolts.models.self_supervised.byol.models import SiameseArm
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR


class BYOL(pl.LightningModule):
    def __init__(self,
                 num_classes,
                 learning_rate: float = 0.2,
                 weight_decay: float = 1.5e-6,
                 input_height: int = 32,
                 batch_size: int = 32,
                 num_workers: int = 0,
                 warmup_epochs: int = 10,
                 max_epochs: int = 1000,
                 **kwargs):
        """
        PyTorch Lightning implementation of `Bootstrap Your Own Latent (BYOL)
        <https://arxiv.org/pdf/2006.07733.pdf>`_

        Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, \
        Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, \
        Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko.

        Model implemented by:
            - `Annika Brundyn <https://github.com/annikabrundyn>`_

        .. warning:: Work in progress. This implementation is still being verified.

        TODOs:
            - verify on CIFAR-10
            - verify on STL-10
            - pre-train on imagenet

        Example::

            import pytorch_lightning as pl
            from pl_bolts.models.self_supervised import BYOL
            from pl_bolts.datamodules import CIFAR10DataModule
            from pl_bolts.models.self_supervised.simclr.transforms import (
                SimCLREvalDataTransform, SimCLRTrainDataTransform)

            # model
            model = BYOL(num_classes=10)

            # data
            dm = CIFAR10DataModule(num_workers=0)
            dm.train_transforms = SimCLRTrainDataTransform(32)
            dm.val_transforms = SimCLREvalDataTransform(32)

            trainer = pl.Trainer()
            trainer.fit(model, dm)

        Train::

            trainer = Trainer()
            trainer.fit(model)

        CLI command::

            # cifar10
            python byol_module.py --gpus 1

            # imagenet
            python byol_module.py
                --gpus 8
                --dataset imagenet2012
                --data_dir /path/to/imagenet/
                --meta_dir /path/to/folder/with/meta.bin/
                --batch_size 32

        Args:
            num_classes: number of classes in the dataset
            learning_rate: the learning rate
            weight_decay: optimizer weight decay
            input_height: image input height
            batch_size: the batch size
            num_workers: number of workers
            warmup_epochs: num of epochs for scheduler warm up
            max_epochs: max epochs for scheduler
        """
        super().__init__()
        self.save_hyperparameters()

        self.online_network = SiameseArm()
        self.target_network = deepcopy(self.online_network)
        self.weight_callback = BYOLMAWeightUpdate()

    def on_train_batch_end(self, batch: Any, batch_idx: int,
                           dataloader_idx: int) -> None:
        # Apply the target-network EMA update after every training batch; this update is
        # central to BYOL (see the weight-update sketch after this example)
        self.weight_callback.on_train_batch_end(self.trainer, self, batch,
                                                batch_idx, dataloader_idx)

    def forward(self, x):
        y, _, _ = self.online_network(x)
        return y

    def cosine_similarity(self, a, b):
        a = F.normalize(a, dim=-1)
        b = F.normalize(b, dim=-1)
        sim = (a * b).sum(-1).mean()
        return sim

    def shared_step(self, batch, batch_idx):
        (img_1, img_2), y = batch

        # Image 1 -> image 2 loss: minimizing -2 * cosine similarity is equivalent,
        # up to an additive constant, to the normalized MSE used in the paper
        y1, z1, h1 = self.online_network(img_1)
        with torch.no_grad():
            y2, z2, h2 = self.target_network(img_2)
        loss_a = -2 * self.cosine_similarity(h1, z2)

        # Image 2 -> image 1 loss (symmetrized pass)
        y1, z1, h1 = self.online_network(img_2)
        with torch.no_grad():
            y2, z2, h2 = self.target_network(img_1)
        loss_b = -2 * self.cosine_similarity(h1, z2)

        # Final loss
        total_loss = loss_a + loss_b

        return loss_a, loss_b, total_loss

    def training_step(self, batch, batch_idx):
        loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)

        # log results
        result = pl.TrainResult(minimize=total_loss)
        result.log_dict({
            '1_2_loss': loss_a,
            '2_1_loss': loss_b,
            'train_loss': total_loss
        })

        return result

    def validation_step(self, batch, batch_idx):
        loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)

        # log results
        result = pl.EvalResult(early_stop_on=total_loss,
                               checkpoint_on=total_loss)
        result.log_dict({
            '1_2_loss': loss_a,
            '2_1_loss': loss_b,
            'valid_loss': total_loss
        })

        return result

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(),
                         lr=self.hparams.learning_rate,
                         weight_decay=self.hparams.weight_decay)
        optimizer = LARSWrapper(optimizer)
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=self.hparams.warmup_epochs,
            max_epochs=self.hparams.max_epochs)
        return [optimizer], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--online_ft',
                            action='store_true',
                            help='run online finetuner')
        parser.add_argument('--dataset',
                            type=str,
                            default='cifar10',
                            help='cifar10, imagenet2012, stl10')

        (args, _) = parser.parse_known_args()

        # Data
        parser.add_argument('--data_dir', type=str, default='.')
        parser.add_argument('--num_workers', default=0, type=int)

        # optim
        parser.add_argument('--batch_size', type=int, default=256)
        parser.add_argument('--learning_rate', type=float, default=1e-3)
        parser.add_argument('--weight_decay', type=float, default=1.5e-6)
        parser.add_argument('--warmup_epochs', type=int, default=10)

        # Model
        parser.add_argument('--meta_dir',
                            default='.',
                            type=str,
                            help='path to meta.bin for imagenet')

        return parser
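
The target network above is never updated by backpropagation; instead, the
BYOLMAWeightUpdate callback copies an exponential moving average (EMA) of the online
weights into it after every batch. Below is a minimal sketch of that idea, assuming the
cosine tau schedule from the BYOL paper; the helper names are illustrative and the actual
pl_bolts callback may differ in details such as buffer handling.

import math

import torch.nn as nn


def ema_update(online_net: nn.Module, target_net: nn.Module, tau: float) -> None:
    # target <- tau * target + (1 - tau) * online, applied parameter-wise
    for online_p, target_p in zip(online_net.parameters(), target_net.parameters()):
        target_p.data.mul_(tau).add_(online_p.data, alpha=1 - tau)


def cosine_tau(initial_tau: float, global_step: int, max_steps: int) -> float:
    # cosine schedule from the BYOL paper: tau anneals from initial_tau towards 1
    return 1 - (1 - initial_tau) * (math.cos(math.pi * global_step / max_steps) + 1) / 2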
Example 2

# Assumes the same imports as in Example 1.
class BYOL(pl.LightningModule):
    def __init__(self,
                 lr: float = 0.2,
                 weight_decay: float = 15e-6,
                 input_height: int = 32,
                 batch_size: int = 32,
                 num_workers: int = 0,
                 warmup_epochs: int = 10,
                 max_epochs: int = 1000,
                 **kwargs):
        super().__init__()
        self.save_hyperparameters()

        self.online_network = SiameseArm()
        self.target_network = deepcopy(self.online_network)
        self.weight_callback = BYOLMAWeightUpdate()

    def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        # Add callback for user automatically since it's key to BYOL weight update
        self.weight_callback.on_train_batch_end(self.trainer, self, batch, batch_idx, dataloader_idx)

    def forward(self, x):
        y, _, _ = self.online_network(x)
        return y

    def shared_step(self, batch, batch_idx):
        img_1, img_2 = batch

        # Image 1 to image 2 loss
        y1, z1, h1 = self.online_network(img_1)
        with torch.no_grad():
            y2, z2, h2 = self.target_network(img_2)
        # L2 normalize
        h1_norm = F.normalize(h1, p=2, dim=1)
        z2_norm = F.normalize(z2, p=2, dim=1)
        loss_a = F.mse_loss(h1_norm, z2_norm)

        # Image 2 to image 1 loss
        y1, z1, h1 = self.online_network(img_2)
        with torch.no_grad():
            y2, z2, h2 = self.target_network(img_1)
        # L2 normalize
        h1_norm = F.normalize(h1, p=2, dim=1)
        z2_norm = F.normalize(z2, p=2, dim=1)
        loss_b = F.mse_loss(h1_norm, z2_norm)

        # Final loss
        total_loss = loss_a + loss_b

        return loss_a, loss_b, total_loss

    def training_step(self, batch, batch_idx):
        loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)

        # log results
        result = pl.TrainResult(minimize=total_loss)
        #  result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})
        result.log_dict({'train_loss': total_loss})

        return result

    def validation_step(self, batch, batch_idx):
        loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)

        # log results
        result = pl.EvalResult(early_stop_on=total_loss, checkpoint_on=total_loss)
        #  result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})
        result.log_dict({'valid_loss': total_loss})

        return result

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        optimizer = LARSWrapper(optimizer)
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=self.hparams.warmup_epochs,
            max_epochs=self.hparams.max_epochs
        )
        return [optimizer], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False, conflict_handler='resolve')
        (args, _) = parser.parse_known_args()
        # Data
        parser.add_argument('--data_dir', type=str, default='.')
        parser.add_argument('--num_workers', default=10, type=int)
        # optim
        parser.add_argument('--batch_size', type=int, default=256)
        parser.add_argument('--lr', type=float, default=1e-3)
        parser.add_argument('--weight_decay', type=float, default=15e-6)
        parser.add_argument('--warmup_epochs', type=int, default=10)
        # Model
        parser.add_argument('--meta_dir', default='.', type=str)
        # Custom
        parser.add_argument('--exp', type=str, default='Pretrain')
        parser.add_argument('--name', type=str, default='test')
        # argparse's type=bool treats any non-empty string as True, so parse booleans explicitly
        parser.add_argument('--benchmark', type=lambda x: str(x).lower() in ('1', 'true', 'yes'), default=True)
        parser.add_argument('--amp_level', type=str, default='O0')

        return parser
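
Note that Example 1 scores agreement with -2 * cosine similarity while Example 2 regresses
L2-normalized vectors with F.mse_loss; up to constants these are the same objective, since
||h_norm - z_norm||^2 = 2 - 2 * <h_norm, z_norm>. The snippet below is an illustrative
check only; the tensor shapes are hypothetical and not taken from either example.

import torch
import torch.nn.functional as F

h = torch.randn(4, 256)   # stand-in for online predictions
z = torch.randn(4, 256)   # stand-in for target projections

h_n = F.normalize(h, dim=-1)
z_n = F.normalize(z, dim=-1)

normalized_mse = ((h_n - z_n) ** 2).sum(dim=-1).mean()   # BYOL paper form
neg_cosine = -2 * (h_n * z_n).sum(dim=-1).mean()         # Example 1 form

# identical up to an additive constant of 2, hence identical gradients
assert torch.allclose(normalized_mse, neg_cosine + 2, atol=1e-5)

# Example 2's F.mse_loss averages over every element instead of summing over the
# feature dimension, so its value is this quantity divided by the feature size.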