Beispiel #1
0
def simple_simclr_example():
    """Train SimCLR on CIFAR-10, then embed validation images with a pretrained encoder.

    Two independent steps happen here:

    1. A fresh ``SimCLR`` model is trained on CIFAR-10 with the SimCLR
       train/eval augmentations (2 GPUs, DDP).
    2. A SimCLR ResNet-50 checkpoint pretrained on ImageNet is loaded and its
       encoder is run over the validation set in inference mode.

    Relies on module-level ``torch``, ``torchvision`` and ``pl``
    (pytorch_lightning) imports.
    """
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    # Load ResNet50 pretrained using SimCLR on ImageNet.
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
    simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

    # Replace CIFAR10 with your own dataset if needed, e.g.
    # MyDataset(transforms=SimCLRTrainDataTransform()).
    train_dataset = torchvision.datasets.CIFAR10("", train=True, download=True, transform=SimCLRTrainDataTransform())
    val_dataset = torchvision.datasets.CIFAR10("", train=False, download=True, transform=SimCLREvalDataTransform())

    # SimCLR needs a lot of compute!
    model = SimCLR(gpus=2, num_samples=len(train_dataset), batch_size=32, dataset="cifar10")

    trainer = pl.Trainer(gpus=2, accelerator="ddp")
    trainer.fit(
        model,
        # Shuffle so every epoch builds different contrastive batches instead
        # of replaying the same fixed sample order.
        torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=12),
        torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=12),
    )

    # --------------------
    # Feature extraction with the pretrained encoder.
    simclr_resnet50 = simclr.encoder
    simclr_resnet50.eval()

    # Go through a DataLoader so the encoder receives batched tensors; indexing
    # the dataset directly yields one unbatched sample at a time.
    feature_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=12)
    with torch.no_grad():  # inference only: no autograd graph needed
        for views, y in feature_loader:
            # The SimCLR transforms return several augmented views per image
            # (collated into a list of batches); feed the first one.
            out = simclr_resnet50(views[0])
Beispiel #2
0
def simclr_example():
    """Train SimCLR on CIFAR-10 through a datamodule, then load a pretrained checkpoint.

    Relies on a module-level ``pl`` (pytorch_lightning) import.
    """
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    # Data module.
    dm = CIFAR10DataModule(num_workers=12, batch_size=32)
    dm.train_transforms = SimCLRTrainDataTransform(input_height=32)
    dm.val_transforms = SimCLREvalDataTransform(input_height=32)

    # Model.
    model = SimCLR(gpus=2, num_samples=dm.num_samples, batch_size=dm.batch_size, dataset="cifar10")

    # Fit.
    trainer = pl.Trainer(gpus=2, accelerator="ddp")
    trainer.fit(model, datamodule=dm)

    # --------------------
    # CIFAR-10 pretrained model (this previously pointed at the ImageNet file):
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp12_87_52/epoch%3D960.ckpt"
    # ImageNet pretrained model:
    # weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
    simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

    # LightningModule.freeze(): disable gradients for all parameters and
    # switch to eval mode.
    simclr.freeze()
def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"):
    """Load a SimCLR checkpoint and describe its encoder as a model config.

    Args:
        path_or_url: Local path or URL of the SimCLR checkpoint.

    Returns:
        A dict holding the checkpoint's encoder under ``'model'`` and the
        embedding width 2048 under ``'emb_size'``.
    """
    encoder = SimCLR.load_from_checkpoint(path_or_url, strict=False).encoder
    return dict(model=encoder, emb_size=2048)
def test_simclr(tmpdir, datadir):
    """Smoke test: one ``fast_dev_run`` SimCLR fit on CIFAR-10 with 2-sample batches."""
    side = 32
    dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    dm.train_transforms = SimCLRTrainDataTransform(side)
    dm.val_transforms = SimCLREvalDataTransform(side)

    simclr_model = SimCLR(
        batch_size=2,
        num_samples=dm.num_samples,
        gpus=0,
        nodes=1,
        dataset='cifar10',
    )
    pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir).fit(simclr_model, datamodule=dm)
Beispiel #5
0
 def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"):
     """Load a SimCLR encoder and drop its final two child modules.

     Args:
         path_or_url: Local path or URL of the SimCLR checkpoint.

     Returns:
         ``(backbone, 2048)``: an ``nn.Sequential`` of the encoder's children
         without the last two, plus the feature width reported to callers.
     """
     encoder = SimCLR.load_from_checkpoint(path_or_url, strict=False).encoder
     # Remove the last two children (presumably global pooling and the
     # classification head -- confirm against the encoder definition).
     modules = list(encoder.children())
     del modules[-2:]
     return nn.Sequential(*modules), 2048
Beispiel #6
0
def cli_main():  # pragma: no cover
    """Fine-tune a classifier on top of a SimCLR checkpoint from the command line.

    Reads Lightning ``Trainer`` flags plus ``--dataset``, ``--ckpt_path`` and
    ``--data_dir``. It loads the SimCLR checkpoint, prepares the chosen
    datamodule with SimCLR transforms, fits an ``SSLFineTuner``, and then runs
    the test loop.
    """
    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    # ``choices`` rejects unsupported names up front. Previously an unknown
    # dataset left ``dm`` unassigned and crashed later with UnboundLocalError.
    parser.add_argument('--dataset',
                        type=str,
                        choices=['cifar10', 'stl10', 'imagenet2012'],
                        help='stl10, cifar10, imagenet2012',
                        default='cifar10')
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    parser.add_argument('--data_dir',
                        type=str,
                        help='path to the dataset directory',
                        default=os.getcwd())
    args = parser.parse_args()

    # load the backbone
    backbone = SimCLR.load_from_checkpoint(args.ckpt_path, strict=False)

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        dm.test_transforms = SimCLREvalDataTransform(32)
        args.num_samples = dm.num_samples

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # Fine-tune and validate on the labeled split only.
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = dm.num_labeled_samples

        _, h, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        _, h, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    # finetune
    tuner = SSLFineTuner(backbone,
                         in_features=2048 * 2 * 2,
                         num_classes=dm.num_classes,
                         hidden_dim=None)
    trainer = pl.Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)
Beispiel #7
0
def test_simclr(tmpdir):
    """A single fast_dev_run SimCLR fit on CIFAR-10 must report a positive loss."""
    seed_everything()

    side = 32
    dm = CIFAR10DataModule(tmpdir, num_workers=0, batch_size=2)
    dm.train_transforms = SimCLRTrainDataTransform(side)
    dm.val_transforms = SimCLREvalDataTransform(side)

    simclr_model = SimCLR(batch_size=2, num_samples=dm.num_samples)
    trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir)
    trainer.fit(simclr_model, dm)

    reported = trainer.progress_bar_dict['loss']
    assert float(reported) > 0
