def test_swav(tmpdir, datadir, batch_size=2):
    """Smoke-test SwAV on CIFAR-10 with a single ``fast_dev_run`` pass on CPU.

    The data pipeline uses two 32px crops plus one 16px crop with Gaussian
    blur disabled; the model uses a ResNet-18 backbone, two prototypes, one
    Sinkhorn iteration and no feature queue so the run stays cheap.
    """

    def crop_transform_kwargs():
        # Build new objects on every call so each transform gets its own
        # normalization and crop lists, exactly like separate literals would.
        return dict(
            normalize=cifar10_normalization(),
            size_crops=[32, 16],
            nmb_crops=[2, 1],
            gaussian_blur=False,
        )

    dm = CIFAR10DataModule(data_dir=datadir, batch_size=batch_size, num_workers=0)
    dm.train_transforms = SwAVTrainDataTransform(**crop_transform_kwargs())
    dm.val_transforms = SwAVEvalDataTransform(**crop_transform_kwargs())

    swav_hparams = dict(
        arch='resnet18',
        hidden_mlp=512,
        gpus=0,
        nodes=1,
        num_samples=dm.num_samples,
        batch_size=batch_size,
        nmb_crops=[2, 1],
        sinkhorn_iterations=1,
        nmb_prototypes=2,
        queue_length=0,
        maxpool1=False,
        first_conv=False,
        dataset='cifar10',
    )
    model = SwAV(**swav_hparams)

    runner = Trainer(gpus=0, fast_dev_run=True, default_root_dir=tmpdir)
    runner.fit(model, datamodule=dm)
def test_swav(tmpdir, datadir):
    """Train SwAV on CIFAR-10 for a ``fast_dev_run`` and check the reported loss.

    Uses two 32px crops plus one 16px crop without Gaussian blur, a tiny
    prototype layer and a single Sinkhorn iteration to keep the run cheap.
    The loss shown in the progress bar must be strictly positive.
    """
    # Pin the seed explicitly: called without an argument, Lightning's
    # seed_everything() falls back to the PL_GLOBAL_SEED environment variable
    # or a randomly chosen seed, which made this test non-reproducible.
    seed_everything(42)

    batch_size = 2
    datamodule = CIFAR10DataModule(data_dir=datadir, batch_size=batch_size, num_workers=0)
    datamodule.train_transforms = SwAVTrainDataTransform(
        normalize=cifar10_normalization(), size_crops=[32, 16], nmb_crops=[2, 1], gaussian_blur=False
    )
    datamodule.val_transforms = SwAVEvalDataTransform(
        normalize=cifar10_normalization(), size_crops=[32, 16], nmb_crops=[2, 1], gaussian_blur=False
    )

    model = SwAV(
        arch='resnet18',
        hidden_mlp=512,
        gpus=0,
        nodes=1,
        num_samples=datamodule.num_samples,
        batch_size=batch_size,
        nmb_crops=[2, 1],
        sinkhorn_iterations=1,
        nmb_prototypes=2,
        maxpool1=False,
        first_conv=False,
        dataset='cifar10',
    )

    # NOTE(review): fast_dev_run=True is expected to cap the run at a single
    # batch, which would make max_epochs/max_steps inert — confirm against the
    # installed Lightning version before relying on them.
    trainer = pl.Trainer(gpus=0, fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir, max_steps=3)
    # Pass the datamodule by keyword rather than relying on fit() re-routing a
    # datamodule given in its positional dataloader slot.
    trainer.fit(model, datamodule=datamodule)

    # Progress-bar metrics may be stored as formatted strings, hence float().
    loss = trainer.progress_bar_dict['loss']
    assert float(loss) > 0
