コード例 #1
0
def cli_main():
    """Command-line entry point: train BYOL with an online evaluation callback.

    Parses the Lightning ``Trainer`` arguments plus the BYOL model arguments,
    builds the datamodule selected by ``args.dataset`` (``cifar10``, ``stl10``
    or ``imagenet2012``) with SimCLR train/eval transforms, and fits the model
    with ``max_steps=300000``.

    Raises:
        NotImplementedError: if ``args.dataset`` is not one of the datasets
            handled below.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    seed_everything(1234)

    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = BYOL.add_model_specific_args(parser)
    args = parser.parse_args()

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # use the datamodule's "mixed" loaders for both training and validation
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed

        # only the height of the (c, h, w) size is needed by the transforms
        _, h, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        _, h, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    else:
        # Fail early with a clear message instead of an AttributeError on a
        # missing datamodule further down.
        raise NotImplementedError(
            f"dataset {args.dataset!r} is not supported; "
            "expected one of: 'cifar10', 'stl10', 'imagenet2012'")

    # every supported datamodule exposes its class count; forward it to the model
    args.num_classes = dm.num_classes

    model = BYOL(**args.__dict__)

    def to_device(batch, device):
        """Move the first augmented view and the labels of ``batch`` to ``device``."""
        (x1, x2), y = batch
        x1 = x1.to(device)
        y = y.to(device)
        return x1, y

    # finetune in real-time
    online_eval = SSLOnlineEvaluator(z_dim=2048, num_classes=dm.num_classes)
    online_eval.to_device = to_device

    trainer = pl.Trainer.from_argparse_args(args,
                                            max_steps=300000,
                                            callbacks=[online_eval])

    trainer.fit(model, dm)
コード例 #2
0
def cli_main():
    """Command-line entry point: train CPCV2 with an online evaluator callback.

    Builds a parser from the Lightning ``Trainer`` arguments, the CPCV2 model
    arguments and a handful of data options, selects the datamodule and
    transforms that match ``--dataset``, then fits the model.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
    from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule

    pl.seed_everything(1234)

    # trainer args first, then model args, then the data-related options
    parser = CPCV2.add_model_specific_args(
        pl.Trainer.add_argparse_args(ArgumentParser()))
    parser.add_argument('--dataset', type=str, default='cifar10')
    parser.add_argument('--data_dir', type=str, default='.')
    parser.add_argument('--meta_dir', type=str, default='.',
                        help='path to meta.bin for imagenet')
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--batch_size', default=128, type=int)
    args = parser.parse_args()

    dm = None
    evaluator = SSLOnlineEvaluator()
    dataset = args.dataset

    if dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = CPCTrainTransformsCIFAR10()
        dm.val_transforms = CPCEvalTransformsCIFAR10()
        args.patch_size = 8

    elif dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # route both loaders through the datamodule's "mixed" variants
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        dm.train_transforms = CPCTrainTransformsSTL10()
        dm.val_transforms = CPCEvalTransformsSTL10()
        args.patch_size = 16

        # 16 GB RAM - 64
        # 32 GB RAM - 144
        args.batch_size = 144

        def to_device(batch, device):
            """Return the second (input, label) pair of ``batch`` moved to ``device``."""
            (_, _), (inputs, labels) = batch
            return inputs.to(device), labels.to(device)

        evaluator.to_device = to_device

    elif dataset == 'imagenet2012':
        dm = SSLImagenetDataModule.from_argparse_args(args)
        dm.train_transforms = CPCTrainTransformsImageNet128()
        dm.val_transforms = CPCEvalTransformsImageNet128()
        args.patch_size = 32

    model = CPCV2(**args.__dict__)
    trainer = pl.Trainer.from_argparse_args(args, callbacks=[evaluator])
    trainer.fit(model, dm)