Beispiel #8
0
def test_simclr(tmpdir):
    """Overfit two batches with online fine-tuning enabled; the loss must be positive."""
    reset_seed()

    dm = CIFAR10DataModule(tmpdir, num_workers=0)
    dm.train_transforms = SimCLRTrainDataTransform(32)
    dm.val_transforms = SimCLREvalDataTransform(32)

    simclr_model = SimCLR(
        data_dir=tmpdir,
        batch_size=2,
        datamodule=dm,
        online_ft=True,
    )
    trainer = pl.Trainer(
        overfit_batches=2,
        max_epochs=1,
        default_root_dir=tmpdir,
    )
    trainer.fit(simclr_model)

    assert trainer.callback_metrics['loss'] > 0
    def __init__(
        self,
        weight_path='https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
    ):
        """Hold a frozen, head-less copy of a pretrained SimCLR encoder.

        Args:
            weight_path: Path or URL of the SimCLR checkpoint whose encoder is
                copied (defaults to the bolts ImageNet checkpoint).
        """
        super().__init__()

        encoder = deepcopy(SimCLR.load_from_checkpoint(weight_path, strict=False).encoder)
        # Replace the final ``fc`` with a pass-through so the encoder outputs
        # features instead of head outputs.
        encoder.fc = nn.Identity()
        self.encoder = encoder

        # Freeze the module. The encoder is the only registered submodule at
        # this point, so all of its weights are frozen, not just one block.
        # The architecture is whatever the checkpoint contains.
        self.freeze()
        self.inplanes = self.encoder.inplanes
Beispiel #10
0
def train_self_supervised(checkpoint_dir=r'D:\Users\lVavrek\research\data'):
    """Pretrain SimCLR on LibriSpeech spectrograms.

    Args:
        checkpoint_dir: Directory for the three best checkpoints (lowest
            ``val_loss``). Defaults to the previously hard-coded location.
    """
    logger = TensorBoardLogger('runs', name='SimCLR_libri_speech')

    # (batch_size, input_height, num_workers) combos that worked:
    # 8, 224, 8 worked well
    # 16, 224, 4 as well
    batch_size = 16
    input_height = 224
    num_workers = 4

    train_dataset = LibrispeechSpectrogramDataset(
        transform=SimCLRTrainDataTransform(input_height=input_height,
                                           gaussian_blur=False),
        train=True)
    val_dataset = LibrispeechSpectrogramDataset(
        transform=SimCLREvalDataTransform(input_height=input_height,
                                          gaussian_blur=False),
        train=False)

    # Shuffle the training data so each epoch forms different contrastive
    # batches; DataLoader does not shuffle by default.
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers)
    val_loader = DataLoader(val_dataset,
                            batch_size=batch_size,
                            num_workers=num_workers)

    model = SimCLR(gpus=1,
                   num_samples=len(train_dataset),
                   batch_size=batch_size,
                   dataset='librispeech')

    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=checkpoint_dir,
        filename="self-supervised-librispeech-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )

    early_stopping = EarlyStopping(monitor="val_loss")

    trainer = Trainer(gpus=1,
                      callbacks=[checkpoint_callback, early_stopping],
                      logger=logger)
    trainer.fit(model, train_loader, val_loader)
Beispiel #11
0
    def __init__(self):
        """Build a 2-class head (PD vs. healthy) on a SimCLR-pretrained encoder."""
        super().__init__()

        # Restore the encoder from a locally saved SimCLR checkpoint.
        checkpoint = r'D:\Users\lVavrek\research\data\sim-clr-backups\01112021-self-supervised-librispeech-epoch=19-val_loss=1.52.ckpt'
        encoder = SimCLR.load_from_checkpoint(checkpoint, strict=False).encoder

        # Use every child except the last one (presumably the ``fc`` head) as
        # the feature extractor. The head's input width sizes the classifier.
        feature_width = encoder.fc.in_features
        trunk = list(encoder.children())[:-1]
        self.feature_extractor = torch.nn.Sequential(*trunk)

        # Downstream task: PD or healthy, i.e. 2 classes.
        n_classes = 2
        self.classifier = torch.nn.Linear(feature_width, n_classes)
        self.sigmoid = torch.nn.Sigmoid()

        # NOTE(review): CrossEntropyLoss expects raw logits; confirm that
        # forward() does not apply ``self.sigmoid`` before this loss.
        self.celoss = torch.nn.CrossEntropyLoss()
Beispiel #12
0
    def __init__(
        self,
        pre,
        weight_path='https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp12_87_52/epoch%3D960.ckpt'
    ):
        """Build a 20-way MLP head on top of a frozen ResNet encoder.

        Args:
            pre: If truthy, copy the encoder from the SimCLR checkpoint at
                ``weight_path``. Otherwise use a randomly initialised
                torchvision ResNet-18.
            weight_path: SimCLR checkpoint path or URL (only read when ``pre``).
        """
        super().__init__()
        self.pre = pre
        self.weight_path = weight_path

        if not pre:
            self.encoder = torchvision.models.resnet18(pretrained=False)
        else:
            self.encoder = deepcopy(
                SimCLR.load_from_checkpoint(weight_path, strict=False).encoder)

        # Only the encoder is registered at this point, so only it is frozen.
        # The head built below remains trainable.
        self.freeze()

        self.avgpool = self.encoder.avgpool
        in_width = self.encoder.fc.in_features
        self.fc = nn.Sequential(
            nn.Linear(in_features=in_width, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=100),
            nn.ReLU(),
            nn.Linear(in_features=100, out_features=20),
        )
