def test_swav(tmpdir, datadir, batch_size=2):
    """Smoke-test SwAV on CIFAR-10 using a CPU trainer with ``fast_dev_run``.

    Args:
        tmpdir: Directory used as the trainer's ``default_root_dir``.
        datadir: Directory passed to ``CIFAR10DataModule`` as ``data_dir``.
        batch_size: Batch size given to both the datamodule and the model.
    """

    def multicrop_kwargs():
        # Built fresh on every call so the two transforms never share the
        # same list objects or normalization instance.
        return dict(
            normalize=cifar10_normalization(),
            size_crops=[32, 16],
            nmb_crops=[2, 1],
            gaussian_blur=False,
        )

    # Original author's note: unpacking ``inputs, y = batch`` does not
    # receive ``y`` here; the cause was not determined.
    dm = CIFAR10DataModule(data_dir=datadir, batch_size=batch_size, num_workers=0)
    dm.train_transforms = SwAVTrainDataTransform(**multicrop_kwargs())
    dm.val_transforms = SwAVEvalDataTransform(**multicrop_kwargs())

    swav = SwAV(
        arch='resnet18',
        hidden_mlp=512,
        gpus=0,
        nodes=1,
        num_samples=dm.num_samples,
        batch_size=batch_size,
        nmb_crops=[2, 1],
        sinkhorn_iterations=1,
        nmb_prototypes=2,
        queue_length=0,
        maxpool1=False,
        first_conv=False,
        dataset='cifar10',
    )

    runner = Trainer(gpus=0, fast_dev_run=True, default_root_dir=tmpdir)
    runner.fit(swav, datamodule=dm)
def cli_main():  # pragma: no-cover
    """Fine-tune a classifier on a pretrained CPCV2 backbone from the CLI.

    Parses the trainer flags plus ``--dataset``, ``--ckpt_path`` and
    ``--data_dir``, loads the backbone checkpoint, builds the matching
    datamodule, fits an ``SSLFineTuner`` and then runs the test loop.

    Raises:
        ValueError: If ``--dataset`` is neither ``cifar10`` nor ``stl10``.
    """
    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--dataset', type=str, help='stl10, cifar10', default='cifar10')
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    # The help text used to be a copy of the --ckpt_path description.
    parser.add_argument('--data_dir', type=str, help='path to dataset', default=os.getcwd())
    args = parser.parse_args()

    # Fail fast: without this check an unknown dataset only surfaced later as
    # a NameError on ``dm``, after the checkpoint had already been loaded.
    if args.dataset not in ('cifar10', 'stl10'):
        raise ValueError(f'cli_main does not support dataset {args.dataset}')

    # load the backbone
    backbone = CPCV2.load_from_checkpoint(args.ckpt_path, strict=False)

    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = CPCTrainTransformsCIFAR10()
        dm.val_transforms = CPCEvalTransformsCIFAR10()
        dm.test_transforms = CPCEvalTransformsCIFAR10()

    else:  # 'stl10'
        dm = STL10DataModule.from_argparse_args(args)
        # Swap in the labeled loaders in place of the datamodule's defaults.
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        dm.train_transforms = CPCTrainTransformsSTL10()
        dm.val_transforms = CPCEvalTransformsSTL10()
        dm.test_transforms = CPCEvalTransformsSTL10()

    # finetune
    tuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes)
    trainer = pl.Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)
Beispiel #3
0
def get_datamodule(cfg: DictConfig) -> pl.LightningDataModule:
    """Build the DataModule selected by the user configuration.

    Only the ``cifar10`` and ``imagenet`` datasets are supported.

    Args:
        cfg: The top-level user configuration object; its ``dataset`` section
            supplies ``name``, ``data_dir``, ``workers`` and ``batch_size``.

    Returns:
        The LightningDataModule for the configured dataset.

    Raises:
        ValueError: The configured dataset name is not supported.
    """
    ds = cfg.dataset
    name = ds.name

    if name == 'cifar10':
        # Only the CIFAR-10 module is given the global seed.
        return CIFAR10DataModule(
            ds.data_dir,
            num_workers=ds.workers,
            batch_size=ds.batch_size,
            seed=cfg.seed,
        )

    if name == 'imagenet':
        return ImagenetDataModule(
            ds.data_dir,
            num_workers=ds.workers,
            batch_size=ds.batch_size,
        )

    raise ValueError(f'get_datamodule does not support dataset {name}')
