Пример #1
0
def cli_main():
    """Train SimCLR from the command line with an online evaluation callback.

    The parser is populated with the ``pl.Trainer`` arguments and the
    ``SimCLR`` model-specific arguments. ``args.dataset`` selects the
    datamodule (``cifar10``, ``stl10`` or ``imagenet2012``); the SimCLR
    train/eval transforms are attached to it and the model is fitted with an
    ``SSLOnlineEvaluator`` callback.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported name.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule

    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = SimCLR.add_model_specific_args(parser)
    args = parser.parse_args()

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        args.num_samples = dm.num_samples

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # use the "mixed" loaders of the STL10 datamodule for both phases
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        _, h, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    elif args.dataset == 'imagenet2012':
        # NOTE(review): unlike the other branches, args.num_samples is not set
        # here - confirm the parser supplies a suitable default.
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        _, h, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    else:
        # Without this guard ``dm`` stays unbound and the script fails later
        # with an unhelpful NameError.
        raise NotImplementedError(
            f"unsupported dataset {args.dataset!r}; "
            "expected 'cifar10', 'stl10' or 'imagenet2012'")

    model = SimCLR(**args.__dict__)

    # finetune in real-time
    def to_device(batch, device):
        """Move the first input of the pair and the labels to ``device``."""
        (x1, _), y = batch
        x1 = x1.to(device)
        y = y.to(device)
        return x1, y

    online_eval = SSLOnlineEvaluator(z_dim=2048 * 2 * 2,
                                     num_classes=dm.num_classes)
    online_eval.to_device = to_device

    trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_eval])
    trainer.fit(model, dm)
def test_simclr(tmpdir, datadir):
    """Smoke test: SimCLR runs a ``fast_dev_run`` fit on the CIFAR-10 datamodule."""
    cifar = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    cifar.train_transforms = SimCLRTrainDataTransform(32)
    cifar.val_transforms = SimCLREvalDataTransform(32)

    simclr = SimCLR(
        batch_size=2,
        num_samples=cifar.num_samples,
        gpus=0,
        nodes=1,
        dataset='cifar10',
    )
    runner = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    runner.fit(simclr, datamodule=cifar)
Пример #3
0
def cli_main():  # pragma: no-cover
    """Fine-tune a classifier on top of a SimCLR checkpoint and test it.

    Loads the backbone from ``--ckpt_path``, builds the datamodule selected by
    ``--dataset``, fits an ``SSLFineTuner`` on it and finally runs
    ``trainer.test`` on the same datamodule.

    Raises:
        NotImplementedError: if ``--dataset`` is not a supported name.
    """
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule

    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--dataset', type=str, help='stl10, cifar10, imagenet2012', default='cifar10')
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    # help text used to be a copy-paste of the --ckpt_path description
    parser.add_argument('--data_dir', type=str, help='path to the data directory', default=os.getcwd())
    args = parser.parse_args()

    # load the backbone
    backbone = SimCLR.load_from_checkpoint(args.ckpt_path, strict=False)

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        dm.test_transforms = SimCLREvalDataTransform(32)
        args.num_samples = dm.num_samples

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # fine-tuning uses the labeled loaders of the STL10 datamodule
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = dm.num_labeled_samples

        _, h, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        _, h, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    else:
        # Without this guard ``dm`` stays unbound and the script fails later
        # with an unhelpful NameError.
        raise NotImplementedError(
            f"unsupported dataset {args.dataset!r}; "
            "expected 'cifar10', 'stl10' or 'imagenet2012'")

    # finetune
    tuner = SSLFineTuner(backbone, in_features=2048 * 2 * 2, num_classes=dm.num_classes, hidden_dim=None)
    trainer = pl.Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)
Пример #4
0
def test_simclr_transforms(img_size):
    """Check that the SimCLR eval and train transforms accept a random RGB image.

    ``img_size`` is unpacked as ``(channels, height, width)``.
    """
    pl.seed_everything(0)

    channels, height, width = img_size
    to_pil = transforms.ToPILImage(mode='RGB')
    image = to_pil(torch.rand(channels, height, width))

    # eval transform first, then train transform; each is built and applied once
    for transform_cls in (SimCLREvalDataTransform, SimCLRTrainDataTransform):
        pipeline = transform_cls(input_height=height)
        pipeline(image)
Пример #5
0
def test_simclr(tmpdir):
    """Fit SimCLR for a ``fast_dev_run`` on CIFAR-10 and expect a positive loss."""
    seed_everything()

    cifar = CIFAR10DataModule(tmpdir, num_workers=0, batch_size=2)
    cifar.train_transforms = SimCLRTrainDataTransform(32)
    cifar.val_transforms = SimCLREvalDataTransform(32)

    simclr = SimCLR(batch_size=2, num_samples=cifar.num_samples)
    runner = pl.Trainer(
        fast_dev_run=True,
        max_epochs=1,
        default_root_dir=tmpdir,
    )
    runner.fit(simclr, cifar)

    # the loss is read back from the trainer's progress-bar metrics
    reported_loss = float(runner.progress_bar_dict['loss'])
    assert reported_loss > 0
def test_simsiam(tmpdir, datadir):
    """Fit SimSiam for a ``fast_dev_run`` on CIFAR-10 and expect a negative loss."""
    seed_everything()

    cifar = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    cifar.train_transforms = SimCLRTrainDataTransform(32)
    cifar.val_transforms = SimCLREvalDataTransform(32)

    simsiam = SimSiam(
        batch_size=2,
        num_samples=cifar.num_samples,
        gpus=0,
        nodes=1,
        dataset='cifar10',
    )
    runner = pl.Trainer(gpus=0, fast_dev_run=True, default_root_dir=tmpdir)
    runner.fit(simsiam, datamodule=cifar)

    # the loss is read back from the trainer's progress-bar metrics
    reported_loss = float(runner.progress_bar_dict['loss'])
    assert reported_loss < 0
def cli_main():
    """Command-line entry point: configure and fit SimCLR for ``args.dataset``.

    Per-dataset settings (datamodule, encoder flags, normalization and
    augmentation strength) are written onto ``args`` so that
    ``SimCLR(**vars(args))`` receives them. A ``ModelCheckpoint`` callback is
    always installed; an ``SSLOnlineEvaluator`` is added when ``args.online_ft``
    is truthy.

    Raises:
        NotImplementedError: for any dataset other than ``stl10``, ``cifar10``
            or ``imagenet``.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform

    # model args are registered first, trainer args second
    cli = ArgumentParser()
    cli = SimCLR.add_model_specific_args(cli)
    cli = pl.Trainer.add_argparse_args(cli)
    args = cli.parse_args()

    dataset_name = args.dataset
    if dataset_name == 'stl10':
        datamodule = STL10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        args.num_samples = datamodule.num_unlabeled_samples

        args.maxpool1, args.first_conv = False, True
        args.input_height = datamodule.size()[-1]

        normalize_fn = stl10_normalization()

        args.gaussian_blur, args.jitter_strength = True, 1.0
    elif dataset_name == 'cifar10':
        # validation split is at least 5000 and at least one global batch
        global_batch = args.num_nodes * args.gpus * args.batch_size
        holdout = max(5000, global_batch)

        datamodule = CIFAR10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            val_split=holdout,
        )
        args.num_samples = datamodule.num_samples

        args.maxpool1, args.first_conv = False, False
        args.input_height = datamodule.size()[-1]
        args.temperature = 0.5

        normalize_fn = cifar10_normalization()

        args.gaussian_blur, args.jitter_strength = False, 0.5
    elif dataset_name == 'imagenet':
        args.maxpool1, args.first_conv = True, True
        normalize_fn = imagenet_normalization()

        args.gaussian_blur, args.jitter_strength = True, 1.0

        # fixed training recipe for this dataset; overrides the CLI values
        args.batch_size = 64
        args.num_nodes = 8
        args.gpus = 8  # per-node
        args.max_epochs = 800

        args.optimizer = 'sgd'
        args.lars_wrapper = True
        args.learning_rate = 4.8
        args.final_lr = 0.0048
        args.start_lr = 0.3
        args.online_ft = True

        datamodule = ImagenetDataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        args.num_samples = datamodule.num_samples
        args.input_height = datamodule.size()[-1]
    else:
        raise NotImplementedError(
            "other datasets have not been implemented till now")

    # both transforms are built from the same settings
    transform_settings = {
        'input_height': args.input_height,
        'gaussian_blur': args.gaussian_blur,
        'jitter_strength': args.jitter_strength,
        'normalize': normalize_fn,
    }
    datamodule.train_transforms = SimCLRTrainDataTransform(**transform_settings)
    datamodule.val_transforms = SimCLREvalDataTransform(**transform_settings)

    model = SimCLR(**vars(args))

    # online eval (optional)
    evaluator_cb = None
    if args.online_ft:
        evaluator_cb = SSLOnlineEvaluator(
            drop_p=0.,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=datamodule.num_classes,
            dataset=args.dataset,
        )

    checkpoint_cb = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss')
    active_callbacks = [checkpoint_cb]
    if args.online_ft:
        active_callbacks.append(evaluator_cb)

    trainer = pl.Trainer.from_argparse_args(
        args,
        sync_batchnorm=args.gpus > 1,
        callbacks=active_callbacks,
    )

    trainer.fit(model, datamodule=datamodule)