Beispiel #13
0
def cli_main():
    """Prepare a SimCLR-based fine-tuning run from command-line arguments.

    Loads an image-folder datamodule and builds a SimCLR model with a
    ResNet-18 encoder. The encoder is randomly initialised, ImageNet
    initialised via ``--pretrain_encoder``, or restored from ``--MODEL_PATH``.
    The model is then wrapped in an ``SSLFineTuner``.

    ``--num_workers``, ``--hidden_dims``, ``--patience``,
    ``--withold_train_percent``, ``--eval`` and ``--version`` are accepted but
    not used by this function.

    Returns:
        The configured ``SSLFineTuner``.
    """

    def str2bool(value):
        """Parse a yes/no style command-line value into a bool."""
        # ``type=bool`` would turn any non-empty string, including "False",
        # into True.
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0'):
            return False
        raise ValueError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH", type=str, help="path to folders with images")
    parser.add_argument("--MODEL_PATH", default=None, type=str, help="path to model checkpoint")
    parser.add_argument("--batch_size", default=128, type=int, help="batch size for SSL")
    parser.add_argument("--image_size", default=256, type=int, help="image size for SSL")
    parser.add_argument("--num_workers", default=1, type=int, help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size", default=128, type=int, help="size of image representation of SIMCLR")
    parser.add_argument("--hidden_dims", default=128, type=int, help="hidden dimensions in classification layer added onto model for finetuning")
    parser.add_argument("--epochs", default=200, type=int, help="number of epochs to train model")
    parser.add_argument("--lr", default=1e-3, type=float, help="learning rate for training model")
    parser.add_argument("--patience", default=-1, type=int, help="automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping.")
    parser.add_argument("--val_split", default=0.2, type=float, help="percent in validation data")
    parser.add_argument("--withold_train_percent", default=0, type=float, help="decimal from 0-1 representing how much of the training data to withold during finetuning")
    parser.add_argument("--gpus", default=1, type=int, help="number of gpus to use for training")
    parser.add_argument("--eval", default=True, type=str2bool, help="Eval Mode will train and evaluate the finetuned model's performance")
    parser.add_argument("--pretrain_encoder", default=False, type=str2bool, help="initialize resnet encoder with pretrained imagenet weights. Ignored if MODEL_PATH is specified.")
    parser.add_argument("--version", default="0", type=str, help="version to name checkpoint for saving")
    parser.add_argument("--fix_backbone", default=True, type=str2bool, help="Fix backbone during finetuning")

    args = parser.parse_args()
    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    embedding_size = args.image_embedding_size
    epochs = args.epochs
    lr = args.lr
    val_split = args.val_split
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus
    pretrain = args.pretrain_encoder
    fix_backbone = args.fix_backbone

    # Honour --image_size (its default, 256, matches the old hard-coded value).
    train_transform = SimCLRFinetuneTransform(image_size, eval_transform=False)
    val_transform = SimCLRFinetuneTransform(image_size, eval_transform=True)
    dm = ImageDataModule(URL, train_transform=train_transform, val_transform=val_transform, val_split=val_split)
    dm.setup()

    # init model with batch size, num_samples (len of data), epochs to train and learning rate
    model = SimCLR(arch='resnet18', batch_size=batch_size, num_samples=dm.num_samples, gpus=gpus, dataset='None', max_epochs=epochs, learning_rate=lr)

    # Override the default encoder/projection: ResNet-18 features (512 wide)
    # projected 512 -> 256 -> embedding_size.
    model.encoder = resnet18(pretrained=pretrain, first_conv=model.first_conv, maxpool1=model.maxpool1, return_all_feature_maps=False)
    model.projection = Projection(input_dim=512, hidden_dim=256, output_dim=embedding_size)

    if model_checkpoint is not None:
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        # map_location lets GPU-saved weights load on CPU-only machines.
        model.load_state_dict(torch.load(model_checkpoint, map_location='cpu'))
        print('Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights')
    elif pretrain:
        print('Using imagenet weights instead of a pretrained SSL model')
    else:
        print('Using random initialization of encoder')

    print('Finetuning to classify ', dm.num_classes, ' Classes')

    # Fine-tune on top of the SimCLR model built above.
    tuner = SSLFineTuner(
        model,
        in_features=512,
        num_classes=dm.num_classes,
        epochs=epochs,
        hidden_dim=None,
        dropout=0,
        learning_rate=0.3,
        weight_decay=1e-6,
        nesterov=False,
        scheduler_type='cosine',
        gamma=0.1,
        final_lr=0.,
        fix_backbone=fix_backbone
    )
    return tuner
Beispiel #14
0
    # Use the on-disk train/test split when ``<data_dir>/train`` exists.
    # Otherwise shuffle every image path with a fixed seed and split by
    # ``train_test_ratio`` (defined outside this snippet).
    if os.path.exists(os.path.join(args.data_dir, 'train')):
        train_set = ImageFolderDataset(os.path.join(args.data_dir, 'train'))
        test_set = ImageFolderDataset(os.path.join(args.data_dir, 'test'))
    else:
        files = utils.recursive_folder_image_paths(args.data_dir)
        random.seed(19)
        random.shuffle(files)
        split_at = int(train_test_ratio * len(files))
        train_set = ImageFilesDataset(files[:split_at])
        test_set = ImageFilesDataset(files[split_at:])

    train_loader = DataLoader(train_set,
                              batch_size=12,
                              shuffle=True,
                              num_workers=4)
    test_loader = DataLoader(test_set,
                             batch_size=12,
                             shuffle=False,
                             num_workers=4)

    model = SimCLR(gpus=1,
                   # SimCLR derives its iterations per epoch (LR schedule)
                   # from this, so count only what the train loader yields.
                   num_samples=len(train_set),
                   batch_size=12,
                   # ``dataset`` is a dataset *name* string, not a DataLoader.
                   dataset='custom')

    trainer = pl.Trainer(gpus=1)
    trainer.fit(model, train_loader, test_loader)
    model.freeze()
Beispiel #15
0
def main():
    """Train SimCLR on a folder of sneaker images from command-line options.

    Checkpoints are kept every 10 epochs (``period=10``, all kept via
    ``save_top_k=-1``), plus the last one. Multi-GPU runs use DDP.
    """

    parser = ArgumentParser()
    parser.add_argument("--data_dir",
                        type=str,
                        required=True,
                        help="path to the folder of images")
    parser.add_argument("--log_dir",
                        type=str,
                        required=True,
                        help="output training logging dir")
    # Options with a default are optional: argparse never applies the default
    # of a required option, so ``required=True`` made these defaults dead.
    parser.add_argument("--learning_rate",
                        type=float,
                        default=1e-3,
                        help="learning rate")
    parser.add_argument(
        "--input_height",
        type=int,
        required=True,
        help="height of image input to SimCLR",
    )
    parser.add_argument("--batch_size", type=int, default=1024)
    parser.add_argument("--gpus",
                        type=int,
                        default=0,
                        help="Number of GPUs")
    parser.add_argument("--num_workers",
                        type=int,
                        default=0,
                        help="Number of dataloader workers")
    parser.add_argument("--max_epochs",
                        default=100,
                        type=int,
                        help="number of total epochs to run")

    args = parser.parse_args()

    dm = SneakerDataModule(image_folder=args.data_dir,
                           batch_size=args.batch_size,
                           num_workers=args.num_workers)
    dm.train_transforms = SimCLRTrainDataTransform(args.input_height)
    dm.val_transforms = SimCLREvalDataTransform(args.input_height)

    model = SimCLR(
        num_samples=dm.num_samples,
        batch_size=dm.batch_size,
        learning_rate=args.learning_rate,
        max_epochs=args.max_epochs,
        gpus=args.gpus,
        dataset="sneakers",
    )

    model_checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        save_last=True,
        save_top_k=-1,
        period=10,
        filename='{epoch}-{val_loss:.2f}-{step}')

    # TODO set the logger folder
    # Warning message is "Missing logger folder: /lightning_logs"
    multi_gpu = args.gpus > 1
    trainer = pl.Trainer(
        default_root_dir=args.log_dir,
        callbacks=[model_checkpoint_callback],
        max_epochs=args.max_epochs,
        gpus=args.gpus,
        accelerator='ddp' if multi_gpu else None,
        enable_pl_optimizer=multi_gpu,
    )

    trainer.fit(model, dm)