Beispiel #4
0
def cpc_v2_example():
    """Train CPC v2 on CIFAR-10, then load and freeze a pretrained checkpoint.

    The training run asks for 2 GPUs with the "ddp" accelerator; the
    pretrained weights are fetched from the URL below.
    """
    from pl_bolts.models.self_supervised import CPC_v2
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.cpc import CPCTrainTransformsCIFAR10, CPCEvalTransformsCIFAR10
    from pytorch_lightning.plugins import DDPPlugin

    # --- data -------------------------------------------------------------
    cifar = CIFAR10DataModule(num_workers=12, batch_size=32)
    cifar.train_transforms = CPCTrainTransformsCIFAR10()
    cifar.val_transforms = CPCEvalTransformsCIFAR10()

    # --- model ------------------------------------------------------------
    cpc = CPC_v2(encoder="cpc_encoder")

    # --- training ---------------------------------------------------------
    ddp = DDPPlugin(find_unused_parameters=False)
    runner = pl.Trainer(gpus=2, accelerator="ddp", plugins=ddp)
    runner.fit(cpc, datamodule=cifar)

    # --- pretrained weights -------------------------------------------------
    # CIFAR-10 checkpoint (active). The STL-10 alternative is:
    #   https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-stl10-v0-exp3/epoch%3D624.ckpt
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-cifar10-v4-exp3/epoch%3D474.ckpt"
    pretrained = CPC_v2.load_from_checkpoint(weight_path, strict=False)
    pretrained.freeze()
def cli_main():
    """Train MocoV2 from the command line.

    Trainer flags and the model-specific flags registered by
    ``MocoV2.add_model_specific_args`` are parsed, a datamodule is built for
    ``args.dataset`` (``cifar10``, ``stl10`` or ``imagenet2012``) and handed
    to the model, and the model is fitted.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = MocoV2.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == 'cifar10':
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
        datamodule.val_transforms = Moco2EvalCIFAR10Transforms()

    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)
        # Swap in the "mixed" loaders in place of the datamodule's defaults.
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = Moco2TrainSTL10Transforms()
        datamodule.val_transforms = Moco2EvalSTL10Transforms()

    elif args.dataset == 'imagenet2012':
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainImagenetTransforms()
        datamodule.val_transforms = Moco2EvalImagenetTransforms()

    else:
        # Used to fall through and crash with a NameError on ``datamodule``.
        raise ValueError(f'cli_main does not support dataset {args.dataset}')

    model = MocoV2(**args.__dict__, datamodule=datamodule)

    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model)
Beispiel #6
0
def simclr_example():
    """Train SimCLR on CIFAR-10, then load and freeze a pretrained checkpoint.

    The training run asks for 2 GPUs with the "ddp" accelerator.
    """
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    # --- data -------------------------------------------------------------
    cifar = CIFAR10DataModule(num_workers=12, batch_size=32)
    cifar.train_transforms = SimCLRTrainDataTransform(input_height=32)
    cifar.val_transforms = SimCLREvalDataTransform(input_height=32)

    # --- model ------------------------------------------------------------
    net = SimCLR(
        gpus=2,
        num_samples=cifar.num_samples,
        batch_size=cifar.batch_size,
        dataset="cifar10",
    )

    # --- training ---------------------------------------------------------
    runner = pl.Trainer(gpus=2, accelerator="ddp")
    runner.fit(net, datamodule=cifar)

    # --- pretrained weights -------------------------------------------------
    # NOTE(review): this URL was labelled "CIFAR-10 pretrained model", but the
    # path names an ImageNet checkpoint (it was also listed a second time as
    # the ImageNet alternative). Confirm which weights are intended.
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
    pretrained = SimCLR.load_from_checkpoint(weight_path, strict=False)
    pretrained.freeze()
def cli_main():
    """Pre-train BYOL from the command line with an ``SSLOnlineEvaluator``.

    Trainer flags and the flags registered by ``BYOL.add_model_specific_args``
    are parsed, a datamodule with SimCLR transforms is built for
    ``args.dataset`` (``cifar10``, ``stl10`` or ``imagenet2012``), and the
    model is fitted with ``max_steps=300000``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    seed_everything(1234)

    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = BYOL.add_model_specific_args(parser)
    args = parser.parse_args()

    # Pick the datamodule and the image height the transforms are built for.
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        input_height = 32

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # Swap in the "mixed" loaders in place of the datamodule's defaults.
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        _, input_height, _ = dm.size()

    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        _, input_height, _ = dm.size()

    else:
        # ``dm`` used to stay None here and crash later on ``dm.num_classes``.
        raise ValueError(f'cli_main does not support dataset {args.dataset}')

    dm.train_transforms = SimCLRTrainDataTransform(input_height)
    dm.val_transforms = SimCLREvalDataTransform(input_height)
    args.num_classes = dm.num_classes

    model = BYOL(**args.__dict__)

    def to_device(batch, device):
        # Only the first element of the input pair is forwarded, with labels.
        (x1, _), y = batch
        x1 = x1.to(device)
        y = y.to(device)
        return x1, y

    # finetune in real-time
    online_eval = SSLOnlineEvaluator(z_dim=2048, num_classes=dm.num_classes)
    online_eval.to_device = to_device

    trainer = pl.Trainer.from_argparse_args(args,
                                            max_steps=300000,
                                            callbacks=[online_eval])

    trainer.fit(model, dm)