コード例 #3
0
def cli_main():
    """Command-line entry point: train SimCLR with an online evaluation callback.

    Parses the Lightning ``Trainer`` arguments plus the SimCLR model arguments,
    builds the datamodule selected by ``args.dataset`` (``cifar10``, ``stl10``
    or ``imagenet2012``) with SimCLR train/eval transforms, records the dataset
    size in ``args.num_samples`` and fits the model.

    Raises:
        NotImplementedError: if ``args.dataset`` is not one of the datasets
            handled below.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule

    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = SimCLR.add_model_specific_args(parser)
    args = parser.parse_args()

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        args.num_samples = dm.num_samples

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # use the datamodule's "mixed" loaders for both training and validation
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        # only the height of the (c, h, w) size is needed by the transforms
        _, h, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        _, h, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)
        # keep this branch consistent with the others, which all hand the
        # dataset size to the model through ``args``
        args.num_samples = dm.num_samples

    else:
        # Without this guard ``dm`` stays unbound and the failure surfaces
        # later as an UnboundLocalError.
        raise NotImplementedError(
            f"dataset {args.dataset!r} is not supported; "
            "expected one of: 'cifar10', 'stl10', 'imagenet2012'")

    model = SimCLR(**args.__dict__)

    # finetune in real-time
    def to_device(batch, device):
        """Move the first augmented view and the labels of ``batch`` to ``device``."""
        (x1, x2), y = batch
        x1 = x1.to(device)
        y = y.to(device)
        return x1, y

    online_eval = SSLOnlineEvaluator(z_dim=2048 * 2 * 2,
                                     num_classes=dm.num_classes)
    online_eval.to_device = to_device

    trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_eval])
    trainer.fit(model, dm)
コード例 #4
0
def cli_main():
    """Command-line entry point: train SwAV on STL10, CIFAR10 or ImageNet.

    The dataset named by ``args.dataset`` decides which datamodule is built and
    which argument overrides are applied before the SwAV transforms, the model,
    the callbacks and the ``Trainer`` are created.

    Raises:
        NotImplementedError: for any dataset other than ``stl10``, ``cifar10``
            or ``imagenet``.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.swav.transforms import SwAVEvalDataTransform, SwAVTrainDataTransform

    # model args
    args = SwAV.add_model_specific_args(ArgumentParser()).parse_args()
    dataset = args.dataset

    if dataset == 'stl10':
        dm = STL10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        # route both loaders through the datamodule's "mixed" variants
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed

        args.num_samples = dm.num_unlabeled_samples
        args.maxpool1 = False
        normalization = stl10_normalization()

    elif dataset == 'cifar10':
        # these two overrides must land before the datamodule reads them
        vars(args).update(batch_size=2, num_workers=0)
        dm = CIFAR10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        vars(args).update(
            num_samples=dm.num_samples,
            maxpool1=False,
            first_conv=False,
        )
        normalization = cifar10_normalization()

        # cifar10 specific params
        vars(args).update(
            size_crops=[32, 16],
            nmb_crops=[2, 1],
            gaussian_blur=False,
        )

    elif dataset == 'imagenet':
        normalization = imagenet_normalization()
        vars(args).update(
            maxpool1=True,
            first_conv=True,
            # multi-crop augmentation settings
            size_crops=[224, 96],
            nmb_crops=[2, 6],
            min_scale_crops=[0.14, 0.05],
            max_scale_crops=[1.0, 0.14],
            gaussian_blur=True,
            jitter_strength=1.0,
            # hardware / schedule
            batch_size=64,
            num_nodes=8,
            gpus=8,  # per-node
            max_epochs=800,
            # optimisation
            optimizer='lars',
            learning_rate=4.8,
            final_lr=0.0048,
            start_lr=0.3,
            nmb_prototypes=3000,
            online_ft=True,
        )

        # built after the overrides so it picks up the batch size set above
        dm = ImagenetDataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]

    else:
        raise NotImplementedError(
            "other datasets have not been implemented till now")

    # train and eval transforms take exactly the same settings
    transform_kwargs = dict(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
    )
    dm.train_transforms = SwAVTrainDataTransform(**transform_kwargs)
    dm.val_transforms = SwAVEvalDataTransform(**transform_kwargs)

    # swav model init
    model = SwAV(**vars(args))

    # online eval (only built when requested)
    online_evaluator = None
    if args.online_ft:
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )

    lr_monitor = LearningRateMonitor(logging_interval="step")
    model_checkpoint = ModelCheckpoint(
        save_last=True,
        save_top_k=1,
        monitor='val_loss',
    )

    # final order: checkpoint, optional evaluator, LR monitor
    callbacks = [model_checkpoint]
    if args.online_ft:
        callbacks.append(online_evaluator)
    callbacks.append(lr_monitor)

    multi_gpu = bool(args.gpus > 1)
    trainer = Trainer(
        max_epochs=args.max_epochs,
        max_steps=args.max_steps if args.max_steps != -1 else None,
        gpus=args.gpus,
        num_nodes=args.num_nodes,
        distributed_backend='ddp' if multi_gpu else None,
        sync_batchnorm=multi_gpu,
        precision=16 if not args.fp32 else 32,
        callbacks=callbacks,
        fast_dev_run=args.fast_dev_run,
    )

    trainer.fit(model, datamodule=dm)
