from argparse import ArgumentParser
from copy import deepcopy
from typing import Any

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.optim import Adam

# pl_bolts-internal imports; exact module paths may differ between pl_bolts versions
from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate
from pl_bolts.models.self_supervised.byol.models import SiameseArm
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR


class BYOL(pl.LightningModule):

    def __init__(self,
                 num_classes,
                 learning_rate: float = 0.2,
                 weight_decay: float = 1.5e-6,
                 input_height: int = 32,
                 batch_size: int = 32,
                 num_workers: int = 0,
                 warmup_epochs: int = 10,
                 max_epochs: int = 1000,
                 **kwargs):
        """
        PyTorch Lightning implementation of `Bootstrap Your Own Latent (BYOL)
        <https://arxiv.org/pdf/2006.07733.pdf>`_

        Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, \
        Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, \
        Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko.

        Model implemented by:
            - `Annika Brundyn <https://github.com/annikabrundyn>`_

        .. warning:: Work in progress. This implementation is still being verified.

        TODOs:
            - verify on CIFAR-10
            - verify on STL-10
            - pre-train on imagenet

        Example::

            import pytorch_lightning as pl
            from pl_bolts.models.self_supervised import BYOL
            from pl_bolts.datamodules import CIFAR10DataModule
            from pl_bolts.models.self_supervised.simclr.transforms import (
                SimCLREvalDataTransform, SimCLRTrainDataTransform)

            # model
            model = BYOL(num_classes=10)

            # data
            dm = CIFAR10DataModule(num_workers=0)
            dm.train_transforms = SimCLRTrainDataTransform(32)
            dm.val_transforms = SimCLREvalDataTransform(32)

            trainer = pl.Trainer()
            trainer.fit(model, dm)

        Train::

            trainer = Trainer()
            trainer.fit(model)

        CLI command::

            # cifar10
            python byol_module.py --gpus 1

            # imagenet
            python byol_module.py
                --gpus 8
                --dataset imagenet2012
                --data_dir /path/to/imagenet/
                --meta_dir /path/to/folder/with/meta.bin/
                --batch_size 32

        Args:
            num_classes: number of classes
            learning_rate: the learning rate
            weight_decay: optimizer weight decay
            input_height: image input height
            batch_size: the batch size
            num_workers: number of workers
            warmup_epochs: num of epochs for scheduler warm up
            max_epochs: max epochs for scheduler
        """
        super().__init__()
        self.save_hyperparameters()

        self.online_network = SiameseArm()
        self.target_network = deepcopy(self.online_network)
        self.weight_callback = BYOLMAWeightUpdate()

    def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        # Add callback for user automatically since it's key to BYOL weight update
        self.weight_callback.on_train_batch_end(self.trainer, self, batch, batch_idx, dataloader_idx)

    def forward(self, x):
        y, _, _ = self.online_network(x)
        return y

    def cosine_similarity(self, a, b):
        # mean cosine similarity over the batch (inputs are L2-normalized first)
        a = F.normalize(a, dim=-1)
        b = F.normalize(b, dim=-1)
        sim = (a * b).sum(-1).mean()
        return sim

    def shared_step(self, batch, batch_idx):
        (img_1, img_2), y = batch

        # Image 1 to image 2 loss
        y1, z1, h1 = self.online_network(img_1)
        with torch.no_grad():
            y2, z2, h2 = self.target_network(img_2)
        loss_a = -2 * self.cosine_similarity(h1, z2)

        # Image 2 to image 1 loss
        y1, z1, h1 = self.online_network(img_2)
        with torch.no_grad():
            y2, z2, h2 = self.target_network(img_1)
        loss_b = -2 * self.cosine_similarity(h1, z2)

        # Final loss
        total_loss = loss_a + loss_b

        return loss_a, loss_b, total_loss

    def training_step(self, batch, batch_idx):
        loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)

        # log results
        result = pl.TrainResult(minimize=total_loss)
        result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})

        return result

    def validation_step(self, batch, batch_idx):
        loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)

        # log results
        result = pl.EvalResult(early_stop_on=total_loss, checkpoint_on=total_loss)
        result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'valid_loss': total_loss})

        return result

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
        optimizer = LARSWrapper(optimizer)
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=self.hparams.warmup_epochs,
            max_epochs=self.hparams.max_epochs
        )
        return [optimizer], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--online_ft', action='store_true', help='run online finetuner')
        parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10, imagenet2012, stl10')

        (args, _) = parser.parse_known_args()

        # Data
        parser.add_argument('--data_dir', type=str, default='.')
        parser.add_argument('--num_workers', default=0, type=int)

        # optim
        parser.add_argument('--batch_size', type=int, default=256)
        parser.add_argument('--learning_rate', type=float, default=1e-3)
        parser.add_argument('--weight_decay', type=float, default=1.5e-6)
        parser.add_argument('--warmup_epochs', type=float, default=10)

        # Model
        parser.add_argument('--meta_dir', default='.', type=str, help='path to meta.bin for imagenet')

        return parser
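

# --- Reference sketch of the helper components (assumptions) -----------------
# `SiameseArm` and `BYOLMAWeightUpdate` are imported from pl_bolts above and are
# not defined in this file. The sketch below shows the interface the BYOL module
# assumes: an encoder -> projector -> predictor arm returning (y, z, h), and a
# moving-average update of the target network after every training batch. The
# resnet18 backbone, MLP sizes, and initial_tau value are illustrative
# assumptions (loosely following the BYOL paper), not the pl_bolts
# implementation; if kept in this file, these definitions shadow the imports.
import math

from torch import nn
from torchvision.models import resnet18


class MLP(nn.Module):
    # projector / predictor head: Linear -> BatchNorm -> ReLU -> Linear
    def __init__(self, input_dim: int = 512, hidden_size: int = 4096, output_dim: int = 256):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_size, bias=False),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, output_dim),
        )

    def forward(self, x):
        return self.model(x)


class SiameseArm(nn.Module):
    # one BYOL arm: representation y, projection z, prediction h
    def __init__(self, encoder: nn.Module = None):
        super().__init__()
        # drop the classification head of resnet18 and flatten to [B, 512]
        self.encoder = encoder or nn.Sequential(*list(resnet18().children())[:-1], nn.Flatten())
        self.projector = MLP(input_dim=512)
        self.predictor = MLP(input_dim=256)

    def forward(self, x):
        y = self.encoder(x)      # representation
        z = self.projector(y)    # projection
        h = self.predictor(z)    # prediction (only used on the online arm)
        return y, z, h


class BYOLMAWeightUpdate:
    # exponential-moving-average update of the target network from the online network
    def __init__(self, initial_tau: float = 0.996):
        self.initial_tau = initial_tau
        self.current_tau = initial_tau

    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        # target <- tau * target + (1 - tau) * online
        for online_p, target_p in zip(pl_module.online_network.parameters(),
                                      pl_module.target_network.parameters()):
            target_p.data = self.current_tau * target_p.data + (1.0 - self.current_tau) * online_p.data

        # cosine schedule drives tau towards 1 over the course of training
        max_steps = len(trainer.train_dataloader) * trainer.max_epochs
        step = pl_module.global_step
        self.current_tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * step / max_steps) + 1) / 2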


# A second variant of the BYOL module: same structure as above, but the regression
# loss is written as the mean-squared error between the L2-normalized prediction and
# target projection, and only the total train/valid losses are logged.
class BYOL(pl.LightningModule):

    def __init__(self,
                 lr: float = 0.2,
                 weight_decay: float = 1.5e-6,
                 input_height: int = 32,
                 batch_size: int = 32,
                 num_workers: int = 0,
                 warmup_epochs: int = 10,
                 max_epochs: int = 1000,
                 **kwargs):
        super().__init__()
        self.save_hyperparameters()

        self.online_network = SiameseArm()
        self.target_network = deepcopy(self.online_network)
        self.weight_callback = BYOLMAWeightUpdate()

    def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        # Add callback for user automatically since it's key to BYOL weight update
        self.weight_callback.on_train_batch_end(self.trainer, self, batch, batch_idx, dataloader_idx)

    def forward(self, x):
        y, _, _ = self.online_network(x)
        return y

    def shared_step(self, batch, batch_idx):
        img_1, img_2 = batch

        # Image 1 to image 2 loss
        y1, z1, h1 = self.online_network(img_1)
        with torch.no_grad():
            y2, z2, h2 = self.target_network(img_2)
        # L2 normalize
        h1_norm = F.normalize(h1, p=2, dim=1)
        z2_norm = F.normalize(z2, p=2, dim=1)
        loss_a = F.mse_loss(h1_norm, z2_norm)

        # Image 2 to image 1 loss
        y1, z1, h1 = self.online_network(img_2)
        with torch.no_grad():
            y2, z2, h2 = self.target_network(img_1)
        # L2 normalize
        h1_norm = F.normalize(h1, p=2, dim=1)
        z2_norm = F.normalize(z2, p=2, dim=1)
        loss_b = F.mse_loss(h1_norm, z2_norm)

        # Final loss
        total_loss = loss_a + loss_b

        return loss_a, loss_b, total_loss

    def training_step(self, batch, batch_idx):
        loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)

        # log results
        result = pl.TrainResult(minimize=total_loss)
        # result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})
        result.log_dict({'train_loss': total_loss})

        return result

    def validation_step(self, batch, batch_idx):
        loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)

        # log results
        result = pl.EvalResult(early_stop_on=total_loss, checkpoint_on=total_loss)
        # result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})
        result.log_dict({'valid_loss': total_loss})

        return result

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        optimizer = LARSWrapper(optimizer)
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=self.hparams.warmup_epochs,
            max_epochs=self.hparams.max_epochs
        )
        return [optimizer], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False, conflict_handler='resolve')

        (args, _) = parser.parse_known_args()

        # Data
        parser.add_argument('--data_dir', type=str, default='.')
        parser.add_argument('--num_workers', default=10, type=int)

        # optim
        parser.add_argument('--batch_size', type=int, default=256)
        parser.add_argument('--lr', type=float, default=1e-3)
        parser.add_argument('--weight_decay', type=float, default=1.5e-6)
        parser.add_argument('--warmup_epochs', type=float, default=10)

        # Model
        parser.add_argument('--meta_dir', default='.', type=str)

        # Custom
        parser.add_argument('--exp', type=str, default='Pretrain')
        parser.add_argument('--name', type=str, default='test')
        parser.add_argument('--benchmark', type=bool, default=True)
        parser.add_argument('--amp_level', type=str, default='O0')

        return parser
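

# --- Sanity check: the two loss formulations agree up to an affine transform --
# The first variant uses -2 * cosine_similarity(h1, z2); the second uses
# F.mse_loss on L2-normalized vectors. For unit vectors ||h - z||^2 = 2 - 2 <h, z>,
# and F.mse_loss also averages over the feature dimension D, so the two losses
# differ only by a positive scale (1 / D) and a constant shift, i.e. they yield
# proportional gradients. Quick numerical check with made-up tensors:
if __name__ == '__main__':
    torch.manual_seed(0)
    h, z = torch.randn(8, 256), torch.randn(8, 256)

    h_hat = F.normalize(h, dim=-1)
    z_hat = F.normalize(z, dim=-1)

    # per-sample identity for unit vectors
    sq_dist = (h_hat - z_hat).pow(2).sum(-1)
    cos = (h_hat * z_hat).sum(-1)
    assert torch.allclose(sq_dist, 2 - 2 * cos, atol=1e-6)

    # F.mse_loss averages over batch and feature dims: (2 - 2 * mean cos) / D
    D = h.shape[-1]
    assert torch.allclose(F.mse_loss(h_hat, z_hat), (2 - 2 * cos.mean()) / D, atol=1e-6)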