Beispiel #8
0
def main(train_dir, batch_size=128, lr=1e-3, num_workers=1):
    """Train ``ClassificationModel`` on CIFAR-10, checkpointing under ``train_dir``.

    Args:
        train_dir: Root directory; checkpoints go to ``<train_dir>/checkpoints``
            and logs to ``<train_dir>/logs``.
        batch_size: Stored in the hyper-parameter namespace given to the model.
        lr: Stored in the hyper-parameter namespace given to the model.
        num_workers: Stored in the hyper-parameter namespace given to the model.
    """
    # Must run before any other local is bound: locals() is exactly the call
    # arguments at this point.
    params = argparse.Namespace(**locals())
    model = ClassificationModel(params)
    # NOTE(review): the datamodule is created with its own defaults, so
    # batch_size / num_workers above reach only the model -- confirm intent.
    dm = CIFAR10DataModule()

    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
    )
    dm.train_transforms = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Fixed typo: this used to assign ``dm.val_transformrs``, so the
    # validation split never received the eval transforms.
    eval_transforms = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    dm.val_transforms = eval_transforms
    dm.test_transforms = eval_transforms
    # show(dm)

    ckpt_dir = os.path.join(train_dir, 'checkpoints')
    ckcb = ModelCheckpoint(dirpath=ckpt_dir,
                           filename='weights#{epoch}',
                           save_top_k=None,
                           period=1)

    logger = TestTubeLogger(save_dir=os.path.join(train_dir, 'logs'),
                            version=1)
    trainer = pl.Trainer(log_every_n_steps=10,
                         checkpoint_callback=ckcb,
                         logger=logger,
                         gpus=1)
    trainer.fit(model, dm)
def test_byol(tmpdir, datadir):
    """Smoke-test BYOL on CIFAR-10 with a ``fast_dev_run`` trainer.

    Args:
        tmpdir: Directory used as the trainer's ``default_root_dir``.
        datadir: Directory holding (or receiving) the CIFAR-10 data.
    """
    datamodule = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    datamodule.train_transforms = CPCTrainTransformsCIFAR10()
    datamodule.val_transforms = CPCEvalTransformsCIFAR10()

    # Pass the class count, not the datamodule object itself.
    model = BYOL(data_dir=datadir, num_classes=datamodule.num_classes)
    trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model, datamodule=datamodule)
Beispiel #10
0
def run_training(config):
    """Instantiate a datamodule, a model and a trainer, then fit the model.

    Args:
        config: Object providing ``rootdir``, ``num_workers`` and
            ``batch_size``; it is also passed whole to ``LightningModel``.
    """
    datamodule = CIFAR10DataModule(
        data_dir=config.rootdir,
        num_workers=config.num_workers,
        batch_size=config.batch_size,
    )
    net = LightningModel(config)
    init_trainer().fit(net, datamodule)
def test_simclr(tmpdir, datadir):
    """Smoke-test SimCLR on CIFAR-10 with a ``fast_dev_run`` trainer.

    Args:
        tmpdir: Directory used as the trainer's ``default_root_dir``.
        datadir: Directory passed to ``CIFAR10DataModule`` as ``data_dir``.
    """
    dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    dm.train_transforms = SimCLRTrainDataTransform(32)
    dm.val_transforms = SimCLREvalDataTransform(32)

    simclr = SimCLR(
        batch_size=2,
        num_samples=dm.num_samples,
        gpus=0,
        nodes=1,
        dataset='cifar10',
    )
    runner = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    runner.fit(simclr, datamodule=dm)
