def cli_main():
    """Command-line entry point: train BYOL with an online evaluation callback.

    Calls ``seed_everything(1234)``, parses the ``pl.Trainer`` and BYOL
    arguments, builds the datamodule selected by ``--dataset`` (``cifar10``,
    ``stl10`` or ``imagenet2012``) with the SimCLR train/eval transforms, and
    fits the model for at most 300000 steps with an ``SSLOnlineEvaluator``
    attached.
    """
    from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    seed_everything(1234)

    # Trainer flags are registered first, then the BYOL-specific ones.
    parser = BYOL.add_model_specific_args(pl.Trainer.add_argparse_args(ArgumentParser()))
    args = parser.parse_args()

    # Pick the datamodule and the image height the transforms are built for.
    dataset_name = args.dataset
    dm = None
    if dataset_name == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        image_height = 32
    elif dataset_name == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # Route the train/val loaders through the "mixed" variants.
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        _, image_height, _ = dm.size()
    elif dataset_name == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        _, image_height, _ = dm.size()

    # Shared tail of every recognised branch; skipped when no dataset matched.
    if dm is not None:
        dm.train_transforms = SimCLRTrainDataTransform(image_height)
        dm.val_transforms = SimCLREvalDataTransform(image_height)
        args.num_classes = dm.num_classes

    model = BYOL(**vars(args))

    def to_device(batch, device):
        # Batches unpack as ((view_1, view_2), labels); only the first view
        # and the labels are moved to the device and returned.
        (first_view, _), labels = batch
        first_view = first_view.to(device)
        labels = labels.to(device)
        return first_view, labels

    # finetune in real-time
    online_eval = SSLOnlineEvaluator(z_dim=2048, num_classes=dm.num_classes)
    online_eval.to_device = to_device

    trainer = pl.Trainer.from_argparse_args(args, max_steps=300000, callbacks=[online_eval])
    trainer.fit(model, dm)
def cli_main():  # pragma: no-cover
    """Command-line entry point: fine-tune on top of a pretrained CPCV2 backbone.

    Calls ``pl.seed_everything(1234)``, parses the ``pl.Trainer`` arguments
    plus ``--dataset``, ``--ckpt_path`` and ``--data_dir``, loads the backbone
    from ``--ckpt_path`` (non-strict), builds the selected datamodule with the
    CPC train/eval transforms, then fits and tests an ``SSLFineTuner``.

    Raises:
        ValueError: if ``--dataset`` is neither ``cifar10`` nor ``stl10``.
    """
    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--dataset', type=str, help='stl10, cifar10', default='cifar10')
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    # The help text used to be a copy of the --ckpt_path one.
    parser.add_argument('--data_dir', type=str, help='path to the dataset directory', default=os.getcwd())
    args = parser.parse_args()

    # load the backbone
    backbone = CPCV2.load_from_checkpoint(args.ckpt_path, strict=False)

    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = CPCTrainTransformsCIFAR10()
        dm.val_transforms = CPCEvalTransformsCIFAR10()
        dm.test_transforms = CPCEvalTransformsCIFAR10()
    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # Fine-tuning uses the labeled loaders for train and val.
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        dm.train_transforms = CPCTrainTransformsSTL10()
        dm.val_transforms = CPCEvalTransformsSTL10()
        dm.test_transforms = CPCEvalTransformsSTL10()
    else:
        # Previously `dm` stayed unbound here and the run died later with an
        # UnboundLocalError at trainer.fit.
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected 'cifar10' or 'stl10'."
        )

    # finetune
    tuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes)
    trainer = pl.Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)