def cli_main():
    """Evaluate SimCLR ResNet-18 embeddings on the validation and training data.

    Parses the command line, builds an ``ImageDataModule`` over ``--DATA_PATH``,
    rebuilds the SimCLR model with a ``resnet18`` encoder and a ``Projection``
    head, loads the weights stored at ``--MODEL_PATH`` and calls
    ``eval_embeddings`` on the validation loader (when there is one) and on the
    training loader. Results go under ``<MODEL_PATH without its last three
    characters>/Evaluation/``.
    """
    import torch  # local import: only needed to read the saved weights

    def _str2bool(value):
        """Parse a boolean flag; plain ``bool`` treats any non-empty string as True."""
        lowered = str(value).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise ValueError(f'expected a boolean, got {value!r}')

    parser = ArgumentParser()
    parser.add_argument(
        "--MODEL_PATH",
        type=str,
        help="path to .pt file containing SSL-trained SimCLR Resnet18 Model")
    parser.add_argument(
        "--DATA_PATH",
        type=str,
        help=
        "path to data. If folder already contains validation data only, set val_split to 0"
    )
    parser.add_argument(
        "--val_split",
        default=0.2,
        type=float,
        help="amount of data to use for validation as a decimal")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument("--image_size",
                        default=128,
                        type=int,
                        help="height of square image to pass through model")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    parser.add_argument("--rank",
                        default=50,
                        type=int,
                        help="number of neighbors to search for")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for Evaluation")
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        type=_str2bool,
        help=
        "initialize resnet encoder with pretrained imagenet weights. Will be ignored if MODEL_PATH is specified."
    )

    args = parser.parse_args()
    MODEL_PATH = args.MODEL_PATH
    URL = args.DATA_PATH
    embedding_size = args.image_embedding_size
    val_split = args.val_split
    gpus = args.gpus
    rank_to = args.rank
    batch_size = args.batch_size
    pretrain = args.pretrain_encoder

    # NOTE(review): --image_size is accepted but the transforms use a fixed
    # size of 256; kept as-is because the flag's default (128) differs.
    train_transform = SimCLRTrainDataTransform(256)
    val_transform = SimCLREvalDataTransform(256)
    dm = ImageDataModule(URL,
                         train_transform=train_transform,
                         val_transform=val_transform,
                         val_split=val_split)
    dm.setup()

    # init model with batch size and num_samples (len of data)
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=dm.num_samples,
                   gpus=gpus,
                   dataset='None')

    # override the default encoder/projection so the layout matches the saved weights
    model.encoder = resnet18(pretrained=pretrain,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)

    # The weights were previously never read, so the metrics described an
    # untrained (or ImageNet-only) network despite the message below.
    if MODEL_PATH is not None:
        model.load_state_dict(torch.load(MODEL_PATH))

    model.cuda()
    print('Successfully loaded your model for evaluation.')

    # running eval on validation data
    save_path = f"{MODEL_PATH[:-3]}/Evaluation/validationMetrics"
    Path(save_path).mkdir(parents=True, exist_ok=True)
    val_loader = dm.val_dataloader()  # build the loader once and reuse it
    if val_loader is not None:
        eval_embeddings(model, val_loader, save_path, rank_to)
        print('Validation Data Evaluation Complete.')

    # running eval on training data
    save_path = f"{MODEL_PATH[:-3]}/Evaluation/trainingMetrics"
    Path(save_path).mkdir(parents=True, exist_ok=True)
    eval_embeddings(model, dm.train_dataloader(), save_path, rank_to)
    print('Training Data Evaluation Complete.')

    print(f'Please check {MODEL_PATH[:-3]}/Evaluation/ for your results')