def test_moco(tmpdir, datadir):
    """Smoke-test Moco v2 on CIFAR-10 with a ``fast_dev_run`` trainer.

    Args:
        tmpdir: Directory used as the trainer's ``default_root_dir``.
        datadir: Directory passed to both the datamodule and the model.
    """
    dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    dm.train_transforms = Moco2TrainCIFAR10Transforms()
    dm.val_transforms = Moco2EvalCIFAR10Transforms()

    moco = Moco_v2(data_dir=datadir, batch_size=2, online_ft=True)
    runner = pl.Trainer(
        fast_dev_run=True,
        default_root_dir=tmpdir,
        callbacks=[MocoLRScheduler()],
    )
    runner.fit(moco, datamodule=dm)
Beispiel #13
0
def cli_main():
    """Pre-train CPCV2 from the command line with an ``SSLOnlineEvaluator``.

    Besides the trainer and model flags, the parser accepts ``--dataset``,
    ``--data_dir``, ``--meta_dir``, ``--num_workers`` and ``--batch_size``.
    Each supported dataset selects its own transforms and ``patch_size``.
    When ``--dataset`` matches none of the branches the datamodule stays
    ``None`` and is passed to ``trainer.fit`` as such.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
    from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule

    pl.seed_everything(1234)

    cli = ArgumentParser()
    cli = pl.Trainer.add_argparse_args(cli)
    cli = CPCV2.add_model_specific_args(cli)
    cli.add_argument('--dataset', default='cifar10', type=str)
    cli.add_argument('--data_dir', default='.', type=str)
    cli.add_argument(
        '--meta_dir',
        default='.',
        type=str,
        help='path to meta.bin for imagenet',
    )
    cli.add_argument('--num_workers', default=8, type=int)
    cli.add_argument('--batch_size', type=int, default=128)
    args = cli.parse_args()

    dm = None
    evaluator = SSLOnlineEvaluator()
    chosen = args.dataset

    if chosen == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = CPCTrainTransformsCIFAR10()
        dm.val_transforms = CPCEvalTransformsCIFAR10()
        args.patch_size = 8

    elif chosen == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # Swap in the "mixed" loaders in place of the datamodule's defaults.
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        dm.train_transforms = CPCTrainTransformsSTL10()
        dm.val_transforms = CPCEvalTransformsSTL10()
        args.patch_size = 16

        # Author's sizing guide: 16 GB RAM -> 64, 32 GB RAM -> 144.
        # NOTE(review): this override happens after the datamodule was built
        # from ``args``, so it only reaches the model -- confirm intent.
        args.batch_size = 144

        def to_device(batch, device):
            # The batch is a pair of pairs; only the second one is forwarded.
            (_, _), (images, labels) = batch
            return images.to(device), labels.to(device)

        evaluator.to_device = to_device

    elif chosen == 'imagenet2012':
        dm = SSLImagenetDataModule.from_argparse_args(args)
        dm.train_transforms = CPCTrainTransformsImageNet128()
        dm.val_transforms = CPCEvalTransformsImageNet128()
        args.patch_size = 32

    cpc = CPCV2(**vars(args))
    runner = pl.Trainer.from_argparse_args(args, callbacks=[evaluator])
    runner.fit(cpc, dm)
Beispiel #14
0
def cli_main():  # pragma: no-cover
    """Fine-tune a classifier on a pretrained SimCLR backbone from the CLI.

    Parses the trainer flags plus ``--dataset``, ``--ckpt_path`` and
    ``--data_dir``, loads the backbone checkpoint, builds the matching
    datamodule, fits an ``SSLFineTuner`` and then runs the test loop.

    Raises:
        ValueError: If ``--dataset`` is not ``cifar10``, ``stl10`` or
            ``imagenet2012``.
    """
    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--dataset',
                        type=str,
                        help='stl10, cifar10',
                        default='cifar10')
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    # The help text used to be a copy of the --ckpt_path description.
    parser.add_argument('--data_dir',
                        type=str,
                        help='path to dataset',
                        default=os.getcwd())
    args = parser.parse_args()

    # Fail fast: without this check an unknown dataset only surfaced later as
    # a NameError on ``dm``, after the checkpoint had already been loaded.
    if args.dataset not in ('cifar10', 'stl10', 'imagenet2012'):
        raise ValueError(f'cli_main does not support dataset {args.dataset}')

    # load the backbone
    backbone = SimCLR.load_from_checkpoint(args.ckpt_path, strict=False)

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        dm.test_transforms = SimCLREvalDataTransform(32)
        args.num_samples = dm.num_samples

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # Swap in the labeled loaders in place of the datamodule's defaults.
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = dm.num_labeled_samples

        _, h, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    else:  # 'imagenet2012'
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        _, h, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    # finetune
    tuner = SSLFineTuner(backbone,
                         in_features=2048 * 2 * 2,
                         num_classes=dm.num_classes,
                         hidden_dim=None)
    trainer = pl.Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)
Beispiel #15
0
def test_cpcv2(tmpdir):
    """Fit CPCV2 on CIFAR-10 and check the reported ``val_nce`` is positive.

    Args:
        tmpdir: Used as the data directory and the trainer's root directory.
    """
    seed_everything()

    dm = CIFAR10DataModule(data_dir=tmpdir, num_workers=0, batch_size=2)
    dm.train_transforms = CPCTrainTransformsCIFAR10()
    dm.val_transforms = CPCEvalTransformsCIFAR10()

    # The datamodule is handed to the model rather than to ``fit``.
    cpc = CPCV2(
        encoder='resnet18',
        data_dir=tmpdir,
        batch_size=2,
        online_ft=True,
        datamodule=dm,
    )
    runner = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir)
    runner.fit(cpc)

    assert float(runner.progress_bar_dict['val_nce']) > 0
Beispiel #16
0
def test_moco(tmpdir):
    """Fit MocoV2 on CIFAR-10 and check the reported loss is positive.

    Args:
        tmpdir: Used as the data directory and the trainer's root directory.
    """
    seed_everything()

    dm = CIFAR10DataModule(tmpdir, num_workers=0, batch_size=2)
    dm.train_transforms = Moco2TrainCIFAR10Transforms()
    dm.val_transforms = Moco2EvalCIFAR10Transforms()

    moco = MocoV2(data_dir=tmpdir, batch_size=2, online_ft=True)
    runner = pl.Trainer(
        fast_dev_run=True,
        max_epochs=1,
        default_root_dir=tmpdir,
        callbacks=[MocoLRScheduler()],
    )
    runner.fit(moco, datamodule=dm)

    assert float(runner.progress_bar_dict['loss']) > 0
Beispiel #17
0
def test_simclr(tmpdir):
    """Fit SimCLR on CIFAR-10 and check the reported loss is positive.

    Args:
        tmpdir: Used as the data directory and the trainer's root directory.
    """
    seed_everything()

    dm = CIFAR10DataModule(tmpdir, num_workers=0, batch_size=2)
    dm.train_transforms = SimCLRTrainDataTransform(32)
    dm.val_transforms = SimCLREvalDataTransform(32)

    simclr = SimCLR(batch_size=2, num_samples=dm.num_samples)
    runner = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir)
    runner.fit(simclr, dm)

    assert float(runner.progress_bar_dict['loss']) > 0
Beispiel #18
0
def test_cpcv2(tmpdir):
    """Overfit CPCV2 on two CIFAR-10 batches and check the loss is positive.

    Args:
        tmpdir: Used as the data directory and the trainer's root directory.
    """
    reset_seed()

    dm = CIFAR10DataModule(data_dir=tmpdir, num_workers=0)
    dm.train_transforms = CPCTrainTransformsCIFAR10()
    dm.val_transforms = CPCEvalTransformsCIFAR10()

    # The datamodule is handed to the model rather than to ``fit``.
    cpc = CPCV2(
        encoder='resnet18',
        data_dir=tmpdir,
        batch_size=2,
        online_ft=True,
        datamodule=dm,
    )
    runner = pl.Trainer(overfit_batches=2, max_epochs=1, default_root_dir=tmpdir)
    runner.fit(cpc)

    assert runner.callback_metrics['loss'] > 0
Beispiel #19
0
def test_simclr(tmpdir):
    """Overfit SimCLR on two CIFAR-10 batches and check the loss is positive.

    Args:
        tmpdir: Used as the data directory and the trainer's root directory.
    """
    reset_seed()

    dm = CIFAR10DataModule(tmpdir, num_workers=0)
    dm.train_transforms = SimCLRTrainDataTransform(32)
    dm.val_transforms = SimCLREvalDataTransform(32)

    # The datamodule is handed to the model rather than to ``fit``.
    simclr = SimCLR(
        data_dir=tmpdir,
        batch_size=2,
        datamodule=dm,
        online_ft=True,
    )
    runner = pl.Trainer(overfit_batches=2, max_epochs=1, default_root_dir=tmpdir)
    runner.fit(simclr)

    assert runner.callback_metrics['loss'] > 0
Beispiel #20
0
def test_moco(tmpdir):
    """Overfit MocoV2 on two CIFAR-10 batches and check the loss is positive.

    Args:
        tmpdir: Used as the data directory and the trainer's root directory.
    """
    reset_seed()

    dm = CIFAR10DataModule(tmpdir, num_workers=0)
    dm.train_transforms = Moco2TrainCIFAR10Transforms()
    dm.val_transforms = Moco2EvalCIFAR10Transforms()

    # The datamodule is handed to the model rather than to ``fit``.
    moco = MocoV2(
        data_dir=tmpdir,
        batch_size=2,
        datamodule=dm,
        online_ft=True,
    )
    runner = pl.Trainer(
        overfit_batches=2,
        max_epochs=1,
        default_root_dir=tmpdir,
        callbacks=[MocoLRScheduler()],
    )
    runner.fit(moco)

    assert runner.callback_metrics['loss'] > 0
Beispiel #21
0
def test_byol(tmpdir):
    """Fit BYOL on CIFAR-10 and check the reported loss stays below 1.0.

    Args:
        tmpdir: Used as the data directory and the trainer's root directory.
    """
    seed_everything()

    datamodule = CIFAR10DataModule(data_dir=tmpdir, num_workers=0, batch_size=2)
    datamodule.train_transforms = CPCTrainTransformsCIFAR10()
    datamodule.val_transforms = CPCEvalTransformsCIFAR10()

    # Pass the class count, not the datamodule object itself.
    model = BYOL(data_dir=tmpdir, num_classes=datamodule.num_classes)
    trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir, max_steps=2)
    trainer.fit(model, datamodule)
    loss = trainer.progress_bar_dict['loss']

    assert float(loss) < 1.0
Beispiel #22
0
def cli_main():
    """Pre-train SimCLR from the command line with an ``SSLOnlineEvaluator``.

    Trainer flags and the flags registered by
    ``SimCLR.add_model_specific_args`` are parsed, and a datamodule with
    SimCLR transforms is built for ``args.dataset`` (``cifar10``, ``stl10``
    or ``imagenet2012``) before the model is fitted.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule

    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = SimCLR.add_model_specific_args(parser)
    args = parser.parse_args()

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        args.num_samples = dm.num_samples

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # Swap in the "mixed" loaders in place of the datamodule's defaults.
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        _, h, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        _, h, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    else:
        # Used to fall through and crash with a NameError on ``dm``.
        raise ValueError(f'cli_main does not support dataset {args.dataset}')

    model = SimCLR(**args.__dict__)

    # finetune in real-time
    def to_device(batch, device):
        # Only the first element of the input pair is forwarded, with labels.
        (x1, _), y = batch
        x1 = x1.to(device)
        y = y.to(device)
        return x1, y

    online_eval = SSLOnlineEvaluator(z_dim=2048 * 2 * 2,
                                     num_classes=dm.num_classes)
    online_eval.to_device = to_device

    trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_eval])
    trainer.fit(model, dm)
