def cli_main():
    """Command-line entry point: train BYOL with an online evaluation callback.

    Parses Trainer and BYOL arguments, builds the datamodule selected by
    ``--dataset`` (``cifar10``, ``stl10`` or ``imagenet2012``), attaches the
    SimCLR train/eval transforms to it and fits the model.

    Raises:
        NotImplementedError: if ``--dataset`` is not one of the supported names.
    """
    from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    seed_everything(1234)

    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = BYOL.add_model_specific_args(parser)
    args = parser.parse_args()

    # Build the datamodule and work out the image height the transforms need.
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        input_height = 32

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # Swap in the datamodule's "mixed" loaders for training and validation.
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        _, input_height, _ = dm.size()

    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        _, input_height, _ = dm.size()

    else:
        # Fail early with a clear message instead of hitting an attribute
        # error on a missing datamodule further down.
        raise NotImplementedError(
            f"unsupported dataset {args.dataset!r}; "
            "expected one of 'cifar10', 'stl10', 'imagenet2012'"
        )

    dm.train_transforms = SimCLRTrainDataTransform(input_height)
    dm.val_transforms = SimCLREvalDataTransform(input_height)
    args.num_classes = dm.num_classes

    model = BYOL(**args.__dict__)

    def to_device(batch, device):
        """Move the first augmented view and the labels of ``batch`` to ``device``."""
        (x1, _), y = batch
        x1 = x1.to(device)
        y = y.to(device)
        return x1, y

    # finetune in real-time
    online_eval = SSLOnlineEvaluator(z_dim=2048, num_classes=dm.num_classes)
    online_eval.to_device = to_device

    trainer = pl.Trainer.from_argparse_args(args,
                                            max_steps=300000,
                                            callbacks=[online_eval])

    trainer.fit(model, dm)
Example #2
0
def simclr_example():
    """Fit SimCLR on the CIFAR-10 datamodule, then load and freeze a pretrained SimCLR checkpoint."""
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform

    # Datamodule with the SimCLR augmentation pipelines for 32-pixel images.
    image_height = 32
    datamodule = CIFAR10DataModule(num_workers=12, batch_size=32)
    datamodule.train_transforms = SimCLRTrainDataTransform(input_height=image_height)
    datamodule.val_transforms = SimCLREvalDataTransform(input_height=image_height)

    # Model sized from the datamodule's sample count and batch size.
    model = SimCLR(
        gpus=2,
        num_samples=datamodule.num_samples,
        batch_size=datamodule.batch_size,
        dataset="cifar10",
    )

    # Train on two GPUs with DDP.
    pl.Trainer(gpus=2, accelerator="ddp").fit(model, datamodule=datamodule)

    # Load a pretrained checkpoint and freeze it.
    # NOTE(review): the checkpoint path names an ImageNet-pretrained model;
    # confirm this is the intended checkpoint for a CIFAR-10 example.
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
    simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
    simclr.freeze()
Example #3
0
def byol_example():
    """Fit a 10-class BYOL model on the CIFAR-10 datamodule using two GPUs with DDP."""
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised import BYOL
    from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform

    # Datamodule with the SimCLR augmentation pipelines for 32-pixel images.
    image_height = 32
    datamodule = CIFAR10DataModule(num_workers=12, batch_size=32)
    datamodule.train_transforms = SimCLRTrainDataTransform(input_height=image_height)
    datamodule.val_transforms = SimCLREvalDataTransform(input_height=image_height)

    byol = BYOL(num_classes=10)

    pl.Trainer(gpus=2, accelerator="ddp").fit(byol, datamodule=datamodule)
Example #4
0
def simsiam_example():
    """Fit SimSiam on the CIFAR-10 datamodule using two GPUs with DDP."""
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised import SimSiam
    from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform

    # Datamodule with the SimCLR augmentation pipelines for 32-pixel images.
    image_height = 32
    datamodule = CIFAR10DataModule(num_workers=12, batch_size=32)
    datamodule.train_transforms = SimCLRTrainDataTransform(input_height=image_height)
    datamodule.val_transforms = SimCLREvalDataTransform(input_height=image_height)

    # Model sized from the datamodule's sample count and batch size.
    simsiam = SimSiam(
        gpus=2,
        num_samples=datamodule.num_samples,
        batch_size=datamodule.batch_size,
        dataset="cifar10",
    )

    pl.Trainer(gpus=2, accelerator="ddp").fit(simsiam, datamodule=datamodule)
