def cli_main():
    """CLI entry point: pretrain BYOL on CIFAR-10 / STL-10 / ImageNet-2012
    with an online linear-probe evaluator fine-tuned in real time."""
    from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    seed_everything(1234)

    parser = ArgumentParser()
    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)
    # model args
    parser = BYOL.add_model_specific_args(parser)
    args = parser.parse_args()

    # pick data
    dm = None

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        args.num_classes = dm.num_classes
    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # use the mixed (labeled + unlabeled) splits for self-supervised pretraining
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)
        args.num_classes = dm.num_classes
    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)
        args.num_classes = dm.num_classes
    else:
        # FIX: previously an unrecognized dataset fell through with dm = None and
        # crashed later on dm.num_classes; fail fast with a clear message instead
        # (mirrors the else-branch of the SimSiam cli_main in this file).
        raise NotImplementedError(f"dataset {args.dataset!r} is not supported")

    model = BYOL(**args.__dict__)

    def to_device(batch, device):
        # The SimCLR transforms yield two augmented views per sample; the online
        # evaluator only needs one view plus the label.
        (x1, x2), y = batch
        x1 = x1.to(device)
        y = y.to(device)
        return x1, y

    # finetune in real-time
    online_eval = SSLOnlineEvaluator(z_dim=2048, num_classes=dm.num_classes)
    online_eval.to_device = to_device

    trainer = pl.Trainer.from_argparse_args(args, max_steps=300000, callbacks=[online_eval])
    trainer.fit(model, dm)
def simclr_example():
    """Pretrain SimCLR on CIFAR-10, then load and freeze a pretrained checkpoint."""
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    # CIFAR-10 datamodule with the paired SimCLR augmentations.
    datamodule = CIFAR10DataModule(num_workers=12, batch_size=32)
    datamodule.train_transforms = SimCLRTrainDataTransform(input_height=32)
    datamodule.val_transforms = SimCLREvalDataTransform(input_height=32)

    # SimCLR model sized to the datamodule.
    simclr_model = SimCLR(
        gpus=2,
        num_samples=datamodule.num_samples,
        batch_size=datamodule.batch_size,
        dataset="cifar10",
    )

    # Train with distributed data parallel on two GPUs.
    pl.Trainer(gpus=2, accelerator="ddp").fit(simclr_model, datamodule=datamodule)

    # --------------------
    # NOTE(review): both URLs below point at the same ImageNet checkpoint file;
    # the "CIFAR-10 pretrained" label looks wrong -- confirm the intended weights.
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
    # ImageNet pretrained model:
    # weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
    simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
    simclr.freeze()
def train_transfer_learning():
    """Fine-tune ImagenetTransferLearning on PC-GITA data, checkpointing the
    single best epoch by validation loss."""
    logger = TensorBoardLogger('runs', name='pc-gita')

    batch_size = 32
    input_height = 224
    num_workers = 4

    train_dataset = PcGitaTorchDataset(
        transform=SimCLRTrainDataTransform(input_height=input_height, gaussian_blur=False),
        train=True)
    # FIX: validation previously reused SimCLRTrainDataTransform; evaluation must
    # not apply random training augmentations, so use the eval transform instead
    # (matches the train/eval transform split used elsewhere in this file).
    val_dataset = PcGitaTorchDataset(
        transform=SimCLREvalDataTransform(input_height=input_height, gaussian_blur=False),
        train=False)

    # FIX: shuffle the training set each epoch (DataLoader defaults to shuffle=False).
    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
    test_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers)

    model = ImagenetTransferLearning()

    # Keep only the best checkpoint by validation loss.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=r'D:\Users\lVavrek\research\data',
        filename="transfer-learning-pcgita-{epoch:02d}-{val_loss:.2f}",
        save_top_k=1,
        mode="min",
    )
    # early_stopping = EarlyStopping(monitor="val_loss")

    trainer = Trainer(gpus=1, callbacks=[checkpoint_callback], logger=logger, max_epochs=20)
    trainer.fit(model, train_loader, test_loader)