Beispiel #23
0
def cli_main():
    """Pre-train CPCV2 from the command line with online fine-tuning enabled.

    Trainer flags and the flags registered by
    ``CPCV2.add_model_specific_args`` are parsed and ``online_ft`` is forced
    to ``True``. Each supported dataset selects its own transforms and
    ``patch_size``. When ``args.dataset`` matches none of the branches the
    datamodule stays ``None`` and is passed to the model as such.
    """
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
    from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule

    pl.seed_everything(1234)

    cli = ArgumentParser()
    cli = pl.Trainer.add_argparse_args(cli)
    cli = CPCV2.add_model_specific_args(cli)
    args = cli.parse_args()
    args.online_ft = True

    dm = None
    evaluator = SSLOnlineEvaluator()
    chosen = args.dataset

    if chosen == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = CPCTrainTransformsCIFAR10()
        dm.val_transforms = CPCEvalTransformsCIFAR10()
        args.patch_size = 8

    elif chosen == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # Swap in the "mixed" loaders in place of the datamodule's defaults.
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        dm.train_transforms = CPCTrainTransformsSTL10()
        dm.val_transforms = CPCEvalTransformsSTL10()
        args.patch_size = 16

        # Author's sizing guide: 16 GB RAM -> 64, 32 GB RAM -> 144.
        # NOTE(review): this override happens after the datamodule was built
        # from ``args``, so it only reaches the model -- confirm intent.
        args.batch_size = 144

        def to_device(batch, device):
            # The batch is a pair of pairs; only the second one is forwarded.
            (_, _), (images, labels) = batch
            return images.to(device), labels.to(device)

        evaluator.to_device = to_device

    elif chosen == 'imagenet2012':
        dm = SSLImagenetDataModule.from_argparse_args(args)
        dm.train_transforms = CPCTrainTransformsImageNet128()
        dm.val_transforms = CPCEvalTransformsImageNet128()
        args.patch_size = 32

    # The datamodule is handed to the model rather than to ``fit``.
    cpc = CPCV2(**vars(args), datamodule=dm)
    runner = pl.Trainer.from_argparse_args(args, callbacks=[evaluator])
    runner.fit(cpc)
