def cli_main():  # pragma: no-cover
    """Fine-tune a CPC v2 backbone restored from a checkpoint.

    Command-line options (in addition to every ``pl.Trainer`` flag):

    * ``--dataset``   -- ``cifar10`` (default) or ``stl10``.
    * ``--ckpt_path`` -- checkpoint of a trained ``CPCV2`` model (required).
    * ``--data_dir``  -- directory handed to the datamodule (defaults to the CWD).

    The backbone is wrapped in an ``SSLFineTuner`` whose input width and class
    count come from ``backbone.z_dim`` / ``backbone.num_classes``; it is fitted
    on the chosen datamodule and then evaluated with ``trainer.test``.
    """
    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    # `choices` rejects unsupported datasets at parse time; previously they fell
    # through both branches below and crashed on an unbound `dm`.
    parser.add_argument(
        '--dataset', type=str, help='stl10, cifar10', default='cifar10', choices=['cifar10', 'stl10']
    )
    # There is nothing to fine-tune without a checkpoint, so fail fast in argparse
    # instead of inside load_from_checkpoint(None).
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt', required=True)
    parser.add_argument('--data_dir', type=str, help='path to the dataset directory', default=os.getcwd())
    args = parser.parse_args()

    # load the backbone; strict=False does not require checkpoint keys to match the model exactly
    backbone = CPCV2.load_from_checkpoint(args.ckpt_path, strict=False)

    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = CPCTrainTransformsCIFAR10()
        dm.val_transforms = CPCEvalTransformsCIFAR10()
        dm.test_transforms = CPCEvalTransformsCIFAR10()
    else:  # 'stl10' -- the only other value `choices` allows
        dm = STL10DataModule.from_argparse_args(args)
        # supervised fine-tuning: route fit/validation to the labelled STL10 loaders
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        dm.train_transforms = CPCTrainTransformsSTL10()
        dm.val_transforms = CPCEvalTransformsSTL10()
        dm.test_transforms = CPCEvalTransformsSTL10()

    # finetune
    tuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes)
    trainer = pl.Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)
    trainer.test(datamodule=dm)