def byol_example():
    """Minimal BYOL pretraining run on CIFAR-10 (2-GPU DDP)."""
    from pl_bolts.models.self_supervised import BYOL
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    # CIFAR-10 datamodule with the paired SimCLR-style augmentations.
    datamodule = CIFAR10DataModule(num_workers=12, batch_size=32)
    datamodule.train_transforms = SimCLRTrainDataTransform(input_height=32)
    datamodule.val_transforms = SimCLREvalDataTransform(input_height=32)

    # BYOL model for the 10 CIFAR classes.
    byol = BYOL(num_classes=10)

    # Train with distributed data parallel on two GPUs.
    pl.Trainer(gpus=2, accelerator="ddp").fit(byol, datamodule=datamodule)
def simsiam_example():
    """Minimal SimSiam pretraining run on CIFAR-10 (2-GPU DDP)."""
    from pl_bolts.models.self_supervised import SimSiam
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    # CIFAR-10 datamodule with the paired SimCLR-style augmentations.
    datamodule = CIFAR10DataModule(num_workers=12, batch_size=32)
    datamodule.train_transforms = SimCLRTrainDataTransform(input_height=32)
    datamodule.val_transforms = SimCLREvalDataTransform(input_height=32)

    # SimSiam model sized to the datamodule.
    simsiam = SimSiam(
        gpus=2,
        num_samples=datamodule.num_samples,
        batch_size=datamodule.batch_size,
        dataset="cifar10",
    )

    # Train with distributed data parallel on two GPUs.
    pl.Trainer(gpus=2, accelerator="ddp").fit(simsiam, datamodule=datamodule)
