def __init__(self):
    """Set up the three CIFAR-10 transform pipelines: SSL training, finetuning, and test."""
    # image augmentation functions
    jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
    self.train_transform = transforms.Compose([
        transforms.RandomResizedCrop(32, scale=(0.2, 1.)),
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        cifar10_normalization(),
    ])

    # lighter augmentation for supervised finetuning
    self.finetune_transform = transforms.Compose([
        transforms.RandomResizedCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        cifar10_normalization(),
    ])

    # deterministic eval pipeline: upscale then center-crop back to 32x32
    self.test_transform = transforms.Compose([
        transforms.Resize(44),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        cifar10_normalization(),
    ])
def test_swav(tmpdir, datadir, batch_size=2):
    """Smoke test: one fast_dev_run of SwAV on CIFAR-10 with tiny multi-crop settings."""
    # inputs, y = batch (doesn't receive y for some reason)
    dm = CIFAR10DataModule(data_dir=datadir, batch_size=batch_size, num_workers=0)
    dm.train_transforms = SwAVTrainDataTransform(
        normalize=cifar10_normalization(),
        size_crops=[32, 16],
        nmb_crops=[2, 1],
        gaussian_blur=False,
    )
    dm.val_transforms = SwAVEvalDataTransform(
        normalize=cifar10_normalization(),
        size_crops=[32, 16],
        nmb_crops=[2, 1],
        gaussian_blur=False,
    )

    model = SwAV(
        arch='resnet18',
        hidden_mlp=512,
        gpus=0,
        nodes=1,
        num_samples=dm.num_samples,
        batch_size=batch_size,
        nmb_crops=[2, 1],
        sinkhorn_iterations=1,
        nmb_prototypes=2,
        queue_length=0,
        maxpool1=False,
        first_conv=False,
        dataset='cifar10',
    )

    trainer = Trainer(gpus=0, fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model, datamodule=dm)
def test_swav(tmpdir, datadir):
    """Train SwAV for a handful of steps on CIFAR-10 and verify the reported loss is positive."""
    seed_everything()
    batch_size = 2

    # inputs, y = batch (doesn't receive y for some reason)
    dm = CIFAR10DataModule(data_dir=datadir, batch_size=batch_size, num_workers=0)
    dm.train_transforms = SwAVTrainDataTransform(
        normalize=cifar10_normalization(),
        size_crops=[32, 16],
        nmb_crops=[2, 1],
        gaussian_blur=False,
    )
    dm.val_transforms = SwAVEvalDataTransform(
        normalize=cifar10_normalization(),
        size_crops=[32, 16],
        nmb_crops=[2, 1],
        gaussian_blur=False,
    )

    model = SwAV(
        arch='resnet18',
        hidden_mlp=512,
        gpus=0,
        nodes=1,
        num_samples=dm.num_samples,
        batch_size=batch_size,
        nmb_crops=[2, 1],
        sinkhorn_iterations=1,
        nmb_prototypes=2,
        maxpool1=False,
        first_conv=False,
        dataset='cifar10',
    )

    trainer = pl.Trainer(
        gpus=0,
        fast_dev_run=True,
        max_epochs=1,
        default_root_dir=tmpdir,
        max_steps=3,
    )
    trainer.fit(model, dm)

    loss = trainer.progress_bar_dict['loss']
    assert float(loss) > 0
def __init__(self, height=32):
    """Build the evaluation transform: upscale by 12px, center-crop to `height`, tensorize, normalize."""
    eval_steps = [
        transforms.Resize(height + 12),
        transforms.CenterCrop(height),
        transforms.ToTensor(),
        cifar10_normalization(),
    ]
    self.test_transform = transforms.Compose(eval_steps)
def default_transforms(self):
    """Return the default CIFAR-10 transform: ToTensor, plus normalization when enabled."""
    steps = [transform_lib.ToTensor()]
    if self.normalize:
        steps.append(cifar10_normalization())
    return transform_lib.Compose(steps)
