def cli_main():  # pragma: no-cover
    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--dataset', type=str, help='stl10, cifar10', default='cifar10')
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    parser.add_argument('--data_dir', type=str, help='path to dataset', default=os.getcwd())
    args = parser.parse_args()

    # load the pretrained backbone from the checkpoint
    backbone = SimCLR.load_from_checkpoint(args.ckpt_path, strict=False)

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        dm.test_transforms = SimCLREvalDataTransform(32)
        args.num_samples = dm.num_samples

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # fine-tuning uses the labeled STL-10 splits
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = dm.num_labeled_samples

        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    # finetune a linear head (hidden_dim=None) on top of the backbone features
    tuner = SSLFineTuner(backbone, in_features=2048 * 2 * 2, num_classes=dm.num_classes, hidden_dim=None)
    trainer = pl.Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)
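# Why in_features=2048 * 2 * 2: a sketch, assuming the checkpointed backbone
# ends in a 2048-channel feature map at 2x2 spatial resolution that is
# flattened before the fine-tuning head. Illustration of the arithmetic only,
# not the backbone's actual code:
import torch

feats = torch.randn(1, 2048, 2, 2)   # assumed backbone output shape
flat = feats.flatten(start_dim=1)    # -> [1, 8192]
assert flat.shape[1] == 2048 * 2 * 2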
def cli_main():
    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = SimCLR.add_model_specific_args(parser)
    args = parser.parse_args()

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        args.num_samples = dm.num_samples

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # pretraining uses the mixed (unlabeled + labeled) STL-10 splits
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    model = SimCLR(**args.__dict__)

    # finetune in real-time
    def to_device(batch, device):
        # SimCLR batches are ((x1, x2), y); the online evaluator needs one view
        (x1, x2), y = batch
        x1 = x1.to(device)
        y = y.to(device)
        return x1, y

    online_eval = SSLOnlineEvaluator(z_dim=2048 * 2 * 2, num_classes=dm.num_classes)
    online_eval.to_device = to_device

    trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_eval])
    trainer.fit(model, dm)
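# The to_device() override above exists because SimCLR batches carry two
# augmented views per image. A minimal sketch of the batch layout the override
# assumes (illustration only, not part of the training script):
batch = next(iter(dm.train_dataloader()))
(x1, x2), y = batch              # two views of the same images, plus labels
assert x1.shape == x2.shape      # the views have identical shape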
def test_simclr_transforms(tmpdir, img_size):
    pl.seed_everything(0)

    (c, h, w) = img_size
    x = torch.rand(c, h, w)
    x = transforms.ToPILImage(mode='RGB')(x)

    transform = SimCLREvalDataTransform(input_height=h)
    transform(x)

    transform = SimCLRTrainDataTransform(input_height=h)
    transform(x)
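# The smoke test above only checks that the transforms run. Assuming (as the
# pretraining code does when it unpacks `(x1, x2), y = batch`) that the train
# transform returns two views, a slightly stronger check might be:
xi, xj = SimCLRTrainDataTransform(input_height=h)(x)
assert xi.shape == xj.shape      # two augmentations of the same image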
def test_simclr(tmpdir):
    reset_seed()

    datamodule = CIFAR10DataModule(tmpdir, num_workers=0)
    datamodule.train_transforms = SimCLRTrainDataTransform(32)
    datamodule.val_transforms = SimCLREvalDataTransform(32)

    model = SimCLR(data_dir=tmpdir, batch_size=2, datamodule=datamodule, online_ft=True)
    trainer = pl.Trainer(overfit_batches=2, max_epochs=1, default_root_dir=tmpdir)
    trainer.fit(model)
    loss = trainer.callback_metrics['loss']

    assert loss > 0
def test_simclr(tmpdir):
    seed_everything()

    datamodule = CIFAR10DataModule(tmpdir, num_workers=0, batch_size=2)
    datamodule.train_transforms = SimCLRTrainDataTransform(32)
    datamodule.val_transforms = SimCLREvalDataTransform(32)

    model = SimCLR(batch_size=2, num_samples=datamodule.num_samples)
    trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir)
    trainer.fit(model, datamodule)
    loss = trainer.progress_bar_dict['loss']

    assert float(loss) > 0
from argparse import ArgumentParser

parser = ArgumentParser()

# trainer args
parser = pl.Trainer.add_argparse_args(parser)

# model args
parser = SimCLR.add_model_specific_args(parser)
args = parser.parse_args()

# init default datamodule
if args.dataset == 'cifar10':
    dm = CIFAR10DataModule.from_argparse_args(args)
    dm.train_transforms = SimCLRTrainDataTransform(32)
    dm.val_transforms = SimCLREvalDataTransform(32)
    args.num_samples = dm.num_samples

elif args.dataset == 'stl10':
    dm = STL10DataModule.from_argparse_args(args)
    dm.train_dataloader = dm.train_dataloader_mixed
    dm.val_dataloader = dm.val_dataloader_mixed
    args.num_samples = dm.num_unlabeled_samples

    (c, h, w) = dm.size()
    dm.train_transforms = SimCLRTrainDataTransform(h)
    dm.val_transforms = SimCLREvalDataTransform(h)

elif args.dataset == 'imagenet2012':
    dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
    (c, h, w) = dm.size()
    dm.train_transforms = SimCLRTrainDataTransform(h)
    dm.val_transforms = SimCLREvalDataTransform(h)
def __init__(self,
             datamodule: pl_bolts.datamodules.LightningDataModule = None,
             data_dir: str = '',
             learning_rate: float = 0.00006,
             weight_decay: float = 0.0005,
             input_height: int = 32,
             batch_size: int = 128,
             online_ft: bool = False,
             num_workers: int = 4,
             optimizer: str = 'lars',
             lr_sched_step: float = 30.0,
             lr_sched_gamma: float = 0.5,
             lars_momentum: float = 0.9,
             lars_eta: float = 0.001,
             loss_temperature: float = 0.5,
             **kwargs):
    """
    PyTorch Lightning implementation of `SimCLR <https://arxiv.org/abs/2002.05709>`_

    Paper authors: Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton.

    Model implemented by:

        - `William Falcon <https://github.com/williamFalcon>`_
        - `Tullie Murrell <https://github.com/tullie>`_

    Example:

        >>> from pl_bolts.models.self_supervised import SimCLR
        ...
        >>> model = SimCLR()

    Train::

        trainer = Trainer()
        trainer.fit(model)

    CLI command::

        # cifar10
        python simclr_module.py --gpus 1

        # imagenet
        python simclr_module.py --gpus 8 --dataset imagenet2012 --data_dir /path/to/imagenet/ --meta_dir /path/to/folder/with/meta.bin/ --batch_size 32

    Args:
        datamodule: The datamodule
        data_dir: directory to store data
        learning_rate: the learning rate
        weight_decay: optimizer weight decay
        input_height: image input height
        batch_size: the batch size
        online_ft: whether to tune online or not
        num_workers: number of workers
        optimizer: optimizer name
        lr_sched_step: step for learning rate scheduler
        lr_sched_gamma: gamma for learning rate scheduler
        lars_momentum: the momentum param for the LARS optimizer
        lars_eta: the eta param for the LARS optimizer
        loss_temperature: the temperature of the NT-Xent contrastive loss
    """
    super().__init__()
    self.save_hyperparameters()
    self.online_evaluator = online_ft

    # init default datamodule
    if datamodule is None:
        datamodule = CIFAR10DataModule(data_dir, num_workers=num_workers)
        datamodule.train_transforms = SimCLRTrainDataTransform(input_height)
        datamodule.val_transforms = SimCLREvalDataTransform(input_height)

    self.datamodule = datamodule
    self.loss_func = self.init_loss()
    self.encoder = self.init_encoder()
    self.projection = self.init_projection()

    if self.online_evaluator:
        z_dim = self.projection.output_dim
        num_classes = self.datamodule.num_classes
        self.non_linear_evaluator = SSLEvaluator(
            n_input=z_dim,
            n_classes=num_classes,
            p=0.2,
            n_hidden=1024)
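# The `loss_temperature` argument above parameterizes SimCLR's NT-Xent
# (normalized temperature-scaled cross entropy) objective. A minimal,
# self-contained sketch for illustration -- this is NOT the module's
# init_loss() implementation:
import torch
import torch.nn.functional as F

def nt_xent_sketch(z1, z2, temperature=0.5):
    """z1, z2: [N, D] projections of two augmented views of the same N images."""
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)   # [2N, D], unit norm
    sim = z @ z.t() / temperature                        # [2N, 2N] scaled cosine sims
    sim.fill_diagonal_(float('-inf'))                    # mask self-similarity
    # the positive for row i is the other view of the same image
    targets = torch.cat([torch.arange(n) + n, torch.arange(n)]).to(z.device)
    return F.cross_entropy(sim, targets)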
# trainer args
parser = pl.Trainer.add_argparse_args(parser)

# model args
parser = SimCLR.add_model_specific_args(parser)
args = parser.parse_args()

# pick data
datamodule = None
if args.dataset == 'stl10':
    datamodule = STL10DataModule.from_argparse_args(args)
    datamodule.train_dataloader = datamodule.train_dataloader_mixed
    datamodule.val_dataloader = datamodule.val_dataloader_mixed

    (c, h, w) = datamodule.size()
    datamodule.train_transforms = SimCLRTrainDataTransform(h)
    datamodule.val_transforms = SimCLREvalDataTransform(h)

elif args.dataset == 'imagenet2012':
    datamodule = ImagenetDataModule.from_argparse_args(args, image_size=196)

    (c, h, w) = datamodule.size()
    datamodule.train_transforms = SimCLRTrainDataTransform(h)
    datamodule.val_transforms = SimCLREvalDataTransform(h)

model = SimCLR(**args.__dict__, datamodule=datamodule)

trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model)
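# Note: when --dataset is cifar10, `datamodule` stays None above and
# SimCLR.__init__ (shown earlier) builds the default CIFAR10DataModule itself.
# A sketch of the equivalent explicit construction, for illustration:
dm = CIFAR10DataModule('', num_workers=4)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)
model = SimCLR(datamodule=dm)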
def __init__(self,
             datamodule: pl.LightningDataModule = None,
             data_dir: str = './',
             learning_rate: float = 0.2,
             weight_decay: float = 15e-6,
             input_height: int = 32,
             batch_size: int = 32,
             num_workers: int = 4,
             warmup_epochs: int = 10,
             max_epochs: int = 1000,
             **kwargs):
    """
    PyTorch Lightning implementation of `Bring Your Own Latent (BYOL) <https://arxiv.org/pdf/2006.07733.pdf>`_

    Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, \
    Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, \
    Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko.

    Model implemented by:

        - `Annika Brundyn <https://github.com/annikabrundyn>`_

    .. warning:: Work in progress. This implementation is still being verified.

    TODOs:
        - verify on CIFAR-10
        - verify on STL-10
        - pre-train on imagenet

    Example:

        >>> from pl_bolts.models.self_supervised import BYOL
        ...
        >>> model = BYOL()

    Train::

        trainer = Trainer()
        trainer.fit(model)

    CLI command::

        # cifar10
        python byol_module.py --gpus 1

        # imagenet
        python byol_module.py --gpus 8 --dataset imagenet2012 --data_dir /path/to/imagenet/ --meta_dir /path/to/folder/with/meta.bin/ --batch_size 32

    Args:
        datamodule: The datamodule
        data_dir: directory to store data
        learning_rate: the learning rate
        weight_decay: optimizer weight decay
        input_height: image input height
        batch_size: the batch size
        num_workers: number of workers
        warmup_epochs: num of epochs for scheduler warm up
        max_epochs: max epochs for scheduler
    """
    super().__init__()
    self.save_hyperparameters()

    # init default datamodule
    if datamodule is None:
        datamodule = CIFAR10DataModule(data_dir, num_workers=num_workers, batch_size=batch_size)
        datamodule.train_transforms = SimCLRTrainDataTransform(input_height)
        datamodule.val_transforms = SimCLREvalDataTransform(input_height)

    self.datamodule = datamodule

    self.online_network = SiameseArm()
    self.target_network = deepcopy(self.online_network)
    self.weight_callback = BYOLMAWeightUpdate()

    # for finetuning callback
    self.z_dim = 2048
    self.num_classes = self.datamodule.num_classes
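# BYOLMAWeightUpdate keeps `target_network` as an exponential moving average
# of `online_network` (the BYOL paper's base decay is 0.996). An illustrative
# sketch of that update rule -- NOT the callback's actual code:
import torch

@torch.no_grad()
def ema_update_sketch(online_net, target_net, tau=0.996):
    # target <- tau * target + (1 - tau) * online
    for p_online, p_target in zip(online_net.parameters(), target_net.parameters()):
        p_target.data.mul_(tau).add_(p_online.data, alpha=1.0 - tau)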