Beispiel #16
0
def cli_main():
    """Fine-tune a SimCLR ResNet-18 encoder for image classification.

    Steps:
    1. Build a ``FolderDataset2`` datamodule.
    2. Build a SimCLR model, optionally restored from ``--MODEL_PATH``.
    3. Fit an ``SSLFineTuner`` with 16-bit precision and DDP.
    4. Optionally evaluate on the validation and training data.
    5. Save the tuner's state dict under
       ``./models/Finetune/SIMCLR_Finetune_<version>/``.

    ``--patience``, ``--withold_train_percent`` and ``--num_workers`` are
    accepted but not used here.
    """

    def str2bool(value):
        """Parse a yes/no style command-line value into a bool."""
        # ``type=bool`` would turn any non-empty string, including "False",
        # into True.
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0'):
            return False
        raise ValueError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument("--MODEL_PATH",
                        default=None,
                        type=str,
                        help="path to model checkpoint")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="image size for SSL")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument(
        "--hidden_dims",
        default=128,
        type=int,
        help=
        "hidden dimensions in classification layer added onto model for finetuning"
    )
    parser.add_argument("--epochs",
                        default=200,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=0.3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping."
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold during finetuning"
    )
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    parser.add_argument(
        "--eval",
        default=True,
        type=str2bool,
        help=
        "Eval Mode will train and evaluate the finetuned model's performance")
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        type=str2bool,
        help=
        "initialize resnet encoder with pretrained imagenet weights. Ignored if MODEL_PATH is specified."
    )
    parser.add_argument("--version",
                        default="0",
                        type=str,
                        help="version to name checkpoint for saving")
    parser.add_argument("--fix_backbone",
                        default=True,
                        type=str2bool,
                        help="Fix backbone during finetuning")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="number of workers to use to fetch data")

    args = parser.parse_args()
    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    embedding_size = args.image_embedding_size
    hidden_dims = args.hidden_dims
    epochs = args.epochs
    lr = args.lr
    val_split = args.val_split
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus
    eval_model = args.eval
    pretrain = args.pretrain_encoder
    fix_backbone = args.fix_backbone

    # Validation must use the deterministic eval transform, not the random
    # training augmentations.
    dm = FolderDataset2(URL,
                        val_split=val_split,
                        train_transform=SimCLRFinetuneTransform(image_size),
                        val_transform=SimCLRFinetuneTransform(
                            image_size, eval_transform=True))
    dm.setup()

    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=dm.num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)
    # The projection must match the one the SSL checkpoint was trained with,
    # so size it from --image_embedding_size instead of a fixed 128.
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)
    model.encoder = resnet18(pretrained=pretrain,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    if model_checkpoint is not None:
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        # map_location lets GPU-saved weights load on CPU-only machines.
        model.load_state_dict(torch.load(model_checkpoint, map_location='cpu'))
        print(
            'Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights'
        )
    elif pretrain:
        print('Using imagenet weights instead of a pretrained SSL model')
    else:
        print('Using random initialization of encoder')

    print('Finetuning to classify ', dm.num_classes, ' Classes')

    tuner = SSLFineTuner(model,
                         in_features=512,
                         num_classes=dm.num_classes,
                         epochs=epochs,
                         hidden_dim=hidden_dims,
                         dropout=0,
                         learning_rate=lr,
                         weight_decay=1e-6,
                         nesterov=False,
                         scheduler_type='cosine',
                         gamma=0.1,
                         final_lr=0.,
                         fix_backbone=fix_backbone)

    trainer = pl.Trainer(gpus=gpus,
                         num_nodes=1,
                         precision=16,
                         max_epochs=epochs,
                         distributed_backend='ddp',
                         sync_batchnorm=False)

    trainer.fit(tuner, dm)

    Path(f"./models/Finetune/SIMCLR_Finetune_{version}").mkdir(parents=True,
                                                               exist_ok=True)

    if eval_model:
        print('Evaluating Model...')
        save_path = f"./models/Finetune/SIMCLR_Finetune_{version}/Evaluation/validationMetrics"
        Path(save_path).mkdir(parents=True, exist_ok=True)

        val_loader = dm.val_dataloader()
        if val_loader is not None:
            eval_finetune(tuner, 'validation', val_loader, save_path)

        save_path = f"./models/Finetune/SIMCLR_Finetune_{version}/Evaluation/trainingMetrics"
        Path(save_path).mkdir(parents=True, exist_ok=True)
        eval_finetune(tuner, 'training', dm.train_dataloader(), save_path)

    print('Saving model...')

    torch.save(
        tuner.state_dict(),
        f"./models/Finetune/SIMCLR_Finetune_{version}/SIMCLR_FINETUNE_{version}.pt"
    )
def cli_main():
    """Pretrain SimCLR (ResNet-18 encoder) on a folder of images.

    Steps:
    1. Build train/validation ``FolderDataset`` loaders with SimCLR
       augmentations.
    2. Optionally restore weights from ``--MODEL_PATH``.
    3. Train with a Lightning ``Trainer``, with early stopping on
       ``val_loss`` when ``--patience`` > 0.
    4. Save the state dict to
       ``./models/SSL/SIMCLR_SSL_<version>/SIMCLR_SSL_<version>.pt``.
    """

    def str2bool(value):
        """Parse a yes/no style command-line value into a bool."""
        # ``type=bool`` would turn any non-empty string, including "False",
        # into True.
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0'):
            return False
        raise ValueError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument("--MODEL_PATH",
                        default=None,
                        type=str,
                        help="path to model checkpoint")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="image size for SSL")
    parser.add_argument(
        "--image_type",
        default="tif",
        type=str,
        help=
        "extension of image for PIL to open and parse - i.e. jpeg, gif, tif, etc. Only put the extension name, not the dot (.)"
    )
    parser.add_argument("--num_workers",
                        default=1,
                        type=int,
                        help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument("--epochs",
                        default=200,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=1e-3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping."
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        type=str2bool,
        help=
        "initialize resnet encoder with pretrained imagenet weights. Cannot be true if passing previous SSL model checkpoint."
    )
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold during SSL training"
    )
    parser.add_argument("--version",
                        default="0",
                        type=str,
                        help="version to name checkpoint for saving")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")

    args = parser.parse_args()
    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    image_type = args.image_type
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    pretrain = args.pretrain_encoder
    withold_train_percent = args.withold_train_percent
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus

    # Separate train/validation datasets: validation needs a different transform.
    train_dataset = FolderDataset(
        URL,
        validation=False,
        val_split=val_split,
        withold_train_percent=withold_train_percent,
        transform=SimCLRTrainDataTransform(image_size),
        image_type=image_type)

    data_loader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=batch_size,
                                              num_workers=num_workers,
                                              drop_last=True)

    print('Training Data Loaded...')
    val_dataset = FolderDataset(URL,
                                validation=True,
                                val_split=val_split,
                                transform=SimCLREvalDataTransform(image_size),
                                image_type=image_type)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             num_workers=num_workers,
                                             drop_last=True)
    print('Validation Data Loaded...')

    num_samples = len(train_dataset)

    # init model with batch size, num_samples (len of data), epochs to train and learning rate
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)

    # Override the default encoder/projection: ResNet-18 features (512 wide)
    # projected 512 -> 256 -> embedding_size.
    model.encoder = resnet18(pretrained=pretrain,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)

    callbacks = []
    if patience > 0:
        callbacks.append(EarlyStopping('val_loss', patience=patience))
    trainer = Trainer(gpus=gpus,
                      max_epochs=epochs,
                      callbacks=callbacks,
                      progress_bar_refresh_rate=5)

    if model_checkpoint is not None:
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        # map_location lets GPU-saved weights load on CPU-only machines.
        model.load_state_dict(torch.load(model_checkpoint, map_location='cpu'))
        print(
            'Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights'
        )

    # No manual model.cuda(): the Trainer places the model on the devices
    # selected by ``gpus``. Forcing CUDA broke CPU-only runs (gpus=0).
    print('Model Initialized')
    trainer.fit(model, data_loader, val_loader)

    Path(f"./models/SSL/SIMCLR_SSL_{version}").mkdir(parents=True,
                                                     exist_ok=True)
    torch.save(model.state_dict(),
               f"./models/SSL/SIMCLR_SSL_{version}/SIMCLR_SSL_{version}.pt")
