def test_swav(tmpdir, datadir, batch_size=2):
    """Smoke-test SwAV pre-training on CIFAR-10 with a single ``fast_dev_run``.

    Uses a ResNet-18 backbone, two 32px crops plus one 16px crop, two
    prototypes, one Sinkhorn iteration and ``queue_length=0``, all on CPU
    (``gpus=0``).

    NOTE(review): another ``test_swav`` is defined later in this file and
    rebinds this name, so only that later definition is collected by pytest —
    confirm whether this variant should be renamed.
    """
    # inputs, y = batch  (doesn't receive y for some reason)

    def make_transform_kwargs():
        # Built fresh on every call so each transform gets its own normalizer
        # object and its own crop lists.
        return {
            'normalize': cifar10_normalization(),
            'size_crops': [32, 16],
            'nmb_crops': [2, 1],
            'gaussian_blur': False,
        }

    dm = CIFAR10DataModule(data_dir=datadir, batch_size=batch_size, num_workers=0)
    dm.train_transforms = SwAVTrainDataTransform(**make_transform_kwargs())
    dm.val_transforms = SwAVEvalDataTransform(**make_transform_kwargs())

    model = SwAV(
        arch='resnet18',
        hidden_mlp=512,
        gpus=0,
        nodes=1,
        num_samples=dm.num_samples,
        batch_size=batch_size,
        nmb_crops=[2, 1],
        sinkhorn_iterations=1,
        nmb_prototypes=2,
        queue_length=0,
        maxpool1=False,
        first_conv=False,
        dataset='cifar10',
    )

    Trainer(gpus=0, fast_dev_run=True, default_root_dir=tmpdir).fit(model, datamodule=dm)
def test_swav(tmpdir, datadir):
    """Train SwAV briefly on CIFAR-10 and check the reported loss is positive.

    The global seed is fixed first. The run uses ``fast_dev_run=True`` on CPU
    with a ResNet-18 backbone, two prototypes and one Sinkhorn iteration; the
    test then reads ``'loss'`` from the trainer's progress-bar metrics.
    """
    seed_everything()

    batch_size = 2

    # inputs, y = batch  (doesn't receive y for some reason)
    datamodule = CIFAR10DataModule(data_dir=datadir, batch_size=batch_size, num_workers=0)

    def swav_transform_kwargs():
        # New normalizer and new crop lists per transform (nothing shared).
        return dict(
            normalize=cifar10_normalization(),
            size_crops=[32, 16],
            nmb_crops=[2, 1],
            gaussian_blur=False,
        )

    datamodule.train_transforms = SwAVTrainDataTransform(**swav_transform_kwargs())
    datamodule.val_transforms = SwAVEvalDataTransform(**swav_transform_kwargs())

    hparams = dict(
        arch='resnet18',
        hidden_mlp=512,
        gpus=0,
        nodes=1,
        num_samples=datamodule.num_samples,
        batch_size=batch_size,
        nmb_crops=[2, 1],
        sinkhorn_iterations=1,
        nmb_prototypes=2,
        maxpool1=False,
        first_conv=False,
        dataset='cifar10',
    )
    model = SwAV(**hparams)

    trainer = pl.Trainer(
        gpus=0,
        fast_dev_run=True,
        max_epochs=1,
        default_root_dir=tmpdir,
        max_steps=3,
    )
    trainer.fit(model, datamodule)

    assert float(trainer.progress_bar_dict['loss']) > 0
