class LinearProbe(LightningModule):
    """
    A class to train and evaluate a linear probe on top of representations learned from a noisy clip student.

    The backbone is frozen (``requires_grad = False``) and always run in eval mode under
    ``torch.no_grad()``; only the final ``nn.Linear`` layer (``self.output``) is optimised.
    """

    def __init__(self, args):
        """
        Build the data transforms, the frozen backbone, the linear head and the metrics.

        Args:
            args: hyperparameter container. Fields read here: ``dataset``, ``distortion``,
                ``fixed_mask`` (only for a single named distortion), ``encoder``,
                ``checkpoint_path`` (only when ``encoder == "clip"``), ``emb_dim`` and
                ``num_classes``.

        Raises:
            ValueError: if ``args.dataset`` or ``args.encoder`` is not one of the supported values.
        """
        super(LinearProbe, self).__init__()
        self.hparams = args

        #(1) Set up the dataset
        #Here, we use a 100-class subset of ImageNet
        if self.hparams.dataset not in ["ImageNet100", "CIFAR10", "CIFAR100"]:
            raise ValueError("Unsupported dataset selected.")

        if self.hparams.distortion == "None":
            self.train_set_transform = ImageNetBaseTransform(self.hparams)
            self.val_set_transform = ImageNetBaseTransformVal(self.hparams)
        elif self.hparams.distortion == "multi":
            self.train_set_transform = ImageNetDistortTrainMulti(self.hparams)
            self.val_set_transform = ImageNetDistortValMulti(self.hparams)
        else:
            #If we are using the ImageNet dataset, then set up the train and val sets to use the same mask if needed!
            self.train_set_transform = ImageNetDistortTrain(self.hparams)

            if self.hparams.fixed_mask:
                self.val_set_transform = ImageNetDistortVal(
                    self.hparams, fixed_distortion=self.train_set_transform.distortion)
            else:
                self.val_set_transform = ImageNetDistortVal(self.hparams)

        #This should be initialised as a trained student CLIP network
        if self.hparams.encoder == "clip":
            saved_student = NoisyCLIP.load_from_checkpoint(self.hparams.checkpoint_path)
            self.backbone = saved_student.noisy_visual_encoder
        elif self.hparams.encoder == "clean":
            saved_student = clip.load('RN101', 'cpu', jit=False)[0]
            self.backbone = saved_student.visual
        else:
            # Previously this fell through and crashed later with an AttributeError on self.backbone.
            raise ValueError("Please select a valid encoder model.")

        #Freeze the backbone: only the linear head below is trained
        self.backbone.eval()
        for param in self.backbone.parameters():
            param.requires_grad = False

        #This is the meat
        self.output = nn.Linear(self.hparams.emb_dim, self.hparams.num_classes)

        #Set up training and validation metrics
        self.criterion = nn.CrossEntropyLoss()

        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)
        self.test_top_1 = Accuracy(top_k=1)
        self.test_top_5 = Accuracy(top_k=5)

    def forward(self, x):
        """
        Given a set of images x with shape [N, c, h, w], get their embeddings and then logits.

        Returns: Logits with shape [N, n_classes]
        """
        #Grab the noisy image embeddings. Both supported encoders take float16 input, and
        #the embeddings are cast back to float32 before the linear head.
        self.backbone.eval()
        with torch.no_grad():
            noisy_embeddings = self.backbone(x.type(torch.float16)).float()

        return self.output(noisy_embeddings)

    def configure_optimizers(self):
        """
        Adam over the linear head only, with a cosine-annealed learning rate.

        ``weight_decay`` defaults to 0 (and is written back to ``hparams``) when absent.
        """
        if not hasattr(self.hparams, 'weight_decay'):
            self.hparams.weight_decay = 0

        opt = torch.optim.Adam(self.output.parameters(), lr=self.hparams.lr,
                               weight_decay=self.hparams.weight_decay)

        if self.hparams.dataset == "ImageNet100":
            #divide N_train by number of distributed iters
            num_steps = 126689//(self.hparams.batch_size * self.hparams.gpus)
            if self.hparams.use_subset:
                num_steps = num_steps * self.hparams.subset_ratio
        else:
            num_steps = 500

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=num_steps)

        return [opt], [scheduler]

    def _check_dataloader_dataset(self):
        """
        Fail with a clear message when the default dataloaders cannot serve the dataset.

        Raises:
            ValueError: if ``hparams.dataset`` is not "ImageNet100".
        """
        if self.hparams.dataset != "ImageNet100":
            raise ValueError("The default dataloaders only support the ImageNet100 dataset.")

    def train_dataloader(self):
        """
        Shuffled ImageNet100 training loader, optionally reduced by ``few_shot_dataset``.

        Raises:
            ValueError: if ``hparams.dataset`` is not "ImageNet100".
        """
        self._check_dataloader_dataset()

        train_dataset = ImageNet100(
            root=self.hparams.dataset_dir,
            split='train',
            transform=self.train_set_transform
        )

        N_train = len(train_dataset)
        if self.hparams.use_subset:
            #Second argument is N_train * subset_ratio / num_classes, rounded up
            train_dataset = few_shot_dataset(
                train_dataset,
                int(np.ceil(N_train*self.hparams.subset_ratio/self.hparams.num_classes)))

        train_dataloader = DataLoader(train_dataset, batch_size=self.hparams.batch_size,
                                      num_workers=self.hparams.workers,
                                      pin_memory=True, shuffle=True)

        return train_dataloader

    def val_dataloader(self):
        """
        Unshuffled ImageNet100 validation loader, using four times the training batch size.

        Raises:
            ValueError: if ``hparams.dataset`` is not "ImageNet100".
        """
        self._check_dataloader_dataset()

        val_dataset = ImageNet100(
            root=self.hparams.dataset_dir,
            split='val',
            transform=self.val_set_transform
        )

        self.N_val = 5000

        val_dataloader = DataLoader(val_dataset, batch_size=4*self.hparams.batch_size,
                                    num_workers=self.hparams.workers,
                                    pin_memory=True, shuffle=False)

        return val_dataloader

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one batch; returns the loss."""
        x, y = batch

        #Log one grid of training images, once, at the very first step
        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Train_Sample', img_grid(x), self.current_epoch)

        logits = self.forward(x)

        loss = self.criterion(logits, y)

        self.log("train_loss", loss, prog_bar=True, on_step=True,
                 on_epoch=True, logger=True, sync_dist=True, sync_dist_op='sum')

        return loss

    def validation_step(self, batch, batch_idx):
        """Update the top-1/top-5 validation metrics with one batch."""
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Val_Sample', img_grid(x), self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        self.log("val_top_1", self.val_top_1(pred_probs, y), prog_bar=False, logger=False)
        self.log("val_top_5", self.val_top_5(pred_probs, y), prog_bar=False, logger=False)

    def validation_epoch_end(self, outputs):
        """Log the accumulated validation accuracies, then reset the metrics."""
        self.log("val_top_1", self.val_top_1.compute(), prog_bar=True, logger=True)
        self.log("val_top_5", self.val_top_5.compute(), prog_bar=True, logger=True)

        self.val_top_1.reset()
        self.val_top_5.reset()

    def test_step(self, batch, batch_idx):
        """Update the top-1/top-5 test metrics with one batch."""
        x, y = batch

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        self.log("test_top_1", self.test_top_1(pred_probs, y), prog_bar=False, logger=False)
        self.log("test_top_5", self.test_top_5(pred_probs, y), prog_bar=False, logger=False)

    def test_epoch_end(self, outputs):
        """Log the accumulated test accuracies, then reset the metrics."""
        self.log("test_top_1", self.test_top_1.compute(), prog_bar=True, logger=True)
        self.log("test_top_5", self.test_top_5.compute(), prog_bar=True, logger=True)

        self.test_top_1.reset()
        self.test_top_5.reset()

    def predict(self, batch, batch_idx, hiddens):
        """Return the arg-max class index for every image in the batch."""
        x, y = batch

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)
        preds = pred_probs.argmax(dim=-1)

        return preds
class NoisyCLIP(LightningModule):
    def __init__(self, args):
        """
        A class that trains OpenAI CLIP in a student-teacher fashion to classify distorted images.

        Given two identical pre-trained networks, Teacher - T() and Student S(), we freeze T() and train S().

        Given a batch of {x1, x2, ..., xN}, apply a given distortion to each one to obtain noisy images {y1, y2, ..., yN}.

        Feed original images to T() and obtain embeddings {T(x1), ..., T(xN)} and feed distorted images to S() and obtain embeddings {S(y1), ..., S(yN)}.

        Maximize the similarity between the pairs {(T(x1), S(y1)), ..., (T(xN), S(yN))} while minimizing the similarity between all non-matched pairs.

        Raises:
            NotImplementedError: if ``args.dataset`` is neither "ImageNet100" nor "Imagenet-100".
            ValueError: if ``args.mapping_and_text_file`` is None.
        """
        super(NoisyCLIP, self).__init__()
        self.hparams = args
        self.world_size = self.hparams.num_nodes * self.hparams.gpus

        #(1) Load the correct dataset class names
        if self.hparams.dataset not in ("ImageNet100", "Imagenet-100"):
            raise NotImplementedError('Handling of the dataset not implemented yet.')

        self.N_val = 5000 # Default ImageNet validation set, only 100 classes.

        #Retrieve the text labels for classes in order to do zero-shot classification.
        if self.hparams.mapping_and_text_file is None:
            raise ValueError('No file from which to read text labels was specified.')

        # NOTE(review): pickle.load executes arbitrary code from the file - only point
        # mapping_and_text_file at trusted data.
        with open(self.hparams.mapping_and_text_file, 'rb') as label_file:
            text_labels = pickle.load(label_file)
        self.text_list = ['A photo of '+label.strip().replace('_',' ') for label in text_labels]

        #Set up the dataset
        #Here, we use a 100-class subset of ImageNet
        #When hparams.increasing is set, the transforms are built per epoch in the dataloaders instead.
        if not hasattr(self.hparams, 'increasing') or not self.hparams.increasing:
            if self.hparams.distortion == "None":
                self.train_set_transform = ImageNetBaseTrainContrastive(self.hparams)
                self.val_set_transform = ImageNetBaseTransformVal(self.hparams)
            elif self.hparams.distortion == "multi":
                self.train_set_transform = ImageNetDistortTrainMultiContrastive(self.hparams)
                self.val_set_transform = ImageNetDistortValMulti(self.hparams)
            else:
                #If we are using the ImageNet dataset, then set up the train and val sets to use the same mask if needed!
                self.train_set_transform = ImageNetDistortTrainContrastive(self.hparams)

                if self.hparams.fixed_mask:
                    self.val_set_transform = ImageNetDistortVal(
                        self.hparams, fixed_distortion=self.train_set_transform.distortion)
                else:
                    self.val_set_transform = ImageNetDistortVal(self.hparams)

        #(2) set up the teacher CLIP network - freeze it and don't use gradients!
        self.logit_scale = self.hparams.logit_scale
        self.baseclip = clip.load(self.hparams.baseclip_type, self.hparams.device, jit=False)[0]
        self.baseclip.eval()
        self.baseclip.requires_grad_(False)

        #(3) set up the student CLIP network - unfreeze it and use gradients!
        self.noisy_visual_encoder = clip.load(self.hparams.baseclip_type, self.hparams.device, jit=False)[0].visual
        self.noisy_visual_encoder.train()

        #(4) set up the training and validation accuracy metrics.
        self.train_top_1 = Accuracy(top_k=1)
        self.train_top_5 = Accuracy(top_k=5)
        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)

    def criterion(self, input1, input2, reduction='mean'):
        """
        Loss between teacher and student embeddings, selected by ``hparams.loss_type``.

        Args:
            input1: Embeddings of the clean/noisy images from the teacher/student. Size [N, embedding_dim].
            input2: Embeddings of the clean/noisy images from the teacher/student (the ones not used as input1). Size [N, embedding_dim].
            reduction: how to scale the final loss. 'mean' divides by the batch size; anything
                else returns the loss undivided. Ignored for ``loss_type == 'mse'``.

        Returns:
            A scalar loss tensor.

        Raises:
            ValueError: if ``hparams.loss_type`` is not a recognised loss.
        """
        bsz = input1.shape[0]
        loss_type = self.hparams.loss_type

        # Use the simclr style InfoNCE
        if loss_type == 'simclr':
            # Interleave the two views so that rows (2k, 2k+1) form a positive pair,
            # then build the cosine similarity matrix between all 2N embeddings.
            full_tensor = torch.cat([input1.unsqueeze(1), input2.unsqueeze(1)], dim=1).view(2*bsz, -1)
            full_tensor = full_tensor / full_tensor.norm(dim=-1, keepdim=True)
            sim_mat = full_tensor @ full_tensor.t()

            # Calculate logits used for the contrastive loss.
            exp_sim_mat = torch.exp(sim_mat/self.hparams.loss_tau)
            # The mask zeroes the diagonal (self-similarity) out of the normalising sum.
            mask = torch.ones_like(exp_sim_mat) - torch.eye(2*bsz).type_as(exp_sim_mat)
            logmat = -torch.log(exp_sim_mat)+torch.log(torch.sum(mask*exp_sim_mat, 1))

            #Grab the two off-diagonal similarities (every second entry = the positive pairs)
            part1 = torch.sum(torch.diag(logmat, diagonal=1)[::2])
            part2 = torch.sum(torch.diag(logmat, diagonal=-1)[::2])

            #Take the mean of the two off-diagonals
            loss = (part1 + part2)/2

        #Use the CLIP-style InfoNCE
        elif loss_type == 'clip':
            # Create similarity matrix between embeddings.
            tensor1 = input1 / input1.norm(dim=-1, keepdim=True)
            tensor2 = input2 / input2.norm(dim=-1, keepdim=True)
            sim_mat = (1/self.hparams.loss_tau)*tensor1 @ tensor2.t()

            #Calculate the cross entropy between the similarities of the positive pairs, counted two ways
            targets = torch.arange(bsz, device=self.device)
            part1 = F.cross_entropy(sim_mat, targets)
            part2 = F.cross_entropy(sim_mat.t(), targets)

            #Take the mean of the two directions
            loss = (part1+part2)/2

        #Take the simple MSE between the clean and noisy embeddings
        elif loss_type == 'mse':
            # Returned as-is: F.mse_loss already averages, so `reduction` is not applied.
            return F.mse_loss(input2, input1)

        elif loss_type.startswith('simclr_'):
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if loss_type not in ['simclr_ss', 'simclr_st', 'simclr_both']:
                raise ValueError('Loss function not understood.')

            # Various schemes for the negative examples
            teacher_embeds = F.normalize(input1, dim=1)
            student_embeds = F.normalize(input2, dim=1)

            # First compute positive examples by taking <S(x_i), T(x_i)>/T for all i
            pos_term = (teacher_embeds * student_embeds).sum(dim=1) / self.hparams.loss_tau

            # Then generate the negative term by constructing various similarity matrices
            if loss_type == 'simclr_ss':
                cov = torch.mm(student_embeds, student_embeds.t())
                sim = torch.exp(cov / self.hparams.loss_tau) # shape is [bsz, bsz]
                neg_term = torch.log(sim.sum(dim=1) - sim.diag())
            elif loss_type == 'simclr_st':
                cov = torch.mm(student_embeds, teacher_embeds.t())
                sim = torch.exp(cov / self.hparams.loss_tau) # shape is [bsz, bsz]
                neg_term = torch.log(sim.sum(dim=1)) # Not removing the diagonal here!
            else:
                cat_embeds = torch.cat([student_embeds, teacher_embeds])
                cov = torch.mm(student_embeds, cat_embeds.t())
                sim = torch.exp(cov / self.hparams.loss_tau) # shape is [bsz, 2 * bsz]
                # and take row-wise sums w/o diagonals and
                neg_term = torch.log(sim.sum(dim=1) - sim.diag())

            # Final loss is
            loss = -1 * (pos_term - neg_term).sum() # (summed and then mean-reduced later)

        else:
            raise ValueError('Loss function not understood.')

        return loss/bsz if reduction == 'mean' else loss

    def configure_optimizers(self):
        """Adam over the student encoder with a cosine-annealed learning rate."""
        optim = torch.optim.Adam(self.noisy_visual_encoder.parameters(), lr=self.hparams.lr)

        #divide N_train by number of distributed iters
        num_steps = 126689//(self.hparams.batch_size * self.hparams.gpus)

        sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=num_steps)
        return [optim], [sched]

    def encode_noisy_image(self, image):
        """
        Return S(yi) where S() is the student network and yi is distorted images.

        The input is cast to float16 before it is fed to the student.
        """
        return self.noisy_visual_encoder(image.type(torch.float16))

    def forward(self, image_features, text=None):
        """
        Given a set of noisy image embeddings, calculate the cosine similarity (scaled by temperature) of each image with each class text prompt.
        Calculates the similarity in two ways: logits per image (size = [N, n_classes]), logits per text (size = [n_classes, N]).
        This is mainly used for validation and classification.

        Args:
            image_features: the noisy image embeddings S(yi) where S() is the student and yi = Distort(xi). Shape [N, embedding_dim]
            text: unused.

        Returns:
            Tuple ``(logits_per_image, logits_per_text)``.
        """
        #encode the class text prompts with the frozen teacher, on the module's device
        text_features = self.baseclip.encode_text(clip.tokenize(self.text_list).to(self.device))

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logits_per_image = self.logit_scale * image_features.type(torch.float16) @ text_features.type(torch.float16).t()
        logits_per_text = self.logit_scale * text_features.type(torch.float16) @ image_features.type(torch.float16).t()

        return logits_per_image, logits_per_text

    # Training methods - here we are concerned with contrastive loss (or MSE) between clean and noisy image embeddings.
    def training_step(self, train_batch, batch_idx):
        """
        Takes a batch of clean and noisy images and returns their respective embeddings.

        Returns:
            embed_clean: T(xi) where T() is the teacher and xi are clean images. Shape [N, embed_dim]
            embed_noisy: S(yi) where S() is the student and yi are noisy images. Shape [N, embed_dim]
        """
        image_clean, image_noisy, labels = train_batch

        embed_clean = self.baseclip.encode_image(image_clean)
        embed_noisy = self.encode_noisy_image(image_noisy)

        return {'embed_clean': embed_clean, 'embed_noisy': embed_noisy}

    def training_step_end(self, outputs):
        """
        Given all the clean and noisy image embeddings form across GPUs from training_step, gather them onto a single GPU and calculate overall loss.
        """
        embed_clean_full = outputs['embed_clean']
        embed_noisy_full = outputs['embed_noisy']

        loss = self.criterion(embed_clean_full, embed_noisy_full)

        self.log('train_loss', loss, prog_bar=False, logger=True, sync_dist=True, on_step=True, on_epoch=True)

        return loss

    # Validation methods - here we are concerned with similarity between noisy image embeddings and classification text embeddings.
    def validation_step(self, test_batch, batch_idx):
        """
        Grab the noisy image embeddings: S(yi), where S() is the student and yi = Distort(xi). Done on each GPU.
        Return these to be evaluated in validation step end.
        """
        images_noisy, labels = test_batch

        #Log one grid of validation images during the first epoch only
        if batch_idx == 0 and self.current_epoch < 1:
            self.logger.experiment.add_image('Val_Sample', img_grid(images_noisy), self.current_epoch)

        image_features = self.encode_noisy_image(images_noisy)

        return {'image_features': image_features, 'labels': labels}

    def validation_step_end(self, outputs):
        """
        Gather the noisy image features and their labels from each GPU.
        Then calculate their similarities, convert to probabilities, and calculate accuracy on each GPU.
        """
        image_features_full = outputs['image_features']
        labels_full = outputs['labels']

        image_logits, _ = self.forward(image_features_full)
        image_logits = image_logits.float()
        image_probs = image_logits.softmax(dim=-1)

        self.log('val_top_1_step', self.val_top_1(image_probs, labels_full), prog_bar=False, logger=False)
        self.log('val_top_5_step', self.val_top_5(image_probs, labels_full), prog_bar=False, logger=False)

    def validation_epoch_end(self, outputs):
        """
        Gather the zero-shot validation accuracies from across GPUs and reduce.
        """
        self.log('val_top_1', self.val_top_1.compute(), prog_bar=True, logger=True)
        self.log('val_top_5', self.val_top_5.compute(), prog_bar=True, logger=True)

        self.val_top_1.reset()
        self.val_top_5.reset()

    # Default dataloaders - can be overwritten by datamodule.
    def train_dataloader(self):
        """Shuffled contrastive training loader over ImageNet100."""
        if hasattr(self.hparams, 'increasing') and self.hparams.increasing:
            #Rebuild the transform with the current epoch
            datatf = ImageNetDistortTrainContrastive(self.hparams, epoch=self.current_epoch)
        else:
            datatf = self.train_set_transform

        #The base dataset is left untransformed; the contrastive wrapper applies datatf
        train_dataset = ImageNet100(
            root=self.hparams.dataset_dir,
            split='train',
            transform=None
        )
        train_contrastive = ContrastiveUnsupervisedDataset(
            train_dataset, transform_contrastive=datatf, return_label=True)

        train_dataloader = DataLoader(train_contrastive, batch_size=self.hparams.batch_size,
                                      num_workers=self.hparams.workers,
                                      pin_memory=True, shuffle=True)

        return train_dataloader

    def val_dataloader(self):
        """Unshuffled validation loader over ImageNet100; also refreshes ``self.N_val``."""
        if hasattr(self.hparams, 'increasing') and self.hparams.increasing:
            datatf = ImageNetDistortVal(self.hparams, epoch=self.current_epoch)
        else:
            datatf = self.val_set_transform

        val_dataset = ImageNet100(
            root=self.hparams.dataset_dir,
            split='val',
            transform=datatf
        )

        self.N_val = len(val_dataset)

        val_dataloader = DataLoader(val_dataset, batch_size=self.hparams.batch_size,
                                    num_workers=self.hparams.workers,
                                    pin_memory=True, shuffle=False)

        return val_dataloader

    def test_dataloader(self):
        """The test loader is the validation loader."""
        return self.val_dataloader()
class HPALit(pl.LightningModule):
    """
    Lightning module that trains ``Net`` with a BCE-with-logits loss on ``CellDataset``.

    The train/validation splits are read from per-fold CSV files, and subset accuracy and
    recall are tracked separately for training and validation.
    """

    def __init__(self, params):
        """
        Args:
            params: hyperparameter container. Fields read: ``lr``, ``model``, ``fold``,
                ``img_size``, ``data_dir``, ``batch_size`` and ``n_workers``.
        """
        super().__init__()
        self.hparams = params
        self.lr = self.hparams.lr
        self.save_hyperparameters()

        # Model and loss
        self.model = Net(name=self.hparams.model)
        self.criterion = torch.nn.BCEWithLogitsLoss()

        # Metrics, one set per phase
        self.train_accuracy = Accuracy(subset_accuracy=True)
        self.val_accuracy = Accuracy(subset_accuracy=True)
        self.train_recall = Recall()
        self.val_recall = Recall()

        # Fold-specific split tables
        train_csv = f'data_preprocessing/train_fold_{self.hparams.fold}.csv'
        self.train_df = pd.read_csv(train_csv)
        valid_csv = f'data_preprocessing/valid_fold_{self.hparams.fold}.csv'
        self.valid_df = pd.read_csv(valid_csv)

        # Augmentations: training adds rotation and flips ahead of the shared resize/normalize
        train_ops = [
            A.Rotate(),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Resize(width=self.hparams.img_size, height=self.hparams.img_size),
            A.Normalize(),
            ToTensorV2(),
        ]
        self.train_transforms = A.Compose(train_ops)

        valid_ops = [
            A.Resize(width=self.hparams.img_size, height=self.hparams.img_size),
            A.Normalize(),
            ToTensorV2(),
        ]
        self.valid_transforms = A.Compose(valid_ops)

        # Datasets
        self.train_dataset = CellDataset(
            data_dir=self.hparams.data_dir,
            csv_file=self.train_df,
            transform=self.train_transforms,
        )
        self.val_dataset = CellDataset(
            data_dir=self.hparams.data_dir,
            csv_file=self.valid_df,
            transform=self.valid_transforms,
        )

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the batch size and worker count from hparams."""
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=shuffle,
            num_workers=self.hparams.n_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.train_dataset, True)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return self._make_loader(self.val_dataset, False)

    def configure_optimizers(self):
        """Adam plus a plateau scheduler that halves the LR, monitoring ``valid_loss_epoch``."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', factor=0.5, patience=2, eps=1e-6)
        scheduler_config = {
            'scheduler': plateau,
            'interval': 'epoch',
            'monitor': 'valid_loss_epoch',
        }
        return [optimizer], [scheduler_config]

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and update the training metrics."""
        inputs, targets = batch
        logits = self.model(inputs)
        loss = self.criterion(logits, targets)

        # The metrics take sigmoid probabilities and integer targets
        probs = torch.sigmoid(logits)
        int_targets = targets.type(torch.int)
        self.train_accuracy(probs, int_targets)
        self.train_recall(probs, int_targets)

        return {'loss': loss}

    def training_epoch_end(self, outputs):
        """Log the mean training loss and the epoch metrics, then reset the metrics."""
        step_losses = [step['loss'] for step in outputs]
        mean_loss = torch.stack(step_losses).mean()

        self.log('train_loss_epoch', mean_loss)
        self.log('train_acc_epoch', self.train_accuracy.compute())
        self.log('train_recall_epoch', self.train_recall.compute())

        self.train_accuracy.reset()
        self.train_recall.reset()

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one batch and update the validation metrics."""
        inputs, targets = batch
        logits = self.model(inputs)
        loss = self.criterion(logits, targets)

        probs = torch.sigmoid(logits)
        int_targets = targets.type(torch.int)
        self.val_accuracy(probs, int_targets)
        self.val_recall(probs, int_targets)

        return {'valid_loss': loss}

    def validation_epoch_end(self, outputs):
        """Log the mean validation loss and the epoch metrics, then reset the metrics."""
        step_losses = [step['valid_loss'] for step in outputs]
        mean_loss = torch.stack(step_losses).mean()

        self.log('valid_loss_epoch', mean_loss)
        self.log('valid_acc_epoch', self.val_accuracy.compute())
        self.log('valid_recall_epoch', self.val_recall.compute())

        self.val_accuracy.reset()
        self.val_recall.reset()
class Baseline(LightningModule):
    """
    Class for training a classification model on distorted ImageNet in an end-to-end fashion.
    """

    def __init__(self, args):
        """
        Args:
            args: hyperparameter container. Fields read here: ``num_nodes``, ``gpus``, ``lr``,
                ``dataset``, ``distortion``, ``fixed_mask`` (single-distortion case) and ``encoder``.

        Raises:
            ValueError: if the dataset is not "ImageNet100" or the encoder is neither
                'resnet' nor 'clip'.
        """
        super().__init__()
        self.hparams = args
        self.world_size = self.hparams.num_nodes * self.hparams.gpus
        self.lr = self.hparams.lr  #initalise this specially as a tuneable parameter

        #(1) Set up the dataset
        #Here, we use a 100-class subset of ImageNet
        if self.hparams.dataset != "ImageNet100":
            raise ValueError("Unsupported dataset selected.")

        distortion = self.hparams.distortion
        if distortion == "None":
            self.train_set_transform = ImageNetBaseTransform(self.hparams)
            self.val_set_transform = ImageNetBaseTransformVal(self.hparams)
        elif distortion == "multi":
            self.train_set_transform = ImageNetDistortTrainMulti(self.hparams)
            self.val_set_transform = ImageNetDistortValMulti(self.hparams)
        else:
            #Single distortion: train and val share the same mask when fixed_mask is set
            self.train_set_transform = ImageNetDistortTrain(self.hparams)
            if not self.hparams.fixed_mask:
                self.val_set_transform = ImageNetDistortVal(self.hparams)
            else:
                shared_distortion = self.train_set_transform.distortion
                self.val_set_transform = ImageNetDistortVal(
                    self.hparams, fixed_distortion=shared_distortion)

        #(2) Grab the correct baseline pre-trained model
        encoder_name = self.hparams.encoder
        if encoder_name == 'resnet':
            self.encoder = RESNET_finetune(self.hparams)
        elif encoder_name == 'clip':
            self.encoder = CLIP_finetune(self.hparams)
        else:
            raise ValueError("Please select a valid encoder model.")

        #(3) Set up our criterion - here we use reduction as "sum" so that we are able to average over all validation sets
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

        #Top-1 / top-5 accuracy trackers for every phase
        self.train_top_1 = Accuracy(top_k=1)
        self.train_top_5 = Accuracy(top_k=5)

        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)

        self.test_top_1 = Accuracy(top_k=1)
        self.test_top_5 = Accuracy(top_k=5)

    def forward(self, x):
        """Run the encoder on ``x`` and return its output."""
        return self.encoder(x)

    def configure_optimizers(self):
        """Adam over the encoder parameters with a cosine-annealed learning rate."""
        optimizer = torch.optim.Adam(self.encoder.parameters(), lr=self.lr)

        #For ImageNet100: divide N_train by number of distributed iters
        t_max = (
            126689 // (self.hparams.batch_size * self.hparams.gpus)
            if self.hparams.dataset == "ImageNet100"
            else 500
        )

        cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)

        return [optimizer], [cosine]

    #DATALOADERS
    def train_dataloader(self):
        """Shuffled loader over the ImageNet100 'train' split."""
        hp = self.hparams
        if hp.dataset == "ImageNet100":
            train_dataset = ImageNet100(
                root=hp.dataset_dir, split='train', transform=self.train_set_transform)

        return DataLoader(
            train_dataset,
            batch_size=hp.batch_size,
            num_workers=hp.workers,
            pin_memory=True,
            shuffle=True,
        )

    def val_dataloader(self):
        """Unshuffled loader over the ImageNet100 'val' split; sets ``self.N_val`` to 5000."""
        hp = self.hparams
        if hp.dataset == "ImageNet100":
            val_dataset = ImageNet100(
                root=hp.dataset_dir, split='val', transform=self.val_set_transform)
            self.N_val = 5000

        return DataLoader(
            val_dataset,
            batch_size=hp.batch_size,
            num_workers=hp.workers,
            pin_memory=True,
            shuffle=False,
        )

    def test_dataloader(self):
        """Unshuffled test loader; it reads the ImageNet100 'val' split."""
        hp = self.hparams
        if hp.dataset == "ImageNet100":
            test_dataset = ImageNet100(
                root=hp.dataset_dir, split='val', transform=self.val_set_transform)

        return DataLoader(
            test_dataset,
            batch_size=hp.batch_size,
            num_workers=hp.workers,
            pin_memory=True,
            shuffle=False,
        )

    #TRAINING
    def training_step(self, batch, batch_idx):
        """
        Given a batch of images, train the model for one step.
        Calculates the crossentropy loss as well as top_1 and top_5 per batch

        Inputs:
            batch - the images to train on, shape [batch_size, num_channels, height, width]
            batch_idx - the index of the current batch

        Returns:
            loss - the crossentropy loss between the model's logits and the true classes of the inputs
        """
        images, targets = batch

        #Log one grid of training images at the very first step
        first_step = batch_idx == 0 and self.current_epoch == 0
        if first_step:
            self.logger.experiment.add_image('Train_Sample', img_grid(images), self.current_epoch)

        logits = self.forward(images)
        pred_probs = logits.softmax(dim=-1)

        loss = self.criterion(logits, targets)

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True,
                 logger=True, sync_dist=True, sync_dist_op='sum')
        top_1 = self.train_top_1(pred_probs, targets)
        self.log("train_top_1", top_1, prog_bar=False, logger=False)
        top_5 = self.train_top_5(pred_probs, targets)
        self.log("train_top_5", top_5, prog_bar=False, logger=False)

        return loss

    def training_epoch_end(self, outputs):
        """Log the accumulated training accuracies, then reset the metrics."""
        self.log("train_top_1", self.train_top_1.compute(), prog_bar=True, logger=True)
        self.log("train_top_5", self.train_top_5.compute(), prog_bar=True, logger=True)

        self.train_top_1.reset()
        self.train_top_5.reset()

    #VALIDATION
    def validation_step(self, batch, batch_idx):
        """Update the top-1/top-5 validation metrics with one batch."""
        images, targets = batch

        first_step = batch_idx == 0 and self.current_epoch == 0
        if first_step:
            self.logger.experiment.add_image('Val_Sample', img_grid(images), self.current_epoch)

        pred_probs = self.forward(images).softmax(dim=-1)

        top_1 = self.val_top_1(pred_probs, targets)
        self.log("val_top_1", top_1, prog_bar=False, logger=False)
        top_5 = self.val_top_5(pred_probs, targets)
        self.log("val_top_5", top_5, prog_bar=False, logger=False)

    def validation_epoch_end(self, outputs):
        """Log the accumulated validation accuracies, then reset the metrics."""
        self.log("val_top_1", self.val_top_1.compute(), prog_bar=True, logger=True)
        self.log("val_top_5", self.val_top_5.compute(), prog_bar=True, logger=True)

        self.val_top_1.reset()
        self.val_top_5.reset()

    #TESTING
    def test_step(self, batch, batch_idx):
        """Update the top-1/top-5 test metrics with one batch."""
        images, targets = batch

        first_step = batch_idx == 0 and self.current_epoch == 0
        if first_step:
            self.logger.experiment.add_image('Test_Sample', img_grid(images), self.current_epoch)

        pred_probs = self.forward(images).softmax(dim=-1)

        top_1 = self.test_top_1(pred_probs, targets)
        self.log("test_top_1", top_1, prog_bar=False, logger=False)
        top_5 = self.test_top_5(pred_probs, targets)
        self.log("test_top_5", top_5, prog_bar=False, logger=False)

    def test_epoch_end(self, outputs):
        """Log the accumulated test accuracies, then reset the metrics."""
        self.log("test_top_1", self.test_top_1.compute(), prog_bar=True, logger=True)
        self.log("test_top_5", self.test_top_5.compute(), prog_bar=True, logger=True)

        self.test_top_1.reset()
        self.test_top_5.reset()
class TransferLearning(LightningModule):
    """
    Trains a linear classifier on top of a frozen backbone taken from a saved
    ``NoisyCLIP`` or ``Baseline`` checkpoint, on one of several downstream datasets.
    """

    def __init__(self, args):
        """
        Args:
            args: hyperparameter container. Fields read here: ``num_nodes``, ``gpus``,
                ``saved_model_type``, ``checkpoint_path``, ``emb_dim``, ``num_classes``
                and ``dataset``.

        Raises:
            ValueError: if ``args.saved_model_type`` is neither 'contrastive' nor 'baseline'.
        """
        super(TransferLearning, self).__init__()
        self.hparams = args
        self.world_size = self.hparams.num_nodes * self.hparams.gpus

        self.train_set_transform, self.val_set_transform = grab_transforms(self.hparams)

        #Grab the correct model - only want the embeddings from the final layer!
        if self.hparams.saved_model_type == 'contrastive':
            saved_model = NoisyCLIP.load_from_checkpoint(self.hparams.checkpoint_path)
            self.backbone = saved_model.noisy_visual_encoder
        elif self.hparams.saved_model_type == 'baseline':
            saved_model = Baseline.load_from_checkpoint(self.hparams.checkpoint_path)
            self.backbone = saved_model.encoder.feature_extractor
        else:
            # Previously this fell through and crashed with an AttributeError on self.backbone.
            raise ValueError("Please select a valid saved model type.")

        #Freeze the backbone: only the classifier below is trained
        for param in self.backbone.parameters():
            param.requires_grad = False

        #Set up a classifier with the correct dimensions
        self.output = nn.Linear(self.hparams.emb_dim, self.hparams.num_classes)

        #Set up the criterion - mean-reduced cross entropy
        self.criterion = nn.CrossEntropyLoss(reduction="mean")

        self.train_top_1 = Accuracy(top_k=1)
        self.train_top_5 = Accuracy(top_k=5)

        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)

        self.test_top_1 = Accuracy(top_k=1)
        self.test_top_5 = Accuracy(top_k=5)

        #class INFECTED has label 0
        if self.hparams.dataset == 'COVID':
            self.val_auc = AUROC(pos_label=0)
            self.test_auc = AUROC(pos_label=0)

    def forward(self, x):
        """
        Embed ``x`` with the frozen backbone and classify the flattened embeddings.

        Raises:
            ValueError: if ``hparams.encoder`` is neither "clip" nor "resnet".
        """
        #Grab the noisy image embeddings
        self.backbone.eval()
        with torch.no_grad():
            if self.hparams.encoder == "clip":
                #The CLIP backbone is fed float16 input; cast the embeddings back to float32
                noisy_embeddings = self.backbone(x.type(torch.float16)).float()
            elif self.hparams.encoder == "resnet":
                noisy_embeddings = self.backbone(x)
            else:
                # Previously this surfaced as an UnboundLocalError on noisy_embeddings.
                raise ValueError("Please select a valid encoder model.")

        return self.output(noisy_embeddings.flatten(1))

    def configure_optimizers(self):
        """
        Adam over the classifier only, cosine-annealed over ``hparams.max_epochs``.

        ``weight_decay`` defaults to 0 (and is written back to ``hparams``) when absent.
        """
        if not hasattr(self.hparams, 'weight_decay'):
            self.hparams.weight_decay = 0

        opt = torch.optim.Adam(self.output.parameters(),
                               lr=self.hparams.lr,
                               weight_decay=self.hparams.weight_decay)

        num_steps = self.hparams.max_epochs

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=num_steps)

        return [opt], [scheduler]

    def _grab_dataset(self, split):
        """
        Given a split ("train" or "val" or "test") and a dataset, returns the proper dataset.
        Dataset needed is defined in this object's hparams

        Any split other than "train" selects the evaluation data and the validation transform.

        Args:
            split: the split to use in the dataset

        Returns:
            dataset: the desired dataset with the correct split

        Raises:
            ValueError: if ``hparams.dataset`` is not one of the supported datasets.
        """
        is_train = split == 'train'
        transform = self.train_set_transform if is_train else self.val_set_transform
        name = self.hparams.dataset
        root = self.hparams.dataset_dir

        if name == "CIFAR10":
            return CIFAR10(root=root, train=is_train, transform=transform, download=True)

        if name == "CIFAR100":
            return CIFAR100(root=root, train=is_train, transform=transform, download=True)

        if name == 'STL10':
            stlsplit = 'train' if is_train else 'test'
            return STL10(root=root, split=stlsplit, transform=transform, download=True)

        if name == 'COVID':
            covidsplit = 'train' if is_train else 'test'
            #The split name is appended directly, so dataset_dir must end with a path separator
            return torchvision.datasets.ImageFolder(root=root + covidsplit, transform=transform)

        if name == 'ImageNet100B' or name == 'imagenet-100B':
            imagenet_split = 'train' if is_train else 'val'
            return ImageNet100(root=root, split=imagenet_split, transform=transform)

        # Previously an unknown dataset surfaced as an UnboundLocalError on `dataset`.
        raise ValueError("Unsupported dataset selected.")

    def train_dataloader(self):
        """Shuffled training loader, optionally reduced by ``few_shot_dataset``."""
        train_dataset = self._grab_dataset(split='train')

        N_train = len(train_dataset)
        if self.hparams.use_subset:
            #Second argument is N_train * subset_ratio / num_classes, rounded up
            train_dataset = few_shot_dataset(
                train_dataset,
                int(np.ceil(N_train * self.hparams.subset_ratio / self.hparams.num_classes)))

        train_dataloader = DataLoader(train_dataset, batch_size=self.hparams.batch_size,
                                      num_workers=self.hparams.workers,
                                      pin_memory=True, shuffle=True)

        return train_dataloader

    def val_dataloader(self):
        """Validation loader (shuffled on purpose, see below)."""
        val_dataset = self._grab_dataset(split='val')

        #SET SHUFFLE TO TRUE SINCE AUROC FREAKS OUT IF IT GETS AN ALL-1 OR ALL-0 BATCH
        val_dataloader = DataLoader(val_dataset, batch_size=self.hparams.batch_size,
                                    num_workers=self.hparams.workers,
                                    pin_memory=True, shuffle=True)

        return val_dataloader

    def test_dataloader(self):
        """Test loader (shuffled on purpose, see below)."""
        test_dataset = self._grab_dataset(split='test')

        #SET SHUFFLE TO TRUE SINCE AUROC FREAKS OUT IF IT GETS AN ALL-1 OR ALL-0 BATCH
        test_dataloader = DataLoader(test_dataset, batch_size=self.hparams.batch_size,
                                     num_workers=self.hparams.workers,
                                     pin_memory=True, shuffle=True)

        return test_dataloader

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one batch; returns the loss."""
        x, y = batch

        #Log one grid of training images at the very first step
        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Train_Sample', img_grid(x), self.current_epoch)

        logits = self.forward(x)

        loss = self.criterion(logits, y)

        self.log("train_loss", loss, prog_bar=False, on_step=True,
                 on_epoch=True, logger=True, sync_dist=True, sync_dist_op='sum')

        return loss

    def validation_step(self, batch, batch_idx):
        """Update the validation metrics (AUROC for COVID, top-5 otherwise, top-1 always)."""
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Val_Sample', img_grid(x), self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)  #(N, num_classes)

        if self.hparams.dataset == 'COVID':
            positive_prob = pred_probs[:, 0].flatten()  #class 0 is INFECTED label
            true_labels = y.flatten()

            self.val_auc.update(positive_prob, true_labels)
            self.log("val_auc", self.val_auc, prog_bar=False, logger=False)

        self.log("val_top_1", self.val_top_1(pred_probs, y), prog_bar=False, logger=False)

        if self.hparams.dataset != 'COVID':
            self.log("val_top_5", self.val_top_5(pred_probs, y), prog_bar=False, logger=False)

    def validation_epoch_end(self, outputs):
        """Log the accumulated validation metrics, then reset them."""
        self.log("val_top_1", self.val_top_1.compute(), prog_bar=True, logger=True)

        if self.hparams.dataset != 'COVID':
            self.log("val_top_5", self.val_top_5.compute(), prog_bar=True, logger=True)

        if self.hparams.dataset == 'COVID':
            self.log("val_auc", self.val_auc.compute(), prog_bar=True, logger=True)
            self.val_auc.reset()

        self.val_top_1.reset()
        self.val_top_5.reset()

    def test_step(self, batch, batch_idx):
        """Update the test metrics (AUROC for COVID, top-5 otherwise, top-1 always)."""
        x, y = batch

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        if self.hparams.dataset == 'COVID':
            positive_prob = pred_probs[:, 0].flatten()  #class 0 is INFECTED label
            true_labels = y.flatten()

            self.test_auc.update(positive_prob, true_labels)
            self.log("test_auc", self.test_auc, prog_bar=False, logger=False)

        self.log("test_top_1", self.test_top_1(pred_probs, y), prog_bar=False, logger=False)

        if self.hparams.dataset != 'COVID':
            self.log("test_top_5", self.test_top_5(pred_probs, y), prog_bar=False, logger=False)

    def test_epoch_end(self, outputs):
        """Log the accumulated test metrics, then reset them."""
        self.log("test_top_1", self.test_top_1.compute(), prog_bar=True, logger=True)

        if self.hparams.dataset != 'COVID':
            self.log("test_top_5", self.test_top_5.compute(), prog_bar=True, logger=True)

        if self.hparams.dataset == 'COVID':
            self.log("test_auc", self.test_auc.compute(), prog_bar=True, logger=True)
            self.test_auc.reset()

        self.test_top_1.reset()
        self.test_top_5.reset()