Beispiel #18
0
    parser.add_argument('--data_dir',
                        required=True,
                        type=str,
                        help='path to image data directory')
    parser.add_argument('--save_dir',
                        required=True,
                        type=str,
                        help='path to the output directory')
    parser.add_argument('--image_index', type=int, default=42)
    parser.add_argument('--n_images', type=int, default=20)

    def _parse_bool(value):
        # argparse's ``type=bool`` maps any non-empty string (even "False")
        # to True, so parse the flag text explicitly.
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0'):
            return False
        raise ValueError(f"expected a boolean value, got {value!r}")

    parser.add_argument('--rgb', type=_parse_bool, default=True)
    args = parser.parse_args()

    image_paths = utils.recursive_folder_image_paths(args.data_dir)

    # NOTE(review): no ``--model_dir`` option is registered in the lines shown;
    # it must be added to ``parser`` earlier or this raises AttributeError.
    model = SimCLR.load_from_checkpoint(checkpoint_path=args.model_dir,
                                        strict=False)
    model_enc = model.encoder
    model_enc.eval()

    transform = transforms.Compose(
        [transforms.Resize((32, 32)),
         transforms.ToTensor()])
    # One 2048-wide row per image.
    y = np.empty((len(image_paths), 2048), float)

    for i, p in enumerate(tqdm(image_paths)):
        image = Image.open(p)
        if args.rgb:
            image = image.convert('RGB')
        # Add a leading batch dimension of size 1.
        image = transform(image).unsqueeze_(0)
        y_hat = model_enc(image)
        y_hat = y_hat[0].detach().numpy().reshape(1, -1)
        # NOTE(review): ``y`` is never written in the code shown; presumably the
        # row should be stored as ``y[i] = y_hat`` -- confirm.