Пример #9
0
import pytorch_lightning as pl
from pl_bolts.models.self_supervised import BYOL
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr.transforms import (
    SimCLREvalDataTransform, SimCLRTrainDataTransform)

# BYOL model configured with 10 classes
model = BYOL(num_classes=10)

# CIFAR-10 datamodule using the SimCLR augmentation pipelines
dm = CIFAR10DataModule(num_workers=4)
dm.train_transforms = SimCLRTrainDataTransform(input_height=32)
dm.val_transforms = SimCLREvalDataTransform(input_height=32)

# fit with a default-configured trainer
trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)
def cli_main():
    """Train SimCLR (ResNet-18 encoder) on an image folder and save its weights.

    Parses the command line, builds an ``ImageDataModule`` over ``--DATA_PATH``,
    creates the SimCLR model with a ``resnet18`` encoder and a ``Projection``
    head, optionally restores weights from ``--MODEL_PATH``, fits the model and
    writes the state dict to
    ``./models/SSL/SIMCLR_SSL_<version>/SIMCLR_SSL_<version>.pt``.
    """

    def _str2bool(value):
        """Parse a boolean flag; plain ``bool`` treats any non-empty string as True."""
        lowered = str(value).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise ValueError(f'expected a boolean, got {value!r}')

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument("--MODEL_PATH",
                        default=None,
                        type=str,
                        help="path to model checkpoint")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="image size for SSL")
    # --num_workers used to be registered twice, which makes argparse raise a
    # "conflicting option string" error before any argument is parsed.
    parser.add_argument("--num_workers",
                        default=1,
                        type=int,
                        help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument("--epochs",
                        default=200,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=1e-3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping."
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        type=_str2bool,
        help=
        "initialize resnet encoder with pretrained imagenet weights. Cannot be true if passing previous SSL model checkpoint."
    )
    # Accepted for command-line compatibility; this script does not forward it
    # to the datamodule.
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold during SSL training"
    )
    parser.add_argument("--version",
                        default="0",
                        type=str,
                        help="version to name checkpoint for saving")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")

    args = parser.parse_args()
    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    pretrain = args.pretrain_encoder
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus

    # honour --image_size (its default, 256, matches the previous fixed value)
    train_transform = SimCLRTrainDataTransform(image_size)
    val_transform = SimCLREvalDataTransform(image_size)
    dm = ImageDataModule(URL,
                         train_transform=train_transform,
                         val_transform=val_transform,
                         val_split=val_split,
                         num_workers=num_workers)
    dm.setup()

    # init model with batch size, num_samples (len of data), epochs and learning rate
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=dm.num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)

    # override the default encoder/projection
    model.encoder = resnet18(pretrained=pretrain,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)

    # early stopping is only enabled for a positive patience
    trainer_kwargs = dict(gpus=gpus,
                          max_epochs=epochs,
                          progress_bar_refresh_rate=5)
    if patience > 0:
        trainer_kwargs['callbacks'] = [
            EarlyStopping('val_loss', patience=patience)
        ]
    trainer = Trainer(**trainer_kwargs)

    if model_checkpoint is not None:
        model.load_state_dict(torch.load(model_checkpoint))
        print(
            'Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights'
        )

    model.cuda()

    print('Model Initialized')
    trainer.fit(model, dm)

    Path(f"./models/SSL/SIMCLR_SSL_{version}").mkdir(parents=True,
                                                     exist_ok=True)
    torch.save(model.state_dict(),
               f"./models/SSL/SIMCLR_SSL_{version}/SIMCLR_SSL_{version}.pt")