def test_cpcv2(tmpdir, datadir):
    """Smoke-test CPC v2 with a MobileNetV3-small encoder on CIFAR-10.

    Args:
        tmpdir: Directory used as the trainer's ``default_root_dir``.
        datadir: Directory passed to ``CIFAR10DataModule`` as ``data_dir``.
    """
    dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    dm.train_transforms = CPCTrainTransformsCIFAR10()
    dm.val_transforms = CPCEvalTransformsCIFAR10()

    model_kwargs = dict(
        encoder='mobilenet_v3_small',
        patch_size=8,
        patch_overlap=2,
        online_ft=True,
        num_classes=dm.num_classes,
    )
    cpc = CPC_v2(**model_kwargs)

    runner = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    runner.fit(cpc, datamodule=dm)
Beispiel #25
0
def test_from_pretrained(datadir):
    """Check that VAE and AE load published weights and reject unknown names.

    The test verifies that both models advertise pretrained weights, that
    the named checkpoints load and a forward pass runs on one validation
    batch, and that an unknown checkpoint name raises ``KeyError`` for each
    model.

    Args:
        datadir: Directory passed to ``CIFAR10DataModule`` as ``data_dir``.
    """
    dm = CIFAR10DataModule(data_dir=datadir, batch_size=4)
    # prepare_data() must come first so the files exist when setup() builds
    # the splits; the two calls used to be in the opposite order.
    dm.prepare_data()
    dm.setup()

    data_loader = dm.val_dataloader()

    vae = VAE(input_height=32)
    ae = AE(input_height=32)

    assert len(VAE.pretrained_weights_available()) > 0
    assert len(AE.pretrained_weights_available()) > 0

    try:
        vae = vae.from_pretrained("cifar10-resnet18")

        # test forward method on pre-trained weights
        for x, y in data_loader:
            vae(x)
            break

        vae = vae.from_pretrained(
            "stl10-resnet18"
        )  # try loading weights not compatible with exact architecture

        ae = ae.from_pretrained("cifar10-resnet18")

        # test forward method on pre-trained weights
        with torch.no_grad():
            for x, y in data_loader:
                ae(x)
                break

    except Exception as err:
        # Chain the original error so the failure report shows what broke.
        raise AssertionError("error in loading weights") from err

    # Check each model on its own: previously the first KeyError ended the
    # try block, so the AE lookup was never exercised.
    for model, bad_name in ((vae, "abc"), (ae, "xyz")):
        try:
            model.from_pretrained(bad_name)
        except KeyError:
            continue
        raise AssertionError("KeyError not raised when provided with illegal checkpoint name")