Beispiel #19
0
def get_self_supervised_model(run_params):
    """Build a SimCLR (ResNet-50) model, then train it or restore it from a checkpoint.

    Uses ``final_df`` and ``SelfSupervisedDataset``, which are defined elsewhere in
    this module (not visible in this function).

    Args:
        run_params: Mapping of run settings. Keys read here: ``PATH_PREFIX``,
            ``MODEL_SAVE_NAME``, ``RESIZE``, ``RANDOM_RESIZE_CROP``,
            ``RAW_PREPROCESS_FOLDER``, ``SELF_SUPERVISED_BATCH_SIZE``,
            ``SELF_SUPERVISED_EPOCHS``, ``SELF_SUPERVISED_WARMUP_EPOCHS`` and
            ``SELF_SUPERVISED_TRAIN``.

    Returns:
        The SimCLR model.

        - When ``SELF_SUPERVISED_TRAIN`` is truthy, the model is trained and then
          reloaded from the best saved checkpoint. Training resumes from
          ``<PATH_PREFIX>/checkpoints/<MODEL_SAVE_NAME>.ckpt`` if that file exists.
        - Otherwise the model is loaded from that checkpoint file when present.
        - If neither applies, the untrained model is returned.
    """
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.models.self_supervised.simclr import (
        SimCLRTrainDataTransform,
        SimCLREvalDataTransform,
    )
    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import TensorBoardLogger
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    checkpoints_dir = os.path.join(run_params["PATH_PREFIX"], "checkpoints")
    checkpoint_resume = os.path.join(checkpoints_dir,
                                     run_params["MODEL_SAVE_NAME"] + ".ckpt")

    # Both transforms are built with the same input height.
    input_height = min(run_params["RESIZE"], run_params["RANDOM_RESIZE_CROP"])
    image_prefix = run_params["RAW_PREPROCESS_FOLDER"] + "/"
    batch_size = run_params["SELF_SUPERVISED_BATCH_SIZE"]

    dataset = SelfSupervisedDataset(
        final_df,
        validation=False,
        transform=SimCLRTrainDataTransform(input_height),
        prefix=image_prefix,
    )
    val_dataset = SelfSupervisedDataset(
        final_df,
        validation=True,
        transform=SimCLREvalDataTransform(input_height),
        prefix=image_prefix,
    )

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              num_workers=0)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             num_workers=0)

    # SimCLR needs num_samples (dataset length) to schedule its learning rate.
    model_self_sup = SimCLR(
        gpus=1,
        arch="resnet50",
        dataset="",
        max_epochs=run_params["SELF_SUPERVISED_EPOCHS"],
        warmup_epochs=run_params["SELF_SUPERVISED_WARMUP_EPOCHS"],
        batch_size=batch_size,
        num_samples=len(dataset),
    )

    if run_params["SELF_SUPERVISED_TRAIN"]:
        logger = TensorBoardLogger(
            os.path.join(run_params["PATH_PREFIX"], "tb_logs", "simCLR"),
            name=run_params["MODEL_SAVE_NAME"],
        )
        early_stopping = EarlyStopping("val_loss", patience=5)
        # Checkpoint into checkpoints_dir in both fresh and resumed runs.
        # Otherwise the reload below would read a stale file and lose the
        # epochs trained in this call.
        checkpoint_callback = ModelCheckpoint(
            monitor="val_loss",
            dirpath=checkpoints_dir,
            filename=run_params["MODEL_SAVE_NAME"],
            save_top_k=1,
            mode="min",
        )

        resume_kwargs = {}
        if os.path.exists(checkpoint_resume):
            resume_kwargs["resume_from_checkpoint"] = checkpoint_resume

        trainer = Trainer(
            gpus=1,
            max_epochs=run_params["SELF_SUPERVISED_EPOCHS"],
            logger=logger,
            auto_scale_batch_size=True,
            callbacks=[checkpoint_callback, early_stopping],
            **resume_kwargs,
        )

        trainer.fit(model_self_sup, data_loader, val_loader)
        # NOTE(review): Lightning may save under a versioned name (e.g. "-v1")
        # when the target file already exists, so prefer the path it reports.
        best_path = checkpoint_callback.best_model_path or checkpoint_resume
        model_self_sup = SimCLR.load_from_checkpoint(best_path)
    elif os.path.exists(checkpoint_resume):
        # load_from_checkpoint returns a new model instance; keep it.
        model_self_sup = SimCLR.load_from_checkpoint(checkpoint_resume)
    else:
        print(
            f"Not checkpoint found, so it could not load model from it\n{checkpoint_resume}"
        )

    return model_self_sup
Beispiel #20
0
def cli_main():
    """Fine-tune a SimCLR ResNet-18 encoder as an image classifier from the CLI.

    Steps:

    1. Parse the command-line arguments.
    2. Build train and validation ``FolderDataset`` loaders.
    3. Create a SimCLR model whose encoder is replaced by a ResNet-18 and whose
       head is replaced by a ``Projection``.
    4. Optionally load weights from ``--MODEL_PATH`` (a ``state_dict`` file).
    5. Wrap the model in ``SSLFineTuner`` and train it.
    6. When ``--eval`` is true, run ``eval_finetune`` on both loaders.

    The fine-tuner's ``state_dict`` is written to
    ``./models/Finetune/SIMCLR_Finetune_<version>/SIMCLR_FINETUNE_<version>.pt``.
    """
    from argparse import ArgumentTypeError

    def _str2bool(value):
        """Parse a boolean CLI value.

        argparse's ``type=bool`` is broken for this purpose because
        ``bool("False")`` is True; accept the usual spellings instead.
        """
        text = str(value).strip().lower()
        if text in ("1", "true", "t", "yes", "y"):
            return True
        if text in ("", "0", "false", "f", "no", "n"):
            return False
        raise ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument("--MODEL_PATH",
                        default=None,
                        type=str,
                        help="path to model checkpoint")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="image size for SSL")
    parser.add_argument(
        "--image_type",
        default="tif",
        type=str,
        help=
        "extension of image for PIL to open and parse - i.e. jpeg, gif, tif, etc. Only put the extension name, not the dot (.)"
    )
    parser.add_argument("--num_workers",
                        default=1,
                        type=int,
                        help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument(
        "--hidden_dims",
        default=128,
        type=int,
        help=
        "hidden dimensions in classification layer added onto model for finetuning"
    )
    parser.add_argument("--epochs",
                        default=200,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=1e-3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping."
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold during finetuning"
    )
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    parser.add_argument(
        "--eval",
        default=True,
        type=_str2bool,
        help=
        "Eval Mode will train and evaluate the finetuned model's performance")
    parser.add_argument("--imagenet_weights",
                        default=False,
                        type=_str2bool,
                        help="Use weights from a non-SSL")
    parser.add_argument("--version",
                        default="0",
                        type=str,
                        help="version to name checkpoint for saving")

    args = parser.parse_args()
    DATA_PATH = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    image_type = args.image_type
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    hidden_dims = args.hidden_dims
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    withold_train_percent = args.withold_train_percent
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus
    eval_model = args.eval
    imagenet_weights = args.imagenet_weights

    # Train and validation datasets are built separately; they cannot be
    # combined because each is a separate FolderDataset split.
    finetune_dataset = FolderDataset(
        DATA_PATH,
        validation=False,
        val_split=val_split,
        withold_train_percent=withold_train_percent,
        transform=SimCLRFinetuneTransform(image_size),
        image_type=image_type)

    finetune_loader = torch.utils.data.DataLoader(finetune_dataset,
                                                  batch_size=batch_size,
                                                  num_workers=num_workers,
                                                  drop_last=True)

    print('Training Data Loaded...')
    finetune_val_dataset = FolderDataset(
        DATA_PATH,
        validation=True,
        val_split=val_split,
        transform=SimCLRFinetuneTransform(image_size),
        image_type=image_type)

    # NOTE(review): drop_last=True on the validation loader discards the final
    # partial batch (and yields no batches if the split is smaller than
    # batch_size), so evaluation skips those samples. Confirm this is intended.
    finetune_val_loader = torch.utils.data.DataLoader(finetune_val_dataset,
                                                      batch_size=batch_size,
                                                      num_workers=num_workers,
                                                      drop_last=True)
    print('Validation Data Loaded...')

    num_samples = len(finetune_dataset)
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)
    # Replace the default encoder/projection with a ResNet-18 (512-d features)
    # and a projection head producing `embedding_size` outputs.
    model.encoder = resnet18(pretrained=imagenet_weights,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)

    if model_checkpoint is not None:
        model.load_state_dict(torch.load(model_checkpoint))
        print(
            'Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights'
        )
    elif imagenet_weights:
        print('Using imagenet weights instead of a pretrained SSL model')
    else:
        print('Using random initialization of encoder')

    num_classes = len(set(finetune_dataset.labels))
    print('Finetuning to classify ', num_classes, ' Classes')

    tuner = SSLFineTuner(model,
                         in_features=512,
                         num_classes=num_classes,
                         hidden_dim=hidden_dims,
                         learning_rate=lr)

    # Early stopping is only enabled for a positive patience.
    callbacks = [EarlyStopping('val_loss', patience=patience)] if patience > 0 else []
    trainer = Trainer(gpus=gpus,
                      max_epochs=epochs,
                      callbacks=callbacks,
                      progress_bar_refresh_rate=5)
    tuner.cuda()
    trainer.fit(tuner,
                train_dataloader=finetune_loader,
                val_dataloaders=finetune_val_loader)

    Path(f"./models/Finetune/SIMCLR_Finetune_{version}").mkdir(parents=True,
                                                               exist_ok=True)

    if eval_model:
        print('Evaluating Model...')
        save_path = f"./models/Finetune/SIMCLR_Finetune_{version}/Evaluation/trainingMetrics"
        Path(save_path).mkdir(parents=True, exist_ok=True)
        eval_finetune(tuner, 'training', finetune_loader, save_path)

        save_path = f"./models/Finetune/SIMCLR_Finetune_{version}/Evaluation/validationMetrics"
        Path(save_path).mkdir(parents=True, exist_ok=True)
        eval_finetune(tuner, 'validation', finetune_val_loader, save_path)

    print('Saving model...')

    torch.save(
        tuner.state_dict(),
        f"./models/Finetune/SIMCLR_Finetune_{version}/SIMCLR_FINETUNE_{version}.pt"
    )