def cli_main():
    """Command-line entry point: train MocoV2 on the dataset chosen by ``--dataset``.

    Parses the ``pl.Trainer`` and MocoV2 arguments, builds the datamodule for
    ``cifar10``, ``stl10`` or ``imagenet2012`` with the matching Moco2
    train/eval transforms, hands it to the ``MocoV2`` constructor and fits.
    For any other dataset name ``datamodule=None`` is passed to the model.
    """
    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = MocoV2.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == 'cifar10':
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
        datamodule.val_transforms = Moco2EvalCIFAR10Transforms()

    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)
        # Route the train/val loaders through the "mixed" variants.
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = Moco2TrainSTL10Transforms()
        datamodule.val_transforms = Moco2EvalSTL10Transforms()

    elif args.dataset == 'imagenet2012':
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainImagenetTransforms()
        datamodule.val_transforms = Moco2EvalImagenetTransforms()

    else:
        # Without this branch `datamodule` was unbound and the constructor
        # call below raised UnboundLocalError. Falling back to None matches
        # the other MocoV2 entry points.
        # NOTE(review): presumably MocoV2 picks its own default datamodule
        # when given None — confirm against the model's __init__.
        datamodule = None

    model = MocoV2(**args.__dict__, datamodule=datamodule)

    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model)
def cli_main():
    """Command-line entry point: train a VAE with image-sampling callbacks.

    Calls ``pl.seed_everything(1234)``, parses ``--dataset`` together with the
    ``pl.Trainer``, VAE, Imagenet and MNIST datamodule arguments, builds an
    Imagenet or STL10 datamodule when requested (otherwise passes
    ``datamodule=None`` to the model), and fits with a
    ``TensorboardGenerativeModelImageSampler`` and a ``LatentDimInterpolator``.
    """
    from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
    from pl_bolts.datamodules import ImagenetDataModule

    pl.seed_everything(1234)
    parser = ArgumentParser()
    # The help text must list the values actually matched below: the code
    # checks for 'imagenet2012', so advertising 'imagenet' silently sent
    # users down the default path.
    parser.add_argument('--dataset', default='mnist', type=str, help='mnist, stl10, imagenet2012')

    parser = pl.Trainer.add_argparse_args(parser)
    parser = VAE.add_model_specific_args(parser)
    parser = ImagenetDataModule.add_argparse_args(parser)
    parser = MNISTDataModule.add_argparse_args(parser)
    args = parser.parse_args()

    # default is mnist
    # NOTE(review): the fallback itself is not visible here; presumably VAE
    # builds an MNIST datamodule when it receives None — confirm.
    datamodule = None
    if args.dataset == 'imagenet2012':
        datamodule = ImagenetDataModule.from_argparse_args(args)
    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)

    callbacks = [
        TensorboardGenerativeModelImageSampler(),
        LatentDimInterpolator(interpolate_epoch_interval=5),
    ]
    vae = VAE(**vars(args), datamodule=datamodule)
    trainer = pl.Trainer.from_argparse_args(
        args,
        callbacks=callbacks,
        progress_bar_refresh_rate=10,
    )
    trainer.fit(vae)
def cli_main():  # pragma: no-cover
    """Command-line entry point: fine-tune on top of a pretrained SimCLR backbone.

    Calls ``pl.seed_everything(1234)``, parses the ``pl.Trainer`` arguments
    plus ``--dataset``, ``--ckpt_path`` and ``--data_dir``, loads the backbone
    from ``--ckpt_path`` (non-strict), builds the selected datamodule with the
    SimCLR train/eval transforms (the eval transform is also used for the test
    split), then fits and tests an ``SSLFineTuner``.

    Raises:
        ValueError: if ``--dataset`` is not one of ``cifar10``, ``stl10`` or
            ``imagenet2012``.
    """
    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    # Help strings now list every dataset handled below and describe
    # --data_dir correctly (it used to repeat the --ckpt_path text).
    parser.add_argument('--dataset', type=str, help='stl10, cifar10, imagenet2012', default='cifar10')
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    parser.add_argument('--data_dir', type=str, help='path to the dataset directory', default=os.getcwd())
    args = parser.parse_args()

    # load the backbone
    backbone = SimCLR.load_from_checkpoint(args.ckpt_path, strict=False)

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        dm.test_transforms = SimCLREvalDataTransform(32)
        args.num_samples = dm.num_samples

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # Fine-tuning uses the labeled loaders for train and val.
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = dm.num_labeled_samples

        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)
        # trainer.test() runs below, so the test split needs the same eval
        # transform; only the cifar10 branch used to set it.
        dm.test_transforms = SimCLREvalDataTransform(h)

    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)
        dm.test_transforms = SimCLREvalDataTransform(h)

    else:
        # Previously `dm` stayed unbound here and the run died with an
        # UnboundLocalError at `dm.num_classes`.
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; "
            "expected 'cifar10', 'stl10' or 'imagenet2012'."
        )

    # finetune
    tuner = SSLFineTuner(backbone, in_features=2048 * 2 * 2, num_classes=dm.num_classes, hidden_dim=None)
    trainer = pl.Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)