Пример #11
0
def cli_main():
    """Train SimCLR (ResNet-18 encoder) on a ``FolderDataset`` and save its weights.

    Parses the command line, builds training and validation ``FolderDataset``
    objects (each with its own transform) and their data loaders, creates the
    SimCLR model with a ``resnet18`` encoder and a ``Projection`` head,
    optionally restores weights from ``--MODEL_PATH``, fits the model and
    writes the state dict to
    ``./models/SSL/SIMCLR_SSL_<version>/SIMCLR_SSL_<version>.pt``.
    """

    def _str2bool(value):
        """Parse a boolean flag; plain ``bool`` treats any non-empty string as True."""
        lowered = str(value).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise ValueError(f'expected a boolean, got {value!r}')

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument("--MODEL_PATH",
                        default=None,
                        type=str,
                        help="path to model checkpoint")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="image size for SSL")
    parser.add_argument(
        "--image_type",
        default="tif",
        type=str,
        help=
        "extension of image for PIL to open and parse - i.e. jpeg, gif, tif, etc. Only put the extension name, not the dot (.)"
    )
    parser.add_argument("--num_workers",
                        default=1,
                        type=int,
                        help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument("--epochs",
                        default=200,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=1e-3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping."
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        type=_str2bool,
        help=
        "initialize resnet encoder with pretrained imagenet weights. Cannot be true if passing previous SSL model checkpoint."
    )
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold during SSL training"
    )
    parser.add_argument("--version",
                        default="0",
                        type=str,
                        help="version to name checkpoint for saving")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")

    args = parser.parse_args()
    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    image_type = args.image_type
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    pretrain = args.pretrain_encoder
    withold_train_percent = args.withold_train_percent
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus

    # Training and validation sets are built separately because each needs
    # its own transform.
    train_dataset = FolderDataset(
        URL,
        validation=False,
        val_split=val_split,
        withold_train_percent=withold_train_percent,
        transform=SimCLRTrainDataTransform(image_size),
        image_type=image_type)

    data_loader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=batch_size,
                                              num_workers=num_workers,
                                              drop_last=True)

    print('Training Data Loaded...')
    val_dataset = FolderDataset(URL,
                                validation=True,
                                val_split=val_split,
                                transform=SimCLREvalDataTransform(image_size),
                                image_type=image_type)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             num_workers=num_workers,
                                             drop_last=True)
    print('Validation Data Loaded...')

    num_samples = len(train_dataset)

    # init model with batch size, num_samples (len of data), epochs and learning rate
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)

    # override the default encoder/projection
    model.encoder = resnet18(pretrained=pretrain,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)

    # early stopping is only enabled for a positive patience
    trainer_kwargs = dict(gpus=gpus,
                          max_epochs=epochs,
                          progress_bar_refresh_rate=5)
    if patience > 0:
        trainer_kwargs['callbacks'] = [
            EarlyStopping('val_loss', patience=patience)
        ]
    trainer = Trainer(**trainer_kwargs)

    if model_checkpoint is not None:
        model.load_state_dict(torch.load(model_checkpoint))
        print(
            'Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights'
        )

    model.cuda()

    print('Model Initialized')
    trainer.fit(model, data_loader, val_loader)

    Path(f"./models/SSL/SIMCLR_SSL_{version}").mkdir(parents=True,
                                                     exist_ok=True)
    torch.save(model.state_dict(),
               f"./models/SSL/SIMCLR_SSL_{version}/SIMCLR_SSL_{version}.pt")