# --- Exemplo n.º 3 (Example no. 3) ---
# 0
def swav_example(num_gpus=2):
    """Pre-train SwAV on STL-10, then load a published pretrained checkpoint.

    Args:
        num_gpus: Number of GPUs to train on. The same value is passed to both
            the ``SwAV`` module and the ``Trainer`` so the model's notion of the
            device count matches the devices it actually runs on.

    Returns:
        The ``SwAV`` model loaded from ``weight_path``, with ``freeze()``
        already applied.
    """
    import pytorch_lightning as pl
    from pl_bolts.models.self_supervised import SwAV
    from pl_bolts.datamodules import STL10DataModule
    from pl_bolts.models.self_supervised.swav.transforms import SwAVTrainDataTransform, SwAVEvalDataTransform
    from pl_bolts.transforms.dataset_normalizations import stl10_normalization

    batch_size = 128

    # Data module: STL-10, trained and validated on the datamodule's "mixed" loaders.
    dm = STL10DataModule(data_dir=".", num_workers=16, batch_size=batch_size)
    dm.train_dataloader = dm.train_dataloader_mixed
    dm.val_dataloader = dm.val_dataloader_mixed
    dm.train_transforms = SwAVTrainDataTransform(normalize=stl10_normalization())
    dm.val_transforms = SwAVEvalDataTransform(normalize=stl10_normalization())

    # Model. SwAV takes the GPU count as a constructor argument (presumably to
    # derive the global batch size / world size — confirm), so it must agree
    # with the Trainer below; it was previously hard-coded to 1 while the
    # Trainer used 2.
    model = SwAV(
        gpus=num_gpus,
        num_samples=dm.num_unlabeled_samples,
        dataset="stl10",
        batch_size=batch_size,
    )

    # Fit with DDP at 16-bit precision.
    trainer = pl.Trainer(gpus=num_gpus, accelerator="ddp", precision=16)
    trainer.fit(model, datamodule=dm)

    # --------------------
    # ImageNet pretrained model:
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/bolts_swav_imagenet/swav_imagenet.ckpt"
    # weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar"
    # STL-10 pretrained model:
    # weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/checkpoints/swav_stl10.pth.tar"
    swav = SwAV.load_from_checkpoint(weight_path, strict=True)

    swav.freeze()
    return swav
