Example #1
def cli_main():
    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument("--MODEL_PATH",
                        default=None,
                        type=str,
                        help="path to model checkpoint")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="image size for SSL")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument(
        "--hidden_dims",
        default=128,
        type=int,
        help=
        "hidden dimensions in classification layer added onto model for finetuning"
    )
    parser.add_argument("--epochs",
                        default=200,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=0.3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "stops training early if the validation loss does not improve for (patience) epochs; keep the default of -1 to disable validation-based early stopping"
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold during finetuning"
    )
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    parser.add_argument(
        "--eval",
        default=True,
        type=lambda s: str(s).lower() in ("true", "1", "yes"),  # type=bool treats any non-empty string as True
        help=
        "Eval Mode will train and evaluate the finetuned model's performance")
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        type=lambda s: str(s).lower() in ("true", "1", "yes"),  # type=bool treats any non-empty string as True
        help=
        "initialize resnet encoder with pretrained imagenet weights. Ignored if MODEL_PATH is specified."
    )
    parser.add_argument("--version",
                        default="0",
                        type=str,
                        help="version to name checkpoint for saving")
    parser.add_argument("--fix_backbone",
                        default=True,
                        type=lambda s: str(s).lower() in ("true", "1", "yes"),  # type=bool treats any non-empty string as True
                        help="Fix backbone during finetuning")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="number of workers to use to fetch data")

    args = parser.parse_args()
    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    embedding_size = args.image_embedding_size
    hidden_dims = args.hidden_dims
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    withold_train_percent = args.withold_train_percent
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus
    eval_model = args.eval
    pretrain = args.pretrain_encoder
    fix_backbone = args.fix_backbone
    num_workers = args.num_workers

    dm = FolderDataset2(URL,
                        val_split=val_split,
                        train_transform=SimCLRFinetuneTransform(image_size),
                        val_transform=SimCLRFinetuneTransform(image_size))
    dm.setup()

    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=dm.num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)  # override the projection head to match --image_embedding_size
    model.encoder = resnet18(pretrained=pretrain,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    if model_checkpoint is not None:
        model.load_state_dict(torch.load(model_checkpoint))
        print(
            'Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights'
        )
    else:
        if pretrain:
            print('Using imagenet weights instead of a pretrained SSL model')
        else:
            print('Using random initialization of encoder')

    print('Finetuning to classify ', dm.num_classes, ' Classes')

    tuner = SSLFineTuner(model,
                         in_features=512,
                         num_classes=dm.num_classes,
                         epochs=epochs,
                         hidden_dim=hidden_dims,
                         dropout=0,
                         learning_rate=lr,
                         weight_decay=1e-6,
                         nesterov=False,
                         scheduler_type='cosine',
                         gamma=0.1,
                         final_lr=0.,
                         fix_backbone=fix_backbone)

    trainer = pl.Trainer(gpus=gpus,
                         num_nodes=1,
                         precision=16,
                         max_epochs=epochs,
                         distributed_backend='ddp',
                         sync_batchnorm=False)

    trainer.fit(tuner, dm)

    Path(f"./models/Finetune/SIMCLR_Finetune_{version}").mkdir(parents=True,
                                                               exist_ok=True)

    if eval_model:
        print('Evaluating Model...')
        save_path = f"./models/Finetune/SIMCLR_Finetune_{version}/Evaluation/validationMetrics"
        Path(save_path).mkdir(parents=True, exist_ok=True)

        if dm.val_dataloader() is not None:
            eval_finetune(tuner, 'validation', dm.val_dataloader(), save_path)

        save_path = f"./models/Finetune/SIMCLR_Finetune_{version}/Evaluation/trainingMetrics"
        Path(save_path).mkdir(parents=True, exist_ok=True)
        eval_finetune(tuner, 'training', dm.train_dataloader(), save_path)

    print('Saving model...')

    torch.save(
        tuner.state_dict(),
        f"./models/Finetune/SIMCLR_Finetune_{version}/SIMCLR_FINETUNE_{version}.pt"
    )
Example #2
def cli_main():
    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument("--MODEL_PATH",
                        default=None,
                        type=str,
                        help="path to model checkpoint")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="image size for SSL")
    parser.add_argument(
        "--image_type",
        default="tif",
        type=str,
        help=
        "image file extension for PIL to open and parse, e.g. jpeg, gif, tif. Give only the extension name, without the dot (.)"
    )
    parser.add_argument("--num_workers",
                        default=1,
                        type=int,
                        help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument(
        "--hidden_dims",
        default=128,
        type=int,
        help=
        "hidden dimensions in classification layer added onto model for finetuning"
    )
    parser.add_argument("--epochs",
                        default=200,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=1e-3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "stops training early if the validation loss does not improve for (patience) epochs; keep the default of -1 to disable validation-based early stopping"
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold during finetuning"
    )
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    parser.add_argument(
        "--eval",
        default=True,
        type=lambda s: str(s).lower() in ("true", "1", "yes"),  # type=bool treats any non-empty string as True
        help=
        "Eval Mode will train and evaluate the finetuned model's performance")
    parser.add_argument("--imagenet_weights",
                        default=False,
                        type=bool,
                        help="Use weights from a non-SSL")
    parser.add_argument("--version",
                        default="0",
                        type=str,
                        help="version to name checkpoint for saving")

    args = parser.parse_args()
    DATA_PATH = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    image_type = args.image_type
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    hidden_dims = args.hidden_dims
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    withold_train_percent = args.withold_train_percent
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus
    eval_model = args.eval
    imagenet_weights = args.imagenet_weights

    # build the train and validation datasets separately, since validation needs a different transform
    finetune_dataset = FolderDataset(
        DATA_PATH,
        validation=False,
        val_split=val_split,
        withold_train_percent=withold_train_percent,
        transform=SimCLRFinetuneTransform(image_size),
        image_type=image_type)

    finetune_loader = torch.utils.data.DataLoader(finetune_dataset,
                                                  batch_size=batch_size,
                                                  num_workers=num_workers,
                                                  drop_last=True)

    print('Training Data Loaded...')
    finetune_val_dataset = FolderDataset(
        DATA_PATH,
        validation=True,
        val_split=val_split,
        transform=SimCLRFinetuneTransform(image_size),
        image_type=image_type)

    finetune_val_loader = torch.utils.data.DataLoader(finetune_val_dataset,
                                                      batch_size=batch_size,
                                                      num_workers=num_workers,
                                                      drop_last=True)
    print('Validation Data Loaded...')

    num_samples = len(finetune_dataset)
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)
    model.encoder = resnet18(pretrained=imagenet_weights,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)  # override the default projection head

    if model_checkpoint is not None:
        model.load_state_dict(torch.load(model_checkpoint))
        print(
            'Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights'
        )
    else:
        if imagenet_weights:
            print('Using imagenet weights instead of a pretrained SSL model')
        else:
            print('Using random initialization of encoder')

    num_classes = len(set(finetune_dataset.labels))
    print('Finetuning to classify ', num_classes, ' Classes')

    tuner = SSLFineTuner(model,
                         in_features=512,
                         num_classes=num_classes,
                         hidden_dim=hidden_dims,
                         learning_rate=lr)
    if patience > 0:
        cb = EarlyStopping('val_loss', patience=patience)
        trainer = Trainer(gpus=gpus,
                          max_epochs=epochs,
                          callbacks=[cb],
                          progress_bar_refresh_rate=5)
    else:
        trainer = Trainer(gpus=gpus,
                          max_epochs=epochs,
                          progress_bar_refresh_rate=5)
    tuner.cuda()  # move the tuner to GPU (assumes CUDA is available)
    trainer.fit(tuner,
                train_dataloader=finetune_loader,
                val_dataloaders=finetune_val_loader)

    Path(f"./models/Finetune/SIMCLR_Finetune_{version}").mkdir(parents=True,
                                                               exist_ok=True)

    if eval_model:
        print('Evaluating Model...')
        save_path = f"./models/Finetune/SIMCLR_Finetune_{version}/Evaluation/trainingMetrics"
        Path(save_path).mkdir(parents=True, exist_ok=True)
        eval_finetune(tuner, 'training', finetune_loader, save_path)

        save_path = f"./models/Finetune/SIMCLR_Finetune_{version}/Evaluation/validationMetrics"
        Path(save_path).mkdir(parents=True, exist_ok=True)
        eval_finetune(tuner, 'validation', finetune_val_loader, save_path)

    print('Saving model...')

    torch.save(
        tuner.state_dict(),
        f"./models/Finetune/SIMCLR_Finetune_{version}/SIMCLR_FINETUNE_{version}.pt"
    )
Example #3
def cli_main():  # pragma: no cover
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule

    seed_everything(1234)

    parser = ArgumentParser()
    parser.add_argument("--dataset", type=str, help="cifar10, stl10, imagenet", default="cifar10")
    parser.add_argument("--ckpt_path", type=str, help="path to ckpt")
    parser.add_argument("--data_dir", type=str, help="path to dataset", default=os.getcwd())

    parser.add_argument("--batch_size", default=64, type=int, help="batch size per gpu")
    parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU")
    parser.add_argument("--gpus", default=4, type=int, help="number of GPUs")
    parser.add_argument("--num_epochs", default=100, type=int, help="number of epochs")

    # fine-tuner params
    parser.add_argument("--in_features", type=int, default=2048)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--learning_rate", type=float, default=0.3)
    parser.add_argument("--weight_decay", type=float, default=1e-6)
    parser.add_argument("--nesterov", type=bool, default=False)  # fix nesterov flag here
    parser.add_argument("--scheduler_type", type=str, default="cosine")
    parser.add_argument("--gamma", type=float, default=0.1)
    parser.add_argument("--final_lr", type=float, default=0.0)

    args = parser.parse_args()

    if args.dataset == "cifar10":
        dm = CIFAR10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        dm.train_transforms = SimCLRFinetuneTransform(
            normalize=cifar10_normalization(), input_height=dm.size()[-1], eval_transform=False
        )
        dm.val_transforms = SimCLRFinetuneTransform(
            normalize=cifar10_normalization(), input_height=dm.size()[-1], eval_transform=True
        )
        dm.test_transforms = SimCLRFinetuneTransform(
            normalize=cifar10_normalization(), input_height=dm.size()[-1], eval_transform=True
        )

        args.maxpool1 = False
        args.first_conv = False
        args.num_samples = 1
    elif args.dataset == "stl10":
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = 1

        dm.train_transforms = SimCLRFinetuneTransform(
            normalize=stl10_normalization(), input_height=dm.size()[-1], eval_transform=False
        )
        dm.val_transforms = SimCLRFinetuneTransform(
            normalize=stl10_normalization(), input_height=dm.size()[-1], eval_transform=True
        )
        dm.test_transforms = SimCLRFinetuneTransform(
            normalize=stl10_normalization(), input_height=dm.size()[-1], eval_transform=True
        )

        args.maxpool1 = False
        args.first_conv = True
    elif args.dataset == "imagenet":
        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        dm.train_transforms = SimCLRFinetuneTransform(
            normalize=imagenet_normalization(), input_height=dm.size()[-1], eval_transform=False
        )
        dm.val_transforms = SimCLRFinetuneTransform(
            normalize=imagenet_normalization(), input_height=dm.size()[-1], eval_transform=True
        )
        dm.test_transforms = SimCLRFinetuneTransform(
            normalize=imagenet_normalization(), input_height=dm.size()[-1], eval_transform=True
        )

        args.num_samples = 1
        args.maxpool1 = True
        args.first_conv = True
    else:
        raise NotImplementedError("other datasets have not been implemented yet")

    backbone = SimCLR(
        gpus=args.gpus,
        nodes=1,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        maxpool1=args.maxpool1,
        first_conv=args.first_conv,
        dataset=args.dataset,
    ).load_from_checkpoint(args.ckpt_path, strict=False)

    tuner = SSLFineTuner(
        backbone,
        in_features=args.in_features,
        num_classes=dm.num_classes,
        epochs=args.num_epochs,
        hidden_dim=None,
        dropout=args.dropout,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        nesterov=args.nesterov,
        scheduler_type=args.scheduler_type,
        gamma=args.gamma,
        final_lr=args.final_lr,
    )

    trainer = Trainer(
        gpus=args.gpus,
        num_nodes=1,
        precision=16,
        max_epochs=args.num_epochs,
        distributed_backend="ddp",
        sync_batchnorm=args.gpus > 1,
    )

    trainer.fit(tuner, dm)
    trainer.test(datamodule=dm)
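For reference, one way to drive this CLI programmatically, assuming the module that defines cli_main has been imported (the script name and checkpoint path are placeholders):

import sys

sys.argv = [
    "simclr_finetuner.py",           # placeholder script name
    "--dataset", "cifar10",
    "--ckpt_path", "./simclr.ckpt",  # placeholder checkpoint path
    "--gpus", "1",
]
cli_main()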
Example #4
def cli_main():
    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH", type=str, help="path to folders with images")
    parser.add_argument("--MODEL_PATH", default=None , type=str, help="path to model checkpoint")
    parser.add_argument("--batch_size", default=128, type=int, help="batch size for SSL")
    parser.add_argument("--image_size", default=256, type=int, help="image size for SSL")
    parser.add_argument("--num_workers", default=1, type=int, help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size", default=128, type=int, help="size of image representation of SIMCLR")
    parser.add_argument("--hidden_dims", default=128, type=int, help="hidden dimensions in classification layer added onto model for finetuning")
    parser.add_argument("--epochs", default=200, type=int, help="number of epochs to train model")
    parser.add_argument("--lr", default=1e-3, type=float, help="learning rate for training model")
    parser.add_argument("--patience", default=-1, type=int, help="automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping.")
    parser.add_argument("--val_split", default=0.2, type=float, help="percent in validation data")
    parser.add_argument("--withold_train_percent", default=0, type=float, help="decimal from 0-1 representing how much of the training data to withold during finetuning")
    parser.add_argument("--gpus", default=1, type=int, help="number of gpus to use for training")
    parser.add_argument("--eval", default=True, type=bool, help="Eval Mode will train and evaluate the finetuned model's performance")
    parser.add_argument("--pretrain_encoder", default=False, type=bool, help="initialize resnet encoder with pretrained imagenet weights. Ignored if MODEL_PATH is specified.")
    parser.add_argument("--version", default="0", type=str, help="version to name checkpoint for saving")
    parser.add_argument("--fix_backbone", default=True, type=bool, help="Fix backbone during finetuning")
    
    args = parser.parse_args()
    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    hidden_dims = args.hidden_dims
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    withold_train_percent = args.withold_train_percent
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus
    eval_model = args.eval
    pretrain = args.pretrain_encoder
    fix_backbone = args.fix_backbone

    train_transform = SimCLRFinetuneTransform(image_size, eval_transform=False)
    val_transform = SimCLRFinetuneTransform(image_size, eval_transform=True)
    dm = ImageDataModule(URL, train_transform=train_transform, val_transform=val_transform, val_split=val_split)
    dm.setup()

    # init the model with batch size, num_samples (length of the dataset), epochs to train, and learning rate
    model = SimCLR(arch='resnet18', batch_size=batch_size, num_samples=dm.num_samples, gpus=gpus, dataset='None', max_epochs=epochs, learning_rate=lr)
    
    model.encoder = resnet18(pretrained=pretrain, first_conv=model.first_conv, maxpool1=model.maxpool1, return_all_feature_maps=False)
    model.projection = Projection(input_dim=512, hidden_dim=256, output_dim=embedding_size)  # override the default projection head
       
    if model_checkpoint is not None:  
        model.load_state_dict(torch.load(model_checkpoint))
        print('Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights')
    else:
        if pretrain:   
            print('Using imagenet weights instead of a pretrained SSL model')
        else:
            print('Using random initialization of encoder')
        
    print('Finetuning to classify ', dm.num_classes, ' Classes')

    tuner = SSLFineTuner(model,
                         in_features=512,
                         num_classes=dm.num_classes,
                         epochs=epochs,
                         hidden_dim=hidden_dims,
                         dropout=0,
                         learning_rate=lr,
                         weight_decay=1e-6,
                         nesterov=False,
                         scheduler_type='cosine',
                         gamma=0.1,
                         final_lr=0.,
                         fix_backbone=fix_backbone)

    trainer = Trainer(gpus=gpus, max_epochs=epochs)
    trainer.fit(tuner, dm)
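The fix_backbone flag in examples #1 and #4 implies freezing the encoder during finetuning. If a given SSLFineTuner version does not accept that argument, the same effect can be sketched in plain PyTorch (a standard pattern, not taken from these scripts):

for param in model.encoder.parameters():
    param.requires_grad = False  # freeze the encoder; only the classification head trains
model.encoder.eval()  # keep batchnorm statistics fixed during finetuning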