Пример #12
0
def main():
    """Train SimCLR on a folder of images using ``SneakerDataModule``.

    ``--data_dir``, ``--log_dir`` and ``--input_height`` must be given; every
    other option falls back to its default. A ``ModelCheckpoint`` callback
    monitoring ``val_loss`` is installed, and the ``ddp`` accelerator is
    selected when more than one GPU is requested.
    """

    parser = ArgumentParser()
    parser.add_argument("--data_dir",
                        type=str,
                        required=True,
                        help="path to the folder of images")
    parser.add_argument("--log_dir",
                        type=str,
                        required=True,
                        help="output training logging dir")
    # The options below used to combine ``required=True`` with a default,
    # which made the defaults unreachable; they are optional now, so passing
    # them explicitly still works as before.
    parser.add_argument("--learning_rate",
                        type=float,
                        default=1e-3,
                        help="learning rate")
    parser.add_argument(
        "--input_height",
        type=int,
        required=True,
        help="height of image input to SimCLR",
    )
    parser.add_argument("--batch_size", type=int, default=1024)
    parser.add_argument("--gpus",
                        type=int,
                        default=0,
                        help="Number of GPUs")
    parser.add_argument("--num_workers",
                        type=int,
                        default=0,
                        help="Number of dataloader workers")
    parser.add_argument("--max_epochs",
                        default=100,
                        type=int,
                        help="number of total epochs to run")

    args = parser.parse_args()

    dm = SneakerDataModule(image_folder=args.data_dir,
                           batch_size=args.batch_size,
                           num_workers=args.num_workers)
    dm.train_transforms = SimCLRTrainDataTransform(args.input_height)
    dm.val_transforms = SimCLREvalDataTransform(args.input_height)

    model = SimCLR(
        num_samples=dm.num_samples,
        batch_size=dm.batch_size,
        learning_rate=args.learning_rate,
        max_epochs=args.max_epochs,
        gpus=args.gpus,
        dataset="sneakers",
    )

    model_checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        save_last=True,
        save_top_k=-1,
        period=10,
        filename='{epoch}-{val_loss:.2f}-{step}')

    multi_gpu = args.gpus > 1

    # TODO set the logger folder
    # Warning message is "Missing logger folder: /lightning_logs"
    trainer = pl.Trainer(
        default_root_dir=args.log_dir,
        callbacks=[model_checkpoint_callback],
        max_epochs=args.max_epochs,
        gpus=args.gpus,
        accelerator='ddp' if multi_gpu else None,
        enable_pl_optimizer=multi_gpu,
    )

    trainer.fit(model, dm)