def cli_main():
    """Pretrain CPC v2 from the command line with an online evaluator.

    Parses ``pl.Trainer`` flags, the model's own flags (via
    ``CPCV2.add_model_specific_args``) and the dataset options below, builds the
    matching datamodule with CPC transforms, and fits a ``CPCV2`` model with an
    ``SSLOnlineEvaluator`` callback attached.

    Supported ``--dataset`` values: ``cifar10`` (patch size 8), ``stl10``
    (patch size 16, batch size forced to 144) and ``imagenet2012`` (patch size 32).
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
    from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule

    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = CPCV2.add_model_specific_args(parser)
    # `choices` rejects unknown datasets up front; otherwise `datamodule` stays None
    # and training fails much later with a less helpful error.
    parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'stl10', 'imagenet2012'])
    parser.add_argument('--data_dir', default='.', type=str)
    parser.add_argument('--meta_dir', default='.', type=str, help='path to meta.bin for imagenet')
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--batch_size', type=int, default=128)
    args = parser.parse_args()

    datamodule = None
    online_evaluator = SSLOnlineEvaluator()
    if args.dataset == 'cifar10':
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsCIFAR10()
        datamodule.val_transforms = CPCEvalTransformsCIFAR10()
        args.patch_size = 8

    elif args.dataset == 'stl10':
        # The batch size must be overridden BEFORE the datamodule is built:
        # from_argparse_args reads args.batch_size at construction time, so the
        # old post-construction assignment never reached the dataloaders.
        # 16 GB RAM - 64
        # 32 GB RAM - 144
        args.batch_size = 144

        datamodule = STL10DataModule.from_argparse_args(args)
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = CPCTrainTransformsSTL10()
        datamodule.val_transforms = CPCEvalTransformsSTL10()
        args.patch_size = 16

        def to_device(batch, device):
            # Mixed STL10 batches are a pair of (x, y) pairs; only the second pair
            # (presumably the labelled split) is handed to the online evaluator.
            (_, _), (x2, y2) = batch
            x2 = x2.to(device)
            y2 = y2.to(device)
            return x2, y2

        online_evaluator.to_device = to_device

    elif args.dataset == 'imagenet2012':
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsImageNet128()
        datamodule.val_transforms = CPCEvalTransformsImageNet128()
        args.patch_size = 32

    model = CPCV2(**vars(args))
    trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_evaluator])
    trainer.fit(model, datamodule)
def cli_main():
    """Pretrain CPC v2 from the command line with online fine-tuning enabled.

    Parses ``pl.Trainer`` flags plus the model's flags (which supply
    ``--dataset`` and ``--batch_size`` among others), builds the matching
    datamodule with CPC transforms, passes it to ``CPCV2`` and fits the model
    with an ``SSLOnlineEvaluator`` callback.

    Raises:
        ValueError: if ``--dataset`` is not one of ``cifar10``, ``stl10`` or
            ``imagenet2012``.
    """
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
    from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule

    pl.seed_everything(1234)
    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = CPCV2.add_model_specific_args(parser)
    args = parser.parse_args()
    args.online_ft = True

    datamodule = None

    online_evaluator = SSLOnlineEvaluator()
    if args.dataset == 'cifar10':
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsCIFAR10()
        datamodule.val_transforms = CPCEvalTransformsCIFAR10()
        args.patch_size = 8

    elif args.dataset == 'stl10':
        # Override the batch size BEFORE building the datamodule: from_argparse_args
        # reads args.batch_size at construction time, so assigning it afterwards
        # (as before) only changed the model's hparams, not the dataloaders.
        # 16 GB RAM - 64
        # 32 GB RAM - 144
        args.batch_size = 144

        datamodule = STL10DataModule.from_argparse_args(args)
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = CPCTrainTransformsSTL10()
        datamodule.val_transforms = CPCEvalTransformsSTL10()
        args.patch_size = 16

        def to_device(batch, device):
            # Mixed STL10 batches are a pair of (x, y) pairs; only the second pair
            # (presumably the labelled split) is handed to the online evaluator.
            (_, _), (x2, y2) = batch
            x2 = x2.to(device)
            y2 = y2.to(device)
            return x2, y2

        online_evaluator.to_device = to_device

    elif args.dataset == 'imagenet2012':
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsImageNet128()
        datamodule.val_transforms = CPCEvalTransformsImageNet128()
        args.patch_size = 32

    else:
        # Previously an unknown dataset silently produced datamodule=None.
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected 'cifar10', 'stl10' or 'imagenet2012'"
        )

    model = CPCV2(**vars(args), datamodule=datamodule)
    trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_evaluator])
    trainer.fit(model)
def cli_main():
    """Pretrain ``CPC_v2`` from the command line with an online evaluator.

    Parses ``Trainer`` flags, the model's own flags and the dataset options
    below, builds the matching datamodule with CPC transforms, and fits the
    model with an ``SSLOnlineEvaluator`` sized from ``--hidden_mlp`` and the
    datamodule's ``num_classes``.

    Supported ``--dataset`` values: ``cifar10`` (patch size 8), ``stl10``
    (patch size 16, batch size forced to 144) and ``imagenet2012`` (patch size 32).
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator

    # STL10DataModule was used below without being imported in this function.
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
    from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule

    seed_everything(1234)

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser = CPC_v2.add_model_specific_args(parser)
    # `choices` rejects unknown datasets up front; otherwise `datamodule` stays None
    # and `datamodule.num_classes` below raises AttributeError.
    parser.add_argument("--dataset", default="cifar10", type=str, choices=["cifar10", "stl10", "imagenet2012"])
    parser.add_argument("--data_dir", default=".", type=str)
    parser.add_argument("--meta_dir", default=".", type=str, help="path to meta.bin for imagenet")
    parser.add_argument("--num_workers", default=8, type=int)
    parser.add_argument("--hidden_mlp", default=2048, type=int, help="hidden layer dimension in projection head")
    parser.add_argument("--batch_size", type=int, default=128)

    args = parser.parse_args()

    datamodule = None
    if args.dataset == "cifar10":
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsCIFAR10()
        datamodule.val_transforms = CPCEvalTransformsCIFAR10()
        args.patch_size = 8

    elif args.dataset == "stl10":
        # Override the batch size BEFORE building the datamodule: from_argparse_args
        # reads args.batch_size at construction time, so the previous assignment
        # (made after construction) never reached the dataloaders.
        # 16 GB RAM - 64
        # 32 GB RAM - 144
        args.batch_size = 144

        datamodule = STL10DataModule.from_argparse_args(args)
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = CPCTrainTransformsSTL10()
        datamodule.val_transforms = CPCEvalTransformsSTL10()
        args.patch_size = 16

    elif args.dataset == "imagenet2012":
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsImageNet128()
        datamodule.val_transforms = CPCEvalTransformsImageNet128()
        args.patch_size = 32

    online_evaluator = SSLOnlineEvaluator(
        drop_p=0.0,
        hidden_dim=None,
        z_dim=args.hidden_mlp,
        num_classes=datamodule.num_classes,
        dataset=args.dataset,
    )
    if args.dataset == "stl10":

        def to_device(batch, device):
            # Mixed STL10 batches are a pair of (x, y) pairs; only the second pair
            # (presumably the labelled split) is handed to the online evaluator.
            (_, _), (x2, y2) = batch
            x2 = x2.to(device)
            y2 = y2.to(device)
            return x2, y2

        online_evaluator.to_device = to_device

    model = CPC_v2(**vars(args))
    trainer = Trainer.from_argparse_args(args, callbacks=[online_evaluator])
    trainer.fit(model, datamodule=datamodule)