def cli_main():
    """Train a SimCLR ResNet-18 model with self-supervision from the CLI.

    Steps:

    1. Parse the command-line arguments.
    2. Build an ``ImageDataModule`` with SimCLR train/eval transforms sized by
       ``--image_size``.
    3. Replace the model's encoder with a ResNet-18 and its head with a
       ``Projection``.
    4. Optionally load a ``state_dict`` from ``--MODEL_PATH``.
    5. Train the model.

    The resulting ``state_dict`` is saved to
    ``./models/SSL/SIMCLR_SSL_<version>/SIMCLR_SSL_<version>.pt``.
    """
    from argparse import ArgumentTypeError

    def _str2bool(value):
        """Parse a boolean CLI value.

        argparse's ``type=bool`` is broken for this purpose because
        ``bool("False")`` is True; accept the usual spellings instead.
        """
        text = str(value).strip().lower()
        if text in ("1", "true", "t", "yes", "y"):
            return True
        if text in ("", "0", "false", "f", "no", "n"):
            return False
        raise ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument("--MODEL_PATH",
                        default=None,
                        type=str,
                        help="path to model checkpoint")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="image size for SSL")
    # --num_workers must be registered only once: argparse raises a
    # "conflicting option string" error on a duplicate registration.
    parser.add_argument("--num_workers",
                        default=1,
                        type=int,
                        help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument("--epochs",
                        default=200,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=1e-3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping."
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        type=_str2bool,
        help=
        "initialize resnet encoder with pretrained imagenet weights. Cannot be true if passing previous SSL model checkpoint."
    )
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold during SSL training"
    )
    parser.add_argument("--version",
                        default="0",
                        type=str,
                        help="version to name checkpoint for saving")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")

    args = parser.parse_args()
    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    pretrain = args.pretrain_encoder
    # NOTE(review): withold_train_percent is parsed but not passed to
    # ImageDataModule here; confirm whether that module supports it.
    withold_train_percent = args.withold_train_percent
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus

    # Honour --image_size (its default, 256, matches the previous hard-coded value).
    train_transform = SimCLRTrainDataTransform(image_size)
    val_transform = SimCLREvalDataTransform(image_size)
    dm = ImageDataModule(URL,
                         train_transform=train_transform,
                         val_transform=val_transform,
                         val_split=val_split,
                         num_workers=num_workers)
    dm.setup()

    # SimCLR needs batch size, number of samples and epochs to build its
    # learning-rate schedule.
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=dm.num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)

    # Replace the default encoder/projection with a ResNet-18 (512-d features)
    # and a projection head producing `embedding_size` outputs.
    model.encoder = resnet18(pretrained=pretrain,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)

    # Early stopping is only enabled for a positive patience.
    callbacks = [EarlyStopping('val_loss', patience=patience)] if patience > 0 else []
    trainer = Trainer(gpus=gpus,
                      max_epochs=epochs,
                      callbacks=callbacks,
                      progress_bar_refresh_rate=5)

    if model_checkpoint is not None:
        model.load_state_dict(torch.load(model_checkpoint))
        print(
            'Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights'
        )

    model.cuda()

    print('Model Initialized')
    trainer.fit(model, dm)

    Path(f"./models/SSL/SIMCLR_SSL_{version}").mkdir(parents=True,
                                                     exist_ok=True)
    torch.save(model.state_dict(),
               f"./models/SSL/SIMCLR_SSL_{version}/SIMCLR_SSL_{version}.pt")
