def load_simclr_imagenet(
    path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
):
    simclr = SimCLR.load_from_checkpoint(path_or_url, strict=False)
    model_config = {'model': simclr.encoder, 'emb_size': 2048}
    return model_config
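# A minimal usage sketch for load_simclr_imagenet above, assuming SimCLR and
# ROOT_S3_BUCKET are already in scope. The input shape is illustrative; note
# that the pl_bolts ResNet encoder returns its features wrapped in a list
# (compare the y_hat[0] indexing in the embedding script further down).
import torch

config = load_simclr_imagenet()
encoder, emb_size = config['model'], config['emb_size']
encoder.eval()
with torch.no_grad():
    feats = encoder(torch.randn(1, 3, 224, 224))[0]  # shape: (1, emb_size)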
def simple_simclr_example():
    import torch
    import torchvision
    import pytorch_lightning as pl
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    # Load ResNet-50 pretrained using SimCLR on ImageNet.
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
    simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

    # train_dataset = MyDataset(transforms=SimCLRTrainDataTransform())
    # val_dataset = MyDataset(transforms=SimCLREvalDataTransform())
    train_dataset = torchvision.datasets.CIFAR10(
        "", train=True, download=True, transform=SimCLRTrainDataTransform(input_height=32))
    val_dataset = torchvision.datasets.CIFAR10(
        "", train=False, download=True, transform=SimCLREvalDataTransform(input_height=32))

    # SimCLR needs a lot of compute!
    model = SimCLR(gpus=2, num_samples=len(train_dataset), batch_size=32, dataset="cifar10")
    trainer = pl.Trainer(gpus=2, accelerator="ddp")
    trainer.fit(
        model,
        torch.utils.data.DataLoader(train_dataset, batch_size=32, num_workers=12),
        torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=12),
    )

    # --------------------
    # Use the pretrained encoder as a frozen feature extractor.
    simclr_resnet50 = simclr.encoder
    simclr_resnet50.eval()

    # my_dataset = SomeDataset()
    my_dataset = val_dataset
    for batch in my_dataset:
        views, y = batch
        # The SimCLR eval transform returns a tuple of augmented views per
        # sample; take one view and add a batch dimension for the encoder.
        x = views[-1].unsqueeze(0)
        out = simclr_resnet50(x)
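# A hedged, self-contained variant of the feature-extraction loop above:
# instead of iterating single samples, batch a plain (single-view) dataset and
# extract embeddings under torch.no_grad(). Plain CIFAR-10 with ToTensor is an
# illustrative choice, not part of the original snippet.
import torch
import torchvision
from pl_bolts.models.self_supervised import SimCLR

weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
encoder = SimCLR.load_from_checkpoint(weight_path, strict=False).encoder
encoder.eval()

plain_dataset = torchvision.datasets.CIFAR10(
    "", train=False, download=True, transform=torchvision.transforms.ToTensor())
loader = torch.utils.data.DataLoader(plain_dataset, batch_size=32)
features = []
with torch.no_grad():
    for x, y in loader:
        features.append(encoder(x)[0])  # pl_bolts encoders return a list of tensors
features = torch.cat(features)  # (num_images, 2048)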
def simclr_example():
    import pytorch_lightning as pl
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    # Data module.
    dm = CIFAR10DataModule(num_workers=12, batch_size=32)
    dm.train_transforms = SimCLRTrainDataTransform(input_height=32)
    dm.val_transforms = SimCLREvalDataTransform(input_height=32)

    # Model.
    model = SimCLR(gpus=2, num_samples=dm.num_samples, batch_size=dm.batch_size, dataset="cifar10")

    # Fit.
    trainer = pl.Trainer(gpus=2, accelerator="ddp")
    trainer.fit(model, datamodule=dm)

    # --------------------
    # CIFAR-10 pretrained model:
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp12_87_52/epoch%3D960.ckpt"
    # ImageNet pretrained model:
    # weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
    simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
    simclr.freeze()
def load_simclr_imagenet(
    path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
):
    simclr = SimCLR.load_from_checkpoint(path_or_url, strict=False)
    # Remove the last two layers (pooling and fc) and turn the rest into a
    # Sequential model; the ResNet-50 trunk outputs 2048-channel feature maps.
    backbone = nn.Sequential(*list(simclr.encoder.children())[:-2])
    return backbone, 2048
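# Because the Sequential backbone above drops the pooling and fc layers, it
# returns a spatial feature map rather than a flat vector. A hedged sketch of
# pooling that map into a 2048-d embedding (shapes assume a standard ResNet-50
# layout and a 224x224 input):
import torch
import torch.nn as nn

backbone, feat_dim = load_simclr_imagenet()
backbone.eval()
with torch.no_grad():
    fmap = backbone(torch.randn(1, 3, 224, 224))    # (1, 2048, 7, 7)
    emb = nn.AdaptiveAvgPool2d(1)(fmap).flatten(1)  # (1, 2048)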
def cli_main():  # pragma: no cover
    import os
    from argparse import ArgumentParser

    import pytorch_lightning as pl
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised import SimCLR, SSLFineTuner
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--dataset', type=str, help='stl10, cifar10', default='cifar10')
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    parser.add_argument('--data_dir', type=str, help='path to the dataset', default=os.getcwd())
    args = parser.parse_args()

    # Load the backbone from the SimCLR checkpoint.
    backbone = SimCLR.load_from_checkpoint(args.ckpt_path, strict=False)

    # Init default datamodule.
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        dm.test_transforms = SimCLREvalDataTransform(32)
        args.num_samples = dm.num_samples
    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = dm.num_labeled_samples
        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)
    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    # Finetune a classifier head on top of the frozen backbone.
    tuner = SSLFineTuner(backbone, in_features=2048 * 2 * 2, num_classes=dm.num_classes, hidden_dim=None)
    trainer = pl.Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)
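# Hypothetical entry point and invocation for the fine-tuning CLI above. The
# script name is an assumption; flags like --gpus and --max_epochs come from
# pl.Trainer.add_argparse_args:
#   python simclr_finetune.py --dataset cifar10 --ckpt_path /path/to/simclr.ckpt --gpus 1 --max_epochs 25
if __name__ == '__main__':
    cli_main()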
def __init__(
    self,
    weight_path='https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
):
    super().__init__()
    backbone = deepcopy(
        SimCLR.load_from_checkpoint(weight_path, strict=False).encoder)
    backbone.fc = nn.Identity()
    self.encoder = backbone
    self.freeze()  # freeze last block of resnet18
    self.inplanes = self.encoder.inplanes
def __init__(self):
    super().__init__()

    # Init a pretrained resnet.
    weight_path = r'D:\Users\lVavrek\research\data\sim-clr-backups\01112021-self-supervised-librispeech-epoch=19-val_loss=1.52.ckpt'
    simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
    backbone = simclr.encoder

    # Drop the final fc layer and keep the rest as a feature extractor.
    num_filters = backbone.fc.in_features
    layers = list(backbone.children())[:-1]
    self.feature_extractor = torch.nn.Sequential(*layers)

    # Use the pretrained model to classify the target problem (PD or healthy, 2 classes).
    num_target_classes = 2
    self.classifier = torch.nn.Linear(num_filters, num_target_classes)
    self.sigmoid = torch.nn.Sigmoid()
    self.celoss = torch.nn.CrossEntropyLoss()
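# A hypothetical forward pass to accompany the __init__ above; the method name
# and shapes are assumptions, not part of the original snippet. The frozen
# extractor runs in eval mode, and the classifier emits raw logits, which is
# what CrossEntropyLoss expects (so self.sigmoid is not applied before the loss).
def forward(self, x):
    self.feature_extractor.eval()
    with torch.no_grad():
        feats = self.feature_extractor(x).flatten(1)  # (batch, num_filters)
    return self.classifier(feats)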
def __init__(
    self,
    pre,
    weight_path='https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp12_87_52/epoch%3D960.ckpt'
):
    super().__init__()
    self.pre = pre
    self.weight_path = weight_path
    if pre:
        # Copy the encoder from a SimCLR checkpoint pretrained on CIFAR-10.
        self.encoder = deepcopy(
            SimCLR.load_from_checkpoint(weight_path, strict=False).encoder)
    else:
        self.encoder = torchvision.models.resnet18(pretrained=False)
    self.freeze()
    self.avgpool = self.encoder.avgpool
    numft = self.encoder.fc.in_features
    # Three-layer MLP head: numft -> 256 -> 100 -> 20 classes.
    self.fc = nn.Sequential(
        nn.Linear(in_features=numft, out_features=256), nn.ReLU(),
        nn.Linear(in_features=256, out_features=100), nn.ReLU(),
        nn.Linear(in_features=100, out_features=20))
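# A hedged forward sketch for the module above, explaining why avgpool is kept
# as a separate attribute: run the convolutional stages of the encoder, stop
# before its own pooling/fc, then apply the stored avgpool and the MLP head.
# Assumes a torchvision-style child ordering; the method name is an assumption.
def forward(self, x):
    for name, layer in self.encoder.named_children():
        if name in ('avgpool', 'fc'):
            break
        x = layer(x)
    x = self.avgpool(x).flatten(1)
    return self.fc(x)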
import argparse
import os

import numpy as np
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

import utils  # local helper module providing recursive_folder_image_paths
from pl_bolts.models.self_supervised import SimCLR

parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', required=True, type=str, help='path to image data directory')
parser.add_argument('--save_dir', required=True, type=str, help='directory to write outputs to')
parser.add_argument('--model_dir', required=True, type=str, help='path to the SimCLR checkpoint')
parser.add_argument('--image_index', type=int, default=42)
parser.add_argument('--n_images', type=int, default=20)
parser.add_argument('--rgb', type=bool, default=True)
args = parser.parse_args()

image_paths = utils.recursive_folder_image_paths(args.data_dir)

model = SimCLR.load_from_checkpoint(checkpoint_path=args.model_dir, strict=False)
model_enc = model.encoder
model_enc.eval()

transform = transforms.Compose(
    [transforms.Resize((32, 32)), transforms.ToTensor()])

y = np.empty((len(image_paths), 2048), float)
for i, p in enumerate(tqdm(image_paths)):
    image = Image.open(p)
    if args.rgb:
        image = image.convert('RGB')
    image = transform(image).unsqueeze_(0)
    y_hat = model_enc(image)
    # The encoder returns a list; take the tensor and flatten it to one row.
    y_hat = y_hat[0].detach().numpy().reshape(1, -1)
    y[i] = y_hat
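# A possible continuation (hedged): persist the accumulated embeddings under the
# parsed --save_dir, which the excerpt otherwise never uses. The file name is
# illustrative.
np.save(os.path.join(args.save_dir, 'simclr_embeddings.npy'), y)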
def get_self_supervised_model(run_params):
    import os

    import torch
    import pl_bolts
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.models.self_supervised.simclr import (
        SimCLRTrainDataTransform,
        SimCLREvalDataTransform,
    )
    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import TensorBoardLogger
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    checkpoints_dir = os.path.join(run_params["PATH_PREFIX"], "checkpoints")
    checkpoint_resume = os.path.join(checkpoints_dir, run_params["MODEL_SAVE_NAME"] + ".ckpt")

    # `final_df` and `SelfSupervisedDataset` are assumed to be defined elsewhere
    # in the surrounding module.
    dataset = SelfSupervisedDataset(
        final_df,
        validation=False,
        transform=SimCLRTrainDataTransform(
            min(run_params["RESIZE"], run_params["RANDOM_RESIZE_CROP"])),
        prefix=run_params["RAW_PREPROCESS_FOLDER"] + "/",
    )
    val_dataset = SelfSupervisedDataset(
        final_df,
        validation=True,
        transform=SimCLREvalDataTransform(
            min(run_params["RESIZE"], run_params["RANDOM_RESIZE_CROP"])),
        prefix=run_params["RAW_PREPROCESS_FOLDER"] + "/",
    )

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=run_params["SELF_SUPERVISED_BATCH_SIZE"],
        num_workers=0)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=run_params["SELF_SUPERVISED_BATCH_SIZE"],
        num_workers=0)

    num_samples = len(dataset)

    # Init the model with batch size, num_samples (len of data) and the number
    # of epochs to train.
    model_self_sup = SimCLR(
        gpus=1,
        arch="resnet50",
        dataset="",
        max_epochs=run_params["SELF_SUPERVISED_EPOCHS"],
        warmup_epochs=run_params["SELF_SUPERVISED_WARMUP_EPOCHS"],
        batch_size=run_params["SELF_SUPERVISED_BATCH_SIZE"],
        num_samples=num_samples,
    )

    if run_params["SELF_SUPERVISED_TRAIN"]:
        logger = TensorBoardLogger(
            os.path.join(run_params["PATH_PREFIX"], "tb_logs", "simCLR"),
            name=run_params["MODEL_SAVE_NAME"],
        )
        early_stopping = EarlyStopping("val_loss", patience=5)

        if os.path.exists(checkpoint_resume):
            trainer = Trainer(
                gpus=1,
                max_epochs=run_params["SELF_SUPERVISED_EPOCHS"],
                logger=logger,
                auto_scale_batch_size=True,
                resume_from_checkpoint=checkpoint_resume,
                callbacks=[early_stopping],
            )
        else:
            checkpoint_callback = ModelCheckpoint(
                monitor="val_loss",
                dirpath=checkpoints_dir,
                filename=run_params["MODEL_SAVE_NAME"],
                save_top_k=1,
                mode="min",
            )
            trainer = Trainer(
                gpus=1,
                max_epochs=run_params["SELF_SUPERVISED_EPOCHS"],
                logger=logger,
                auto_scale_batch_size=True,
                callbacks=[checkpoint_callback, early_stopping],
            )

        trainer.fit(model_self_sup, data_loader, val_loader)
        model_self_sup = SimCLR.load_from_checkpoint(checkpoint_resume)
    elif os.path.exists(checkpoint_resume):
        # load_from_checkpoint returns a new instance, so assign the result back.
        model_self_sup = SimCLR.load_from_checkpoint(checkpoint_resume)
    else:
        print(
            f"No checkpoint found, so the model could not be loaded from\n{checkpoint_resume}"
        )

    return model_self_sup
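# A hedged example of the run_params dict that get_self_supervised_model
# expects; the keys are taken from the function body above, the values are
# purely illustrative.
run_params = {
    "PATH_PREFIX": "./runs",
    "MODEL_SAVE_NAME": "simclr_resnet50",
    "RAW_PREPROCESS_FOLDER": "preprocessed",
    "RESIZE": 256,
    "RANDOM_RESIZE_CROP": 224,
    "SELF_SUPERVISED_BATCH_SIZE": 64,
    "SELF_SUPERVISED_EPOCHS": 100,
    "SELF_SUPERVISED_WARMUP_EPOCHS": 10,
    "SELF_SUPERVISED_TRAIN": True,
}
model = get_self_supervised_model(run_params)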