Example #1
class LinearProbe(LightningModule):
    """
    A class to train and evaluate a linear probe on top of representations learned by a Noisy CLIP student.
    """
    def __init__(self, args):
        super(LinearProbe, self).__init__()
        self.hparams = args

        #(1) Set up the dataset
        #Here, we use a 100-class subset of ImageNet
        if self.hparams.dataset not in ["ImageNet100", "CIFAR10", "CIFAR100"]:
            raise ValueError("Unsupported dataset selected.")
        else:
            if self.hparams.distortion == "None":
                self.train_set_transform = ImageNetBaseTransform(self.hparams)
                self.val_set_transform = ImageNetBaseTransformVal(self.hparams)
            elif self.hparams.distortion == "multi":
                self.train_set_transform = ImageNetDistortTrainMulti(self.hparams)
                self.val_set_transform = ImageNetDistortValMulti(self.hparams)

            else:
                #For a single distortion type, set the train and val sets to share the same mask when a fixed mask is requested
                self.train_set_transform = ImageNetDistortTrain(self.hparams)

                if self.hparams.fixed_mask:
                    self.val_set_transform = ImageNetDistortVal(self.hparams, fixed_distortion=self.train_set_transform.distortion)
                else:
                    self.val_set_transform = ImageNetDistortVal(self.hparams)

        #Initialise the backbone: either a trained Noisy CLIP student or a clean pre-trained CLIP visual encoder
        if self.hparams.encoder == "clip":
            saved_student = NoisyCLIP.load_from_checkpoint(self.hparams.checkpoint_path)
            self.backbone = saved_student.noisy_visual_encoder
        elif self.hparams.encoder == "clean":
            saved_student = clip.load('RN101', 'cpu', jit=False)[0]
            self.backbone = saved_student.visual


        self.backbone.eval()
        for param in self.backbone.parameters():
            param.requires_grad = False

        #The linear probe itself: a single linear layer mapping embeddings to class logits
        self.output = nn.Linear(self.hparams.emb_dim, self.hparams.num_classes)

        #Set up training and validation metrics
        self.criterion = nn.CrossEntropyLoss()

        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)
        self.test_top_1 = Accuracy(top_k=1)
        self.test_top_5 = Accuracy(top_k=5)


    def forward(self, x):
        """
        Given a set of images x with shape [N, c, h, w], get their embeddings and then logits.

        Returns: Logits with shape [N, n_classes]
        """

        #Grab the image embeddings from the frozen backbone. Both encoder
        #options are CLIP visual towers stored in fp16, so cast the input to
        #fp16 and the output back to fp32.
        self.backbone.eval()
        with torch.no_grad():
            noisy_embeddings = self.backbone(x.type(torch.float16)).float()

        return self.output(noisy_embeddings)

    def configure_optimizers(self):
        if not hasattr(self.hparams, 'weight_decay'):
            self.hparams.weight_decay = 0

        opt = torch.optim.Adam(self.output.parameters(), lr = self.hparams.lr, weight_decay = self.hparams.weight_decay)

        if self.hparams.dataset == "ImageNet100":
            num_steps = 126689//(self.hparams.batch_size * self.hparams.gpus) #N_train for ImageNet100, divided by the global batch size to get steps per epoch
            if self.hparams.use_subset:
                num_steps = int(num_steps * self.hparams.subset_ratio)
        else:
            num_steps = 500

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=num_steps)

        return [opt], [scheduler]

    def train_dataloader(self):
        if self.hparams.dataset == "ImageNet100":
            train_dataset = ImageNet100(
                root=self.hparams.dataset_dir,
                split = 'train',
                transform = self.train_set_transform
            )
        N_train = len(train_dataset)
        if self.hparams.use_subset:
            train_dataset = few_shot_dataset(train_dataset, int(np.ceil(N_train*self.hparams.subset_ratio/self.hparams.num_classes)))

        train_dataloader = DataLoader(train_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=True)

        return train_dataloader

    def val_dataloader(self):
        if self.hparams.dataset == "ImageNet100":
            val_dataset = ImageNet100(
                root=self.hparams.dataset_dir,
                split = 'val',
                transform = self.val_set_transform
            )

            self.N_val = 5000

        val_dataloader = DataLoader(val_dataset, batch_size=4*self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=False)

        return val_dataloader

    def training_step(self, batch, batch_idx):
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Train_Sample', img_grid(x), self.current_epoch)

        logits = self.forward(x)

        loss = self.criterion(logits, y)

        self.log("train_loss", loss, prog_bar=True, on_step=True, \
                    on_epoch=True, logger=True, sync_dist=True, sync_dist_op='sum')

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Val_Sample', img_grid(x), self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        self.log("val_top_1", self.val_top_1(pred_probs, y), prog_bar=False, logger=False)
        self.log("val_top_5", self.val_top_5(pred_probs, y), prog_bar=False, logger=False)

    def validation_epoch_end(self, outputs):
        self.log("val_top_1", self.val_top_1.compute(), prog_bar=True, logger=True)
        self.log("val_top_5", self.val_top_5.compute(), prog_bar=True, logger=True)
        self.val_top_1.reset()
        self.val_top_5.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        self.log("test_top_1", self.test_top_1(pred_probs, y), prog_bar=False, logger=False)
        self.log("test_top_5", self.test_top_5(pred_probs, y), prog_bar=False, logger=False)

    def test_epoch_end(self, outputs):
        self.log("test_top_1", self.test_top_1.compute(), prog_bar=True, logger=True)
        self.log("test_top_5", self.test_top_5.compute(), prog_bar=True, logger=True)
        self.test_top_1.reset()
        self.test_top_5.reset()

    def predict(self, batch, batch_idx, hiddens):
        x, y = batch
        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)
        preds = pred_probs.argmax(dim=-1)
        return preds
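
A minimal usage sketch for LinearProbe (not from the original source: every hparam value below is hypothetical but mirrors a field the class reads, and checkpoint_path must point at a trained NoisyCLIP checkpoint):

from argparse import Namespace
from pytorch_lightning import Trainer

args = Namespace(
    dataset="ImageNet100", dataset_dir="/path/to/imagenet100",  #hypothetical path
    distortion="None", fixed_mask=False,
    encoder="clip", checkpoint_path="noisy_clip_student.ckpt",  #hypothetical checkpoint
    emb_dim=512, num_classes=100,
    lr=1e-3, weight_decay=0.0, batch_size=256, workers=8,
    gpus=1, use_subset=False, subset_ratio=1.0,
)

probe = LinearProbe(args)
Trainer(gpus=args.gpus, max_epochs=10).fit(probe)
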
class NoisyCLIP(LightningModule):
    def __init__(self, args):
        """
        A class that trains OpenAI CLIP in a student-teacher fashion to classify distorted images.

        Given two identical pre-trained networks, a teacher T() and a student S(), we freeze T() and train S().

        Given a batch {x1, x2, ..., xN}, apply a given distortion to each image to obtain noisy images {y1, y2, ..., yN}.

        Feed the original images to T() to obtain embeddings {T(x1), ..., T(xN)}, and feed the distorted images to S() to obtain embeddings {S(y1), ..., S(yN)}.

        Maximize the similarity between the matched pairs {(T(x1), S(y1)), ..., (T(xN), S(yN))} while minimizing the similarity between all non-matched pairs.
        """
        super(NoisyCLIP, self).__init__()
        self.hparams = args
        self.world_size = self.hparams.num_nodes * self.hparams.gpus

        #(1) Load the correct dataset class names
        if self.hparams.dataset == "ImageNet100" or self.hparams.dataset == "Imagenet-100":
            self.N_val = 5000 # Default ImageNet validation set, only 100 classes.

            #Retrieve the text labels for classes in order to do zero-shot classification.
            if self.hparams.mapping_and_text_file is None:
                raise ValueError('No file from which to read text labels was specified.')

            text_labels = pickle.load(open(self.hparams.mapping_and_text_file, 'rb'))
            self.text_list = ['A photo of '+label.strip().replace('_',' ') for label in text_labels]
        else:
            raise NotImplementedError('Handling of the dataset not implemented yet.')

        #Set up the dataset
        #Here, we use a 100-class subset of ImageNet
        if self.hparams.dataset != "ImageNet100" and self.hparams.dataset != "Imagenet-100":
            raise ValueError("Unsupported dataset selected.")
        elif not hasattr(self.hparams, 'increasing') or not self.hparams.increasing: #with increasing distortion, transforms are instead built per-epoch in the dataloaders
            if self.hparams.distortion == "None":
                self.train_set_transform = ImageNetBaseTrainContrastive(self.hparams)
                self.val_set_transform = ImageNetBaseTransformVal(self.hparams)
            elif self.hparams.distortion == "multi":
                self.train_set_transform = ImageNetDistortTrainMultiContrastive(self.hparams)
                self.val_set_transform = ImageNetDistortValMulti(self.hparams)
            else:
                #For a single distortion type, set the train and val sets to share the same mask when a fixed mask is requested
                self.train_set_transform = ImageNetDistortTrainContrastive(self.hparams)

                if self.hparams.fixed_mask:
                    self.val_set_transform = ImageNetDistortVal(self.hparams, fixed_distortion=self.train_set_transform.distortion)
                else:
                    self.val_set_transform = ImageNetDistortVal(self.hparams)

        #(2) set up the teacher CLIP network - freeze it and don't use gradients!
        self.logit_scale = self.hparams.logit_scale
        self.baseclip = clip.load(self.hparams.baseclip_type, self.hparams.device, jit=False)[0]
        self.baseclip.eval()
        self.baseclip.requires_grad_(False)

        #(3) set up the student CLIP network - unfreeze it and use gradients!
        self.noisy_visual_encoder = clip.load(self.hparams.baseclip_type, self.hparams.device, jit=False)[0].visual
        self.noisy_visual_encoder.train()

        #(4) set up the training and validation accuracy metrics.
        self.train_top_1 = Accuracy(top_k=1)
        self.train_top_5 = Accuracy(top_k=5)
        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)

    def criterion(self, input1, input2, reduction='mean'):
        """
        Args:
            input1: Embeddings of the clean/noisy images from the teacher/student. Size [N, embedding_dim].
            input2: Embeddings of the clean/noisy images from the teacher/student (the ones not used as input1). Size [N, embedding_dim].
            reduction: how to scale the final loss
        """
        bsz = input1.shape[0]

        # Use the simclr style InfoNCE
        if self.hparams.loss_type == 'simclr':
            # Create the cosine-similarity matrix between all 2*bsz embeddings.
            full_tensor = torch.cat([input1.unsqueeze(1), input2.unsqueeze(1)], dim=1).view(2*bsz, -1)
            full_tensor = full_tensor / full_tensor.norm(dim=-1, keepdim=True)
            sim_mat = full_tensor @ full_tensor.t()
            # Calculate logits used for the contrastive loss.
            exp_sim_mat = torch.exp(sim_mat/self.hparams.loss_tau)
            mask = torch.ones_like(exp_sim_mat) - torch.eye(2*bsz).type_as(exp_sim_mat)
            logmat = -torch.log(exp_sim_mat)+torch.log(torch.sum(mask*exp_sim_mat, 1))

            #Grab the two off-diagonal similarities
            part1 = torch.sum(torch.diag(logmat, diagonal=1)[np.arange(0,2*bsz,2)])
            part2 = torch.sum(torch.diag(logmat, diagonal=-1)[np.arange(0,2*bsz,2)])

            #Take the mean of the two off-diagonals
            loss = (part1 + part2)/2

        #Use the CLIP-style InfoNCE
        elif self.hparams.loss_type == 'clip':
            # Create similarity matrix between embeddings.
            tensor1 = input1 / input1.norm(dim=-1, keepdim=True)
            tensor2 = input2 / input2.norm(dim=-1, keepdim=True)
            sim_mat = (1/self.hparams.loss_tau)*tensor1 @ tensor2.t()

            #Calculate the cross entropy between the similarities of the positive pairs, counted two ways
            part1 = F.cross_entropy(sim_mat, torch.LongTensor(np.arange(bsz)).to(self.device))
            part2 = F.cross_entropy(sim_mat.t(), torch.LongTensor(np.arange(bsz)).to(self.device))

            #Take the mean of the two cross-entropy directions
            loss = (part1+part2)/2

        #Take the simple MSE between the clean and noisy embeddings
        elif self.hparams.loss_type == 'mse':
            return F.mse_loss(input2, input1)

        elif self.hparams.loss_type.startswith('simclr_'):
            assert self.hparams.loss_type in ['simclr_ss', 'simclr_st', 'simclr_both']
            # Various schemes for the negative examples
            teacher_embeds = F.normalize(input1, dim=1)
            student_embeds = F.normalize(input2, dim=1)
            # First compute positive examples by taking <S(x_i), T(x_i)>/T for all i
            pos_term = (teacher_embeds * student_embeds).sum(dim=1) / self.hparams.loss_tau
            # Then generate the negative term by constructing various similarity matrices
            if self.hparams.loss_type == 'simclr_ss':
                cov = torch.mm(student_embeds, student_embeds.t())
                sim = torch.exp(cov / self.hparams.loss_tau) # shape is [bsz, bsz]
                neg_term = torch.log(sim.sum(dim=1) - sim.diag())
            elif self.hparams.loss_type == 'simclr_st':
                cov = torch.mm(student_embeds, teacher_embeds.t())
                sim = torch.exp(cov / self.hparams.loss_tau) # shape is [bsz, bsz]
                neg_term = torch.log(sim.sum(dim=1)) # Not removing the diagonal here!
            else:
                cat_embeds = torch.cat([student_embeds, teacher_embeds])
                cov = torch.mm(student_embeds, cat_embeds.t())
                sim = torch.exp(cov / self.hparams.loss_tau) # shape is [bsz, 2 * bsz]
                # take row-wise sums without the diagonal self-similarity terms
                neg_term = torch.log(sim.sum(dim=1) - sim.diag())
            # Final loss, summed here and mean-reduced at the end of this method
            loss = -1 * (pos_term - neg_term).sum()

        else:
            raise ValueError('Loss function not understood.')

        return loss/bsz if reduction == 'mean' else loss
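
    # Shape note for the 'simclr' branch above (illustrative): full_tensor
    # interleaves the two views as [x1, y1, x2, y2, ...], so each matched pair
    # (xi, yi) lands on the +1/-1 off-diagonals of the (2*bsz x 2*bsz) logmat,
    # at the even positions 0, 2, ... that
    # torch.diag(logmat, diagonal=+/-1)[np.arange(0, 2*bsz, 2)] picks out.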


    def configure_optimizers(self):
        optim = torch.optim.Adam(self.noisy_visual_encoder.parameters(), lr=self.hparams.lr)
        num_steps = 126689//(self.hparams.batch_size * self.hparams.gpus) #N_train for ImageNet100, divided by the global batch size to get steps per epoch
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=num_steps)
        return [optim], [sched]

    def encode_noisy_image(self, image):
        """
        Return S(y_i), where S() is the student network and y_i are the distorted images.
        """
        return self.noisy_visual_encoder(image.type(torch.float16))

    def forward(self, image_features, text=None):
        """
        Given a set of noisy image embeddings, calculate the cosine similarity (scaled by temperature) of each image with each class text prompt.
        Calculates the similarity in two ways: logits per image (size = [N, n_classes]), logits per text (size = [n_classes, N]).
        This is mainly used for validation and classification.

        Args:
            image_features: the noisy image embeddings S(yi) where S() is the student and yi = Distort(xi). Shape [N, embedding_dim]
        """

        #Encode the class text prompts on the correct device (note: recomputed on every call)
        text_features = self.baseclip.encode_text(clip.tokenize(self.text_list).to(self.device))

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logits_per_image = self.logit_scale * image_features.type(torch.float16) @ text_features.type(torch.float16).t()
        logits_per_text = self.logit_scale * text_features.type(torch.float16) @ image_features.type(torch.float16).t()

        return logits_per_image, logits_per_text

    # Training methods - here we are concerned with contrastive loss (or MSE) between clean and noisy image embeddings.
    def training_step(self, train_batch, batch_idx):
        """
        Takes a batch of clean and noisy images and returns their respective embeddings.

        Returns:
            embed_clean: T(xi) where T() is the teacher and xi are clean images. Shape [N, embed_dim]
            embed_noisy: S(yi) where S() is the student and yi are noisy images. Shape [N, embed_dim]
        """
        image_clean, image_noisy, labels = train_batch
        embed_clean = self.baseclip.encode_image(image_clean)
        embed_noisy = self.encode_noisy_image(image_noisy)
        return {'embed_clean': embed_clean, 'embed_noisy': embed_noisy}

    def training_step_end(self, outputs):
        """
        Given all the clean and noisy image embeddings from across GPUs produced by training_step, gather them onto a single GPU and calculate the overall loss.
        """
        embed_clean_full = outputs['embed_clean']
        embed_noisy_full = outputs['embed_noisy']
        loss = self.criterion(embed_clean_full, embed_noisy_full)
        self.log('train_loss', loss, prog_bar=False, logger=True, sync_dist=True, on_step=True, on_epoch=True)
        return loss

    # Validation methods - here we are concerned with similarity between noisy image embeddings and classification text embeddings.
    def validation_step(self, test_batch, batch_idx):
        """
        Grab the noisy image embeddings: S(yi), where S() is the student and yi = Distort(xi). Done on each GPU.
        Return these to be evaluated in validation step end.
        """
        images_noisy, labels = test_batch
        if batch_idx == 0 and self.current_epoch < 1:
            self.logger.experiment.add_image('Val_Sample', img_grid(images_noisy), self.current_epoch)
        image_features = self.encode_noisy_image(images_noisy)
        return {'image_features': image_features, 'labels': labels}

    def validation_step_end(self, outputs):
        """
        Gather the noisy image features and their labels from each GPU.
        Then calculate their similarities, convert to probabilities, and calculate accuracy on each GPU.
        """
        image_features_full = outputs['image_features']
        labels_full = outputs['labels']

        image_logits, _ = self.forward(image_features_full)
        image_logits = image_logits.float()
        image_probs = image_logits.softmax(dim=-1)
        self.log('val_top_1_step', self.val_top_1(image_probs, labels_full), prog_bar=False, logger=False)
        self.log('val_top_5_step', self.val_top_5(image_probs, labels_full), prog_bar=False, logger=False)

    def validation_epoch_end(self, outputs):
        """
        Gather the zero-shot validation accuracies from across GPUs and reduce.
        """
        self.log('val_top_1', self.val_top_1.compute(), prog_bar=True, logger=True)
        self.log('val_top_5', self.val_top_5.compute(), prog_bar=True, logger=True)
        self.val_top_1.reset()
        self.val_top_5.reset()

    # Default dataloaders - can be overwritten by datamodule.
    def train_dataloader(self):
        if hasattr(self.hparams, 'increasing') and self.hparams.increasing:
            datatf = ImageNetDistortTrainContrastive(self.hparams, epoch=self.current_epoch)
        else:
            datatf = self.train_set_transform

        train_dataset = ImageNet100(
            root=self.hparams.dataset_dir,
            split = 'train',
            transform = None
        )
        train_contrastive = ContrastiveUnsupervisedDataset(train_dataset, transform_contrastive=datatf, return_label=True)

        train_dataloader = DataLoader(train_contrastive, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=True)

        return train_dataloader

    def val_dataloader(self):
        if hasattr(self.hparams, 'increasing') and self.hparams.increasing:
            datatf = ImageNetDistortVal(self.hparams, epoch=self.current_epoch)
        else:
            datatf = self.val_set_transform

        val_dataset = ImageNet100(
            root=self.hparams.dataset_dir,
            split = 'val',
            transform = datatf
        )
        self.N_val = len(val_dataset)

        val_dataloader = DataLoader(val_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=False)

        return val_dataloader

    def test_dataloader(self):
        return self.val_dataloader()
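
The 'clip' branch of criterion above is the standard symmetric InfoNCE loss. A self-contained sketch of the same computation (illustrative only; note that criterion additionally divides by the batch size when reduction='mean'):

import torch
import torch.nn.functional as F

def clip_style_infonce(teacher_embeds, student_embeds, tau=0.1):
    #Normalize so the dot products below are cosine similarities
    t = F.normalize(teacher_embeds, dim=-1)
    s = F.normalize(student_embeds, dim=-1)
    sim = (t @ s.t()) / tau  #[N, N]; matched pairs sit on the diagonal
    targets = torch.arange(t.shape[0], device=t.device)
    #Cross-entropy over rows and over columns, then average the two directions
    return (F.cross_entropy(sim, targets) + F.cross_entropy(sim.t(), targets)) / 2

loss = clip_style_infonce(torch.randn(8, 512), torch.randn(8, 512))
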
Example #3
class HPALit(pl.LightningModule):
    def __init__(self, params):
        super().__init__()

        self.hparams = params
        self.lr = self.hparams.lr
        self.save_hyperparameters()
        self.model = Net(name=self.hparams.model)
        self.criterion = torch.nn.BCEWithLogitsLoss()

        self.train_accuracy = Accuracy(subset_accuracy=True)
        self.val_accuracy = Accuracy(subset_accuracy=True)
        self.train_recall = Recall()
        self.val_recall = Recall()

        self.train_df = pd.read_csv(
            f'data_preprocessing/train_fold_{self.hparams.fold}.csv')
        self.valid_df = pd.read_csv(
            f'data_preprocessing/valid_fold_{self.hparams.fold}.csv')

        self.train_transforms = A.Compose([
            A.Rotate(),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Resize(width=self.hparams.img_size,
                     height=self.hparams.img_size),
            A.Normalize(),
            ToTensorV2(),
        ])

        self.valid_transforms = A.Compose([
            A.Resize(width=self.hparams.img_size,
                     height=self.hparams.img_size),
            A.Normalize(),
            ToTensorV2(),
        ])

        self.train_dataset = CellDataset(data_dir=self.hparams.data_dir,
                                         csv_file=self.train_df,
                                         transform=self.train_transforms)
        self.val_dataset = CellDataset(data_dir=self.hparams.data_dir,
                                       csv_file=self.valid_df,
                                       transform=self.valid_transforms)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.n_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.n_workers,
            pin_memory=True,
        )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               'min',
                                                               factor=0.5,
                                                               patience=2,
                                                               eps=1e-6)
        lr_scheduler = {
            'scheduler': scheduler,
            'interval': 'epoch',
            'monitor': 'valid_loss_epoch'
        }

        return [optimizer], [lr_scheduler]

    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self.model(x)

        train_loss = self.criterion(pred, y)
        self.train_accuracy(torch.sigmoid(pred), y.type(torch.int))
        self.train_recall(torch.sigmoid(pred), y.type(torch.int))

        return {'loss': train_loss}

    def training_epoch_end(self, outputs):
        train_loss_epoch = torch.stack([x['loss'] for x in outputs]).mean()
        self.log('train_loss_epoch', train_loss_epoch)
        self.log('train_acc_epoch', self.train_accuracy.compute())
        self.log('train_recall_epoch', self.train_recall.compute())
        self.train_accuracy.reset()
        self.train_recall.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self.model(x)

        val_loss = self.criterion(pred, y)
        self.val_accuracy(torch.sigmoid(pred), y.type(torch.int))
        self.val_recall(torch.sigmoid(pred), y.type(torch.int))

        return {'valid_loss': val_loss}

    def validation_epoch_end(self, outputs):
        val_loss_epoch = torch.stack([x['valid_loss'] for x in outputs]).mean()
        self.log('valid_loss_epoch', val_loss_epoch)
        self.log('valid_acc_epoch', self.val_accuracy.compute())
        self.log('valid_recall_epoch', self.val_recall.compute())
        self.val_accuracy.reset()
        self.val_recall.reset()
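
A minimal driver sketch for HPALit (hypothetical values mirroring the fields __init__ reads; the ReduceLROnPlateau scheduler above depends on the 'valid_loss_epoch' key that validation_epoch_end logs):

from argparse import Namespace
import pytorch_lightning as pl

params = Namespace(lr=3e-4, model="resnet50", fold=0, img_size=512,
                   data_dir="/path/to/hpa_cells",  #hypothetical path
                   batch_size=32, n_workers=4)

module = HPALit(params)
pl.Trainer(gpus=1, max_epochs=20).fit(module)
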
Example #4
class Baseline(LightningModule):
    """
    Class for training a classification model on distorted ImageNet in an end-to-end fashion.
    """
    def __init__(self, args):
        super(Baseline, self).__init__()

        self.hparams = args
        self.world_size = self.hparams.num_nodes * self.hparams.gpus

        self.lr = self.hparams.lr  #initialise this separately as a tunable parameter

        #(1) Set up the dataset
        #Here, we use a 100-class subset of ImageNet
        if self.hparams.dataset != "ImageNet100":
            raise ValueError("Unsupported dataset selected.")
        else:
            if self.hparams.distortion == "None":
                self.train_set_transform = ImageNetBaseTransform(self.hparams)
                self.val_set_transform = ImageNetBaseTransformVal(self.hparams)
            elif self.hparams.distortion == "multi":
                self.train_set_transform = ImageNetDistortTrainMulti(
                    self.hparams)
                self.val_set_transform = ImageNetDistortValMulti(self.hparams)
            else:
                #For a single distortion type, set the train and val sets to share the same mask when a fixed mask is requested
                self.train_set_transform = ImageNetDistortTrain(self.hparams)

                if self.hparams.fixed_mask:
                    self.val_set_transform = ImageNetDistortVal(
                        self.hparams,
                        fixed_distortion=self.train_set_transform.distortion)
                else:
                    self.val_set_transform = ImageNetDistortVal(self.hparams)

        #(2) Grab the correct baseline pre-trained model
        if self.hparams.encoder == 'resnet':
            self.encoder = RESNET_finetune(self.hparams)
        elif self.hparams.encoder == 'clip':
            self.encoder = CLIP_finetune(self.hparams)
        else:
            raise ValueError("Please select a valid encoder model.")

        #(3) Set up our criterion - here we use reduction as "sum" so that we are able to average over all validation sets
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

        self.train_top_1 = Accuracy(top_k=1)
        self.train_top_5 = Accuracy(top_k=5)

        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)

        self.test_top_1 = Accuracy(top_k=1)
        self.test_top_5 = Accuracy(top_k=5)

    def forward(self, x):
        return self.encoder(x)

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.encoder.parameters(), lr=self.lr)

        if self.hparams.dataset == "ImageNet100":
            num_steps = 126689 // (
                self.hparams.batch_size * self.hparams.gpus
            )  #divide N_train by number of distributed iters
        else:
            num_steps = 500

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt,
                                                               T_max=num_steps)

        return [opt], [scheduler]

    #DATALOADERS
    def train_dataloader(self):
        if self.hparams.dataset == "ImageNet100":
            train_dataset = ImageNet100(root=self.hparams.dataset_dir,
                                        split='train',
                                        transform=self.train_set_transform)

        train_dataloader = DataLoader(train_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=True)

        return train_dataloader

    def val_dataloader(self):
        if self.hparams.dataset == "ImageNet100":
            val_dataset = ImageNet100(root=self.hparams.dataset_dir,
                                      split='val',
                                      transform=self.val_set_transform)

            self.N_val = 5000

        val_dataloader = DataLoader(val_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=False)

        return val_dataloader

    def test_dataloader(self):
        if self.hparams.dataset == "ImageNet100":
            test_dataset = ImageNet100(root=self.hparams.dataset_dir,
                                       split='val',
                                       transform=self.val_set_transform)

        test_dataloader = DataLoader(test_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=False)

        return test_dataloader

    #TRAINING
    def training_step(self, batch, batch_idx):
        """
        Given a batch of images, train the model for one step.
        Calculates the crossentropy loss as well as top_1 and top_5 per batch

        Inputs:
            batch - the images to train on, shape [batch_size, num_channels, height, width]
            batch_idx - the index of the current batch
        
        Returns:
            loss - the crossentropy loss between the model's logits and the true classes of the inputs
        """
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Train_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        loss = self.criterion(logits, y)

        self.log("train_loss",
                 loss,
                 prog_bar=True,
                 on_step=True,
                 on_epoch=True,
                 logger=True,
                 sync_dist=True,
                 sync_dist_op='sum')
        self.log("train_top_1",
                 self.train_top_1(pred_probs, y),
                 prog_bar=False,
                 logger=False)
        self.log("train_top_5",
                 self.train_top_5(pred_probs, y),
                 prog_bar=False,
                 logger=False)

        return loss

    def training_epoch_end(self, outputs):
        self.log("train_top_1",
                 self.train_top_1.compute(),
                 prog_bar=True,
                 logger=True)
        self.log("train_top_5",
                 self.train_top_5.compute(),
                 prog_bar=True,
                 logger=True)

        self.train_top_1.reset()
        self.train_top_5.reset()

    #VALIDATION
    def validation_step(self, batch, batch_idx):
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Val_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        self.log("val_top_1",
                 self.val_top_1(pred_probs, y),
                 prog_bar=False,
                 logger=False)
        self.log("val_top_5",
                 self.val_top_5(pred_probs, y),
                 prog_bar=False,
                 logger=False)

    def validation_epoch_end(self, outputs):
        self.log("val_top_1",
                 self.val_top_1.compute(),
                 prog_bar=True,
                 logger=True)
        self.log("val_top_5",
                 self.val_top_5.compute(),
                 prog_bar=True,
                 logger=True)

        self.val_top_1.reset()
        self.val_top_5.reset()

    #TESTING
    def test_step(self, batch, batch_idx):
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Test_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        self.log("test_top_1",
                 self.test_top_1(pred_probs, y),
                 prog_bar=False,
                 logger=False)
        self.log("test_top_5",
                 self.test_top_5(pred_probs, y),
                 prog_bar=False,
                 logger=False)

    def test_epoch_end(self, outputs):
        self.log("test_top_1",
                 self.test_top_1.compute(),
                 prog_bar=True,
                 logger=True)
        self.log("test_top_5",
                 self.test_top_5.compute(),
                 prog_bar=True,
                 logger=True)

        self.test_top_1.reset()
        self.test_top_5.reset()
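
A note on the reduction="sum" choice in Baseline: summing per-batch losses (and summing across devices via sync_dist_op='sum') lets an exact dataset-level mean be recovered by dividing by N afterwards, whereas averaging per-batch means is biased when the last batch is smaller. A toy check:

import torch
import torch.nn as nn

logits = torch.randn(10, 5)
targets = torch.randint(0, 5, (10,))

sum_loss = nn.CrossEntropyLoss(reduction="sum")(logits, targets)
mean_loss = nn.CrossEntropyLoss(reduction="mean")(logits, targets)
assert torch.allclose(sum_loss / len(targets), mean_loss)
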
class TransferLearning(LightningModule):
    def __init__(self, args):
        super(TransferLearning, self).__init__()

        self.hparams = args
        self.world_size = self.hparams.num_nodes * self.hparams.gpus

        self.train_set_transform, self.val_set_transform = grab_transforms(
            self.hparams)

        #Grab the correct model - only want the embeddings from the final layer!
        if self.hparams.saved_model_type == 'contrastive':
            saved_model = NoisyCLIP.load_from_checkpoint(
                self.hparams.checkpoint_path)
            self.backbone = saved_model.noisy_visual_encoder
        elif self.hparams.saved_model_type == 'baseline':
            saved_model = Baseline.load_from_checkpoint(
                self.hparams.checkpoint_path)
            self.backbone = saved_model.encoder.feature_extractor
        else:
            raise ValueError("Please select a valid saved model type.")

        for param in self.backbone.parameters():
            param.requires_grad = False

        #Set up a classifier with the correct dimensions
        self.output = nn.Linear(self.hparams.emb_dim, self.hparams.num_classes)

        #Set up the criterion - note the mean reduction here, unlike the sum reduction used in Baseline
        self.criterion = nn.CrossEntropyLoss(reduction="mean")

        self.train_top_1 = Accuracy(top_k=1)
        self.train_top_5 = Accuracy(top_k=5)

        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)

        self.test_top_1 = Accuracy(top_k=1)
        self.test_top_5 = Accuracy(top_k=5)

        #class INFECTED has label 0
        if self.hparams.dataset == 'COVID':
            self.val_auc = AUROC(pos_label=0)

            self.test_auc = AUROC(pos_label=0)

    def forward(self, x):
        #Grab the noisy image embeddings
        self.backbone.eval()
        with torch.no_grad():
            if self.hparams.encoder == "clip":
                noisy_embeddings = self.backbone(x.type(torch.float16)).float()
            elif self.hparams.encoder == "resnet":
                noisy_embeddings = self.backbone(x)

        return self.output(noisy_embeddings.flatten(1))

    def configure_optimizers(self):
        if not hasattr(self.hparams, 'weight_decay'):
            self.hparams.weight_decay = 0

        opt = torch.optim.Adam(self.output.parameters(),
                               lr=self.hparams.lr,
                               weight_decay=self.hparams.weight_decay)

        num_steps = self.hparams.max_epochs

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt,
                                                               T_max=num_steps)

        return [opt], [scheduler]

    def _grab_dataset(self, split):
        """
        Given a split ("train", "val", or "test"), return the corresponding dataset.
        Which dataset to use is defined in this object's hparams.

        Args:
            split: the split to use in the dataset
        Returns:
            dataset: the desired dataset with the correct split
        """
        if self.hparams.dataset == "CIFAR10":
            if split == 'train':
                train = True
                transform = self.train_set_transform
            else:
                train = False
                transform = self.val_set_transform

            dataset = CIFAR10(root=self.hparams.dataset_dir,
                              train=train,
                              transform=transform,
                              download=True)

        elif self.hparams.dataset == "CIFAR100":
            if split == 'train':
                train = True
                transform = self.train_set_transform
            else:
                train = False
                transform = self.val_set_transform

            dataset = CIFAR100(root=self.hparams.dataset_dir,
                               train=train,
                               transform=transform,
                               download=True)

        elif self.hparams.dataset == 'STL10':
            if split == 'train':
                stlsplit = 'train'
                transform = self.train_set_transform
            else:
                stlsplit = 'test'
                transform = self.val_set_transform

            dataset = STL10(root=self.hparams.dataset_dir,
                            split=stlsplit,
                            transform=transform,
                            download=True)

        elif self.hparams.dataset == 'COVID':
            if split == 'train':
                covidsplit = 'train'
                transform = self.train_set_transform
            else:
                covidsplit = 'test'
                transform = self.val_set_transform

            dataset = torchvision.datasets.ImageFolder(
                root=self.hparams.dataset_dir + covidsplit,
                transform=transform)

        elif self.hparams.dataset in ('ImageNet100B', 'imagenet-100B'):
            if split == 'train':
                transform = self.train_set_transform
            else:
                split = 'val'
                transform = self.val_set_transform

            dataset = ImageNet100(root=self.hparams.dataset_dir,
                                  split=split,
                                  transform=transform)

        else:
            raise ValueError("Please select a valid dataset.")

        return dataset

    def train_dataloader(self):
        train_dataset = self._grab_dataset(split='train')

        N_train = len(train_dataset)
        if self.hparams.use_subset:
            train_dataset = few_shot_dataset(
                train_dataset,
                int(
                    np.ceil(N_train * self.hparams.subset_ratio /
                            self.hparams.num_classes)))

        train_dataloader = DataLoader(train_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=True)

        return train_dataloader

    def val_dataloader(self):
        val_dataset = self._grab_dataset(split='val')

        N_val = len(val_dataset)

        #Shuffle the val set so that AUROC never receives a batch whose labels are all one class
        val_dataloader = DataLoader(val_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=True)

        return val_dataloader

    def test_dataloader(self):
        test_dataset = self._grab_dataset(split='test')

        N_test = len(test_dataset)

        #Shuffle the test set so that AUROC never receives a batch whose labels are all one class
        test_dataloader = DataLoader(test_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=True)

        return test_dataloader

    def training_step(self, batch, batch_idx):
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Train_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)

        loss = self.criterion(logits, y)

        self.log("train_loss", loss, prog_bar=False, on_step=True, \
                    on_epoch=True, logger=True, sync_dist=True, sync_dist_op='sum')

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Val_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)  #(N, num_classes)

        if self.hparams.dataset == 'COVID':
            positive_prob = pred_probs[:, 0].flatten()  #class 0 is the INFECTED label
            true_labels = y.flatten()

            self.val_auc.update(positive_prob, true_labels)

            self.log("val_auc", self.val_auc, prog_bar=False, logger=False)

        self.log("val_top_1",
                 self.val_top_1(pred_probs, y),
                 prog_bar=False,
                 logger=False)

        if self.hparams.dataset != 'COVID':
            self.log("val_top_5",
                     self.val_top_5(pred_probs, y),
                     prog_bar=False,
                     logger=False)

    def validation_epoch_end(self, outputs):
        self.log("val_top_1",
                 self.val_top_1.compute(),
                 prog_bar=True,
                 logger=True)

        if self.hparams.dataset != 'COVID':
            self.log("val_top_5",
                     self.val_top_5.compute(),
                     prog_bar=True,
                     logger=True)

        if self.hparams.dataset == 'COVID':
            self.log("val_auc",
                     self.val_auc.compute(),
                     prog_bar=True,
                     logger=True)

            self.val_auc.reset()

        self.val_top_1.reset()
        self.val_top_5.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        if self.hparams.dataset == 'COVID':
            positive_prob = pred_probs[:, 0].flatten()  #class 0 is the INFECTED label
            true_labels = y.flatten()

            self.test_auc.update(positive_prob, true_labels)

            self.log("test_auc", self.test_auc, prog_bar=False, logger=False)

        self.log("test_top_1",
                 self.test_top_1(pred_probs, y),
                 prog_bar=False,
                 logger=False)
        if self.hparams.dataset != 'COVID':
            self.log("test_top_5",
                     self.test_top_5(pred_probs, y),
                     prog_bar=False,
                     logger=False)

    def test_epoch_end(self, outputs):
        self.log("test_top_1",
                 self.test_top_1.compute(),
                 prog_bar=True,
                 logger=True)
        if self.hparams.dataset != 'COVID':
            self.log("test_top_5",
                     self.test_top_5.compute(),
                     prog_bar=True,
                     logger=True)

        if self.hparams.dataset == 'COVID':
            self.log("test_auc",
                     self.test_auc.compute(),
                     prog_bar=True,
                     logger=True)

            self.test_auc.reset()

        self.test_top_1.reset()
        self.test_top_5.reset()
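
The few_shot_dataset helper used by the train dataloaders above is not shown in this listing. A plausible sketch, assuming it keeps at most n_per_class examples of each class (an assumption - the real helper may differ):

from torch.utils.data import Subset

def few_shot_dataset(dataset, n_per_class):
    #Walk the dataset once and keep indices until every class has n_per_class
    #examples. Loading each item is slow for image data; a real helper would
    #likely read labels from e.g. dataset.targets instead.
    kept, counts = [], {}
    for idx in range(len(dataset)):
        _, label = dataset[idx]
        label = int(label)
        if counts.get(label, 0) < n_per_class:
            kept.append(idx)
            counts[label] = counts.get(label, 0) + 1
    return Subset(dataset, kept)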