def __init__(self, height: int = 32):
    """Create the test-time transform (resize + center-crop to `height`); requires torchvision."""
    if not _TORCHVISION_AVAILABLE:  # pragma: no cover
        raise ModuleNotFoundError('You want to use `transforms` from `torchvision` which is not installed yet.')

    pipeline = [
        transforms.Resize(height + 12),
        transforms.CenterCrop(height),
        transforms.ToTensor(),
        cifar10_normalization(),
    ]
    self.test_transform = transforms.Compose(pipeline)
def instantiate_datamodule(args):
    """Create a CIFAR-10 datamodule: augmented training transforms, plain eval transforms for val/test."""
    augmented = torchvision.transforms.Compose([
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        cifar10_normalization(),
    ])
    plain = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        cifar10_normalization(),
    ])
    return pl_bolts.datamodules.CIFAR10DataModule(
        batch_size=args.batch_size,
        train_transforms=augmented,
        test_transforms=plain,
        val_transforms=plain,
    )
def __init__(self, height: int = 32):
    """Build the SSL training augmentation pipeline for `height`-sized images; requires torchvision."""
    if not _TORCHVISION_AVAILABLE:
        raise ModuleNotFoundError(  # pragma: no-cover
            'You want to use `transforms` from `torchvision` which is not installed yet.'
        )

    # image augmentation functions
    jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
    self.train_transform = transforms.Compose([
        transforms.RandomResizedCrop(height, scale=(0.2, 1.)),
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        cifar10_normalization(),
    ])
def on_train_epoch_end(self, trainer, pl_module, outputs):
    """Sample latents from N(0, 1), decode them, and log the resulting image grid.

    Args:
        trainer: the running ``Trainer`` (used for the logger and global step).
        pl_module: the LightningModule; must expose ``decoder`` and ``hparams.latent_dim``.
        outputs: epoch outputs (unused).
    """
    figure(figsize=(8, 3), dpi=300)

    # Z COMES FROM NORMAL(0, 1)
    rand_v = torch.rand((self.num_preds, pl_module.hparams.latent_dim), device=pl_module.device)
    # BUG FIX: the std tensor was zeros_like(rand_v), which makes the distribution
    # degenerate (every sample is exactly 0). N(0, 1) requires a std of ones.
    p = torch.distributions.Normal(torch.zeros_like(rand_v), torch.ones_like(rand_v))
    z = p.rsample()

    # SAMPLE IMAGES
    with torch.no_grad():
        pred = pl_module.decoder(z.to(pl_module.device)).cpu()

    # UNDO DATA NORMALIZATION
    normalize = cifar10_normalization()
    mean, std = np.array(normalize.mean), np.array(normalize.std)
    img = make_grid(pred).permute(1, 2, 0).numpy() * std + mean

    # PLOT IMAGES
    trainer.logger.experiment.add_image('img', torch.tensor(img).permute(2, 0, 1), global_step=trainer.global_step)
def cli_main():
    """CLI entry point: configure the datamodule per dataset and train SwAV."""
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.swav.transforms import SwAVEvalDataTransform, SwAVTrainDataTransform

    parser = ArgumentParser()
    # model args
    parser = SwAV.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == 'stl10':
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples
        args.maxpool1 = False
        normalization = stl10_normalization()
    elif args.dataset == 'cifar10':
        args.batch_size = 2
        args.num_workers = 0
        dm = CIFAR10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        args.num_samples = dm.num_samples
        args.maxpool1 = False
        args.first_conv = False
        normalization = cifar10_normalization()
        # cifar10 specific params
        args.size_crops = [32, 16]
        args.nmb_crops = [2, 1]
        args.gaussian_blur = False
    elif args.dataset == 'imagenet':
        args.maxpool1 = True
        args.first_conv = True
        normalization = imagenet_normalization()
        # imagenet recipe: large/small multi-crops plus LARS training schedule
        args.size_crops = [224, 96]
        args.nmb_crops = [2, 6]
        args.min_scale_crops = [0.14, 0.05]
        args.max_scale_crops = [1., 0.14]
        args.gaussian_blur = True
        args.jitter_strength = 1.
        args.batch_size = 64
        args.num_nodes = 8
        args.gpus = 8  # per-node
        args.max_epochs = 800
        args.optimizer = 'lars'
        args.learning_rate = 4.8
        args.final_lr = 0.0048
        args.start_lr = 0.3
        args.nmb_prototypes = 3000
        args.online_ft = True
        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]
    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    dm.train_transforms = SwAVTrainDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
    )
    dm.val_transforms = SwAVEvalDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
    )

    # swav model init
    model = SwAV(**args.__dict__)

    online_evaluator = None
    if args.online_ft:
        # online eval
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )

    lr_monitor = LearningRateMonitor(logging_interval="step")
    model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss')

    callbacks = [model_checkpoint]
    if args.online_ft:
        callbacks.append(online_evaluator)
    callbacks.append(lr_monitor)

    trainer = Trainer(
        max_epochs=args.max_epochs,
        max_steps=None if args.max_steps == -1 else args.max_steps,
        gpus=args.gpus,
        num_nodes=args.num_nodes,
        distributed_backend='ddp' if args.gpus > 1 else None,
        sync_batchnorm=args.gpus > 1,
        precision=32 if args.fp32 else 16,
        callbacks=callbacks,
        fast_dev_run=args.fast_dev_run,
    )

    trainer.fit(model, datamodule=dm)