Пример #13
0
def cli_main():
    """Evaluate saved SimCLR ResNet-18 embeddings on ``FolderDataset`` splits.

    Parses the command line, builds the training and validation
    ``FolderDataset`` objects, rebuilds the SimCLR model with a ``resnet18``
    encoder and a ``Projection`` head, loads the weights from ``--MODEL_PATH``
    and calls ``eval_embeddings`` on the validation set and then the training
    set. Results go under ``<MODEL_PATH without its last three
    characters>/Evaluation/``.
    """

    def _str2bool(value):
        """Parse a boolean flag; plain ``bool`` treats any non-empty string as True."""
        lowered = str(value).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise ValueError(f'expected a boolean, got {value!r}')

    parser = ArgumentParser()
    parser.add_argument(
        "--MODEL_PATH",
        type=str,
        help="path to .pt file containing SSL-trained SimCLR Resnet18 Model")
    parser.add_argument(
        "--DATA_PATH",
        type=str,
        help=
        "path to data. If folder already contains validation data only, set val_split to 0"
    )
    parser.add_argument(
        "--val_split",
        default=0.2,
        type=float,
        help="amount of data to use for validation as a decimal")
    parser.add_argument(
        "--image_type",
        default="tif",
        type=str,
        help=
        "extension of image for PIL to open and parse - i.e. jpeg, gif, tif, etc. Only put the extension name, not the dot (.)"
    )
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument("--image_size",
                        default=128,
                        type=int,
                        help="height of square image to pass through model")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    parser.add_argument("--rank",
                        default=50,
                        type=int,
                        help="number of neighbors to search for")
    parser.add_argument(
        "--filter_same_group",
        default=False,
        type=_str2bool,
        help="custom arg for hurricane data to filter same hurricanes out")

    args = parser.parse_args()
    MODEL_PATH = args.MODEL_PATH
    DATA_PATH = args.DATA_PATH
    image_size = args.image_size
    image_type = args.image_type
    embedding_size = args.image_embedding_size
    val_split = args.val_split
    gpus = args.gpus
    rank_to = args.rank
    filter_hur = args.filter_same_group

    # Training and validation sets are built separately because each needs
    # its own transform.
    train_dataset = FolderDataset(
        DATA_PATH,
        validation=False,
        val_split=val_split,
        transform=SimCLRTrainDataTransform(image_size),
        image_type=image_type)

    print('Training Data Loaded...')
    val_dataset = FolderDataset(DATA_PATH,
                                validation=True,
                                val_split=val_split,
                                transform=SimCLREvalDataTransform(image_size),
                                image_type=image_type)

    print('Validation Data Loaded...')

    # load model
    num_samples = len(train_dataset)

    # init model with batch size and num_samples (len of data)
    model = SimCLR(arch='resnet18',
                   batch_size=1,
                   num_samples=num_samples,
                   gpus=gpus,
                   dataset='None')

    # override the default encoder/projection so the layout matches the saved weights
    model.encoder = resnet18(pretrained=False,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)

    model.load_state_dict(torch.load(MODEL_PATH))

    model.cuda()
    print('Successfully loaded your model for evaluation.')

    # running eval on validation data
    save_path = f"{MODEL_PATH[:-3]}/Evaluation/validationMetrics"
    Path(save_path).mkdir(parents=True, exist_ok=True)
    eval_embeddings(model, val_dataset, save_path, rank_to, filter_hur)
    print('Validation Data Evaluation Complete.')

    # running eval on training data
    save_path = f"{MODEL_PATH[:-3]}/Evaluation/trainingMetrics"
    Path(save_path).mkdir(parents=True, exist_ok=True)
    eval_embeddings(model, train_dataset, save_path, rank_to, filter_hur)
    print('Training Data Evaluation Complete.')

    print(f'Please check {MODEL_PATH[:-3]}/Evaluation/ for your results')