def train_self_supervised():
    """Pretrain SimCLR on LibriSpeech spectrograms with checkpointing and early stopping."""
    logger = TensorBoardLogger('runs', name='SimCLR_libri_speech')

    # Combinations that worked well previously (batch_size, input_height, num_workers):
    #   (8, 224, 8) and (16, 224, 4)
    batch_size, input_height, num_workers = 16, 224, 4

    train_ds = LibrispeechSpectrogramDataset(
        transform=SimCLRTrainDataTransform(input_height=input_height, gaussian_blur=False),
        train=True)
    val_ds = LibrispeechSpectrogramDataset(
        transform=SimCLREvalDataTransform(input_height=input_height, gaussian_blur=False),
        train=False)

    train_dl = DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers)
    val_dl = DataLoader(val_ds, batch_size=batch_size, num_workers=num_workers)

    simclr = SimCLR(gpus=1, num_samples=len(train_ds), batch_size=batch_size, dataset='librispeech')

    # Keep the three best checkpoints by validation loss.
    ckpt_cb = ModelCheckpoint(
        monitor="val_loss",
        dirpath=r'D:\Users\lVavrek\research\data',
        filename="self-supervised-librispeech-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )
    stop_cb = EarlyStopping(monitor="val_loss")

    trainer = Trainer(gpus=1, callbacks=[ckpt_cb, stop_cb], logger=logger)
    trainer.fit(simclr, train_dl, val_dl)
def get_self_supervised_model(run_params):
    """Build a SimCLR model and either train it, resume/restore it from a
    checkpoint, or return it untrained.

    Args:
        run_params: dict of run configuration (PATH_PREFIX, MODEL_SAVE_NAME,
            RESIZE, RANDOM_RESIZE_CROP, RAW_PREPROCESS_FOLDER,
            SELF_SUPERVISED_* keys).

    Returns:
        The SimCLR model, loaded from the checkpoint when one exists.
    """
    import pl_bolts
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.models.self_supervised.simclr import (
        SimCLRTrainDataTransform,
        SimCLREvalDataTransform,
    )
    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import TensorBoardLogger
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    checkpoints_dir = os.path.join(run_params["PATH_PREFIX"], "checkpoints")
    checkpoint_resume = os.path.join(checkpoints_dir, run_params["MODEL_SAVE_NAME"] + ".ckpt")

    # Train/val datasets share the same crop size: the smaller of RESIZE and
    # RANDOM_RESIZE_CROP.
    dataset = SelfSupervisedDataset(
        final_df,
        validation=False,
        transform=SimCLRTrainDataTransform(
            min(run_params["RESIZE"], run_params["RANDOM_RESIZE_CROP"])),
        prefix=run_params["RAW_PREPROCESS_FOLDER"] + "/",
    )
    val_dataset = SelfSupervisedDataset(
        final_df,
        validation=True,
        transform=SimCLREvalDataTransform(
            min(run_params["RESIZE"], run_params["RANDOM_RESIZE_CROP"])),
        prefix=run_params["RAW_PREPROCESS_FOLDER"] + "/",
    )

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=run_params["SELF_SUPERVISED_BATCH_SIZE"], num_workers=0)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=run_params["SELF_SUPERVISED_BATCH_SIZE"], num_workers=0)

    num_samples = len(dataset)

    # Init model with batch size, num_samples (len of data) and epochs to train.
    model_self_sup = SimCLR(
        gpus=1,
        arch="resnet50",
        dataset="",
        max_epochs=run_params["SELF_SUPERVISED_EPOCHS"],
        warmup_epochs=run_params["SELF_SUPERVISED_WARMUP_EPOCHS"],
        batch_size=run_params["SELF_SUPERVISED_BATCH_SIZE"],
        num_samples=num_samples,
    )

    if run_params["SELF_SUPERVISED_TRAIN"]:
        logger = TensorBoardLogger(
            os.path.join(run_params["PATH_PREFIX"], "tb_logs", "simCLR"),
            name=run_params["MODEL_SAVE_NAME"],
        )
        early_stopping = EarlyStopping("val_loss", patience=5)

        if os.path.exists(checkpoint_resume):
            # Resume a previous run from its checkpoint.
            trainer = Trainer(
                gpus=1,
                max_epochs=run_params["SELF_SUPERVISED_EPOCHS"],
                logger=logger,
                auto_scale_batch_size=True,
                resume_from_checkpoint=checkpoint_resume,
                callbacks=[early_stopping],
            )
        else:
            # Fresh run: save the single best checkpoint by validation loss.
            checkpoint_callback = ModelCheckpoint(
                monitor="val_loss",
                dirpath=checkpoints_dir,
                filename=run_params["MODEL_SAVE_NAME"],
                save_top_k=1,
                mode="min",
            )
            trainer = Trainer(
                gpus=1,
                max_epochs=run_params["SELF_SUPERVISED_EPOCHS"],
                logger=logger,
                auto_scale_batch_size=True,
                callbacks=[checkpoint_callback, early_stopping],
            )

        trainer.fit(model_self_sup, data_loader, val_loader)
        model_self_sup = model_self_sup.load_from_checkpoint(checkpoint_resume)
    elif os.path.exists(checkpoint_resume):
        # FIX: load_from_checkpoint is a classmethod that RETURNS a new model;
        # the original discarded the result and silently returned untrained weights.
        model_self_sup = model_self_sup.load_from_checkpoint(checkpoint_resume)
    else:
        # FIX: corrected message typo ("Not checkpoint found").
        print(
            f"No checkpoint found, so it could not load model from it\n{checkpoint_resume}"
        )
    return model_self_sup
def cli_main():
    # CLI entry point: pretrain SimSiam, picking dataset-specific augmentation
    # and optimization hyper-parameters, with an optional online linear probe.
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform

    seed_everything(1234)

    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = SimSiam.add_model_specific_args(parser)
    args = parser.parse_args()

    # pick data
    dm = None

    # init datamodule
    if args.dataset == "stl10":
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        # pretrain on the mixed (labeled + unlabeled) STL-10 splits
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        args.maxpool1 = False
        args.first_conv = True
        args.input_height = dm.size()[-1]

        normalization = stl10_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.0

    elif args.dataset == "cifar10":
        # val split must cover at least one global batch
        val_split = 5000
        if args.num_nodes * args.gpus * args.batch_size > val_split:
            val_split = args.num_nodes * args.gpus * args.batch_size

        dm = CIFAR10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            val_split=val_split,
        )

        args.num_samples = dm.num_samples

        args.maxpool1 = False
        args.first_conv = False
        args.input_height = dm.size()[-1]
        args.temperature = 0.5

        normalization = cifar10_normalization()

        args.gaussian_blur = False
        args.jitter_strength = 0.5

    elif args.dataset == "imagenet":
        # NOTE: this branch overrides several CLI args (batch size, nodes,
        # gpus, epochs, optimizer settings) with fixed large-scale values.
        args.maxpool1 = True
        args.first_conv = True
        normalization = imagenet_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.0

        args.batch_size = 64
        args.num_nodes = 8
        args.gpus = 8  # per-node
        args.max_epochs = 800

        args.optimizer = "sgd"
        args.lars_wrapper = True
        args.learning_rate = 4.8
        args.final_lr = 0.0048
        args.start_lr = 0.3
        args.online_ft = True

        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]

    else:
        raise NotImplementedError(
            "other datasets have not been implemented till now")

    # paired SimCLR augmentations shared by all datasets
    dm.train_transforms = SimCLRTrainDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    dm.val_transforms = SimCLREvalDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    model = SimSiam(**args.__dict__)

    # finetune in real-time
    online_evaluator = None
    if args.online_ft:
        # online eval
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.0,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )

    trainer = pl.Trainer.from_argparse_args(
        args,
        sync_batchnorm=True if args.gpus > 1 else False,
        callbacks=[online_evaluator] if args.online_ft else None,
    )

    trainer.fit(model, datamodule=dm)