コード例 #5
0
def cli_main():
    """Command-line entry point: train CPC_v2 with an online evaluator callback.

    Builds a parser from the Lightning ``Trainer`` arguments, the CPC_v2 model
    arguments and a handful of data options, selects the datamodule and
    transforms that match ``--dataset`` (``cifar10``, ``stl10`` or
    ``imagenet2012``), then fits the model.

    Raises:
        NotImplementedError: if ``--dataset`` is not one of the datasets
            handled below.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    # STL10DataModule is needed by the 'stl10' branch below
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
    from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule

    seed_everything(1234)
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser = CPC_v2.add_model_specific_args(parser)
    parser.add_argument("--dataset", default="cifar10", type=str)
    parser.add_argument("--data_dir", default=".", type=str)
    parser.add_argument("--meta_dir",
                        default=".",
                        type=str,
                        help="path to meta.bin for imagenet")
    parser.add_argument("--num_workers", default=8, type=int)
    parser.add_argument("--hidden_mlp",
                        default=2048,
                        type=int,
                        help="hidden layer dimension in projection head")
    parser.add_argument("--batch_size", type=int, default=128)

    args = parser.parse_args()

    if args.dataset == "cifar10":
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsCIFAR10()
        datamodule.val_transforms = CPCEvalTransformsCIFAR10()
        args.patch_size = 8

    elif args.dataset == "stl10":
        datamodule = STL10DataModule.from_argparse_args(args)
        # use the datamodule's "mixed" loaders for both training and validation
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = CPCTrainTransformsSTL10()
        datamodule.val_transforms = CPCEvalTransformsSTL10()
        args.patch_size = 16

    elif args.dataset == "imagenet2012":
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsImageNet128()
        datamodule.val_transforms = CPCEvalTransformsImageNet128()
        args.patch_size = 32

    else:
        # Fail early with a clear message; otherwise the evaluator setup below
        # would hit an AttributeError on a missing datamodule.
        raise NotImplementedError(
            f"dataset {args.dataset!r} is not supported; "
            "expected one of: 'cifar10', 'stl10', 'imagenet2012'")

    online_evaluator = SSLOnlineEvaluator(
        drop_p=0.0,
        hidden_dim=None,
        z_dim=args.hidden_mlp,
        num_classes=datamodule.num_classes,
        dataset=args.dataset,
    )
    if args.dataset == "stl10":
        # 16 GB RAM - 64
        # 32 GB RAM - 144
        # NOTE(review): the datamodule was already built from ``args`` above, so
        # this presumably only changes the value handed to the model - confirm.
        args.batch_size = 144

        def to_device(batch, device):
            """Return the second (input, label) pair of ``batch`` moved to ``device``."""
            (_, _), (x2, y2) = batch
            x2 = x2.to(device)
            y2 = y2.to(device)
            return x2, y2

        online_evaluator.to_device = to_device

    model = CPC_v2(**vars(args))
    trainer = Trainer.from_argparse_args(args, callbacks=[online_evaluator])
    trainer.fit(model, datamodule=datamodule)
コード例 #6
0
def cli_main():
    """Command-line entry point: train SimCLR on STL10, CIFAR10 or ImageNet.

    The dataset named by ``args.dataset`` decides which datamodule is built and
    which argument overrides are applied before the SimCLR transforms, the
    model, the callbacks and the ``Trainer`` are created.

    Raises:
        NotImplementedError: for any dataset other than ``stl10``, ``cifar10``
            or ``imagenet``.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform

    # model args first, trainer args second
    parser = pl.Trainer.add_argparse_args(
        SimCLR.add_model_specific_args(ArgumentParser()))
    args = parser.parse_args()
    dataset = args.dataset

    if dataset == 'stl10':
        dm = STL10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        # route both loaders through the datamodule's "mixed" variants
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed

        vars(args).update(
            num_samples=dm.num_unlabeled_samples,
            maxpool1=False,
            first_conv=True,
            input_height=dm.size()[-1],
        )
        normalization = stl10_normalization()
        vars(args).update(gaussian_blur=True, jitter_strength=1.0)

    elif dataset == 'cifar10':
        # validation split is 5000, or one global batch if that is larger
        val_split = max(5000, args.num_nodes * args.gpus * args.batch_size)

        dm = CIFAR10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            val_split=val_split,
        )

        vars(args).update(
            num_samples=dm.num_samples,
            maxpool1=False,
            first_conv=False,
            input_height=dm.size()[-1],
            temperature=0.5,
        )
        normalization = cifar10_normalization()
        vars(args).update(gaussian_blur=False, jitter_strength=0.5)

    elif dataset == 'imagenet':
        normalization = imagenet_normalization()
        vars(args).update(
            maxpool1=True,
            first_conv=True,
            gaussian_blur=True,
            jitter_strength=1.0,
            # hardware / schedule
            batch_size=64,
            num_nodes=8,
            gpus=8,  # per-node
            max_epochs=800,
            # optimisation
            optimizer='sgd',
            lars_wrapper=True,
            learning_rate=4.8,
            final_lr=0.0048,
            start_lr=0.3,
            online_ft=True,
        )

        # built after the overrides so it picks up the batch size set above
        dm = ImagenetDataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]

    else:
        raise NotImplementedError(
            "other datasets have not been implemented till now")

    # train and eval transforms take exactly the same settings
    transform_kwargs = dict(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )
    dm.train_transforms = SimCLRTrainDataTransform(**transform_kwargs)
    dm.val_transforms = SimCLREvalDataTransform(**transform_kwargs)

    model = SimCLR(**vars(args))

    # online eval (only built when requested)
    online_evaluator = None
    if args.online_ft:
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )

    model_checkpoint = ModelCheckpoint(
        save_last=True,
        save_top_k=1,
        monitor='val_loss',
    )
    callbacks = [model_checkpoint]
    if args.online_ft:
        callbacks.append(online_evaluator)

    trainer = pl.Trainer.from_argparse_args(
        args,
        sync_batchnorm=bool(args.gpus > 1),
        callbacks=callbacks,
    )

    trainer.fit(model, datamodule=dm)