def __init__(
    self,
    datamodule: pl.LightningDataModule = None,
    encoder: Union[str, torch.nn.Module, pl.LightningModule] = 'cpc_encoder',
    patch_size: int = 8,
    patch_overlap: int = 4,
    online_ft: bool = True,
    task: str = 'cpc',
    num_workers: int = 4,
    learning_rate: float = 1e-4,
    data_dir: str = '',
    batch_size: int = 32,
    pretrained: str = None,
    **kwargs,
):
    """
    Args:
        datamodule: A Datamodule (optional). If ``None``, a ``CIFAR10DataModule`` is built
            from ``data_dir``, ``num_workers`` and ``batch_size`` with the CPC CIFAR-10
            train/val transforms.
        encoder: A string for any of the resnets in torchvision, or the original CPC encoder,
            or a custom nn.Module encoder. String values are resolved by ``self.init_encoder()``.
        patch_size: How big to make the image patches; also used to size the CPC task.
        patch_overlap: How much overlap each patch should have.
        online_ft: Whether to enable the online fine-tuning evaluator
            (copied to ``self.online_evaluator``).
        task: Which self-supervised task to use ('cpc', 'amdim', etc...). Note: this
            constructor always builds a ``CPCTask``; the value is only recorded in hparams here.
        num_workers: number of dataloader workers for the default datamodule.
        learning_rate: what learning rate to use (recorded in hparams).
        data_dir: where to store data for the default datamodule.
        batch_size: batch size for the default datamodule.
        pretrained: If truthy, pretrained weights are loaded via ``self.load_pretrained(encoder)``,
            online evaluation is forced on, and the value is stored as ``hparams.dataset``.
    """
    super().__init__()
    # Records every constructor argument in self.hparams (read back below).
    self.save_hyperparameters()

    self.online_evaluator = self.hparams.online_ft

    if pretrained:
        self.hparams.dataset = pretrained
        self.online_evaluator = True

    # link data
    if datamodule is None:
        datamodule = CIFAR10DataModule(
            self.hparams.data_dir, num_workers=self.hparams.num_workers, batch_size=batch_size
        )
        datamodule.train_transforms = CPCTrainTransformsCIFAR10()
        datamodule.val_transforms = CPCEvalTransformsCIFAR10()
    self.datamodule = datamodule

    # init encoder
    self.encoder = encoder
    if isinstance(encoder, str):
        self.encoder = self.init_encoder()

    # info nce loss
    # presumably (channels, grid side) of the encoder output for this patch size;
    # z_dim below flattens a c x h x h feature map.
    c, h = self.__compute_final_nb_c(self.hparams.patch_size)
    self.contrastive_task = CPCTask(num_input_channels=c, target_dim=64, embed_scale=0.1)

    self.z_dim = c * h * h
    self.num_classes = self.datamodule.num_classes

    if pretrained:
        self.load_pretrained(encoder)