Example #5
0
def train_self_supervised(*,
                          batch_size=16,
                          input_height=224,
                          num_workers=4,
                          checkpoint_dir=r'D:\Users\lVavrek\research\data'):
    """Train SimCLR on the LibriSpeech spectrogram dataset.

    Builds train/validation datasets with the SimCLR transforms (Gaussian blur
    disabled), trains on one GPU, logs to TensorBoard under ``runs`` and keeps
    the three checkpoints with the lowest ``val_loss``. Early stopping also
    monitors ``val_loss``.

    All parameters are keyword-only and default to the values that were
    previously hard-coded, so calling the function without arguments behaves
    as before. Settings noted as having worked well: (8, 224, 8) and
    (16, 224, 4) — presumably (batch_size, input_height, num_workers).

    Args:
        batch_size: Batch size for both data loaders and the SimCLR model.
        input_height: Image height passed to the SimCLR transforms.
        num_workers: Worker count for both data loaders.
        checkpoint_dir: Directory where ``ModelCheckpoint`` writes checkpoints.
    """
    logger = TensorBoardLogger('runs', name='SimCLR_libri_speech')

    train_dataset = LibrispeechSpectrogramDataset(
        transform=SimCLRTrainDataTransform(input_height=input_height,
                                           gaussian_blur=False),
        train=True)
    val_dataset = LibrispeechSpectrogramDataset(
        transform=SimCLREvalDataTransform(input_height=input_height,
                                          gaussian_blur=False),
        train=False)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              num_workers=num_workers)
    val_loader = DataLoader(val_dataset,
                            batch_size=batch_size,
                            num_workers=num_workers)

    model = SimCLR(gpus=1,
                   num_samples=len(train_dataset),
                   batch_size=batch_size,
                   dataset='librispeech')

    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=checkpoint_dir,
        filename="self-supervised-librispeech-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )

    early_stopping = EarlyStopping(monitor="val_loss")

    trainer = Trainer(gpus=1,
                      callbacks=[checkpoint_callback, early_stopping],
                      logger=logger)
    trainer.fit(model, train_loader, val_loader)
