def cli_main():

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument("--MODEL_PATH",
                        default=None,
                        type=str,
                        help="path to model checkpoint")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="image size for SSL")
    parser.add_argument("--num_workers",
                        default=1,
                        type=int,
                        help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument("--epochs",
                        default=200,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=1e-3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "stop training early if validation loss does not improve for this many epochs. Keep the default of -1 to disable validation-based early stopping."
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        action="store_true",
        help=
        "initialize the resnet encoder with pretrained ImageNet weights (pass the flag to enable). Cannot be combined with a previous SSL checkpoint (MODEL_PATH)."
    )
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold during SSL training"
    )
    parser.add_argument("--version",
                        default="0",
                        type=str,
                        help="version to name checkpoint for saving")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="number of workers to use to fetch data")

    args = parser.parse_args()
    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    pretrain = args.pretrain_encoder
    withold_train_percent = args.withold_train_percent
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus

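    # The pl_bolts SimCLR transforms return two independently augmented views of
    # each image, which is what the contrastive (NT-Xent) loss compares.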
    train_transform = SimCLRTrainDataTransform(image_size)
    val_transform = SimCLREvalDataTransform(image_size)
    dm = ImageDataModule(URL,
                         train_transform=train_transform,
                         val_transform=val_transform,
                         val_split=val_split,
                         num_workers=num_workers)
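    # ImageDataModule is assumed to expose num_samples after setup(); it is
    # passed to SimCLR below to size the learning-rate schedule.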
    dm.setup()

    # Initialize SimCLR; num_samples and max_epochs drive the learning-rate
    # schedule inside the pl_bolts SimCLR module.
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=dm.num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)

    # Swap in a ResNet-18 backbone (512-d features) and a projection head that
    # maps those features down to the requested embedding size.
    model.encoder = resnet18(pretrained=pretrain,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)

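    # EarlyStopping watches the 'val_loss' metric logged by the SimCLR module
    # during validation; a non-positive patience disables early stopping.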
    if patience > 0:
        cb = EarlyStopping('val_loss', patience=patience)
        trainer = Trainer(gpus=gpus,
                          max_epochs=epochs,
                          callbacks=[cb],
                          progress_bar_refresh_rate=5)
    else:
        trainer = Trainer(gpus=gpus,
                          max_epochs=epochs,
                          progress_bar_refresh_rate=5)

    if model_checkpoint is not None:
        model.load_state_dict(torch.load(model_checkpoint))
        print(
            'Successfully loaded your checkpoint. Note that this restores only the model weights, not the previous trainer state.'
        )

    model.cuda()

    print('Model Initialized')
    trainer.fit(model, dm)

    Path(f"./models/SSL/SIMCLR_SSL_{version}").mkdir(parents=True,
                                                     exist_ok=True)
    torch.save(model.state_dict(),
               f"./models/SSL/SIMCLR_SSL_{version}/SIMCLR_SSL_{version}.pt")
def cli_main():

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument("--MODEL_PATH",
                        default=None,
                        type=str,
                        help="path to model checkpoint")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="image size for SSL")
    parser.add_argument(
        "--image_type",
        default="tif",
        type=str,
        help=
        "extension of image for PIL to open and parse - i.e. jpeg, gif, tif, etc. Only put the extension name, not the dot (.)"
    )
    parser.add_argument("--num_workers",
                        default=1,
                        type=int,
                        help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument("--epochs",
                        default=200,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=1e-3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "stop training early if validation loss does not improve for this many epochs. Keep the default of -1 to disable validation-based early stopping."
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        action="store_true",
        help=
        "initialize the resnet encoder with pretrained ImageNet weights (pass the flag to enable). Cannot be combined with a previous SSL checkpoint (MODEL_PATH)."
    )
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold during SSL training"
    )
    parser.add_argument("--version",
                        default="0",
                        type=str,
                        help="version to name checkpoint for saving")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")

    args = parser.parse_args()
    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    image_type = args.image_type
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    pretrain = args.pretrain_encoder
    withold_train_percent = args.withold_train_percent
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus

    # Build the train and validation datasets separately, since each needs a different transform.
    train_dataset = FolderDataset(
        URL,
        validation=False,
        val_split=val_split,
        withold_train_percent=withold_train_percent,
        transform=SimCLRTrainDataTransform(image_size),
        image_type=image_type)

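    # drop_last=True keeps every batch at the full batch_size, which the
    # batch-wise contrastive loss in SimCLR assumes.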
    data_loader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=batch_size,
                                              num_workers=num_workers,
                                              drop_last=True)

    print('Training Data Loaded...')
    val_dataset = FolderDataset(URL,
                                validation=True,
                                val_split=val_split,
                                transform=SimCLREvalDataTransform(image_size),
                                image_type=image_type)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             num_workers=num_workers,
                                             drop_last=True)
    print('Validation Data Loaded...')

    num_samples = len(train_dataset)

    # Initialize SimCLR; num_samples and max_epochs drive the learning-rate
    # schedule inside the pl_bolts SimCLR module.
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)

    # Swap in a ResNet-18 backbone (512-d features) and a projection head that
    # maps those features down to the requested embedding size.
    model.encoder = resnet18(pretrained=pretrain,
                             first_conv=model.first_conv,
                             maxpool1=model.maxpool1,
                             return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)

    if patience > 0:
        cb = EarlyStopping('val_loss', patience=patience)
        trainer = Trainer(gpus=gpus,
                          max_epochs=epochs,
                          callbacks=[cb],
                          progress_bar_refresh_rate=5)
    else:
        trainer = Trainer(gpus=gpus,
                          max_epochs=epochs,
                          progress_bar_refresh_rate=5)

    if model_checkpoint is not None:
        model.load_state_dict(torch.load(model_checkpoint))
        print(
            'Successfully loaded your checkpoint. Note that this restores only the model weights, not the previous trainer state.'
        )

    model.cuda()

    print('Model Initialized')
    trainer.fit(model, data_loader, val_loader)

    Path(f"./models/SSL/SIMCLR_SSL_{version}").mkdir(parents=True,
                                                     exist_ok=True)
    torch.save(model.state_dict(),
               f"./models/SSL/SIMCLR_SSL_{version}/SIMCLR_SSL_{version}.pt")