def swav_example():
    """Example: pre-train SwAV on STL-10 with DDP, then load pretrained weights.

    The first part fits SwAV on the STL-10 ``*_mixed`` dataloaders with 16-bit
    precision. The second part restores a published SwAV checkpoint by URL and
    freezes it for downstream use.
    """
    from pl_bolts.models.self_supervised import SwAV
    from pl_bolts.datamodules import STL10DataModule
    from pl_bolts.models.self_supervised.swav.transforms import SwAVTrainDataTransform, SwAVEvalDataTransform
    from pl_bolts.transforms.dataset_normalizations import stl10_normalization

    batch_size = 128
    # Single source of truth for the GPU count. SwAV receives it as a
    # hyperparameter and was previously told gpus=1 while the Trainer ran on 2.
    # (Presumably SwAV uses it to size the global batch / LR schedule — TODO confirm.)
    num_gpus = 2

    # Data module.
    dm = STL10DataModule(data_dir=".", num_workers=16, batch_size=batch_size)
    dm.train_dataloader = dm.train_dataloader_mixed
    dm.val_dataloader = dm.val_dataloader_mixed
    dm.train_transforms = SwAVTrainDataTransform(normalize=stl10_normalization())
    dm.val_transforms = SwAVEvalDataTransform(normalize=stl10_normalization())

    # Model.
    model = SwAV(
        gpus=num_gpus,
        num_samples=dm.num_unlabeled_samples,
        dataset="stl10",
        batch_size=batch_size,
    )

    # Fit.
    trainer = pl.Trainer(gpus=num_gpus, accelerator="ddp", precision=16)
    trainer.fit(model, datamodule=dm)

    # --------------------
    # ImageNet pretrained model:
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/bolts_swav_imagenet/swav_imagenet.ckpt"
    # Other published weights. NOTE(review): the .pth.tar files may not be
    # Lightning checkpoints loadable via load_from_checkpoint — verify first.
    #   ImageNet: "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar"
    #   STL-10:   "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/checkpoints/swav_stl10.pth.tar"
    swav = SwAV.load_from_checkpoint(weight_path, strict=True)
    swav.freeze()