Example #6
0
def get_self_supervised_model(run_params):
    """Build a SimCLR model and either train it or restore it from a checkpoint.

    The checkpoint path is ``<PATH_PREFIX>/checkpoints/<MODEL_SAVE_NAME>.ckpt``.

    * If ``run_params["SELF_SUPERVISED_TRAIN"]`` is truthy, the model is
      trained (resuming from the checkpoint when it already exists) and then
      reloaded from the checkpoint path.
    * Otherwise, if the checkpoint exists, the model is loaded from it.
    * Otherwise a message is printed and the freshly initialised model is
      returned.

    Args:
        run_params: Mapping of run settings. Keys read here: ``PATH_PREFIX``,
            ``MODEL_SAVE_NAME``, ``RESIZE``, ``RANDOM_RESIZE_CROP``,
            ``RAW_PREPROCESS_FOLDER``, ``SELF_SUPERVISED_BATCH_SIZE``,
            ``SELF_SUPERVISED_EPOCHS``, ``SELF_SUPERVISED_WARMUP_EPOCHS`` and
            ``SELF_SUPERVISED_TRAIN``.

    Returns:
        The SimCLR model (trained, restored, or freshly initialised).
    """
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.models.self_supervised.simclr import (
        SimCLRTrainDataTransform,
        SimCLREvalDataTransform,
    )
    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import TensorBoardLogger
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    checkpoints_dir = os.path.join(run_params["PATH_PREFIX"], "checkpoints")
    checkpoint_resume = os.path.join(checkpoints_dir,
                                     run_params["MODEL_SAVE_NAME"] + ".ckpt")

    # Values shared by the train and validation pipelines.
    crop_size = min(run_params["RESIZE"], run_params["RANDOM_RESIZE_CROP"])
    data_prefix = run_params["RAW_PREPROCESS_FOLDER"] + "/"
    batch_size = run_params["SELF_SUPERVISED_BATCH_SIZE"]
    max_epochs = run_params["SELF_SUPERVISED_EPOCHS"]

    dataset = SelfSupervisedDataset(
        final_df,
        validation=False,
        transform=SimCLRTrainDataTransform(crop_size),
        prefix=data_prefix,
    )
    val_dataset = SelfSupervisedDataset(
        final_df,
        validation=True,
        transform=SimCLREvalDataTransform(crop_size),
        prefix=data_prefix,
    )

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=0)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, num_workers=0)

    # Initialise the model with batch size, number of samples and epoch counts.
    model_self_sup = SimCLR(
        gpus=1,
        arch="resnet50",
        dataset="",
        max_epochs=max_epochs,
        warmup_epochs=run_params["SELF_SUPERVISED_WARMUP_EPOCHS"],
        batch_size=batch_size,
        num_samples=len(dataset),
    )

    if run_params["SELF_SUPERVISED_TRAIN"]:
        logger = TensorBoardLogger(
            os.path.join(run_params["PATH_PREFIX"], "tb_logs", "simCLR"),
            name=run_params["MODEL_SAVE_NAME"],
        )
        early_stopping = EarlyStopping("val_loss", patience=5)

        # Arguments common to both the resume and the fresh-start trainer.
        trainer_kwargs = dict(
            gpus=1,
            max_epochs=max_epochs,
            logger=logger,
            auto_scale_batch_size=True,
        )

        if os.path.exists(checkpoint_resume):
            # NOTE(review): when resuming, no ModelCheckpoint targeting
            # checkpoints_dir is registered, so the reload after fit() reads
            # the pre-existing checkpoint file — confirm this is intended.
            trainer = Trainer(
                resume_from_checkpoint=checkpoint_resume,
                callbacks=[early_stopping],
                **trainer_kwargs,
            )
        else:
            checkpoint_callback = ModelCheckpoint(
                monitor="val_loss",
                dirpath=checkpoints_dir,
                filename=run_params["MODEL_SAVE_NAME"],
                save_top_k=1,
                mode="min",
            )
            trainer = Trainer(
                callbacks=[checkpoint_callback, early_stopping],
                **trainer_kwargs,
            )

        trainer.fit(model_self_sup, data_loader, val_loader)
        model_self_sup = SimCLR.load_from_checkpoint(checkpoint_resume)
    elif os.path.exists(checkpoint_resume):
        # load_from_checkpoint returns a new model rather than mutating the
        # receiver, so its result must be kept for the weights to be used.
        model_self_sup = SimCLR.load_from_checkpoint(checkpoint_resume)
    else:
        print(
            f"Not checkpoint found, so it could not load model from it\n{checkpoint_resume}"
        )

    return model_self_sup