コード例 #7
0
def cli_main():
    """Command-line entry point: train SwAV on STL10 or CIFAR10.

    The dataset named by ``args.dataset`` decides which datamodule is built,
    which argument overrides are applied and which normalization the SwAV
    train/eval transforms use. An online evaluator callback is attached when
    ``args.online_ft`` is set.

    Raises:
        NotImplementedError: for any dataset other than ``stl10`` or
            ``cifar10``.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.models.self_supervised.swav.transforms import SwAVTrainDataTransform, SwAVEvalDataTransform
    from pl_bolts.datamodules import STL10DataModule, CIFAR10DataModule

    parser = ArgumentParser()

    # model args
    parser = SwAV.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == 'stl10':
        dm = STL10DataModule(data_dir=args.data_path,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers)

        # use the datamodule's "mixed" loaders for both training and validation
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        args.maxpool1 = False

        normalization = stl10_normalization()
    elif args.dataset == 'cifar10':
        # these overrides must land before the datamodule reads them
        args.batch_size = 2
        args.num_workers = 0

        dm = CIFAR10DataModule(data_dir=args.data_path,
                               batch_size=args.batch_size,
                               num_workers=args.num_workers)

        args.num_samples = dm.num_samples

        args.maxpool1 = False
        args.first_conv = False

        normalization = cifar10_normalization()

        # cifar10 specific params
        args.size_crops = [32, 16]
        args.nmb_crops = [2, 1]
        args.gaussian_blur = False
    else:
        raise NotImplementedError(
            "other datasets have not been implemented till now")

    # Use the normalization chosen for the selected dataset; hard-coding the
    # STL10 statistics here would normalize CIFAR10 images with the wrong values.
    dm.train_transforms = SwAVTrainDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength)

    dm.val_transforms = SwAVEvalDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength)

    # swav model init
    model = SwAV(**args.__dict__)

    online_evaluator = None
    if args.online_ft:
        # online eval
        online_evaluator = SSLOnlineEvaluator(drop_p=0.,
                                              hidden_dim=None,
                                              z_dim=args.hidden_mlp,
                                              num_classes=dm.num_classes,
                                              dataset=args.dataset)

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        # -1 on the command line means "no step limit"
        max_steps=None if args.max_steps == -1 else args.max_steps,
        gpus=args.gpus,
        distributed_backend='ddp' if args.gpus > 1 else None,
        sync_batchnorm=True if args.gpus > 1 else False,
        precision=32 if args.fp32 else 16,
        callbacks=[online_evaluator] if args.online_ft else None,
        fast_dev_run=args.fast_dev_run)

    trainer.fit(model, dm)
コード例 #8
0
def cli_main():
    """Command-line entry point: train SimSiam on STL10, CIFAR10/100 or ImageNet.

    The dataset named by ``args.dataset`` decides which datamodule is built and
    which argument overrides are applied before the SimCLR-style transforms,
    the model, the optional online evaluator and the ``Trainer`` are created.

    Raises:
        NotImplementedError: for any dataset other than ``stl10``, ``cifar10``,
            ``cifar100`` or ``imagenet``.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform

    seed_everything(1234)

    # trainer args first, model args second
    parser = SimSiam.add_model_specific_args(
        pl.Trainer.add_argparse_args(ArgumentParser()))
    args = parser.parse_args()
    dataset = args.dataset

    # init datamodule
    if dataset == "stl10":
        dm = STL10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        # route both loaders through the datamodule's "mixed" variants
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed

        vars(args).update(
            num_samples=dm.num_unlabeled_samples,
            maxpool1=False,
            first_conv=True,
            input_height=dm.size()[-1],
        )
        normalization = stl10_normalization()
        vars(args).update(gaussian_blur=True, jitter_strength=1.0)

    elif dataset in ("cifar10", "cifar100"):
        # both CIFAR variants share everything except the datamodule class
        # and the normalization statistics
        is_cifar10 = dataset == "cifar10"

        # validation split is 5000, or one global batch if that is larger
        val_split = max(5000, args.nodes * args.gpus * args.batch_size)

        datamodule_cls = CIFAR10DataModule if is_cifar10 else CIFAR100DataModule
        dm = datamodule_cls(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            val_split=val_split,
        )

        vars(args).update(
            num_samples=dm.num_samples,
            maxpool1=False,
            first_conv=False,
            input_height=dm.size()[-1],
            temperature=0.5,
        )

        if is_cifar10:
            normalization = cifar10_normalization()
        else:
            # ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023))
            normalization = transforms.Normalize(
                mean=(0.5071, 0.4866, 0.4409),
                std=(0.2009, 0.1984, 0.2023),
            )

        vars(args).update(gaussian_blur=False, jitter_strength=0.5)

    elif dataset == "imagenet":
        normalization = imagenet_normalization()
        vars(args).update(
            maxpool1=True,
            first_conv=True,
            gaussian_blur=True,
            jitter_strength=1.0,
            # hardware / schedule
            batch_size=64,
            nodes=8,
            gpus=8,  # per-node
            max_epochs=800,
            # optimisation
            optimizer="sgd",
            lars_wrapper=True,
            learning_rate=4.8,
            final_lr=0.0048,
            start_lr=0.3,
            online_ft=True,
        )

        # built after the overrides so it picks up the batch size set above
        dm = ImagenetDataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]

    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    # train and eval transforms take exactly the same settings
    transform_kwargs = dict(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )
    dm.train_transforms = SimCLRTrainDataTransform(**transform_kwargs)
    dm.val_transforms = SimCLREvalDataTransform(**transform_kwargs)

    model = SimSiam(**vars(args))

    # finetune in real-time (only built when requested)
    online_evaluator = None
    if args.online_ft:
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.0,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )

    multi_gpu = bool(args.gpus > 1)
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        max_steps=args.max_steps if args.max_steps != -1 else None,
        gpus=args.gpus,
        num_nodes=args.nodes,
        distributed_backend="ddp" if multi_gpu else None,
        sync_batchnorm=multi_gpu,
        precision=16 if not args.fp32 else 32,
        callbacks=[online_evaluator] if args.online_ft else None,
        fast_dev_run=args.fast_dev_run,
    )

    trainer.fit(model, dm)