def cli_main():
    """Command-line entry point: pre-train SwAV on STL-10, CIFAR-10 or ImageNet.

    SwAV's model arguments are parsed from the command line and then
    overridden per dataset (crop layout, stem options and, for ImageNet, the
    full optimization recipe). The matching datamodule and SwAV train/eval
    transforms are built; a checkpoint callback monitoring ``val_loss`` and a
    per-step learning-rate monitor are attached, plus an ``SSLOnlineEvaluator``
    when ``args.online_ft`` is set; then the model is fit.

    Raises:
        NotImplementedError: if ``args.dataset`` is not ``stl10``, ``cifar10``
            or ``imagenet``.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.swav.transforms import SwAVEvalDataTransform, SwAVTrainDataTransform

    args = SwAV.add_model_specific_args(ArgumentParser()).parse_args()

    dataset = args.dataset
    if dataset == 'stl10':
        dm = STL10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        # Train and validate on the datamodule's "mixed" loaders.
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples
        args.maxpool1 = False
        normalization = stl10_normalization()

    elif dataset == 'cifar10':
        # NOTE(review): these two assignments discard any --batch_size /
        # --num_workers given on the command line; looks like a debugging
        # leftover — confirm before relying on CLI values for CIFAR-10.
        args.batch_size = 2
        args.num_workers = 0

        dm = CIFAR10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        args.num_samples = dm.num_samples

        # Stem options switched off for the small 32px crops used below.
        args.maxpool1 = False
        args.first_conv = False
        normalization = cifar10_normalization()

        args.size_crops = [32, 16]
        args.nmb_crops = [2, 1]
        args.gaussian_blur = False

    elif dataset == 'imagenet':
        args.maxpool1 = True
        args.first_conv = True
        normalization = imagenet_normalization()

        # ImageNet recipe: 2x224px + 6x96px multi-crop, 8 nodes x 8 GPUs
        # (per node), 800 epochs, LARS, 3000 prototypes, online fine-tuning.
        vars(args).update(
            size_crops=[224, 96],
            nmb_crops=[2, 6],
            min_scale_crops=[0.14, 0.05],
            max_scale_crops=[1., 0.14],
            gaussian_blur=True,
            jitter_strength=1.,
            batch_size=64,
            num_nodes=8,
            gpus=8,
            max_epochs=800,
            optimizer='lars',
            learning_rate=4.8,
            final_lr=0.0048,
            start_lr=0.3,
            nmb_prototypes=3000,
            online_ft=True,
        )

        dm = ImagenetDataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]

    else:
        raise NotImplementedError(
            "other datasets have not been implemented till now")

    # Train and eval transforms share one configuration.
    transform_kwargs = dict(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
    )
    dm.train_transforms = SwAVTrainDataTransform(**transform_kwargs)
    dm.val_transforms = SwAVEvalDataTransform(**transform_kwargs)

    # swav model init: every (possibly overridden) CLI arg is forwarded.
    model = SwAV(**vars(args))

    online_evaluator = None
    if args.online_ft:
        # online eval
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )

    lr_monitor = LearningRateMonitor(logging_interval="step")
    model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss')

    callbacks = [model_checkpoint]
    if args.online_ft:
        callbacks.append(online_evaluator)
    callbacks.append(lr_monitor)

    multi_gpu = args.gpus > 1
    trainer = Trainer(
        max_epochs=args.max_epochs,
        # -1 on the CLI means "no step limit".
        max_steps=args.max_steps if args.max_steps != -1 else None,
        gpus=args.gpus,
        num_nodes=args.num_nodes,
        distributed_backend='ddp' if multi_gpu else None,
        sync_batchnorm=multi_gpu,
        precision=32 if args.fp32 else 16,
        callbacks=callbacks,
        fast_dev_run=args.fast_dev_run,
    )

    trainer.fit(model, datamodule=dm)
# --- Exemplo n.º 5 (Example no. 5) ---
# 0
def cli_main():
    """Command-line entry point: pre-train SwAV on STL-10 or CIFAR-10.

    Parses SwAV's model arguments, applies dataset-specific overrides, builds
    the matching datamodule and SwAV train/eval transforms (normalized with
    the statistics of the selected dataset), optionally attaches an
    ``SSLOnlineEvaluator``, and fits the model.

    Raises:
        NotImplementedError: if ``args.dataset`` is neither ``stl10`` nor
            ``cifar10``.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.models.self_supervised.swav.transforms import SwAVTrainDataTransform, SwAVEvalDataTransform
    from pl_bolts.datamodules import STL10DataModule, CIFAR10DataModule

    parser = ArgumentParser()

    # model args
    parser = SwAV.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == 'stl10':
        dm = STL10DataModule(data_dir=args.data_path,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers)

        # Train and validate on the datamodule's "mixed" loaders.
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        args.maxpool1 = False

        normalization = stl10_normalization()
    elif args.dataset == 'cifar10':
        # NOTE(review): these overrides discard any --batch_size /
        # --num_workers given on the command line — confirm this is intended.
        args.batch_size = 2
        args.num_workers = 0

        dm = CIFAR10DataModule(data_dir=args.data_path,
                               batch_size=args.batch_size,
                               num_workers=args.num_workers)

        args.num_samples = dm.num_samples

        args.maxpool1 = False
        args.first_conv = False

        normalization = cifar10_normalization()

        # cifar10 specific params
        args.size_crops = [32, 16]
        args.nmb_crops = [2, 1]
        args.gaussian_blur = False
    else:
        raise NotImplementedError(
            "other datasets have not been implemented till now")

    # Normalize with the statistics of the dataset actually being trained on
    # (the per-branch `normalization`), not unconditionally with STL-10's.
    dm.train_transforms = SwAVTrainDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength)

    dm.val_transforms = SwAVEvalDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength)

    # swav model init: every (possibly overridden) CLI arg is forwarded.
    model = SwAV(**args.__dict__)

    online_evaluator = None
    if args.online_ft:
        # online eval
        online_evaluator = SSLOnlineEvaluator(drop_p=0.,
                                              hidden_dim=None,
                                              z_dim=args.hidden_mlp,
                                              num_classes=dm.num_classes,
                                              dataset=args.dataset)

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        # -1 on the CLI means "no step limit".
        max_steps=None if args.max_steps == -1 else args.max_steps,
        gpus=args.gpus,
        distributed_backend='ddp' if args.gpus > 1 else None,
        sync_batchnorm=args.gpus > 1,
        precision=32 if args.fp32 else 16,
        callbacks=[online_evaluator] if args.online_ft else None,
        fast_dev_run=args.fast_dev_run)

    trainer.fit(model, dm)
# --- Exemplo n.º 6 (Example no. 6) ---
# 0
import pytorch_lightning as pl
from swav_module import SwAV
from pl_bolts.datamodules import STL10DataModule
from pl_bolts.models.self_supervised.swav.transforms import (
    SwAVTrainDataTransform, SwAVEvalDataTransform
)
from pl_bolts.transforms.dataset_normalizations import stl10_normalization

# data: STL-10 at batch size 128, using the datamodule's "mixed" loaders for
# both training and validation
batch_size = 128
dm = STL10DataModule(data_dir='.', batch_size=batch_size, num_workers=8)
dm.train_dataloader, dm.val_dataloader = dm.train_dataloader_mixed, dm.val_dataloader_mixed

# SwAV train/eval augmentations, each with its own STL-10 normalizer
dm.train_transforms = SwAVTrainDataTransform(normalize=stl10_normalization())
dm.val_transforms = SwAVEvalDataTransform(normalize=stl10_normalization())

# model
model = SwAV(
    lars_wrapper=True,
    online_ft=True,
    gpus=1,
    learning_rate=1e-3,
    num_samples=dm.num_unlabeled_samples,
    gaussian_blur=True,
    queue_length=0,