def cli_main():
    """Command-line entry point: train SimCLR with an online evaluation callback.

    Parses the ``pl.Trainer`` and SimCLR arguments, builds the datamodule
    selected by ``--dataset`` (``cifar10``, ``stl10`` or ``imagenet2012``) with
    the SimCLR train/eval transforms, and fits the model with an
    ``SSLOnlineEvaluator`` attached.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule

    # Trainer flags are registered first, then the SimCLR-specific ones.
    parser = SimCLR.add_model_specific_args(pl.Trainer.add_argparse_args(ArgumentParser()))
    args = parser.parse_args()

    def attach_simclr_transforms(datamodule, input_height):
        # Same train/val transform pair for every dataset; only the size varies.
        datamodule.train_transforms = SimCLRTrainDataTransform(input_height)
        datamodule.val_transforms = SimCLREvalDataTransform(input_height)

    # init default datamodule
    dataset_name = args.dataset
    if dataset_name == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        attach_simclr_transforms(dm, 32)
        args.num_samples = dm.num_samples
    elif dataset_name == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # Route the train/val loaders through the "mixed" variants.
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples
        _, height, _ = dm.size()
        attach_simclr_transforms(dm, height)
    elif dataset_name == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        _, height, _ = dm.size()
        attach_simclr_transforms(dm, height)

    model = SimCLR(**vars(args))

    # finetune in real-time
    def to_device(batch, device):
        # Batches unpack as ((view_1, view_2), labels); only the first view
        # and the labels are moved to the device and returned.
        (first_view, _), labels = batch
        first_view = first_view.to(device)
        labels = labels.to(device)
        return first_view, labels

    online_eval = SSLOnlineEvaluator(z_dim=2048 * 2 * 2, num_classes=dm.num_classes)
    online_eval.to_device = to_device

    trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_eval])
    trainer.fit(model, dm)
def cli_main():
    """Command-line entry point: train CPCV2 with an online evaluation callback.

    Calls ``pl.seed_everything(1234)``, parses the ``pl.Trainer`` and CPCV2
    arguments, forces ``args.online_ft = True``, and configures the datamodule,
    transforms and ``args.patch_size`` for ``cifar10``, ``stl10`` or
    ``imagenet2012``. For any other dataset name ``datamodule=None`` is passed
    to the model. The model is then fitted with an ``SSLOnlineEvaluator``.
    """
    pl.seed_everything(1234)

    parser = CPCV2.add_model_specific_args(pl.Trainer.add_argparse_args(ArgumentParser()))
    args = parser.parse_args()
    args.online_ft = True

    def to_device(batch, device):
        # Batches unpack as two pairs; only the second pair is moved to the
        # device and returned. Installed on the evaluator for STL10 only.
        (_, _), (inputs, targets) = batch
        inputs = inputs.to(device)
        targets = targets.to(device)
        return inputs, targets

    datamodule = None
    online_evaluator = SSLOnlineEvaluator()

    dataset_name = args.dataset
    if dataset_name == 'cifar10':
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsCIFAR10()
        datamodule.val_transforms = CPCEvalTransformsCIFAR10()
        args.patch_size = 8

    elif dataset_name == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)
        # Route the train/val loaders through the "mixed" variants.
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = CPCTrainTransformsSTL10()
        datamodule.val_transforms = CPCEvalTransformsSTL10()
        args.patch_size = 16

        # 16 GB RAM - 64
        # 32 GB RAM - 144
        args.batch_size = 144

        online_evaluator.to_device = to_device

    elif dataset_name == 'imagenet2012':
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsImageNet128()
        datamodule.val_transforms = CPCEvalTransformsImageNet128()
        args.patch_size = 32

    model = CPCV2(**vars(args), datamodule=datamodule)
    trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_evaluator])
    trainer.fit(model)
def cli_main():
    """Command-line entry point: train MocoV2 and log the run to Weights & Biases.

    Parses the ``pl.Trainer`` and MocoV2 arguments, builds the datamodule for
    ``cifar10``, ``stl10`` or ``imagenet2012`` with the matching Moco2
    train/eval transforms (``None`` for anything else), fits the model with a
    ``WandbLogger`` and finally calls ``wandb.finish()``.
    """
    from pl_bolts.datamodules import CIFAR10DataModule, SSLImagenetDataModule, STL10DataModule

    # Trainer flags are registered first, then the MocoV2-specific ones.
    parser = MocoV2.add_model_specific_args(pl.Trainer.add_argparse_args(ArgumentParser()))
    args = parser.parse_args()

    # Stays None for unrecognised datasets. The original author's note: replace
    # with your own dataset, otherwise CIFAR-10 will be used by default if
    # `None` is passed in.
    datamodule = None
    dataset_name = args.dataset

    if dataset_name == 'cifar10':
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
        datamodule.val_transforms = Moco2EvalCIFAR10Transforms()
    elif dataset_name == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)
        # Route the train/val loaders through the "mixed" variants.
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = Moco2TrainSTL10Transforms()
        datamodule.val_transforms = Moco2EvalSTL10Transforms()
    elif dataset_name == 'imagenet2012':
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainImagenetTransforms()
        datamodule.val_transforms = Moco2EvalImagenetTransforms()

    model = MocoV2(**vars(args))

    wandb_logger = WandbLogger(name='Baseline', project='MocoV2')
    trainer = pl.Trainer.from_argparse_args(args, logger=wandb_logger)
    trainer.fit(model, datamodule=datamodule)

    wandb.finish()
def cli_main():
    """Command-line entry point: train Moco_v2 on the dataset chosen by ``--dataset``.

    Parses the ``Trainer`` and Moco_v2 arguments, builds the datamodule for
    ``cifar10``, ``stl10`` or ``imagenet2012`` with the matching Moco2
    train/eval transforms (``None`` for anything else) and fits the model.
    """
    from pl_bolts.datamodules import CIFAR10DataModule, SSLImagenetDataModule, STL10DataModule

    # Trainer flags are registered first, then the model-specific ones.
    parser = Moco_v2.add_model_specific_args(Trainer.add_argparse_args(ArgumentParser()))
    args = parser.parse_args()

    # Stays None for unrecognised datasets. The original author's note: replace
    # with your own dataset, otherwise CIFAR-10 will be used by default if
    # `None` is passed in.
    datamodule = None
    requested = args.dataset

    if requested == "cifar10":
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
        datamodule.val_transforms = Moco2EvalCIFAR10Transforms()
    elif requested == "stl10":
        datamodule = STL10DataModule.from_argparse_args(args)
        # Route the train/val loaders through the "mixed" variants.
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = Moco2TrainSTL10Transforms()
        datamodule.val_transforms = Moco2EvalSTL10Transforms()
    elif requested == "imagenet2012":
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainImagenetTransforms()
        datamodule.val_transforms = Moco2EvalImagenetTransforms()

    model = Moco_v2(**vars(args))

    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule=datamodule)
grid, global_step=trainer.global_step) # todo: covert to CLI func and add test if __name__ == '__main__': from pl_bolts.datamodules import ImagenetDataModule parser = ArgumentParser() parser = Trainer.add_argparse_args(parser) parser = GAN.add_model_specific_args(parser) parser = ImagenetDataModule.add_argparse_args(parser) args = parser.parse_args() # default is mnist datamodule = None if args.dataset == 'imagenet2012': datamodule = ImagenetDataModule.from_argparse_args(args) elif args.dataset == 'stl10': datamodule = STL10DataModule.from_argparse_args(args) gan = GAN(**vars(args), datamodule=datamodule) callbacks = [ImageGenerator(), LatentDimInterpolator()] # no val loop... thus we condition on loss and always save the last checkpoint_cb = ModelCheckpoint(monitor='loss', save_last=True) trainer = Trainer.from_argparse_args(args, callbacks=callbacks, checkpoint_callback=checkpoint_cb) trainer.fit(gan)