def test_swav(tmpdir, datadir):
    """Fit SwAV on CIFAR-10 on CPU and check the reported loss is positive.

    Args:
        tmpdir: Directory used as the trainer's ``default_root_dir``.
        datadir: Directory passed to ``CIFAR10DataModule`` as ``data_dir``.
    """
    seed_everything()

    batch_size = 2

    def multicrop_kwargs():
        # Built fresh on every call so the two transforms never share the
        # same list objects or normalization instance.
        return dict(
            normalize=cifar10_normalization(),
            size_crops=[32, 16],
            nmb_crops=[2, 1],
            gaussian_blur=False,
        )

    # Original author's note: unpacking ``inputs, y = batch`` does not
    # receive ``y`` here; the cause was not determined.
    dm = CIFAR10DataModule(data_dir=datadir, batch_size=batch_size, num_workers=0)
    dm.train_transforms = SwAVTrainDataTransform(**multicrop_kwargs())
    dm.val_transforms = SwAVEvalDataTransform(**multicrop_kwargs())

    swav = SwAV(
        arch='resnet18',
        hidden_mlp=512,
        gpus=0,
        nodes=1,
        num_samples=dm.num_samples,
        batch_size=batch_size,
        nmb_crops=[2, 1],
        sinkhorn_iterations=1,
        nmb_prototypes=2,
        maxpool1=False,
        first_conv=False,
        dataset='cifar10',
    )

    runner = pl.Trainer(
        gpus=0,
        fast_dev_run=True,
        max_epochs=1,
        default_root_dir=tmpdir,
        max_steps=3,
    )
    runner.fit(swav, dm)

    assert float(runner.progress_bar_dict['loss']) > 0