def cli_main():
    """Command-line entry point: train SimSiam on STL-10, CIFAR-10 or ImageNet.

    Parses Trainer and SimSiam arguments, applies per-dataset overrides to the
    parsed namespace, builds the matching datamodule with SimCLR transforms,
    optionally attaches an online evaluator, and fits the model.

    Raises:
        NotImplementedError: if ``--dataset`` is not ``stl10``, ``cifar10`` or
            ``imagenet``.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform

    seed_everything(1234)

    # Trainer arguments first, then the model-specific ones.
    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SimSiam.add_model_specific_args(parser)
    args = parser.parse_args()

    dataset_name = args.dataset

    if dataset_name == "stl10":
        dm = STL10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )

        # Swap in the datamodule's "mixed" loaders for training and validation.
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        args.maxpool1 = False
        args.first_conv = True
        args.input_height = dm.size()[-1]

        normalization = stl10_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.0

    elif dataset_name == "cifar10":
        # Validation split is 5000, or the global batch size when that is larger.
        val_split = max(5000, args.num_nodes * args.gpus * args.batch_size)

        dm = CIFAR10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            val_split=val_split,
        )

        args.num_samples = dm.num_samples

        args.maxpool1 = False
        args.first_conv = False
        args.input_height = dm.size()[-1]
        args.temperature = 0.5

        normalization = cifar10_normalization()

        args.gaussian_blur = False
        args.jitter_strength = 0.5

    elif dataset_name == "imagenet":
        args.maxpool1 = True
        args.first_conv = True
        normalization = imagenet_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.0

        # Fixed cluster and schedule settings for this dataset.
        args.batch_size = 64
        args.num_nodes = 8
        args.gpus = 8  # per-node
        args.max_epochs = 800

        args.optimizer = "sgd"
        args.lars_wrapper = True
        args.learning_rate = 4.8
        args.final_lr = 0.0048
        args.start_lr = 0.3
        args.online_ft = True

        # Built after the overrides so it picks up the batch size set above.
        dm = ImagenetDataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )

        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]

    else:
        raise NotImplementedError(
            "other datasets have not been implemented till now")

    # Train and eval transforms take the same settings.
    transform_kwargs = dict(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )
    dm.train_transforms = SimCLRTrainDataTransform(**transform_kwargs)
    dm.val_transforms = SimCLREvalDataTransform(**transform_kwargs)

    model = SimSiam(**vars(args))

    # Optional online evaluation callback (finetune in real-time).
    callbacks = None
    if args.online_ft:
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.0,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )
        callbacks = [online_evaluator]

    trainer = pl.Trainer.from_argparse_args(
        args,
        sync_batchnorm=bool(args.gpus > 1),
        callbacks=callbacks,
    )

    trainer.fit(model, datamodule=dm)
Example #8
0
    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = BYOL.add_model_specific_args(parser)
    parser.add_argument('--input_size', type=int, default=32)
    args = parser.parse_args()

    # pick data
    dm = None

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(args.input_size)
        dm.val_transforms = SimCLREvalDataTransform(args.input_size)
        args.num_classes = dm.num_classes
        dm.name_classes = ['plane', 'car', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']
        dm.num_channels = 3

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed

        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)
        args.num_classes = dm.num_classes
        dm.num_channels = 3
Example #9
0
def simple_simclr_example():
    """Load a pretrained SimCLR model, train a new one on CIFAR-10, then run the pretrained encoder.

    Steps:
      1. Load a SimCLR checkpoint from ``weight_path``.
      2. Train a separate SimCLR model on torchvision's CIFAR-10 with the
         SimCLR train/eval transforms (two GPUs, DDP).
      3. Put the pretrained encoder in eval mode and run it over the
         validation samples, one sample at a time.
    """
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    # Load ResNet50 pretrained using SimCLR on ImageNet.
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
    simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

    # Any dataset that applies the SimCLR transforms can be substituted here.
    train_dataset = torchvision.datasets.CIFAR10("", train=True, download=True, transform=SimCLRTrainDataTransform())
    val_dataset = torchvision.datasets.CIFAR10("", train=False, download=True, transform=SimCLREvalDataTransform())

    # SimCLR needs a lot of compute!
    model = SimCLR(gpus=2, num_samples=len(train_dataset), batch_size=32, dataset="cifar10")

    trainer = pl.Trainer(gpus=2, accelerator="ddp")
    trainer.fit(
        model,
        torch.utils.data.DataLoader(train_dataset, batch_size=32, num_workers=12),
        torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=12),
    )

    #--------------------
    simclr_resnet50 = simclr.encoder
    simclr_resnet50.eval()

    my_dataset = val_dataset
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        for views, y in my_dataset:
            # The SimCLR transform yields several augmented views per image;
            # the encoder needs a single tensor, so take the first view and
            # add the batch dimension that iterating a dataset does not give.
            x = views[0].unsqueeze(0)
            out = simclr_resnet50(x)