コード例 #9
0
def cli_main():
    """Command-line entry point: self-supervised SIMCLR training with W&B logging.

    Parses the options below, then either resumes a ``SIMCLR`` model from a
    ``.ckpt`` path given via ``--encoder`` or builds a fresh model around the
    encoder returned by ``load_encoder``. If loading the checkpoint fails, the
    error is printed and the fresh-model path is taken instead. After training,
    a checkpoint is saved under ``./models/SSL/``.

    Exits with a usage error (via ``parser.error``) when ``--log_name`` is
    missing, because it is needed to build the run and checkpoint name.
    """

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string, including "False",
        into ``True``; this accepts the usual spellings explicitly.
        """
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        # the empty string stays False, as it was with ``bool('')``
        if text in ('false', 'f', 'no', 'n', '0', ''):
            return False
        # argparse reports a ValueError from a ``type`` callable as a usage error
        raise ValueError(f'expected a boolean value, got {value!r}')

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument(
        "--encoder",
        default=None,
        type=str,
        help=
        "encoder to initialize. Can accept SimCLR model checkpoint or just encoder name in from encoders_dali"
    )
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--num_workers",
                        default=1,
                        type=int,
                        help="number of workers to use to fetch data")
    parser.add_argument(
        "--hidden_dims",
        default=128,
        type=int,
        help=
        "hidden dimensions in classification layer added onto model for finetuning"
    )
    parser.add_argument("--epochs",
                        default=400,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=1e-3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping."
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    parser.add_argument(
        "--withhold_split",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold from either training or validation. Used for experimenting with labels neeeded"
    )
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    parser.add_argument("--log_name",
                        type=str,
                        help="name of model to log on wandb and locally")
    parser.add_argument(
        "--online_eval",
        default=False,
        type=_str2bool,
        help="Do finetuning on model if labels are provided as a sanity check")

    args = parser.parse_args()
    if args.log_name is None:
        # a clear usage error instead of a TypeError from ``str + None`` below
        parser.error('--log_name is required')

    DATA_PATH = args.DATA_PATH
    batch_size = args.batch_size
    num_workers = args.num_workers
    hidden_dims = args.hidden_dims
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    withhold = args.withhold_split
    gpus = args.gpus
    encoder = args.encoder
    log_name = 'SIMCLR_SSL_' + args.log_name + '.ckpt'
    online_eval = args.online_eval

    wandb_logger = WandbLogger(name=log_name, project='SpaceForce')

    # ``--encoder`` defaults to None, so guard the substring test
    checkpointed = encoder is not None and '.ckpt' in encoder
    if checkpointed:
        print('Resuming SSL Training from Model Checkpoint')
        try:
            model = SIMCLR.load_from_checkpoint(checkpoint_path=encoder)
            embedding_size = model.embedding_size
        except Exception as e:
            # deliberate best-effort: report the problem and fall back to
            # building a fresh model below
            print(e)
            print(
                'invalid checkpoint to initialize SIMCLR. This checkpoint needs to include the encoder and projection and is of the SIMCLR class from this library. Will try to initialize just the encoder'
            )
            checkpointed = False

    # A separate ``if`` (not ``elif``) so the fallback announced above actually
    # runs when loading the checkpoint failed.
    if not checkpointed:
        encoder, embedding_size = load_encoder(encoder)
        model = SIMCLR(encoder=encoder,
                       embedding_size=embedding_size,
                       gpus=gpus,
                       epochs=epochs,
                       DATA_PATH=DATA_PATH,
                       withhold=withhold,
                       batch_size=batch_size,
                       val_split=val_split,
                       hidden_dims=hidden_dims,
                       train_transform=SimCLRTrainDataTransform,
                       # NOTE(review): validation reuses the *train* transform;
                       # confirm whether an eval transform was intended.
                       val_transform=SimCLRTrainDataTransform,
                       num_workers=num_workers,
                       lr=lr)

    online_evaluator = SSLOnlineEvaluator(drop_p=0.,
                                          hidden_dim=None,
                                          z_dim=embedding_size,
                                          num_classes=model.num_classes,
                                          dataset='None')

    cbs = []
    backend = 'dp'

    # a non-positive patience disables early stopping
    if patience > 0:
        cb = EarlyStopping('val_loss', patience=patience)
        cbs.append(cb)

    if online_eval:
        cbs.append(online_evaluator)
        backend = 'ddp'

    trainer = Trainer(
        gpus=gpus,
        max_epochs=epochs,
        progress_bar_refresh_rate=5,
        callbacks=cbs,
        distributed_backend=backend if gpus > 1 else None,
        logger=wandb_logger,
        enable_pl_optimizer=True)

    print('USING BACKEND______________________________ ', backend)
    trainer.fit(model)
    Path("./models/SSL").mkdir(parents=True, exist_ok=True)
    trainer.save_checkpoint(f"./models/SSL/{log_name}")