def __init__(
    self,
    datamodule: pl.LightningDataModule = None,
    encoder: Union[str, torch.nn.Module, pl.LightningModule] = 'cpc_encoder',
    patch_size: int = 8,
    patch_overlap: int = 4,
    online_ft: bool = True,
    task: str = 'cpc',
    num_workers: int = 4,
    learning_rate: float = 1e-4,
    data_dir: str = '',
    batch_size: int = 32,
    pretrained: str = None,
    **kwargs,
):
    """
    PyTorch Lightning implementation of `Data-Efficient Image Recognition with Contrastive Predictive Coding
    <https://arxiv.org/abs/1905.09272>`_

    Paper authors: (Olivier J. Hénaff, Aravind Srinivas, Jeffrey De Fauw, Ali Razavi,
    Carl Doersch, S. M. Ali Eslami, Aaron van den Oord).

    Model implemented by:

        - `William Falcon <https://github.com/williamFalcon>`_
        - `Tullie Murrell <https://github.com/tullie>`_

    Example:

        >>> from pl_bolts.models.self_supervised import CPCV2
        ...
        >>> model = CPCV2()

    Train::

        trainer = Trainer()
        trainer.fit(model)

    CLI command::

        # cifar10
        python cpc_module.py --gpus 1

        # imagenet
        python cpc_module.py
            --gpus 8
            --dataset imagenet2012
            --data_dir /path/to/imagenet/
            --meta_dir /path/to/folder/with/meta.bin/
            --batch_size 32

    To Finetune::

        python cpc_finetuner.py
            --ckpt_path path/to/checkpoint.ckpt
            --dataset cifar10
            --gpus x

    Some uses::

        # load resnet18 pretrained using CPC on imagenet
        model = CPCV2(encoder='resnet18', pretrained=True)
        resnet18 = model.encoder
        resnet18.freeze()

        # it supports any torchvision resnet
        model = CPCV2(encoder='resnet50', pretrained=True)

        # use it as a feature extractor
        x = torch.rand(2, 3, 224, 224)
        out = model(x)

    Args:
        datamodule: A Datamodule (optional). If ``None``, a ``CIFAR10DataModule`` is built
            from ``data_dir``, ``num_workers`` and ``batch_size`` with the CPC CIFAR-10
            train/val transforms.
        encoder: A string for any of the resnets in torchvision, or the original CPC encoder,
            or a custom nn.Module encoder. String values are resolved by ``self.init_encoder()``.
        patch_size: How big to make the image patches; also used to size the CPC task.
        patch_overlap: How much overlap each patch should have.
        online_ft: Whether to enable the online fine-tuning evaluator
            (copied to ``self.online_evaluator``).
        task: Which self-supervised task to use ('cpc', 'amdim', etc...). Note: this
            constructor always builds a ``CPCTask``; the value is only recorded in hparams here.
        num_workers: number of dataloader workers for the default datamodule.
        learning_rate: what learning rate to use (recorded in hparams).
        data_dir: where to store data for the default datamodule.
        batch_size: batch size for the default datamodule.
        pretrained: If truthy, pretrained weights are loaded via ``self.load_pretrained(encoder)``,
            online evaluation is forced on, and the value is stored as ``hparams.dataset``.
    """
    super().__init__()
    # Records every constructor argument in self.hparams (read back below).
    self.save_hyperparameters()

    self.online_evaluator = self.hparams.online_ft

    if pretrained:
        self.hparams.dataset = pretrained
        self.online_evaluator = True

    # link data
    if datamodule is None:
        datamodule = CIFAR10DataModule(
            self.hparams.data_dir, num_workers=self.hparams.num_workers, batch_size=batch_size
        )
        datamodule.train_transforms = CPCTrainTransformsCIFAR10()
        datamodule.val_transforms = CPCEvalTransformsCIFAR10()
    self.datamodule = datamodule

    # init encoder
    self.encoder = encoder
    if isinstance(encoder, str):
        self.encoder = self.init_encoder()

    # info nce loss
    # presumably (channels, grid side) of the encoder output for this patch size;
    # z_dim below flattens a c x h x h feature map.
    c, h = self.__compute_final_nb_c(self.hparams.patch_size)
    self.contrastive_task = CPCTask(num_input_channels=c, target_dim=64, embed_scale=0.1)

    self.z_dim = c * h * h
    self.num_classes = self.datamodule.num_classes

    if pretrained:
        self.load_pretrained(encoder)
# todo: covert to CLI func and add test if __name__ == '__main__': pl.seed_everything(1234) parser = ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) parser = CPCV2.add_model_specific_args(parser) args = parser.parse_args() args.online_ft = True datamodule = None if args.dataset == 'cifar10': datamodule = CIFAR10DataModule.from_argparse_args(args) datamodule.train_transforms = CPCTrainTransformsCIFAR10() datamodule.val_transforms = CPCEvalTransformsCIFAR10() args.patch_size = 8 elif args.dataset == 'stl10': datamodule = STL10DataModule.from_argparse_args(args) datamodule.train_dataloader = datamodule.train_dataloader_mixed datamodule.val_dataloader = datamodule.val_dataloader_mixed datamodule.train_transforms = CPCTrainTransformsSTL10() datamodule.val_transforms = CPCEvalTransformsSTL10() args.patch_size = 16 elif args.dataset == 'imagenet2012': datamodule = SSLImagenetDataModule.from_argparse_args(args) datamodule.train_transforms = CPCTrainTransformsImageNet128() datamodule.val_transforms = CPCEvalTransformsImageNet128()