Beispiel #27
0
def byol_example():
    """Train BYOL on CIFAR-10 using SimCLR transforms.

    The training run asks for 2 GPUs with the "ddp" accelerator.
    """
    from pl_bolts.models.self_supervised import BYOL
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    # --- data -------------------------------------------------------------
    cifar = CIFAR10DataModule(num_workers=12, batch_size=32)
    cifar.train_transforms = SimCLRTrainDataTransform(input_height=32)
    cifar.val_transforms = SimCLREvalDataTransform(input_height=32)

    # --- model ------------------------------------------------------------
    byol = BYOL(num_classes=10)

    # --- training ---------------------------------------------------------
    runner = pl.Trainer(gpus=2, accelerator="ddp")
    runner.fit(byol, datamodule=cifar)
def train():
    """Train a VAE from the command line on CIFAR-10 or ImageNet.

    Accepts ``--gpus`` (int, default ``None``) and ``--dataset`` (default
    ``cifar10``), then fits for 20 epochs with an ``ImageSampler`` callback.

    Raises:
        ValueError: If ``--dataset`` is neither ``cifar10`` nor ``imagenet``.
    """
    parser = ArgumentParser()
    parser.add_argument('--gpus', type=int, default=None)
    parser.add_argument('--dataset', type=str, default='cifar10')
    args = parser.parse_args()

    if args.dataset == 'cifar10':
        dataset = CIFAR10DataModule('.')
    elif args.dataset == 'imagenet':
        dataset = ImagenetDataModule('.')
    else:
        # Used to fall through and crash with a NameError on ``dataset``.
        raise ValueError(f'train does not support dataset {args.dataset}')

    sampler = ImageSampler()

    vae = VAE()
    trainer = pl.Trainer(gpus=args.gpus, max_epochs=20, callbacks=[sampler])
    trainer.fit(vae, dataset)
Beispiel #29
0
def simsiam_example():
    """Train SimSiam on CIFAR-10 using SimCLR transforms.

    The training run asks for 2 GPUs with the "ddp" accelerator.
    """
    from pl_bolts.models.self_supervised import SimSiam
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    # --- data -------------------------------------------------------------
    cifar = CIFAR10DataModule(num_workers=12, batch_size=32)
    cifar.train_transforms = SimCLRTrainDataTransform(input_height=32)
    cifar.val_transforms = SimCLREvalDataTransform(input_height=32)

    # --- model ------------------------------------------------------------
    siam = SimSiam(
        gpus=2,
        num_samples=cifar.num_samples,
        batch_size=cifar.batch_size,
        dataset="cifar10",
    )

    # --- training ---------------------------------------------------------
    runner = pl.Trainer(gpus=2, accelerator="ddp")
    runner.fit(siam, datamodule=cifar)
Beispiel #30
0
def moco_v2_example():
    """Train Moco v2 on CIFAR-10 with its CIFAR-10 transforms.

    The training run asks for 2 GPUs with the "ddp" accelerator and a
    ``DDPPlugin`` configured with ``find_unused_parameters=False``.
    """
    from pl_bolts.models.self_supervised import Moco_v2
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.moco import Moco2TrainCIFAR10Transforms, Moco2EvalCIFAR10Transforms
    from pytorch_lightning.plugins import DDPPlugin

    # --- data -------------------------------------------------------------
    cifar = CIFAR10DataModule(num_workers=12, batch_size=32)
    cifar.train_transforms = Moco2TrainCIFAR10Transforms()
    cifar.val_transforms = Moco2EvalCIFAR10Transforms()

    # --- model ------------------------------------------------------------
    moco = Moco_v2()

    # --- training ---------------------------------------------------------
    ddp = DDPPlugin(find_unused_parameters=False)
    runner = pl.Trainer(gpus=2, accelerator="ddp", plugins=ddp)
    runner.fit(moco, datamodule=cifar)