def cli_main():  # pragma: no cover
    """CLI entry point: fine-tune a pretrained SimCLR backbone on cifar10/stl10/imagenet.

    Fixes:
    - ``--nesterov`` previously used ``type=bool``: argparse passes the raw string to
      ``bool()``, so any non-empty value (including "False") parsed as True. It now
      parses the flag text explicitly; the default (False) is unchanged.
    - The per-dataset transform construction is deduplicated via a shared
      ``normalization`` variable (same transforms as before).
    """
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule

    def _str2bool(value):
        # Explicit truthy-string parsing; bool("False") would be True.
        return str(value).lower() in ("1", "true", "yes", "y")

    seed_everything(1234)

    parser = ArgumentParser()
    parser.add_argument("--dataset", type=str, help="cifar10, stl10, imagenet", default="cifar10")
    parser.add_argument("--ckpt_path", type=str, help="path to ckpt")
    parser.add_argument("--data_dir", type=str, help="path to dataset", default=os.getcwd())
    parser.add_argument("--batch_size", default=64, type=int, help="batch size per gpu")
    parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU")
    parser.add_argument("--gpus", default=4, type=int, help="number of GPUs")
    parser.add_argument("--num_epochs", default=100, type=int, help="number of epochs")

    # fine-tuner params
    parser.add_argument("--in_features", type=int, default=2048)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--learning_rate", type=float, default=0.3)
    parser.add_argument("--weight_decay", type=float, default=1e-6)
    parser.add_argument("--nesterov", type=_str2bool, default=False)  # fixed: was type=bool
    parser.add_argument("--scheduler_type", type=str, default="cosine")
    parser.add_argument("--gamma", type=float, default=0.1)
    parser.add_argument("--final_lr", type=float, default=0.0)

    args = parser.parse_args()

    if args.dataset == "cifar10":
        dm = CIFAR10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        normalization = cifar10_normalization()
        args.maxpool1 = False
        args.first_conv = False
        args.num_samples = 1
    elif args.dataset == "stl10":
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        # fine-tune on the labeled split only
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        normalization = stl10_normalization()
        args.num_samples = 1
        args.maxpool1 = False
        args.first_conv = True
    elif args.dataset == "imagenet":
        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        normalization = imagenet_normalization()
        args.num_samples = 1
        args.maxpool1 = True
        args.first_conv = True
    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    # train uses the augmented transform; val/test use the deterministic eval transform
    dm.train_transforms = SimCLRFinetuneTransform(
        normalize=normalization, input_height=dm.size()[-1], eval_transform=False
    )
    dm.val_transforms = SimCLRFinetuneTransform(
        normalize=normalization, input_height=dm.size()[-1], eval_transform=True
    )
    dm.test_transforms = SimCLRFinetuneTransform(
        normalize=normalization, input_height=dm.size()[-1], eval_transform=True
    )

    # load the pretrained backbone weights (strict=False: finetune head differs)
    backbone = SimCLR(
        gpus=args.gpus,
        nodes=1,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        maxpool1=args.maxpool1,
        first_conv=args.first_conv,
        dataset=args.dataset,
    ).load_from_checkpoint(args.ckpt_path, strict=False)

    tuner = SSLFineTuner(
        backbone,
        in_features=args.in_features,
        num_classes=dm.num_classes,
        epochs=args.num_epochs,
        hidden_dim=None,
        dropout=args.dropout,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        nesterov=args.nesterov,
        scheduler_type=args.scheduler_type,
        gamma=args.gamma,
        final_lr=args.final_lr,
    )

    trainer = Trainer(
        gpus=args.gpus,
        num_nodes=1,
        precision=16,
        max_epochs=args.num_epochs,
        distributed_backend="ddp",
        sync_batchnorm=args.gpus > 1,
    )

    trainer.fit(tuner, dm)
    trainer.test(datamodule=dm)
