Example #1
0
def cli_main():
    """Finetune a SimCLR ResNet-18 encoder as an image classifier.

    Reads its configuration from the command line, builds a ``FolderDataset2``
    data module from ``--DATA_PATH``, creates a SimCLR model (optionally
    restoring weights from ``--MODEL_PATH``), wraps it in an ``SSLFineTuner``
    and trains it with a PyTorch Lightning ``Trainer``.

    Side effects:
        * creates ``./models/Finetune/SIMCLR_Finetune_<version>/``;
        * when ``--eval`` is true, runs ``eval_finetune`` on the validation
          and training loaders, writing into ``Evaluation/validationMetrics``
          and ``Evaluation/trainingMetrics`` under that directory;
        * saves the finetuner's ``state_dict`` to
          ``SIMCLR_FINETUNE_<version>.pt`` in that directory.
    """

    def str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string (including "False")
        into True, so common spellings are mapped explicitly. Raising
        ValueError makes argparse report an "invalid value" usage error.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise ValueError(value)

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument("--MODEL_PATH",
                        default=None,
                        type=str,
                        help="path to model checkpoint")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="image size for SSL")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument(
        "--hidden_dims",
        default=128,
        type=int,
        help=
        "hidden dimensions in classification layer added onto model for finetuning"
    )
    parser.add_argument("--epochs",
                        default=200,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=0.3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping."
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold during finetuning"
    )
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    parser.add_argument(
        "--eval",
        default=True,
        type=str2bool,
        help=
        "Eval Mode will train and evaluate the finetuned model's performance")
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        type=str2bool,
        help=
        "initialize resnet encoder with pretrained imagenet weights. Ignored if MODEL_PATH is specified."
    )
    parser.add_argument("--version",
                        default="0",
                        type=str,
                        help="version to name checkpoint for saving")
    parser.add_argument("--fix_backbone",
                        default=True,
                        type=str2bool,
                        help="Fix backbone during finetuning")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="number of workers to use to fetch data")

    args = parser.parse_args()
    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    embedding_size = args.image_embedding_size
    hidden_dims = args.hidden_dims
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus
    eval_model = args.eval
    pretrain = args.pretrain_encoder
    fix_backbone = args.fix_backbone
    # NOTE(review): --withold_train_percent and --num_workers are parsed but
    # not forwarded to FolderDataset2 (its signature is not visible here).

    dm = FolderDataset2(URL,
                        val_split=val_split,
                        train_transform=SimCLRFinetuneTransform(image_size),
                        val_transform=SimCLRFinetuneTransform(image_size))
    dm.setup()

    # Pass the CLI gpus/epochs instead of the previously hard-coded 1 / 100 so
    # the model is configured like the trainer built below.
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=dm.num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)
    # Replace SimCLR's default projection head and encoder; the projection
    # output size follows --image_embedding_size.
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)
    model.encoder = resnet18(pretrained=pretrain,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    if model_checkpoint is not None:
        # map_location='cpu' lets GPU-saved weights load on any machine;
        # load_state_dict copies them into the model's own tensors.
        model.load_state_dict(torch.load(model_checkpoint, map_location='cpu'))
        print(
            'Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights'
        )
    elif pretrain:
        print('Using imagenet weights instead of a pretrained SSL model')
    else:
        print('Using random initialization of encoder')

    print('Finetuning to classify ', dm.num_classes, ' Classes')

    tuner = SSLFineTuner(model,
                         in_features=512,
                         num_classes=dm.num_classes,
                         epochs=epochs,
                         hidden_dim=hidden_dims,
                         dropout=0,
                         learning_rate=lr,
                         weight_decay=1e-6,
                         nesterov=False,
                         scheduler_type='cosine',
                         gamma=0.1,
                         final_lr=0.,
                         fix_backbone=fix_backbone)

    # --patience <= 0 (the default) disables early stopping.
    # NOTE(review): assumes SSLFineTuner logs 'val_loss' — confirm.
    callbacks = []
    if patience > 0:
        callbacks.append(
            pl.callbacks.EarlyStopping('val_loss', patience=patience))

    trainer = pl.Trainer(gpus=gpus,
                         num_nodes=1,
                         precision=16,
                         max_epochs=epochs,
                         distributed_backend='ddp',
                         sync_batchnorm=False,
                         callbacks=callbacks)

    trainer.fit(tuner, dm)

    out_dir = f"./models/Finetune/SIMCLR_Finetune_{version}"
    Path(out_dir).mkdir(parents=True, exist_ok=True)

    if eval_model:
        print('Evaluating Model...')
        save_path = f"{out_dir}/Evaluation/validationMetrics"
        Path(save_path).mkdir(parents=True, exist_ok=True)

        # Build the validation loader once and reuse it.
        val_loader = dm.val_dataloader()
        if val_loader is not None:
            eval_finetune(tuner, 'validation', val_loader, save_path)

        save_path = f"{out_dir}/Evaluation/trainingMetrics"
        Path(save_path).mkdir(parents=True, exist_ok=True)
        eval_finetune(tuner, 'training', dm.train_dataloader(), save_path)

    print('Saving model...')

    torch.save(tuner.state_dict(), f"{out_dir}/SIMCLR_FINETUNE_{version}.pt")
Example #2
0
def cli_main():
    """Finetune a SimCLR ResNet-18 encoder using ``FolderDataset`` loaders.

    Reads its configuration from the command line, builds training and
    validation ``FolderDataset``/``DataLoader`` pairs from ``--DATA_PATH``,
    creates a SimCLR model (optionally restoring weights from
    ``--MODEL_PATH``), wraps it in an ``SSLFineTuner`` and trains it.

    Side effects:
        * creates ``./models/Finetune/SIMCLR_Finetune_<version>/``;
        * when ``--eval`` is true, runs ``eval_finetune`` on the training and
          validation loaders, writing into ``Evaluation/trainingMetrics`` and
          ``Evaluation/validationMetrics`` under that directory;
        * saves the finetuner's ``state_dict`` to
          ``SIMCLR_FINETUNE_<version>.pt`` in that directory.
    """

    def str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string (including "False")
        into True, so common spellings are mapped explicitly. Raising
        ValueError makes argparse report an "invalid value" usage error.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise ValueError(value)

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument("--MODEL_PATH",
                        default=None,
                        type=str,
                        help="path to model checkpoint")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="image size for SSL")
    parser.add_argument(
        "--image_type",
        default="tif",
        type=str,
        help=
        "extension of image for PIL to open and parse - i.e. jpeg, gif, tif, etc. Only put the extension name, not the dot (.)"
    )
    parser.add_argument("--num_workers",
                        default=1,
                        type=int,
                        help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument(
        "--hidden_dims",
        default=128,
        type=int,
        help=
        "hidden dimensions in classification layer added onto model for finetuning"
    )
    parser.add_argument("--epochs",
                        default=200,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=1e-3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping."
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold during finetuning"
    )
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    parser.add_argument(
        "--eval",
        default=True,
        type=str2bool,
        help=
        "Eval Mode will train and evaluate the finetuned model's performance")
    parser.add_argument("--imagenet_weights",
                        default=False,
                        type=str2bool,
                        help="Use weights from a non-SSL")
    parser.add_argument("--version",
                        default="0",
                        type=str,
                        help="version to name checkpoint for saving")

    args = parser.parse_args()
    DATA_PATH = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    image_type = args.image_type
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    hidden_dims = args.hidden_dims
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    withold_train_percent = args.withold_train_percent
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus
    eval_model = args.eval
    imagenet_weights = args.imagenet_weights

    # Training and validation come from two FolderDataset instances over the
    # same folder, split by validation=False/True with the same val_split.
    finetune_dataset = FolderDataset(
        DATA_PATH,
        validation=False,
        val_split=val_split,
        withold_train_percent=withold_train_percent,
        transform=SimCLRFinetuneTransform(image_size),
        image_type=image_type)

    # DataLoader defaults to sequential order; shuffle so each epoch sees the
    # training samples in a fresh order.
    finetune_loader = torch.utils.data.DataLoader(finetune_dataset,
                                                  batch_size=batch_size,
                                                  num_workers=num_workers,
                                                  shuffle=True,
                                                  drop_last=True)

    print('Training Data Loaded...')
    finetune_val_dataset = FolderDataset(
        DATA_PATH,
        validation=True,
        val_split=val_split,
        transform=SimCLRFinetuneTransform(image_size),
        image_type=image_type)

    # Keep the final partial batch: with drop_last=True, held-out images were
    # silently skipped and a validation set smaller than batch_size produced
    # no batches at all.
    finetune_val_loader = torch.utils.data.DataLoader(finetune_val_dataset,
                                                      batch_size=batch_size,
                                                      num_workers=num_workers,
                                                      drop_last=False)
    print('Validation Data Loaded...')

    num_samples = len(finetune_dataset)
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)
    # Replace SimCLR's default encoder and projection head; the projection
    # output size follows --image_embedding_size.
    model.encoder = resnet18(pretrained=imagenet_weights,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)

    if model_checkpoint is not None:
        # map_location='cpu' lets GPU-saved weights load on any machine;
        # load_state_dict copies them into the model's own tensors.
        model.load_state_dict(torch.load(model_checkpoint, map_location='cpu'))
        print(
            'Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights'
        )
    elif imagenet_weights:
        print('Using imagenet weights instead of a pretrained SSL model')
    else:
        print('Using random initialization of encoder')

    num_classes = len(set(finetune_dataset.labels))
    print('Finetuning to classify ', num_classes, ' Classes')

    tuner = SSLFineTuner(model,
                         in_features=512,
                         num_classes=num_classes,
                         hidden_dim=hidden_dims,
                         learning_rate=lr)

    # --patience <= 0 (the default) disables early stopping.
    callbacks = []
    if patience > 0:
        callbacks.append(EarlyStopping('val_loss', patience=patience))
    trainer = Trainer(gpus=gpus,
                      max_epochs=epochs,
                      callbacks=callbacks,
                      progress_bar_refresh_rate=5)

    # Only move to CUDA when GPUs were requested so --gpus 0 works on
    # CPU-only machines.
    if gpus > 0:
        tuner.cuda()
    trainer.fit(tuner,
                train_dataloader=finetune_loader,
                val_dataloaders=finetune_val_loader)

    out_dir = f"./models/Finetune/SIMCLR_Finetune_{version}"
    Path(out_dir).mkdir(parents=True, exist_ok=True)

    if eval_model:
        print('Evaluating Model...')
        save_path = f"{out_dir}/Evaluation/trainingMetrics"
        Path(save_path).mkdir(parents=True, exist_ok=True)
        eval_finetune(tuner, 'training', finetune_loader, save_path)

        save_path = f"{out_dir}/Evaluation/validationMetrics"
        Path(save_path).mkdir(parents=True, exist_ok=True)
        eval_finetune(tuner, 'validation', finetune_val_loader, save_path)

    print('Saving model...')

    torch.save(tuner.state_dict(), f"{out_dir}/SIMCLR_FINETUNE_{version}.pt")