# trainer args parser = pl.Trainer.add_argparse_args(parser) # model args parser = BYOL.add_model_specific_args(parser) parser.add_argument('--input_size', type=int, default=32) args = parser.parse_args() # pick data dm = None # init default datamodule if args.dataset == 'cifar10': dm = CIFAR10DataModule.from_argparse_args(args) dm.train_transforms = SimCLRTrainDataTransform(args.input_size) dm.val_transforms = SimCLREvalDataTransform(args.input_size) args.num_classes = dm.num_classes dm.name_classes = ['plane', 'car', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck'] dm.num_channels = 3 elif args.dataset == 'stl10': dm = STL10DataModule.from_argparse_args(args) dm.train_dataloader = dm.train_dataloader_mixed dm.val_dataloader = dm.val_dataloader_mixed (c, h, w) = dm.size() dm.train_transforms = SimCLRTrainDataTransform(h) dm.val_transforms = SimCLREvalDataTransform(h) args.num_classes = dm.num_classes dm.num_channels = 3
def simple_simclr_example():
    """Pretrain SimCLR on CIFAR-10, then run the frozen encoder over a dataset."""
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    # Load ResNet50 pretrained using SimCLR on ImageNet.
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
    simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

    # train_dataset = MyDataset(transforms=SimCLRTrainDataTransform())
    # val_dataset = MyDataset(transforms=SimCLREvalDataTransform())
    train_dataset = torchvision.datasets.CIFAR10("", train=True, download=True, transform=SimCLRTrainDataTransform())
    val_dataset = torchvision.datasets.CIFAR10("", train=False, download=True, transform=SimCLREvalDataTransform())

    # SimCLR needs a lot of compute!
    pretrain_model = SimCLR(gpus=2, num_samples=len(train_dataset), batch_size=32, dataset="cifar10")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, num_workers=12)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=12)
    pl.Trainer(gpus=2, accelerator="ddp").fit(pretrain_model, train_loader, val_loader)

    # --------------------
    # Use the checkpoint's encoder as a frozen-ish feature extractor.
    simclr_resnet50 = simclr.encoder
    simclr_resnet50.eval()

    # my_dataset = SomeDataset()
    my_dataset = val_dataset
    for batch in my_dataset:
        x, y = batch
        out = simclr_resnet50(x)
# trainer = Trainer() # trainer.fit(model) # prediction # model = ImagenetTransferLearning.load_from_checkpoint(PATH) # model.freeze() # x = some_images_from_cifar10() # predictions = model(x) if __name__ == '__main__': visualize_spectrograms() # train_self_supervised() # train_transfer_learning() exit() d1 = LibrispeechSpectrogramDataset(transform=SimCLRTrainDataTransform( input_height=224, gaussian_blur=False), train=True) d2 = PcGitaTorchDataset(transform=SimCLRTrainDataTransform( input_height=224, gaussian_blur=False), train=True) dl1 = DataLoader(d1, batch_size=16, num_workers=4) dl2 = DataLoader(d2, batch_size=16, num_workers=4) x1 = next(iter(dl1)) print(len(x1)) print(x1[1]) print(len(x1[0])) x2 = next(iter(dl2))