def cli_main():
    """CLI entry point: configure the datamodule per dataset and train SimCLR."""
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform

    parser = ArgumentParser()
    # model args
    parser = SimCLR.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    if args.dataset == 'stl10':
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples
        args.maxpool1 = False
        args.first_conv = True
        args.input_height = dm.size()[-1]
        normalization = stl10_normalization()
        args.gaussian_blur = True
        args.jitter_strength = 1.
    elif args.dataset == 'cifar10':
        # keep the validation split at least one global batch large
        val_split = max(5000, args.num_nodes * args.gpus * args.batch_size)
        dm = CIFAR10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            val_split=val_split,
        )
        args.num_samples = dm.num_samples
        args.maxpool1 = False
        args.first_conv = False
        args.input_height = dm.size()[-1]
        args.temperature = 0.5
        normalization = cifar10_normalization()
        args.gaussian_blur = False
        args.jitter_strength = 0.5
    elif args.dataset == 'imagenet':
        args.maxpool1 = True
        args.first_conv = True
        normalization = imagenet_normalization()
        # imagenet recipe: LARS wrapper plus large-batch schedule
        args.gaussian_blur = True
        args.jitter_strength = 1.
        args.batch_size = 64
        args.num_nodes = 8
        args.gpus = 8  # per-node
        args.max_epochs = 800
        args.optimizer = 'sgd'
        args.lars_wrapper = True
        args.learning_rate = 4.8
        args.final_lr = 0.0048
        args.start_lr = 0.3
        args.online_ft = True
        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]
    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    dm.train_transforms = SimCLRTrainDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )
    dm.val_transforms = SimCLREvalDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    model = SimCLR(**args.__dict__)

    online_evaluator = None
    if args.online_ft:
        # online eval
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )

    model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss')
    callbacks = [model_checkpoint]
    if args.online_ft:
        callbacks.append(online_evaluator)

    trainer = pl.Trainer.from_argparse_args(
        args,
        sync_batchnorm=args.gpus > 1,
        callbacks=callbacks,
    )

    trainer.fit(model, datamodule=dm)