def cli_main():
    """Train a SimCLR ResNet-18 model with self-supervision on an image folder.

    Reads its configuration from the command line, builds an
    ``ImageDataModule`` from ``--DATA_PATH``, creates a SimCLR model
    (optionally restoring weights from ``--MODEL_PATH``), trains it with a
    Lightning ``Trainer`` and saves its ``state_dict`` to
    ``./models/SSL/SIMCLR_SSL_<version>/SIMCLR_SSL_<version>.pt``.
    """

    def str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string (including "False")
        into True, so common spellings are mapped explicitly. Raising
        ValueError makes argparse report an "invalid value" usage error.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise ValueError(value)

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument("--MODEL_PATH",
                        default=None,
                        type=str,
                        help="path to model checkpoint")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="image size for SSL")
    # Registered exactly once: a second add_argument("--num_workers", ...)
    # makes argparse raise "conflicting option string" before parsing.
    parser.add_argument("--num_workers",
                        default=1,
                        type=int,
                        help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument("--epochs",
                        default=200,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=1e-3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping."
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        type=str2bool,
        help=
        "initialize resnet encoder with pretrained imagenet weights. Cannot be true if passing previous SSL model checkpoint."
    )
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold during SSL training"
    )
    parser.add_argument("--version",
                        default="0",
                        type=str,
                        help="version to name checkpoint for saving")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")

    args = parser.parse_args()
    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    pretrain = args.pretrain_encoder
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus
    # NOTE(review): --withold_train_percent is parsed but not forwarded, and
    # batch_size is not passed to ImageDataModule (signature not visible
    # here) — confirm the data module's batch size matches SimCLR's.

    # Use --image_size rather than a hard-coded 256.
    train_transform = SimCLRTrainDataTransform(image_size)
    val_transform = SimCLREvalDataTransform(image_size)
    dm = ImageDataModule(URL,
                         train_transform=train_transform,
                         val_transform=val_transform,
                         val_split=val_split,
                         num_workers=num_workers)
    dm.setup()

    # Init model with batch size, num_samples (len of data), epochs to train
    # and the learning rate given by --lr.
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=dm.num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)

    # Replace SimCLR's default encoder and projection head; the projection
    # output size follows --image_embedding_size.
    model.encoder = resnet18(pretrained=pretrain,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)

    # --patience <= 0 (the default) disables early stopping.
    callbacks = []
    if patience > 0:
        callbacks.append(EarlyStopping('val_loss', patience=patience))
    trainer = Trainer(gpus=gpus,
                      max_epochs=epochs,
                      callbacks=callbacks,
                      progress_bar_refresh_rate=5)

    if model_checkpoint is not None:
        # map_location='cpu' lets GPU-saved weights load on any machine;
        # load_state_dict copies them into the model's own tensors.
        model.load_state_dict(torch.load(model_checkpoint, map_location='cpu'))
        print(
            'Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights'
        )

    # Only move to CUDA when GPUs were requested so --gpus 0 works on
    # CPU-only machines.
    if gpus > 0:
        model.cuda()

    print('Model Initialized')
    trainer.fit(model, dm)

    out_dir = f"./models/SSL/SIMCLR_SSL_{version}"
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), f"{out_dir}/SIMCLR_SSL_{version}.pt")
def cli_main():
    """Train a SimCLR ResNet-18 model with self-supervision on an image folder.

    Reads its configuration from the command line, builds training and
    validation ``FolderDataset``/``DataLoader`` pairs from ``--DATA_PATH``,
    creates a SimCLR model (optionally restoring weights from
    ``--MODEL_PATH``), trains it with a Lightning ``Trainer`` and saves its
    ``state_dict`` to
    ``./models/SSL/SIMCLR_SSL_<version>/SIMCLR_SSL_<version>.pt``.
    """

    def str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string (including "False")
        into True, so common spellings are mapped explicitly. Raising
        ValueError makes argparse report an "invalid value" usage error.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise ValueError(value)

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument("--MODEL_PATH",
                        default=None,
                        type=str,
                        help="path to model checkpoint")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="image size for SSL")
    parser.add_argument(
        "--image_type",
        default="tif",
        type=str,
        help=
        "extension of image for PIL to open and parse - i.e. jpeg, gif, tif, etc. Only put the extension name, not the dot (.)"
    )
    parser.add_argument("--num_workers",
                        default=1,
                        type=int,
                        help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument("--epochs",
                        default=200,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=1e-3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping."
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        type=str2bool,
        help=
        "initialize resnet encoder with pretrained imagenet weights. Cannot be true if passing previous SSL model checkpoint."
    )
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold during SSL training"
    )
    parser.add_argument("--version",
                        default="0",
                        type=str,
                        help="version to name checkpoint for saving")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")

    args = parser.parse_args()
    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    image_type = args.image_type
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    pretrain = args.pretrain_encoder
    withold_train_percent = args.withold_train_percent
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus

    # Training and validation are separate datasets because they need
    # different transforms (train vs. eval augmentation).
    train_dataset = FolderDataset(
        URL,
        validation=False,
        val_split=val_split,
        withold_train_percent=withold_train_percent,
        transform=SimCLRTrainDataTransform(image_size),
        image_type=image_type)

    # DataLoader defaults to sequential order; shuffle so each epoch sees the
    # training samples in a fresh order. drop_last keeps every batch at the
    # configured batch_size.
    data_loader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=batch_size,
                                              num_workers=num_workers,
                                              shuffle=True,
                                              drop_last=True)

    print('Training Data Loaded...')
    val_dataset = FolderDataset(URL,
                                validation=True,
                                val_split=val_split,
                                transform=SimCLREvalDataTransform(image_size),
                                image_type=image_type)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             num_workers=num_workers,
                                             drop_last=True)
    print('Validation Data Loaded...')

    num_samples = len(train_dataset)

    # Init model with batch size, num_samples (len of data), epochs to train
    # and the learning rate given by --lr.
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)

    # Replace SimCLR's default encoder and projection head; the projection
    # output size follows --image_embedding_size.
    model.encoder = resnet18(pretrained=pretrain,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)

    # --patience <= 0 (the default) disables early stopping.
    callbacks = []
    if patience > 0:
        callbacks.append(EarlyStopping('val_loss', patience=patience))
    trainer = Trainer(gpus=gpus,
                      max_epochs=epochs,
                      callbacks=callbacks,
                      progress_bar_refresh_rate=5)

    if model_checkpoint is not None:
        # map_location='cpu' lets GPU-saved weights load on any machine;
        # load_state_dict copies them into the model's own tensors.
        model.load_state_dict(torch.load(model_checkpoint, map_location='cpu'))
        print(
            'Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights'
        )

    # Only move to CUDA when GPUs were requested so --gpus 0 works on
    # CPU-only machines.
    if gpus > 0:
        model.cuda()

    print('Model Initialized')
    trainer.fit(model, data_loader, val_loader)

    out_dir = f"./models/SSL/SIMCLR_SSL_{version}"
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), f"{out_dir}/SIMCLR_SSL_{version}.pt")
Example #5
0
def cli_main():
    """Build an ``SSLFineTuner`` around a SimCLR ResNet-18 encoder from CLI args.

    Reads its configuration from the command line, builds an
    ``ImageDataModule`` from ``--DATA_PATH``, creates a SimCLR model
    (optionally restoring weights from ``--MODEL_PATH``) and wraps it in an
    ``SSLFineTuner``.

    Returns:
        The constructed ``SSLFineTuner``. This variant does not train,
        evaluate or save it.
    """

    def str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string (including "False")
        into True, so common spellings are mapped explicitly. Raising
        ValueError makes argparse report an "invalid value" usage error.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise ValueError(value)

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH", type=str, help="path to folders with images")
    parser.add_argument("--MODEL_PATH", default=None, type=str, help="path to model checkpoint")
    parser.add_argument("--batch_size", default=128, type=int, help="batch size for SSL")
    parser.add_argument("--image_size", default=256, type=int, help="image size for SSL")
    parser.add_argument("--num_workers", default=1, type=int, help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size", default=128, type=int, help="size of image representation of SIMCLR")
    parser.add_argument("--hidden_dims", default=128, type=int, help="hidden dimensions in classification layer added onto model for finetuning")
    parser.add_argument("--epochs", default=200, type=int, help="number of epochs to train model")
    parser.add_argument("--lr", default=1e-3, type=float, help="learning rate for training model")
    parser.add_argument("--patience", default=-1, type=int, help="automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping.")
    parser.add_argument("--val_split", default=0.2, type=float, help="percent in validation data")
    parser.add_argument("--withold_train_percent", default=0, type=float, help="decimal from 0-1 representing how much of the training data to withold during finetuning")
    parser.add_argument("--gpus", default=1, type=int, help="number of gpus to use for training")
    parser.add_argument("--eval", default=True, type=str2bool, help="Eval Mode will train and evaluate the finetuned model's performance")
    parser.add_argument("--pretrain_encoder", default=False, type=str2bool, help="initialize resnet encoder with pretrained imagenet weights. Ignored if MODEL_PATH is specified.")
    parser.add_argument("--version", default="0", type=str, help="version to name checkpoint for saving")
    parser.add_argument("--fix_backbone", default=True, type=str2bool, help="Fix backbone during finetuning")

    args = parser.parse_args()
    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    hidden_dims = args.hidden_dims
    epochs = args.epochs
    lr = args.lr
    val_split = args.val_split
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus
    pretrain = args.pretrain_encoder
    fix_backbone = args.fix_backbone
    # NOTE(review): --patience, --withold_train_percent, --eval and --version
    # are parsed but not used by this variant.

    # Use --image_size rather than a hard-coded 256.
    train_transform = SimCLRFinetuneTransform(image_size, eval_transform=False)
    val_transform = SimCLRFinetuneTransform(image_size, eval_transform=True)
    dm = ImageDataModule(URL,
                         train_transform=train_transform,
                         val_transform=val_transform,
                         val_split=val_split,
                         num_workers=num_workers)
    dm.setup()

    # Init model with batch size, num_samples (len of data), epochs to train
    # and the learning rate given by --lr.
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=dm.num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)

    # Replace SimCLR's default encoder and projection head; the projection
    # output size follows --image_embedding_size.
    model.encoder = resnet18(pretrained=pretrain,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)

    if model_checkpoint is not None:
        # map_location='cpu' lets GPU-saved weights load on any machine;
        # load_state_dict copies them into the model's own tensors.
        model.load_state_dict(torch.load(model_checkpoint, map_location='cpu'))
        print('Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights')
    elif pretrain:
        print('Using imagenet weights instead of a pretrained SSL model')
    else:
        print('Using random initialization of encoder')

    print('Finetuning to classify ', dm.num_classes, ' Classes')

    # The SimCLR model built above is the backbone (the original referenced an
    # undefined name `backbone`). --hidden_dims, --lr and --fix_backbone are
    # forwarded instead of being ignored.
    tuner = SSLFineTuner(
        model,
        in_features=512,
        num_classes=dm.num_classes,
        epochs=epochs,
        hidden_dim=hidden_dims,
        dropout=0,
        learning_rate=lr,
        weight_decay=1e-6,
        nesterov=False,
        scheduler_type='cosine',
        gamma=0.1,
        final_lr=0.,
        fix_backbone=fix_backbone,
    )
    return tuner
