def __init__(self,
                 encoder,
                 DATA_PATH,
                 VAL_PATH,
                 hidden_dim,
                 image_size,
                 seed,
                 cpus,
                 transform=SimCLRTransform,
                 **classifier_hparams):
        super().__init__()

        self.DATA_PATH = DATA_PATH
        self.VAL_PATH = VAL_PATH
        self.transform = transform
        self.image_size = image_size
        self.cpus = cpus
        self.seed = seed

        self.batch_size = classifier_hparams['batch_size']
        self.classifier_hparams = classifier_hparams

        self.linear_layer = SSLEvaluator(
            n_input=encoder.embedding_size,
            n_classes=self.classifier_hparams['num_classes'],
            p=self.classifier_hparams['dropout'],
            n_hidden=hidden_dim)

        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)

        self.encoder = encoder

        self.save_hyperparameters()

    def on_pretrain_routine_start(self, trainer, pl_module):
        self.classifier = SSLEvaluator(
            n_input=self.z_dim,
            n_classes=self.datamodule.num_classes,
            n_hidden=None
        ).to(pl_module.device)

        self.optimizer = torch.optim.Adam(self.classifier.parameters(), lr=1e-3)

    def __init__(self, hparams):
        super().__init__()

        self.hparams = hparams
        self.online_evaluator = self.hparams.online_ft
        self.dataset = self.get_dataset(hparams.dataset)

        # encoder network (Z vectors)
        dummy_batch = torch.zeros(
            (2, 3, hparams.patch_size, hparams.patch_size))
        self.encoder = CPCResNet101(dummy_batch)

        # info nce loss
        c, h = self.__compute_final_nb_c(hparams.patch_size)
        self.info_nce = InfoNCE(num_input_channels=c,
                                target_dim=64,
                                embed_scale=0.1)

        if self.online_evaluator:
            z_dim = c * h * h
            num_classes = self.dataset.num_classes
            self.non_linear_evaluator = SSLEvaluator(n_input=z_dim,
                                                     n_classes=num_classes,
                                                     p=0.2,
                                                     n_hidden=1024)
 def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
     """
     Moves metrics and the online evaluator to the correct GPU.
     If training happens via DDP, SyncBatchNorm is enabled for the online evaluator, and it is converted to
     a DDP module.
     """
     for prefix, metrics in [("train", self.train_metrics), ("val", self.val_metrics)]:
         add_submodules_to_same_device(pl_module, metrics, prefix=prefix)
     self.evaluator = SSLEvaluator(n_input=self.z_dim,
                                   n_classes=self.num_classes,
                                   p=self.drop_p,
                                   n_hidden=self.hidden_dim)
     self.evaluator.to(pl_module.device)
     if hasattr(trainer, "accelerator_connector"):
         # This works with Lightning 1.3.8
         accelerator = trainer.accelerator_connector
     elif hasattr(trainer, "_accelerator_connector"):
         # This works with Lightning 1.5.5
         accelerator = trainer._accelerator_connector
     else:
         raise ValueError("Unable to retrieve the accelerator information")
     if accelerator.is_distributed:
         if accelerator.use_ddp:
             self.evaluator = SyncBatchNorm.convert_sync_batchnorm(self.evaluator)
             self.evaluator = DistributedDataParallel(self.evaluator, device_ids=[pl_module.device])  # type: ignore
         else:
             rank_zero_warn("This type of distributed accelerator is not supported. "
                            "The online evaluator will not synchronize across GPUs.")
     self.optimizer = torch.optim.Adam(self.evaluator.parameters(),
                                       lr=self.learning_rate,
                                       weight_decay=self.weight_decay)
     if self.evaluator_state is not None:
         self._wrapped_evaluator().load_state_dict(self.evaluator_state)
     if self.optimizer_state is not None:
         self.optimizer.load_state_dict(self.optimizer_state)
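
# For context: a hedged sketch (an assumption, not part of the callback above) of how such a
# callback typically trains the online evaluator on each batch. Representations are detached so
# no gradients flow back into the encoder; only the evaluator's parameters are updated. The exact
# hook name and signature vary across Lightning versions, so treat this purely as an illustration.
import torch
import torch.nn.functional as F

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
    x, y = batch
    with torch.no_grad():
        # encode the batch; the evaluator sees flattened features only
        representations = pl_module(x).flatten(start_dim=1)
    representations = representations.detach()  # no gradients into the encoder

    logits = self.evaluator(representations)
    mlp_loss = F.cross_entropy(logits, y)

    # update only the online evaluator
    mlp_loss.backward()
    self.optimizer.step()
    self.optimizer.zero_grad()
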
Example #5
    def __init__(self,
                 dataset,
                 data_dir,
                 lr,
                 wd,
                 input_height,
                 batch_size,
                 online_ft=False,
                 num_workers=0,
                 optimizer='adam',
                 step=30,
                 gamma=0.5,
                 temperature=0.5,
                 **kwargs):
        super().__init__()

        self.hparams = Namespace(
            **{
                'lr': lr,
                'step': step,
                'gamma': gamma,
                'temperature': temperature,
                'dataset': dataset,
                'data_dir': data_dir,
                'wd': wd,
                'input_height': input_height,
                'batch_size': batch_size,
                'online_ft': online_ft,
                'num_workers': num_workers,
                'optimizer': optimizer
            })

        self.online_evaluator = online_ft
        self.batch_size = batch_size
        self.input_height = input_height
        self.gamma = gamma
        self.step = step
        self.optimizer = optimizer
        self.wd = wd
        self.lr = lr
        self.temp = temperature
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.dataset_name = dataset
        self.dataset = self.get_dataset(dataset)
        self.loss_func = self.init_loss()
        self.encoder = self.init_encoder()
        self.projection = self.init_projection()

        if self.online_evaluator:
            z_dim = self.projection.output_dim
            num_classes = self.dataset.num_classes
            self.non_linear_evaluator = SSLEvaluator(n_input=z_dim,
                                                     n_classes=num_classes,
                                                     p=0.2,
                                                     n_hidden=1024)
Example #6
    def __init__(self,
                 encoder,
                 DATA_PATH,
                 VAL_PATH,
                 hidden_dim,
                 image_size,
                 seed,
                 cpus,
                 transform=SimCLRTransform,
                 **classifier_hparams):
        super().__init__()

        self.DATA_PATH = DATA_PATH
        self.VAL_PATH = VAL_PATH
        self.transform = transform
        self.image_size = image_size
        self.cpus = cpus
        self.seed = seed

        self.batch_size = classifier_hparams['batch_size']
        self.classifier_hparams = classifier_hparams

        self.linear_layer = SSLEvaluator(
            n_input=encoder.embedding_size,
            n_classes=self.classifier_hparams['num_classes'],
            p=self.classifier_hparams['dropout'],
            n_hidden=hidden_dim)

        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)
        self.encoder = encoder

        self.weights = None

        print(classifier_hparams)
        if classifier_hparams['weights'] is not None:
            self.weights = torch.tensor([
                float(item)
                for item in classifier_hparams['weights'].split(',')
            ])
            self.weights = self.weights.cuda()

        self.save_hyperparameters()

    def on_pretrain_routine_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        from pl_bolts.models.self_supervised.evaluator import SSLEvaluator

        pl_module.non_linear_evaluator = SSLEvaluator(
            n_input=self.z_dim,
            n_classes=self.num_classes,
            p=self.drop_p,
            n_hidden=self.hidden_dim,
        ).to(pl_module.device)

        self.optimizer = torch.optim.Adam(pl_module.non_linear_evaluator.parameters(), lr=1e-4)
Example #8
 def on_pretrain_routine_start(self, trainer, pl_module):
     from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
     # takes an input --> flatten it --> add dropouts --> put a linear layer
     # simple linear mapping
     pl_module.non_linear_evaluator = SSLEvaluator(
         n_input=self.z_dim,
         n_classes = self.num_classes,
         p=self.drop_p,).to(pl_module.device)
     # pass in the parameters of the non_linear_evaluator and use SGD
     # need to learn the weights for the mlp_loss
     # need an optimizer
     self.optimizer = torch.optim.SGD(pl_module.non_linear_evaluator.parameters(), lr=1e-3)
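
# The comments above describe SSLEvaluator as "flatten --> dropout --> linear". For reference,
# here is a minimal self-contained sketch of such a module. It approximates the pl_bolts
# SSLEvaluator used throughout these examples, but it is not the verbatim implementation:
# with n_hidden=None it is a plain linear probe, otherwise a single-hidden-layer MLP head.
import torch.nn as nn

class SSLEvaluatorSketch(nn.Module):
    def __init__(self, n_input, n_classes, n_hidden=512, p=0.1):
        super().__init__()
        if n_hidden is None:
            # linear probe on top of the (flattened) representations
            self.block_forward = nn.Sequential(
                nn.Flatten(),
                nn.Dropout(p=p),
                nn.Linear(n_input, n_classes, bias=True),
            )
        else:
            # small non-linear MLP head
            self.block_forward = nn.Sequential(
                nn.Flatten(),
                nn.Dropout(p=p),
                nn.Linear(n_input, n_hidden, bias=False),
                nn.BatchNorm1d(n_hidden),
                nn.ReLU(inplace=True),
                nn.Dropout(p=p),
                nn.Linear(n_hidden, n_classes, bias=True),
            )

    def forward(self, x):
        return self.block_forward(x)
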
    def __init__(self, encoder, DATA_PATH, withhold, batch_size, val_split,
                 hidden_dims, train_transform, val_transform, num_workers,
                 **kwargs):
        super().__init__()

        self.DATA_PATH = DATA_PATH
        self.val_split = val_split
        self.batch_size = batch_size
        self.hidden_dims = hidden_dims
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.num_workers = num_workers
        self.withhold = withhold

        #data stuff
        shutil.rmtree('split_data', ignore_errors=True)
        if not (path.isdir(f"{self.DATA_PATH}/train")
                and path.isdir(f"{self.DATA_PATH}/val")):
            splitfolders.ratio(self.DATA_PATH,
                               output="split_data",
                               ratio=(1 - self.val_split - self.withhold,
                                      self.val_split, self.withhold),
                               seed=10)
            self.DATA_PATH = 'split_data'
            print(
                f'automatically split data into train/val with val_split={self.val_split} and withhold={self.withhold}'
            )

        self.num_classes = len(os.listdir(f'{self.DATA_PATH}/train'))

        #model stuff
        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)
        print('KWARGS:', kwargs)
        self.encoder, self.embedding_size = load_encoder(encoder, kwargs)

        self.linear_layer = SSLEvaluator(n_input=self.embedding_size,
                                         n_classes=self.num_classes,
                                         p=0.1,
                                         n_hidden=self.hidden_dims)
Example #10
    def on_pretrain_routine_start(self, trainer: pl.Trainer,
                                  pl_module: pl.LightningModule) -> None:
        """
        Initializes modules and moves metrics and class weights to module device
        """
        for metric in [*self.train_metrics, *self.val_metrics]:
            metric.to(device=pl_module.device)  # type: ignore

        pl_module.non_linear_evaluator = SSLEvaluator(
            n_input=self.z_dim,
            n_classes=self.num_classes,
            p=self.drop_p,
            n_hidden=self.hidden_dim).to(pl_module.device)
        assert isinstance(pl_module.non_linear_evaluator, torch.nn.Module)
        self.optimizer = torch.optim.Adam(
            pl_module.non_linear_evaluator.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay)
Example #11
    def on_pretrain_routine_start(self, trainer, pl_module):
        from pl_bolts.models.self_supervised.evaluator import SSLEvaluator

        # attach the evaluator to the module

        if hasattr(pl_module, 'z_dim'):
            self.z_dim = pl_module.z_dim
        if hasattr(pl_module, 'num_classes'):
            self.num_classes = pl_module.num_classes

        pl_module.non_linear_evaluator = SSLEvaluator(
            n_input=self.z_dim,
            n_classes=self.num_classes,
            p=self.drop_p,
            n_hidden=self.hidden_dim
        ).to(pl_module.device)

        self.optimizer = torch.optim.SGD(pl_module.non_linear_evaluator.parameters(), lr=1e-3)
Example #12
    def __init__(self, hparams):
        super().__init__()

        self.hparams = hparams
        self.online_evaluator = self.hparams.online_ft
        self.dataset = self.get_dataset(hparams.dataset)

        self.encoder = self.init_encoder()

        # info nce loss
        c, h = self.__compute_final_nb_c(hparams.patch_size)
        self.info_nce = InfoNCE(num_input_channels=c,
                                target_dim=64,
                                embed_scale=0.1)

        if self.online_evaluator:
            z_dim = c * h * h
            num_classes = self.dataset.num_classes
            self.non_linear_evaluator = SSLEvaluator(n_input=z_dim,
                                                     n_classes=num_classes,
                                                     p=0.2,
                                                     n_hidden=1024)
Example #13
    def on_pretrain_routine_start(self, trainer: Trainer,
                                  pl_module: LightningModule) -> None:
        # must move to device after setup, as during setup, pl_module is still on cpu
        self.online_evaluator = SSLEvaluator(
            n_input=self.z_dim,
            n_classes=self.num_classes,
            p=self.drop_p,
            n_hidden=self.hidden_dim,
        ).to(pl_module.device)

        # switch for PL compatibility reasons
        accel = (trainer.accelerator_connector if hasattr(
            trainer, "accelerator_connector") else
                 trainer._accelerator_connector)
        if accel.is_distributed:
            if accel.use_ddp:
                from torch.nn.parallel import DistributedDataParallel as DDP

                self.online_evaluator = DDP(self.online_evaluator,
                                            device_ids=[pl_module.device])
            elif accel.use_dp:
                from torch.nn.parallel import DataParallel as DP

                self.online_evaluator = DP(self.online_evaluator,
                                           device_ids=[pl_module.device])
            else:
                rank_zero_warn(
                    "Does not support this type of distributed accelerator. The online evaluator will not sync."
                )

        self.optimizer = torch.optim.Adam(self.online_evaluator.parameters(),
                                          lr=1e-4)

        if self._recovered_callback_state is not None:
            self.online_evaluator.load_state_dict(
                self._recovered_callback_state["state_dict"])
            self.optimizer.load_state_dict(
                self._recovered_callback_state["optimizer_state"])
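
# The snippet above restores the evaluator and optimizer from self._recovered_callback_state.
# A plausible counterpart for capturing and re-attaching that state is sketched below; checkpoint
# hook signatures and return conventions differ across Lightning versions, so this is an
# assumption rather than the exact upstream code.
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
    # the returned dict is stored under this callback's entry in the checkpoint
    return {
        "state_dict": self.online_evaluator.state_dict(),
        "optimizer_state": self.optimizer.state_dict(),
    }

def on_load_checkpoint(self, trainer, pl_module, callback_state):
    # stash the state; it is applied later in on_pretrain_routine_start (see above)
    self._recovered_callback_state = callback_state
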
class SSLOnlineEvaluator(Callback):

    def __init__(self, data_dir, z_dim, max_epochs=10, check_val_every_n_epoch=1, batch_size=1024, num_workers=32):
        self.z_dim = z_dim
        self.max_epochs = max_epochs
        self.check_val_every_n_epoch = check_val_every_n_epoch

        self.datamodule = BigearthnetDataModule(
            data_dir=data_dir,
            train_frac=0.01,
            val_frac=0.01,
            lmdb=True,
            batch_size=batch_size,
            num_workers=num_workers
        )
        self.datamodule.setup()

        self.criterion = nn.MultiLabelSoftMarginLoss()
        self.metric = lambda output, target: average_precision_score(target, output, average='micro') * 100.0

    def on_pretrain_routine_start(self, trainer, pl_module):
        self.classifier = SSLEvaluator(
            n_input=self.z_dim,
            n_classes=self.datamodule.num_classes,
            n_hidden=None
        ).to(pl_module.device)

        self.optimizer = torch.optim.Adam(self.classifier.parameters(), lr=1e-3)

    def on_epoch_end(self, trainer, pl_module):
        if (trainer.current_epoch + 1) % self.check_val_every_n_epoch != 0:
            return

        encoder = pl_module.encoder_q

        self.classifier.train()
        for _ in range(self.max_epochs):
            for inputs, targets in self.datamodule.train_dataloader():
                inputs = inputs.to(pl_module.device)
                targets = targets.to(pl_module.device)

                with torch.no_grad():
                    representations = encoder(inputs)
                representations = representations.detach()

                logits = self.classifier(representations)
                loss = self.criterion(logits, targets)

                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

        self.classifier.eval()
        accuracies = []
        for inputs, targets in self.datamodule.val_dataloader():
            inputs = inputs.to(pl_module.device)

            with torch.no_grad():
                representations = encoder(inputs)
            representations = representations.detach()

            logits = self.classifier(representations)
            preds = torch.sigmoid(logits).detach().cpu()
            acc = self.metric(preds, targets)
            accuracies.append(acc)
        acc = torch.mean(torch.tensor(accuracies))

        metrics = {'online_val_acc': acc}
        trainer.logger_connector.log_metrics(metrics, {})
        trainer.logger_connector.add_progress_bar_metrics(metrics)
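
# A hedged wiring sketch for the SSLOnlineEvaluator callback above. The data path, z_dim and the
# self-supervised `model` being pre-trained are placeholders, not values from the original code;
# the Trainer arguments follow the Lightning 1.x API used elsewhere in these examples.
import pytorch_lightning as pl

online_eval = SSLOnlineEvaluator(data_dir="/path/to/bigearthnet", z_dim=2048)
trainer = pl.Trainer(max_epochs=100, gpus=1, callbacks=[online_eval])
# `model` stands for the self-supervised LightningModule being pre-trained; it must expose
# `encoder_q`, which the callback probes every `check_val_every_n_epoch` epochs.
trainer.fit(model)
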
Example #15
class CLASSIFIER(pl.LightningModule):  #SSLFineTuner
    def __init__(self,
                 encoder,
                 DATA_PATH,
                 VAL_PATH,
                 hidden_dim,
                 image_size,
                 seed,
                 cpus,
                 transform=SimCLRTransform,
                 **classifier_hparams):
        super().__init__()

        self.DATA_PATH = DATA_PATH
        self.VAL_PATH = VAL_PATH
        self.transform = transform
        self.image_size = image_size
        self.cpus = cpus
        self.seed = seed

        self.batch_size = classifier_hparams['batch_size']
        self.classifier_hparams = classifier_hparams

        self.linear_layer = SSLEvaluator(
            n_input=encoder.embedding_size,
            n_classes=self.classifier_hparams['num_classes'],
            p=self.classifier_hparams['dropout'],
            n_hidden=hidden_dim)

        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)
        self.encoder = encoder

        self.weights = None

        print(classifier_hparams)
        if classifier_hparams['weights'] is not None:
            self.weights = torch.tensor([
                float(item)
                for item in classifier_hparams['weights'].split(',')
            ])
            self.weights = self.weights.cuda()

        self.save_hyperparameters()

    #override optimizer to allow modification of encoder learning rate
    def configure_optimizers(self):
        optimizer = SGD([{
            'params': self.encoder.parameters(),
            'lr': 0
        }, {
            'params': self.linear_layer.parameters(),
            'lr': self.classifier_hparams['linear_lr']
        }],
                        lr=self.classifier_hparams['learning_rate'],
                        momentum=self.classifier_hparams['momentum'])

        if self.classifier_hparams['scheduler_type'] == "step":
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                self.classifier_hparams['decay_epochs'],
                gamma=self.classifier_hparams['gamma'])
        elif self.classifier_hparams['scheduler_type'] == "cosine":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                self.classifier_hparams['epochs'],  # total epochs to run
                eta_min=self.classifier_hparams['final_lr'])
        else:
            raise ValueError(
                f"unknown scheduler_type: {self.classifier_hparams['scheduler_type']}")

        return [optimizer], [scheduler]

    def forward(self, x):
        feats = self.encoder(x)[-1]
        feats = feats.view(feats.size(0), -1)
        logits = self.linear_layer(feats)
        return logits

    def shared_step(self, batch):
        x, y = batch
        logits = self.forward(x)
        loss = self.loss_fn(logits, y)
        return loss, logits, y

    def training_step(self, batch, batch_idx):

        loss, logits, y = self.shared_step(batch)
        acc = self.train_acc(logits, y)
        self.log('tloss', loss, prog_bar=True)
        self.log('tastep', acc, prog_bar=True)
        self.log('ta_epoch', self.train_acc)

        return loss

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            loss, logits, y = self.shared_step(batch)
            self.val_acc(logits, y)

        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        self.log('val_acc_epoch', self.val_acc, prog_bar=True)
        return loss

    def loss_fn(self, logits, labels):
        return F.cross_entropy(logits, labels, weight=self.weights)

    def setup(self, stage='inference'):
        Options = Enum('Loader', 'fit test inference')
        if stage == Options.fit.name:
            train = self.transform(self.DATA_PATH,
                                   batch_size=self.batch_size,
                                   input_height=self.image_size,
                                   copies=1,
                                   stage='train',
                                   num_threads=self.cpus,
                                   device_id=self.local_rank,
                                   seed=self.seed)
            val = self.transform(self.VAL_PATH,
                                 batch_size=self.batch_size,
                                 input_height=self.image_size,
                                 copies=1,
                                 stage='validation',
                                 num_threads=self.cpus,
                                 device_id=self.local_rank,
                                 seed=self.seed)
            self.train_loader = ClassifierWrapper(transform=train)
            self.val_loader = ClassifierWrapper(transform=val)

        elif stage == Options.inference.name:
            self.test_dataloader = ClassifierWrapper(
                transform=self.transform(self.DATA_PATH,
                                         batch_size=self.batch_size,
                                         input_height=self.image_size,
                                         copies=1,
                                         stage='inference',
                                         num_threads=2 * self.cpus,
                                         device_id=self.local_rank,
                                         seed=self.seed))
            self.inference_dataloader = self.test_dataloader

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    # Give the user permission to add extra arguments for the SIMSIAM model in particular.
    # These argument names cannot share the name of any parameters from train.py.
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        # training params
        parser.add_argument("--linear_lr",
                            default=1e-1,
                            type=float,
                            help="learning rate for classification head.")
        parser.add_argument("--dropout",
                            default=0.1,
                            type=float,
                            help="dropout of neurons during training [0-1].")
        parser.add_argument("--nesterov",
                            default=False,
                            type=bool,
                            help="Use nesterov during training.")
        parser.add_argument(
            "--scheduler_type",
            default='cosine',
            type=str,
            help="learning rate scheduler: ['cosine' or 'step']")
        parser.add_argument("--gamma",
                            default=0.1,
                            type=float,
                            help="gamma param for learning rate.")
        parser.add_argument("--decay_epochs",
                            default=[60, 80],
                            type=list,
                            help="epochs to do optimizer decay")
        parser.add_argument("--weight_decay",
                            default=1e-6,
                            type=float,
                            help="weight decay")
        parser.add_argument("--final_lr",
                            type=float,
                            default=1e-6,
                            help="final learning rate")
        parser.add_argument("--momentum",
                            type=float,
                            default=0.9,
                            help="momentum for learning rate")
        parser.add_argument(
            '--weights',
            type=str,
            help='delimited list of weights for penalty during classification')
        return parser
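
# A hedged sketch of how add_model_specific_args is typically combined with a top-level parser in
# a training script. The parent arguments shown here are assumptions standing in for the real
# train.py arguments (which, per the comment above, must not collide with the model-specific ones).
from argparse import ArgumentParser

parent_parser = ArgumentParser(add_help=False)
parent_parser.add_argument("--learning_rate", type=float, default=1e-3)
parent_parser.add_argument("--epochs", type=int, default=30)

parser = CLASSIFIER.add_model_specific_args(parent_parser)
args = parser.parse_args()
# the merged namespace can then be forwarded to the model, e.g. as **vars(args) for the
# **classifier_hparams catch-all in CLASSIFIER.__init__
print(vars(args))
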
Example #16
    def __init__(
        self,
        datamodule: pl.LightningDataModule = None,
        encoder: Union[str, torch.nn.Module,
                       pl.LightningModule] = 'cpc_encoder',
        patch_size: int = 8,
        patch_overlap: int = 4,
        online_ft: bool = True,
        task: str = 'cpc',
        num_workers: int = 4,
        learning_rate: float = 1e-4,
        data_dir: str = '',
        batch_size: int = 32,
        pretrained: str = None,
        **kwargs,
    ):
        """
        PyTorch Lightning implementation of `Data-Efficient Image Recognition with Contrastive Predictive Coding
        <https://arxiv.org/abs/1905.09272>`_

        Paper authors: (Olivier J. Hénaff, Aravind Srinivas, Jeffrey De Fauw, Ali Razavi,
        Carl Doersch, S. M. Ali Eslami, Aaron van den Oord).

        Model implemented by:

            - `William Falcon <https://github.com/williamFalcon>`_
            - `Tullie Murrell <https://github.com/tullie>`_

        Example:

            >>> from pl_bolts.models.self_supervised import CPCV2
            ...
            >>> model = CPCV2()

        Train::

            trainer = Trainer()
            trainer.fit(model)

        CLI command::

            # cifar10
            python cpc_module.py --gpus 1

            # imagenet
            python cpc_module.py
                --gpus 8
                --dataset imagenet2012
                --data_dir /path/to/imagenet/
                --meta_dir /path/to/folder/with/meta.bin/
                --batch_size 32

        Some uses::

            # load resnet18 pretrained using CPC on imagenet
            model = CPCV2(encoder='resnet18', pretrained=True)
            resnet18 = model.encoder
            resnet18.freeze()

            # it supports any torchvision resnet
            model = CPCV2(encoder='resnet50', pretrained=True)

            # use it as a feature extractor
            x = torch.rand(2, 3, 224, 224)
            out = model(x)

        Args:
            datamodule: A Datamodule (optional). Otherwise set the dataloaders directly
            encoder: A string for any of the resnets in torchvision, or the original CPC encoder,
                or a custom nn.Module encoder
            patch_size: How big to make the image patches
            patch_overlap: How much overlap should each patch have.
            online_ft: Enable a 1024-unit MLP to fine-tune online
            task: Which self-supervised task to use ('cpc', 'amdim', etc...)
            num_workers: number of dataloader workers
            learning_rate: what learning rate to use
            data_dir: where to store data
            batch_size: batch size
            pretrained: If true, will use the weights pretrained (using CPC) on Imagenet
        """

        super().__init__()
        self.save_hyperparameters()

        self.online_evaluator = self.hparams.online_ft

        if pretrained:
            self.hparams.dataset = pretrained
            self.online_evaluator = True

        # link data
        if datamodule is None:
            datamodule = CIFAR10DataModule(
                self.hparams.data_dir,
                num_workers=self.hparams.num_workers,
                batch_size=batch_size)
            datamodule.train_transforms = CPCTrainTransformsCIFAR10()
            datamodule.val_transforms = CPCEvalTransformsCIFAR10()
        self.datamodule = datamodule

        # init encoder
        self.encoder = encoder
        if isinstance(encoder, str):
            self.encoder = self.init_encoder()

        # info nce loss
        c, h = self.__compute_final_nb_c(self.hparams.patch_size)
        self.contrastive_task = CPCTask(num_input_channels=c,
                                        target_dim=64,
                                        embed_scale=0.1)

        if self.online_evaluator:
            z_dim = c * h * h
            num_classes = self.datamodule.num_classes
            self.non_linear_evaluator = SSLEvaluator(n_input=z_dim,
                                                     n_classes=num_classes,
                                                     p=0.2,
                                                     n_hidden=1024)

        if pretrained:
            self.load_pretrained(encoder)


class finetuneSIMCLR(pl.LightningModule):
    def __init__(self, encoder, DATA_PATH, withhold, batch_size, val_split,
                 hidden_dims, train_transform, val_transform, num_workers,
                 **kwargs):
        super().__init__()

        self.DATA_PATH = DATA_PATH
        self.val_split = val_split
        self.batch_size = batch_size
        self.hidden_dims = hidden_dims
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.num_workers = num_workers
        self.withhold = withhold

        #data stuff
        shutil.rmtree('split_data', ignore_errors=True)
        if not (path.isdir(f"{self.DATA_PATH}/train")
                and path.isdir(f"{self.DATA_PATH}/val")):
            splitfolders.ratio(self.DATA_PATH,
                               output="split_data",
                               ratio=(1 - self.val_split - self.withhold,
                                      self.val_split, self.withhold),
                               seed=10)
            self.DATA_PATH = 'split_data'
            print(
                f'automatically split data into train/val with val_split={self.val_split} and withhold={self.withhold}'
            )

        self.num_classes = len(os.listdir(f'{self.DATA_PATH}/train'))

        #model stuff
        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)
        print('KWARGS:', kwargs)
        self.encoder, self.embedding_size = load_encoder(encoder, kwargs)

        self.linear_layer = SSLEvaluator(n_input=self.embedding_size,
                                         n_classes=self.num_classes,
                                         p=0.1,
                                         n_hidden=self.hidden_dims)


#   def forward(self, x):
#       x = self.encoder(x)[0]
#       x = F.log_softmax(self.fc1(x), dim = 1)
#       return x

    def shared_step(self, batch):
        x, y = batch
        feats = self.encoder(x)[-1]
        feats = feats.view(feats.size(0), -1)
        logits = self.linear_layer(feats)
        loss = self.loss_fn(logits, y)
        return loss, logits, y

    def training_step(self, batch, batch_idx):
        loss, logits, y = self.shared_step(batch)
        acc = self.train_acc(logits, y)
        self.log('tloss', loss, prog_bar=True)
        self.log('tastep', acc, prog_bar=True)
        self.log('ta_epoch', self.train_acc)

        return loss

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            loss, logits, y = self.shared_step(batch)
            self.val_acc(logits, y)

        self.log('vloss', loss, prog_bar=True, sync_dist=True)
        self.log('val_acc_epoch', self.val_acc, prog_bar=True)

        return loss

    def loss_fn(self, logits, labels):
        return F.cross_entropy(logits, labels)

    def configure_optimizers(self):
        opt = SGD([{
            'params': self.encoder.parameters()
        }, {
            'params': self.linear_layer.parameters(),
            'lr': 0.1
        }],
                  lr=1e-4,
                  momentum=0.9)

        return [opt]

    def prepare_data(self):

        train_pipeline = self.train_transform(
            DATA_PATH=f"{self.DATA_PATH}/train",
            input_height=256,
            batch_size=self.batch_size,
            num_threads=self.num_workers,
            device_id=0)
        print(f"{self.DATA_PATH}/train")
        val_pipeline = self.val_transform(DATA_PATH=f"{self.DATA_PATH}/val",
                                          input_height=256,
                                          batch_size=self.batch_size,
                                          num_threads=self.num_workers,
                                          device_id=0)

        class LightningWrapper(DALIClassificationIterator):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

            def __next__(self):
                out = super().__next__()
                out = out[0]
                return [
                    out[k] if k != "label" else torch.squeeze(out[k])
                    for k in self.output_map
                ]

        self.train_loader = LightningWrapper(train_pipeline,
                                             fill_last_batch=False,
                                             auto_reset=True,
                                             reader_name="Reader")
        self.val_loader = LightningWrapper(val_pipeline,
                                           fill_last_batch=False,
                                           auto_reset=True,
                                           reader_name="Reader")

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    def __init__(self,
                 datamodule: pl_bolts.datamodules.LightningDataModule = None,
                 data_dir: str = '',
                 learning_rate: float = 0.00006,
                 weight_decay: float = 0.0005,
                 input_height: int = 32,
                 batch_size: int = 128,
                 online_ft: bool = False,
                 num_workers: int = 4,
                 optimizer: str = 'lars',
                 lr_sched_step: float = 30.0,
                 lr_sched_gamma: float = 0.5,
                 lars_momentum: float = 0.9,
                 lars_eta: float = 0.001,
                 loss_temperature: float = 0.5,
                 **kwargs):
        """
        PyTorch Lightning implementation of `SIMCLR <https://arxiv.org/abs/2002.05709.>`_

        Paper authors: Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton.

        Model implemented by:

            - `William Falcon <https://github.com/williamFalcon>`_
            - `Tullie Murrell <https://github.com/tullie>`_

        Example:

            >>> from pl_bolts.models.self_supervised import SimCLR
            ...
            >>> model = SimCLR()

        Train::

            trainer = Trainer()
            trainer.fit(model)

        CLI command::

            # cifar10
            python simclr_module.py --gpus 1

            # imagenet
            python simclr_module.py
                --gpus 8
                --dataset imagenet2012
                --data_dir /path/to/imagenet/
                --meta_dir /path/to/folder/with/meta.bin/
                --batch_size 32

        Args:
            datamodule: The datamodule
            data_dir: directory to store data
            learning_rate: the learning rate
            weight_decay: optimizer weight decay
            input_height: image input height
            batch_size: the batch size
            online_ft: whether to tune online or not
            num_workers: number of workers
            optimizer: optimizer name
            lr_sched_step: step for learning rate scheduler
            lr_sched_gamma: gamma for learning rate scheduler
            lars_momentum: the mom param for lars optimizer
            lars_eta: for lars optimizer
            loss_temperature: temperature parameter for the contrastive loss
        """
        super().__init__()
        self.save_hyperparameters()
        self.online_evaluator = online_ft

        # init default datamodule
        if datamodule is None:
            datamodule = CIFAR10DataModule(data_dir, num_workers=num_workers)
            datamodule.train_transforms = SimCLRTrainDataTransform(
                input_height)
            datamodule.val_transforms = SimCLREvalDataTransform(input_height)

        self.datamodule = datamodule
        self.loss_func = self.init_loss()
        self.encoder = self.init_encoder()
        self.projection = self.init_projection()

        if self.online_evaluator:
            z_dim = self.projection.output_dim
            num_classes = self.datamodule.num_classes
            self.non_linear_evaluator = SSLEvaluator(n_input=z_dim,
                                                     n_classes=num_classes,
                                                     p=0.2,
                                                     n_hidden=1024)