Example #1
def load_simclr_imagenet(
    path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
):
    # strict=False tolerates checkpoint keys that do not match the current model
    simclr = SimCLR.load_from_checkpoint(path_or_url, strict=False)
    model_config = {'model': simclr.encoder, 'emb_size': 2048}  # ResNet-50 feature size
    return model_config
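
A hypothetical consumption sketch (assumes ROOT_S3_BUCKET and SimCLR are already in scope; names are illustrative):

model_config = load_simclr_imagenet()
encoder = model_config['model']      # the pretrained ResNet-50 encoder
emb_size = model_config['emb_size']  # 2048-dimensional features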
Example #2
def simple_simclr_example():
	import torch
	import torchvision
	import pytorch_lightning as pl
	from pl_bolts.models.self_supervised import SimCLR
	from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

	# Load ResNet50 pretrained using SimCLR on ImageNet.
	weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
	simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

	#train_dataset = MyDataset(transforms=SimCLRTrainDataTransform())
	#val_dataset = MyDataset(transforms=SimCLREvalDataTransform())
	train_dataset = torchvision.datasets.CIFAR10("", train=True, download=True, transform=SimCLRTrainDataTransform())
	val_dataset = torchvision.datasets.CIFAR10("", train=False, download=True, transform=SimCLREvalDataTransform())

	# SimCLR needs a lot of compute!
	model = SimCLR(gpus=2, num_samples=len(train_dataset), batch_size=32, dataset="cifar10")

	trainer = pl.Trainer(gpus=2, accelerator="ddp")
	trainer.fit(
		model,
		torch.utils.data.DataLoader(train_dataset, batch_size=32, num_workers=12),
		torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=12),
	)

	# --------------------
	# Use the pretrained encoder as a frozen feature extractor.
	simclr_resnet50 = simclr.encoder
	simclr_resnet50.eval()

	#my_dataset = SomeDataset()
	my_dataset = val_dataset
	for batch in torch.utils.data.DataLoader(my_dataset, batch_size=32):
		x, y = batch
		# The SimCLR transforms return a tuple of augmented views; encode the first one.
		if isinstance(x, (list, tuple)):
			x = x[0]
		with torch.no_grad():
			out = simclr_resnet50(x)
Example #3
def simclr_example():
	import pytorch_lightning as pl
	from pl_bolts.models.self_supervised import SimCLR
	from pl_bolts.datamodules import CIFAR10DataModule
	from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

	# Data module.
	dm = CIFAR10DataModule(num_workers=12, batch_size=32)
	dm.train_transforms = SimCLRTrainDataTransform(input_height=32)
	dm.val_transforms = SimCLREvalDataTransform(input_height=32)

	# Model.
	model = SimCLR(gpus=2, num_samples=dm.num_samples, batch_size=dm.batch_size, dataset="cifar10")

	# Fit.
	trainer = pl.Trainer(gpus=2, accelerator="ddp")
	trainer.fit(model, datamodule=dm)

	# --------------------
	# CIFAR-10 pretrained model:
	weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp12_87_52/epoch%3D960.ckpt"
	# ImageNet pretrained model:
	#weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
	simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

	simclr.freeze()
Example #4
def load_simclr_imagenet(
    path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
):
    simclr = SimCLR.load_from_checkpoint(path_or_url, strict=False)
    # remove the last two layers & turn the encoder into a Sequential model
    backbone = nn.Sequential(*list(simclr.encoder.children())[:-2])
    return backbone, 2048
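
A minimal usage sketch for the loader above (the shapes are an assumption based on a torchvision-style ResNet-50; the pl_bolts encoder may differ by version):

import torch

backbone, emb_size = load_simclr_imagenet()
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
# with the final pooling and fc layers stripped, the output is a spatial
# feature map (e.g. 1 x 2048 x 7 x 7) rather than a pooled 2048-d vector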
Example #5
def cli_main():  # pragma: no cover
    import os
    from argparse import ArgumentParser

    import pytorch_lightning as pl
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised import SimCLR, SSLFineTuner
    from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform
    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--dataset',
                        type=str,
                        help='cifar10, stl10, imagenet2012',
                        default='cifar10')
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    parser.add_argument('--data_dir',
                        type=str,
                        help='path to dataset directory',
                        default=os.getcwd())
    args = parser.parse_args()

    # load the backbone
    backbone = SimCLR.load_from_checkpoint(args.ckpt_path, strict=False)

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        dm.test_transforms = SimCLREvalDataTransform(32)
        args.num_samples = dm.num_samples

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = dm.num_labeled_samples

        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    # finetune
    tuner = SSLFineTuner(backbone,
                         in_features=2048 * 2 * 2,
                         num_classes=dm.num_classes,
                         hidden_dim=None)
    trainer = pl.Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)
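
A hypothetical entry point for the script above (the guard and script name are assumptions):

if __name__ == '__main__':
    # e.g. python simclr_finetuner.py --dataset cifar10 --ckpt_path /path/to/simclr.ckpt --data_dir ./data
    cli_main()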
Example #6
    def __init__(
        self,
        weight_path='https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
    ):
        super().__init__()

        backbone = deepcopy(
            SimCLR.load_from_checkpoint(weight_path, strict=False).encoder)
        backbone.fc = nn.Identity()

        self.encoder = backbone

        self.freeze()  # freeze all parameters of the pretrained encoder
        self.inplanes = self.encoder.inplanes
Example #7
    def __init__(self):
        super().__init__()

        # init a pretrained resnet
        weight_path = r'D:\Users\lVavrek\research\data\sim-clr-backups\01112021-self-supervised-librispeech-epoch=19-val_loss=1.52.ckpt'
        simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
        backbone = simclr.encoder

        # extract last layer
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = torch.nn.Sequential(*layers)

        # use the pretrained model for the downstream classification task (PD or healthy, 2 classes)
        num_target_classes = 2
        self.classifier = torch.nn.Linear(num_filters, num_target_classes)
        self.sigmoid = torch.nn.Sigmoid()

        self.celoss = torch.nn.CrossEntropyLoss()
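
A plausible forward pass to pair with the __init__ above (a sketch, not the original author's code; note that CrossEntropyLoss expects raw logits, so the sigmoid should stay out of the loss path):

    def forward(self, x):
        # extract frozen features with the pretrained backbone, then classify
        representations = self.feature_extractor(x).flatten(1)
        return self.classifier(representations)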
Example #8
    def __init__(
        self,
        pre,
        weight_path='https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp12_87_52/epoch%3D960.ckpt'
    ):
        super().__init__()
        self.pre = pre
        self.weight_path = weight_path

        if pre:
            self.encoder = deepcopy(
                SimCLR.load_from_checkpoint(weight_path, strict=False).encoder)
        else:
            self.encoder = torchvision.models.resnet18(pretrained=False)

        self.freeze()

        self.avgpool = self.encoder.avgpool
        numft = self.encoder.fc.in_features
        self.fc = nn.Sequential(nn.Linear(in_features=numft, out_features=256),
                                nn.ReLU(),
                                nn.Linear(in_features=256, out_features=100),
                                nn.ReLU(),
                                nn.Linear(in_features=100, out_features=20))
Example #9
    parser.add_argument('--data_dir',
                        required=True,
                        type=str,
                        help='path to image data directory')
    parser.add_argument('--save_dir',
                        required=True,
                        type=str,
                        help='path to output directory')
    parser.add_argument('--model_dir',
                        required=True,
                        type=str,
                        help='path to the SimCLR checkpoint')
    parser.add_argument('--image_index', type=int, default=42)
    parser.add_argument('--n_images', type=int, default=20)
    # argparse's type=bool treats any non-empty string as True; parse the flag explicitly
    parser.add_argument('--rgb', type=lambda s: s.lower() in ('true', '1'), default=True)
    args = parser.parse_args()

    image_paths = utils.recursive_folder_image_paths(args.data_dir)

    model = SimCLR.load_from_checkpoint(checkpoint_path=args.model_dir,
                                        strict=False)
    model_enc = model.encoder
    model_enc.eval()

    transform = transforms.Compose(
        [transforms.Resize((32, 32)),
         transforms.ToTensor()])
    y = np.empty((len(image_paths), 2048), float)

    for i, p in enumerate(tqdm(image_paths)):
        image = Image.open(p)
        if args.rgb:
            image = image.convert('RGB')
        image = transform(image).unsqueeze_(0)
        y_hat = model_enc(image)
        y_hat = y_hat[0].detach().numpy().reshape(-1)
        y[i] = y_hat  # store the embedding row; y is preallocated above but was never filled
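
A plausible completion of the truncated loop above (an assumption; args.save_dir is parsed but unused in the excerpt):

    np.save(os.path.join(args.save_dir, 'embeddings.npy'), y)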
Example #10
def get_self_supervised_model(run_params):
    import os

    import torch
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.models.self_supervised.simclr import (
        SimCLRTrainDataTransform,
        SimCLREvalDataTransform,
    )
    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import TensorBoardLogger
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    checkpoints_dir = os.path.join(run_params["PATH_PREFIX"], "checkpoints")
    checkpoint_resume = os.path.join(checkpoints_dir,
                                     run_params["MODEL_SAVE_NAME"] + ".ckpt")

    # final_df and SelfSupervisedDataset are defined elsewhere in the project
    dataset = SelfSupervisedDataset(
        final_df,
        validation=False,
        transform=SimCLRTrainDataTransform(
            min(run_params["RESIZE"], run_params["RANDOM_RESIZE_CROP"])),
        prefix=run_params["RAW_PREPROCESS_FOLDER"] + "/",
    )
    val_dataset = SelfSupervisedDataset(
        final_df,
        validation=True,
        transform=SimCLREvalDataTransform(
            min(run_params["RESIZE"], run_params["RANDOM_RESIZE_CROP"])),
        prefix=run_params["RAW_PREPROCESS_FOLDER"] + "/",
    )

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=run_params["SELF_SUPERVISED_BATCH_SIZE"],
        num_workers=0)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=run_params["SELF_SUPERVISED_BATCH_SIZE"],
        num_workers=0)
    num_samples = len(dataset)

    # init the model with batch size, num_samples (dataset length), and epochs to train
    model_self_sup = SimCLR(
        gpus=1,
        arch="resnet50",
        dataset="",
        max_epochs=run_params["SELF_SUPERVISED_EPOCHS"],
        warmup_epochs=run_params["SELF_SUPERVISED_WARMUP_EPOCHS"],
        batch_size=run_params["SELF_SUPERVISED_BATCH_SIZE"],
        num_samples=num_samples,
    )

    if run_params["SELF_SUPERVISED_TRAIN"]:
        logger = TensorBoardLogger(
            os.path.join(run_params["PATH_PREFIX"], "tb_logs", "simCLR"),
            name=run_params["MODEL_SAVE_NAME"],
        )
        early_stopping = EarlyStopping("val_loss", patience=5)

        if os.path.exists(checkpoint_resume):
            trainer = Trainer(
                gpus=1,
                max_epochs=run_params["SELF_SUPERVISED_EPOCHS"],
                logger=logger,
                auto_scale_batch_size=True,
                resume_from_checkpoint=checkpoint_resume,
                callbacks=[early_stopping],
            )
        else:
            checkpoint_callback = ModelCheckpoint(
                monitor="val_loss",
                dirpath=checkpoints_dir,
                filename=run_params["MODEL_SAVE_NAME"],
                save_top_k=1,
                mode="min",
            )

            trainer = Trainer(
                gpus=1,
                max_epochs=run_params["SELF_SUPERVISED_EPOCHS"],
                logger=logger,
                auto_scale_batch_size=True,
                callbacks=[checkpoint_callback, early_stopping],
            )

        trainer.fit(model_self_sup, data_loader, val_loader)
        # reload the best checkpoint written during training
        model_self_sup = SimCLR.load_from_checkpoint(checkpoint_resume)
    elif os.path.exists(checkpoint_resume):
        # load_from_checkpoint is a classmethod that returns a new model;
        # the original call discarded its result
        model_self_sup = SimCLR.load_from_checkpoint(checkpoint_resume)
    else:
        print(
            f"No checkpoint found at {checkpoint_resume}; returning the untrained model"
        )

    return model_self_sup
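
A hypothetical usage sketch (every run_params key below is one the function reads; the values are illustrative):

run_params = {
    "PATH_PREFIX": "./runs",
    "MODEL_SAVE_NAME": "simclr_backbone",
    "RAW_PREPROCESS_FOLDER": "data",
    "RESIZE": 256,
    "RANDOM_RESIZE_CROP": 224,
    "SELF_SUPERVISED_BATCH_SIZE": 32,
    "SELF_SUPERVISED_EPOCHS": 20,
    "SELF_SUPERVISED_WARMUP_EPOCHS": 2,
    "SELF_SUPERVISED_TRAIN": True,
}
model = get_self_supervised_model(run_params)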