Beispiel #22
0
def cli_main():
    """Evaluate embeddings of an SSL-trained SimCLR ResNet-18 from the CLI.

    Steps:

    1. Build train and validation ``FolderDataset`` objects.
    2. Rebuild the SimCLR/ResNet-18/``Projection`` architecture.
    3. Load the ``state_dict`` at ``--MODEL_PATH``.
    4. Run ``eval_embeddings`` on the validation set, then on the training set.

    Results are written under ``<MODEL_PATH without its last 3 chars>/Evaluation/``.
    For a ``.pt`` file, that means the path without the extension.
    """
    from argparse import ArgumentTypeError

    def _str2bool(value):
        """Parse a boolean CLI value.

        argparse's ``type=bool`` is broken for this purpose because
        ``bool("False")`` is True; accept the usual spellings instead.
        """
        text = str(value).strip().lower()
        if text in ("1", "true", "t", "yes", "y"):
            return True
        if text in ("", "0", "false", "f", "no", "n"):
            return False
        raise ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument(
        "--MODEL_PATH",
        type=str,
        help="path to .pt file containing SSL-trained SimCLR Resnet18 Model")
    parser.add_argument(
        "--DATA_PATH",
        type=str,
        help=
        "path to data. If folder already contains validation data only, set val_split to 0"
    )
    parser.add_argument(
        "--val_split",
        default=0.2,
        type=float,
        help="amount of data to use for validation as a decimal")
    parser.add_argument(
        "--image_type",
        default="tif",
        type=str,
        help=
        "extension of image for PIL to open and parse - i.e. jpeg, gif, tif, etc. Only put the extension name, not the dot (.)"
    )
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument("--image_size",
                        default=128,
                        type=int,
                        help="height of square image to pass through model")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    parser.add_argument("--rank",
                        default=50,
                        type=int,
                        help="number of neighbors to search for")
    parser.add_argument(
        "--filter_same_group",
        default=False,
        type=_str2bool,
        help="custom arg for hurricane data to filter same hurricanes out")

    args = parser.parse_args()
    MODEL_PATH = args.MODEL_PATH
    DATA_PATH = args.DATA_PATH
    image_size = args.image_size
    image_type = args.image_type
    embedding_size = args.image_embedding_size
    val_split = args.val_split
    gpus = args.gpus
    rank_to = args.rank
    filter_hur = args.filter_same_group

    # Train and validation datasets use different transforms, so they are
    # built separately.
    train_dataset = FolderDataset(
        DATA_PATH,
        validation=False,
        val_split=val_split,
        transform=SimCLRTrainDataTransform(image_size),
        image_type=image_type)

    print('Training Data Loaded...')
    val_dataset = FolderDataset(DATA_PATH,
                                validation=True,
                                val_split=val_split,
                                transform=SimCLREvalDataTransform(image_size),
                                image_type=image_type)

    print('Validation Data Loaded...')

    num_samples = len(train_dataset)

    # Rebuild the same architecture the checkpoint was trained with so the
    # state_dict keys and shapes match.
    model = SimCLR(arch='resnet18',
                   batch_size=1,
                   num_samples=num_samples,
                   gpus=gpus,
                   dataset='None')

    model.encoder = resnet18(pretrained=False,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)

    model.load_state_dict(torch.load(MODEL_PATH))

    model.cuda()
    print('Successfully loaded your model for evaluation.')

    # Strip the trailing 3 characters (".pt") to get the output root.
    eval_root = MODEL_PATH[:-3]

    # Evaluate on validation data.
    save_path = f"{eval_root}/Evaluation/validationMetrics"
    Path(save_path).mkdir(parents=True, exist_ok=True)
    eval_embeddings(model, val_dataset, save_path, rank_to, filter_hur)
    print('Validation Data Evaluation Complete.')

    # Evaluate on training data.
    save_path = f"{eval_root}/Evaluation/trainingMetrics"
    Path(save_path).mkdir(parents=True, exist_ok=True)
    eval_embeddings(model, train_dataset, save_path, rank_to, filter_hur)
    print('Training Data Evaluation Complete.')

    print(f'Please check {eval_root}/Evaluation/ for your results')
def cli_main():
    """Evaluate embeddings of a SimCLR ResNet-18 on an ``ImageDataModule`` from the CLI.

    Steps:

    1. Rebuild the SimCLR/ResNet-18/``Projection`` architecture.
    2. Load the ``state_dict`` from ``--MODEL_PATH`` when given. Otherwise the
       encoder is random, or ImageNet-pretrained with ``--pretrain_encoder``.
    3. Run ``eval_embeddings`` on the validation loader (if any), then on the
       training loader.

    Results are written under ``<MODEL_PATH without its last 3 chars>/Evaluation/``.
    """
    from argparse import ArgumentTypeError

    def _str2bool(value):
        """Parse a boolean CLI value.

        argparse's ``type=bool`` is broken for this purpose because
        ``bool("False")`` is True; accept the usual spellings instead.
        """
        text = str(value).strip().lower()
        if text in ("1", "true", "t", "yes", "y"):
            return True
        if text in ("", "0", "false", "f", "no", "n"):
            return False
        raise ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument(
        "--MODEL_PATH",
        type=str,
        help="path to .pt file containing SSL-trained SimCLR Resnet18 Model")
    parser.add_argument(
        "--DATA_PATH",
        type=str,
        help=
        "path to data. If folder already contains validation data only, set val_split to 0"
    )
    parser.add_argument(
        "--val_split",
        default=0.2,
        type=float,
        help="amount of data to use for validation as a decimal")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    # Default is 256 so runs without --image_size keep the size the transforms
    # were previously hard-coded to.
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="height of square image to pass through model")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    parser.add_argument("--rank",
                        default=50,
                        type=int,
                        help="number of neighbors to search for")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for Evaluation")
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        type=_str2bool,
        help=
        "initialize resnet encoder with pretrained imagenet weights. Will be ignored if MODEL_PATH is specified."
    )

    args = parser.parse_args()
    MODEL_PATH = args.MODEL_PATH
    URL = args.DATA_PATH
    image_size = args.image_size
    embedding_size = args.image_embedding_size
    val_split = args.val_split
    gpus = args.gpus
    rank_to = args.rank
    batch_size = args.batch_size
    pretrain = args.pretrain_encoder

    train_transform = SimCLRTrainDataTransform(image_size)
    val_transform = SimCLREvalDataTransform(image_size)
    dm = ImageDataModule(URL,
                         train_transform=train_transform,
                         val_transform=val_transform,
                         val_split=val_split)
    dm.setup()

    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=dm.num_samples,
                   gpus=gpus,
                   dataset='None')

    # --pretrain_encoder is ignored when a checkpoint is given (its weights
    # would be overwritten anyway), which also avoids a needless download.
    model.encoder = resnet18(pretrained=pretrain and MODEL_PATH is None,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)

    if MODEL_PATH is not None:
        model.load_state_dict(torch.load(MODEL_PATH))

    model.cuda()
    print('Successfully loaded your model for evaluation.')

    # NOTE(review): MODEL_PATH is still required to derive the output folder;
    # running without it fails here exactly as before.
    eval_root = MODEL_PATH[:-3]

    # Evaluate on validation data, if the data module provides any.
    save_path = f"{eval_root}/Evaluation/validationMetrics"
    Path(save_path).mkdir(parents=True, exist_ok=True)
    val_loader = dm.val_dataloader()
    if val_loader is not None:
        eval_embeddings(model, val_loader, save_path, rank_to)
        print('Validation Data Evaluation Complete.')

    # Evaluate on training data.
    save_path = f"{eval_root}/Evaluation/trainingMetrics"
    Path(save_path).mkdir(parents=True, exist_ok=True)
    eval_embeddings(model, dm.train_dataloader(), save_path, rank_to)
    print('Training Data Evaluation Complete.')

    print(f'Please check {eval_root}/Evaluation/ for your results')