def cli_main():
    """CLI entry point: configure the datamodule per dataset and train SwAV.

    Fix: the train/val transforms were constructed with a hard-coded
    ``stl10_normalization()`` even though each dataset branch computes its own
    ``normalization`` — cifar10 runs were normalized with STL-10 statistics.
    Both transforms now receive the per-dataset ``normalization``.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.models.self_supervised.swav.transforms import SwAVTrainDataTransform, SwAVEvalDataTransform
    from pl_bolts.datamodules import STL10DataModule, CIFAR10DataModule

    parser = ArgumentParser()
    # model args
    parser = SwAV.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == 'stl10':
        dm = STL10DataModule(data_dir=args.data_path, batch_size=args.batch_size, num_workers=args.num_workers)
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples
        args.maxpool1 = False
        normalization = stl10_normalization()
    elif args.dataset == 'cifar10':
        args.batch_size = 2
        args.num_workers = 0
        dm = CIFAR10DataModule(data_dir=args.data_path, batch_size=args.batch_size, num_workers=args.num_workers)
        args.num_samples = dm.num_samples
        args.maxpool1 = False
        args.first_conv = False
        normalization = cifar10_normalization()
        # cifar10 specific params
        args.size_crops = [32, 16]
        args.nmb_crops = [2, 1]
        args.gaussian_blur = False
    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    # BUG FIX: pass the per-dataset normalization (was stl10_normalization() for all datasets)
    dm.train_transforms = SwAVTrainDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
    )
    dm.val_transforms = SwAVEvalDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
    )

    # swav model init
    model = SwAV(**args.__dict__)

    online_evaluator = None
    if args.online_ft:
        # online eval
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        max_steps=None if args.max_steps == -1 else args.max_steps,
        gpus=args.gpus,
        distributed_backend='ddp' if args.gpus > 1 else None,
        sync_batchnorm=args.gpus > 1,
        precision=32 if args.fp32 else 16,
        callbacks=[online_evaluator] if args.online_ft else None,
        fast_dev_run=args.fast_dev_run,
    )

    trainer.fit(model, dm)
def _default_transforms(self):
    """Default transform: convert to tensor then apply CIFAR-10 channel normalization."""
    # (renamed the local: it held CIFAR-10 transforms, not MNIST ones)
    return transform_lib.Compose([
        transform_lib.ToTensor(),
        cifar10_normalization(),
    ])
def cli_main():
    """CLI entry point: configure the datamodule per dataset and train SimSiam."""
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform

    seed_everything(1234)

    parser = ArgumentParser()
    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)
    # model args
    parser = SimSiam.add_model_specific_args(parser)
    args = parser.parse_args()

    # pick data
    dm = None

    # init datamodule
    if args.dataset == "stl10":
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples
        args.maxpool1 = False
        args.first_conv = True
        args.input_height = dm.size()[-1]
        normalization = stl10_normalization()
        args.gaussian_blur = True
        args.jitter_strength = 1.0
    elif args.dataset == "cifar10":
        # keep the validation split at least one global batch large
        val_split = max(5000, args.nodes * args.gpus * args.batch_size)
        dm = CIFAR10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            val_split=val_split,
        )
        args.num_samples = dm.num_samples
        args.maxpool1 = False
        args.first_conv = False
        args.input_height = dm.size()[-1]
        args.temperature = 0.5
        normalization = cifar10_normalization()
        args.gaussian_blur = False
        args.jitter_strength = 0.5
    elif args.dataset == "cifar100":
        # keep the validation split at least one global batch large
        val_split = max(5000, args.nodes * args.gpus * args.batch_size)
        dm = CIFAR100DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            val_split=val_split,
        )
        args.num_samples = dm.num_samples
        args.maxpool1 = False
        args.first_conv = False
        args.input_height = dm.size()[-1]
        args.temperature = 0.5
        # ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023))
        normalization = transforms.Normalize(
            mean=(0.5071, 0.4866, 0.4409),
            std=(0.2009, 0.1984, 0.2023),
        )
        args.gaussian_blur = False
        args.jitter_strength = 0.5
    elif args.dataset == "imagenet":
        args.maxpool1 = True
        args.first_conv = True
        normalization = imagenet_normalization()
        args.gaussian_blur = True
        args.jitter_strength = 1.0
        # imagenet recipe: LARS wrapper plus large-batch schedule
        args.batch_size = 64
        args.nodes = 8
        args.gpus = 8  # per-node
        args.max_epochs = 800
        args.optimizer = "sgd"
        args.lars_wrapper = True
        args.learning_rate = 4.8
        args.final_lr = 0.0048
        args.start_lr = 0.3
        args.online_ft = True
        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]
    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    dm.train_transforms = SimCLRTrainDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )
    dm.val_transforms = SimCLREvalDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    model = SimSiam(**args.__dict__)

    # finetune in real-time
    online_evaluator = None
    if args.online_ft:
        # online eval
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.0,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        max_steps=None if args.max_steps == -1 else args.max_steps,
        gpus=args.gpus,
        num_nodes=args.nodes,
        distributed_backend="ddp" if args.gpus > 1 else None,
        sync_batchnorm=args.gpus > 1,
        precision=32 if args.fp32 else 16,
        callbacks=[online_evaluator] if args.online_ft else None,
        fast_dev_run=args.fast_dev_run,
    )

    trainer.fit(model, dm)