def cli_main():
    """Parse CLI arguments, configure SwAV for the chosen dataset and train it.

    ``--dataset`` selects one of ``stl10``, ``cifar10`` or ``imagenet``; each
    branch builds the matching data module and overrides a set of ``args``
    attributes before the model is instantiated from ``args``.

    Raises:
        NotImplementedError: if ``--dataset`` names any other dataset.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.swav.transforms import SwAVEvalDataTransform, SwAVTrainDataTransform

    parser = SwAV.add_model_specific_args(ArgumentParser())
    args = parser.parse_args()
    dataset = args.dataset

    if dataset == 'stl10':
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        # swap in the STL-10 "*_mixed" loaders for training and validation
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples
        args.maxpool1 = False
        normalization = stl10_normalization()
    elif dataset == 'cifar10':
        # NOTE(review): these hard-coded values override any --batch_size /
        # --num_workers given on the command line.
        vars(args).update(batch_size=2, num_workers=0)
        dm = CIFAR10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        args.num_samples = dm.num_samples
        # cifar10-specific architecture and crop settings
        vars(args).update(
            maxpool1=False,
            first_conv=False,
            size_crops=[32, 16],
            nmb_crops=[2, 1],
            gaussian_blur=False,
        )
        normalization = cifar10_normalization()
    elif dataset == 'imagenet':
        normalization = imagenet_normalization()
        # imagenet recipe: crops, augmentation, cluster size and optimizer
        vars(args).update(
            maxpool1=True,
            first_conv=True,
            size_crops=[224, 96],
            nmb_crops=[2, 6],
            min_scale_crops=[0.14, 0.05],
            max_scale_crops=[1., 0.14],
            gaussian_blur=True,
            jitter_strength=1.,
            batch_size=64,
            num_nodes=8,
            gpus=8,  # per-node
            max_epochs=800,
            optimizer='lars',
            learning_rate=4.8,
            final_lr=0.0048,
            start_lr=0.3,
            nmb_prototypes=3000,
            online_ft=True,
        )
        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]
    else:
        raise NotImplementedError(
            "other datasets have not been implemented till now")

    # Train and eval transforms share the same normalization and crop settings.
    transform_kwargs = dict(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
    )
    dm.train_transforms = SwAVTrainDataTransform(**transform_kwargs)
    dm.val_transforms = SwAVEvalDataTransform(**transform_kwargs)

    # every (possibly overridden) CLI argument becomes a model hyperparameter
    model = SwAV(**vars(args))

    online_evaluator = None
    if args.online_ft:
        # online eval head fed with the projection-MLP features
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )
    lr_monitor = LearningRateMonitor(logging_interval="step")
    model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss')

    callbacks = [model_checkpoint]
    if args.online_ft:
        callbacks.append(online_evaluator)
    callbacks.append(lr_monitor)

    multi_gpu = args.gpus > 1
    trainer = Trainer(
        max_epochs=args.max_epochs,
        max_steps=None if args.max_steps == -1 else args.max_steps,
        gpus=args.gpus,
        num_nodes=args.num_nodes,
        distributed_backend='ddp' if multi_gpu else None,
        sync_batchnorm=multi_gpu,
        precision=32 if args.fp32 else 16,
        callbacks=callbacks,
        fast_dev_run=args.fast_dev_run,
    )
    trainer.fit(model, datamodule=dm)
def cli_main():
    """Parse CLI arguments and train SwAV on STL-10 or CIFAR-10.

    Each dataset branch builds its data module, overrides a few ``args``
    attributes and selects the dataset's own normalization, which both the
    train and eval transforms then use.

    Raises:
        NotImplementedError: if ``--dataset`` is neither ``stl10`` nor ``cifar10``.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.models.self_supervised.swav.transforms import SwAVTrainDataTransform, SwAVEvalDataTransform
    from pl_bolts.datamodules import STL10DataModule, CIFAR10DataModule

    parser = ArgumentParser()
    # model args
    parser = SwAV.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == 'stl10':
        dm = STL10DataModule(data_dir=args.data_path, batch_size=args.batch_size, num_workers=args.num_workers)
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples
        args.maxpool1 = False
        normalization = stl10_normalization()
    elif args.dataset == 'cifar10':
        # NOTE(review): these hard-coded values override any --batch_size /
        # --num_workers given on the command line.
        args.batch_size = 2
        args.num_workers = 0
        dm = CIFAR10DataModule(data_dir=args.data_path, batch_size=args.batch_size, num_workers=args.num_workers)
        args.num_samples = dm.num_samples
        args.maxpool1 = False
        args.first_conv = False
        normalization = cifar10_normalization()
        # cifar10 specific params
        args.size_crops = [32, 16]
        args.nmb_crops = [2, 1]
        args.gaussian_blur = False
    else:
        raise NotImplementedError(
            "other datasets have not been implemented till now")

    # Use the dataset-specific normalization selected above. Previously both
    # transforms always used stl10_normalization(), so CIFAR-10 images were
    # normalized with STL-10 statistics and `normalization` was never used.
    dm.train_transforms = SwAVTrainDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
    )
    dm.val_transforms = SwAVEvalDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
    )

    # swav model init: every CLI argument becomes a hyperparameter
    model = SwAV(**args.__dict__)

    online_evaluator = None
    if args.online_ft:
        # online eval
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        max_steps=None if args.max_steps == -1 else args.max_steps,
        gpus=args.gpus,
        distributed_backend='ddp' if args.gpus > 1 else None,
        sync_batchnorm=True if args.gpus > 1 else False,
        precision=32 if args.fp32 else 16,
        callbacks=[online_evaluator] if args.online_ft else None,
        fast_dev_run=args.fast_dev_run,
    )
    # datamodule passed by keyword, matching the other fit() calls in this file
    trainer.fit(model, datamodule=dm)
SwAVTrainDataTransform, SwAVEvalDataTransform ) from pl_bolts.transforms.dataset_normalizations import stl10_normalization # data batch_size = 128 dm = STL10DataModule(data_dir='.', batch_size=batch_size, num_workers=8) dm.train_dataloader = dm.train_dataloader_mixed dm.val_dataloader = dm.val_dataloader_mixed dm.train_transforms = SwAVTrainDataTransform( normalize=stl10_normalization() ) dm.val_transforms = SwAVEvalDataTransform( normalize=stl10_normalization() ) # model model = SwAV( lars_wrapper=True, online_ft=True, gpus=1, learning_rate=1e-3, num_samples=dm.num_unlabeled_samples, gaussian_blur=True, queue_length=0, dataset='stl10', jitter_strength=1.0, batch_size=batch_size, nmb_prototypes=512,