Example #6
0
def _str_to_bool(value):
    """Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into ``True``; this converter interprets the text instead.

    Args:
        value: the raw command-line string (or an already-parsed bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


def cli_main():
    """Evaluate an SSL-trained SimCLR ResNet-18 checkpoint on a folder dataset.

    Parses command-line arguments, loads the training and validation splits of
    the dataset at ``--DATA_PATH`` and restores the weights stored at
    ``--MODEL_PATH``. It then runs ``eval_embeddings`` on each split. Results
    are written under ``<MODEL_PATH without extension>/Evaluation/``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--MODEL_PATH",
        type=str,
        required=True,
        help="path to .pt file containing SSL-trained SimCLR Resnet18 Model")
    parser.add_argument(
        "--DATA_PATH",
        type=str,
        required=True,
        help=
        "path to data. If folder already contains validation data only, set val_split to 0"
    )
    parser.add_argument(
        "--val_split",
        default=0.2,
        type=float,
        help="amount of data to use for validation as a decimal")
    parser.add_argument(
        "--image_type",
        default="tif",
        type=str,
        help=
        "extension of image for PIL to open and parse - i.e. jpeg, gif, tif, etc. Only put the extension name, not the dot (.)"
    )
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument("--image_size",
                        default=128,
                        type=int,
                        help="height of square image to pass through model")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    parser.add_argument("--rank",
                        default=50,
                        type=int,
                        help="number of neighbors to search for")
    # type=bool would map the string "False" to True; parse the text instead.
    parser.add_argument(
        "--filter_same_group",
        default=False,
        type=_str_to_bool,
        help="custom arg for hurricane data to filter same hurricanes out")

    args = parser.parse_args()
    MODEL_PATH = args.MODEL_PATH
    DATA_PATH = args.DATA_PATH
    image_size = args.image_size
    image_type = args.image_type
    embedding_size = args.image_embedding_size
    val_split = args.val_split
    gpus = args.gpus
    rank_to = args.rank
    filter_hur = args.filter_same_group

    # The two splits are loaded separately because validation data needs a
    # different transform than training data.
    train_dataset = FolderDataset(
        DATA_PATH,
        validation=False,
        val_split=val_split,
        transform=SimCLRTrainDataTransform(image_size),
        image_type=image_type)

    print('Training Data Loaded...')
    val_dataset = FolderDataset(DATA_PATH,
                                validation=True,
                                val_split=val_split,
                                transform=SimCLREvalDataTransform(image_size),
                                image_type=image_type)

    print('Validation Data Loaded...')

    # Build the model shell; the weights come from the checkpoint below.
    model = SimCLR(arch='resnet18',
                   batch_size=1,
                   num_samples=len(train_dataset),
                   gpus=gpus,
                   dataset='None')

    # Override the default encoder/projection so the module layout matches the
    # checkpoint (ResNet-18 encoder, projection head ending in embedding_size).
    model.encoder = resnet18(pretrained=False,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)

    # Load onto the CPU first so checkpoints saved on any device can be read;
    # the model is moved to the GPU right after.
    model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))

    model.cuda()
    print('Successfully loaded your model for evaluation.')

    # Strip the real file extension (not a fixed 3 characters) to get the
    # results directory next to the checkpoint.
    eval_root = Path(MODEL_PATH).with_suffix('') / 'Evaluation'

    splits = (
        ('validationMetrics', val_dataset, 'Validation'),
        ('trainingMetrics', train_dataset, 'Training'),
    )
    for metrics_dir, dataset, label in splits:
        save_path = eval_root / metrics_dir
        save_path.mkdir(parents=True, exist_ok=True)
        # eval_embeddings received a str in the original; keep passing one.
        eval_embeddings(model, dataset, str(save_path), rank_to, filter_hur)
        print(f'{label} Data Evaluation Complete.')

    print(f'Please check {eval_root}/ for your results')