class IrisClassification(pl.LightningModule):
    """Small fully connected classifier with 4 input features and 3 output classes.

    Keyword Args:
        lr: SGD learning rate (default ``0.01``).
        momentum: SGD momentum (default ``0.9``).
        weight_decay: SGD weight decay (default ``0.1``).

    All keyword arguments are additionally stored untouched in ``self.args``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # One metric object per stage so their accumulated states never mix.
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()
        self.args = kwargs

        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 3)
        self.cross_entropy_loss = nn.CrossEntropyLoss()

        self.lr = kwargs.get("lr", 0.01)
        self.momentum = kwargs.get("momentum", 0.9)
        self.weight_decay = kwargs.get("weight_decay", 0.1)

    def forward(self, x):
        """Run the three linear layers, each followed by a ReLU, and return the result."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # NOTE(review): the ReLU on the last layer clamps the scores fed to the
        # cross-entropy loss to be non-negative. Kept as-is so existing
        # checkpoints behave the same - confirm whether it is intentional.
        x = F.relu(self.fc3(x))
        return x

    def configure_optimizers(self):
        """Return an SGD optimizer built from the stored hyper-parameters."""
        return torch.optim.SGD(
            self.parameters(),
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and update the training accuracy.

        Returns:
            A dict with the key ``"loss"``.
        """
        x, y = batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.train_acc(torch.argmax(logits, dim=1), y)
        # Log the metric object itself rather than ``compute()``: Lightning then
        # computes it at epoch end and resets it, instead of reporting a value
        # accumulated over every batch seen since the metric was created.
        self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
        self.log("loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Log validation accuracy and loss for one batch."""
        x, y = batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        self.val_acc(torch.argmax(logits, dim=1), y)
        self.log("val_acc", self.val_acc)
        self.log("val_loss", loss, sync_dist=True)

    def test_step(self, batch, batch_idx):
        """Log test loss and accuracy for one batch."""
        x, y = batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        self.test_acc(torch.argmax(logits, dim=1), y)
        self.log("test_loss", loss)
        self.log("test_acc", self.test_acc)
def test_wrong_params(top_k, threshold):
    """Check that an invalid ``top_k`` / ``threshold`` combination raises ``ValueError``.

    Both the class-based metric and the functional ``accuracy`` are exercised
    on the multi-class probability inputs.
    """
    preds = _input_mcls_prob.preds
    target = _input_mcls_prob.target
    bad_kwargs = {"threshold": threshold, "top_k": top_k}

    # Class interface: the error may surface at construction, update or compute.
    with pytest.raises(ValueError):
        metric = Accuracy(**bad_kwargs)
        metric(preds, target)
        metric.compute()

    # Functional interface.
    with pytest.raises(ValueError):
        accuracy(preds, target, **bad_kwargs)
class SimpleModel(LightningModule):
    """Base model holding an embedding layer, a BCE-with-logits loss and accuracy metrics.

    ``forward`` is deliberately unimplemented here; the step methods expect it
    to return a ``(loss, logits)`` pair.
    """

    def __init__(self, vocab_size, embedding_dim=32):
        super().__init__()
        self.embeddings_layer = nn.Embedding(vocab_size, embedding_dim)
        self.loss = nn.BCEWithLogitsLoss()
        self.valid_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

    def forward(self, inputs, labels):
        """Not implemented in this class; always raises ``NotImplementedError``."""
        raise NotImplementedError("forward not implemented")

    def configure_optimizers(self):
        """Return a single-element list containing an Adam optimizer (lr=1e-3)."""
        return [torch.optim.Adam(self.parameters(), lr=1e-3)]

    def training_step(self, batch, _):
        """Return the loss produced by ``forward`` for this batch."""
        inputs, labels = batch
        loss, _logits = self(inputs, labels)
        return loss

    def validation_step(self, batch, _):
        """Update the validation accuracy and log loss and accuracy."""
        inputs, labels = batch
        val_loss, logits = self(inputs, labels)
        # A maximum label of 1 selects sigmoid; anything else uses softmax over dim 1.
        pred = torch.sigmoid(logits) if torch.max(labels) == 1 else torch.softmax(logits, 1)
        self.valid_accuracy.update(pred, labels.long())
        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", self.valid_accuracy)

    def validation_epoch_end(self, outs):
        """Log the accumulated validation accuracy."""
        epoch_acc = self.valid_accuracy.compute()
        self.log("val_acc_epoch", epoch_acc, prog_bar=True)

    def test_step(self, batch, _):
        """Update the test accuracy and log loss and accuracy."""
        inputs, labels = batch
        test_loss, logits = self(inputs, labels)
        # Same activation choice as in validation_step.
        pred = torch.sigmoid(logits) if torch.max(labels) == 1 else torch.softmax(logits, 1)
        self.test_accuracy.update(pred, labels.long())
        self.log("test_loss", test_loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy)

    def test_epoch_end(self, outs):
        """Log the accumulated test accuracy."""
        epoch_acc = self.test_accuracy.compute()
        self.log("test_acc_epoch", epoch_acc, prog_bar=True)
def segmentation_model_attack(model, model_type, config, num_classes=NUM_CLASSES):
    """Evaluate a segmentation model on the perturbed (salt and pepper) test data.

    The difference between the returned value and the accuracy on the normal
    test data is a measure of the model's resiliency.

    Args:
        model: the model to evaluate; called directly for ``"pt"`` and via
            ``model.evaluate`` for ``"tf"``.
        model_type: ``"pt"`` or ``"tf"``.
        config: mapping that provides ``'batch_size'``.
        num_classes: currently unused; kept so existing callers keep working.

    Returns:
        For ``"pt"`` the accumulated accuracy as a Python number, for ``"tf"``
        whatever ``model.evaluate`` returns, and ``None`` for any other type.
    """
    if model_type == "pt":
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        test_set = PT_GISDataset(perturbed_pt_gis_test_data())
        testloader = DataLoader(test_set, batch_size=int(config['batch_size']))
        accuracy = Accuracy()
        # Pure evaluation: no autograd graph is needed, so don't build one.
        with torch.no_grad():
            for sample in tqdm(testloader):
                out = model(sample[0].to(device))
                # The metric lives on the CPU, so move the prediction back first.
                accuracy(out.to('cpu').squeeze(1), sample[1])
        return accuracy.compute().item()

    if model_type == "tf":
        x_test, y_test = perturbed_tf_gis_test_data()
        return model.evaluate(x_test, y_test, batch_size=config['batch_size'])

    print("Unknown model type, failure.")
    return None
class ModelParallelClassificationModel(LightningModule):
    """Stack of 32->32 linear blocks with a 3-way head, built in ``configure_sharded_model``.

    Args:
        lr: learning rate for Adam.
        num_blocks: number of ``Linear + ReLU`` blocks placed before the head.
    """

    def __init__(self, lr: float = 0.01, num_blocks: int = 5):
        super().__init__()
        self.lr = lr
        self.num_blocks = num_blocks
        self.train_acc = Accuracy()
        self.valid_acc = Accuracy()
        self.test_acc = Accuracy()

    def make_block(self):
        """Return one bias-free 32->32 linear layer followed by a ReLU."""
        linear = nn.Linear(32, 32, bias=False)
        return nn.Sequential(linear, nn.ReLU())

    def configure_sharded_model(self) -> None:
        """Create ``self.model``: ``num_blocks`` blocks followed by a 32->3 linear head."""
        blocks = [self.make_block() for _ in range(self.num_blocks)]
        head = nn.Linear(32, 3)
        self.model = nn.Sequential(*blocks, head)

    def forward(self, x):
        """Return the softmax (over dim 1) of the network output, computed in float32."""
        # Ensure output is in float32 for softmax operation
        scores = self.model(x).float()
        return F.softmax(scores, dim=1)

    def training_step(self, batch, batch_idx):
        """Log training loss/accuracy and return ``{"loss": loss}``."""
        x, y = batch
        logits = self.forward(x)
        # NOTE(review): ``forward`` already applies softmax and these values are
        # passed to ``F.cross_entropy`` - confirm that is intended.
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss, prog_bar=True)
        batch_acc = self.train_acc(logits, y)
        self.log('train_acc', batch_acc, prog_bar=True, sync_dist=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy for one batch."""
        x, y = batch
        logits = self.forward(x)
        val_loss = F.cross_entropy(logits, y)
        self.log('val_loss', val_loss, prog_bar=False, sync_dist=True)
        batch_acc = self.valid_acc(logits, y)
        self.log('val_acc', batch_acc, prog_bar=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        """Log test loss and accuracy for one batch."""
        x, y = batch
        logits = self.forward(x)
        test_loss = F.cross_entropy(logits, y)
        self.log('test_loss', test_loss, prog_bar=False, sync_dist=True)
        batch_acc = self.test_acc(logits, y)
        self.log('test_acc', batch_acc, prog_bar=True, sync_dist=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Update ``test_acc`` with this batch and return its accumulated value."""
        x, y = batch
        self.test_acc(self.forward(x), y)
        return self.test_acc.compute()

    def configure_optimizers(self):
        """Return Adam plus an exponential LR decay (gamma=0.99) that is stepped every step."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler_config = {
            'scheduler': torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99),
            'interval': 'step',
        }
        return [optimizer], [scheduler_config]
def str_accuracy(m: Accuracy, detail: bool = False):
    """Format the current value of an accuracy metric as a percentage string.

    ``m.correct`` and ``m.total`` are saved before ``m.compute()`` is called
    and written back afterwards, so the metric's counters are left as found.

    Args:
        m: metric exposing ``correct``, ``total`` and ``compute()``.
        detail: when true, append the raw ``correct/total`` counts.

    Returns:
        ``'N/A'`` when the computed value is NaN or infinite, otherwise the
        value times 100 with two decimals and a ``%`` suffix.
    """
    saved_correct, saved_total = m.correct, m.total
    value = m.compute()
    m.correct, m.total = saved_correct, saved_total

    if math.isnan(value) or math.isinf(value):
        return 'N/A'

    percent = f'{value * 100:.2f}%'
    if detail:
        return f'{percent}(= {m.correct}/{m.total})'
    return percent
def test_topk_accuracy(preds, target, exp_result, k, subset_accuracy):
    """Check top-k accuracy for both the class metric and the functional form.

    The class metric is fed one slice of the leading dimension at a time; the
    functional form receives the same data with the first two dimensions merged.
    """
    metric = Accuracy(top_k=k, subset_accuracy=subset_accuracy)
    for idx, batch_preds in enumerate(preds):
        metric(batch_preds, target[idx])
    assert metric.compute() == exp_result

    # Test functional
    n_batches, batch_size = target.shape[0], target.shape[1]
    total_samples = n_batches * batch_size
    flat_preds = preds.view(total_samples, 4, -1)
    flat_target = target.view(total_samples, -1)
    functional_result = accuracy(flat_preds, flat_target, top_k=k, subset_accuracy=subset_accuracy)
    assert functional_result == exp_result
class Densenet121Lightning(BaseParticipantModel, pl.LightningModule):
    """DenseNet-121 participant model whose classifier is replaced by a ``num_classes`` head.

    Args:
        num_classes: number of output features of the new classifier layer.
        weights: optional class weights forwarded to ``CrossEntropyLoss``.
        pretrain: forwarded to torchvision as ``pretrained``.
        *args, **kwargs: forwarded to the base-class constructor.
    """

    def __init__(self, num_classes, *args, weights=None, pretrain=True, **kwargs):
        backbone = torchvision.models.densenet121(pretrained=pretrain)
        backbone.classifier = Linear(in_features=1024, out_features=num_classes, bias=True)
        super().__init__(*args, model=backbone, **kwargs)
        self.model = backbone
        self.accuracy = Accuracy()
        self.train_accuracy = Accuracy()
        self.criterion = CrossEntropyLoss(weight=weights)

    def training_step(self, train_batch, batch_idx):
        """Log per-participant training accuracy and loss; return the loss tensor."""
        x, y = train_batch
        y = y.long()
        logits = self.model(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        batch_acc = self.train_accuracy(preds, y)
        self.log(f'train/acc/{self.participant_name}', batch_acc)
        self.log(f'train/loss/{self.participant_name}', loss.item())
        return loss

    def test_step(self, test_batch, batch_idx):
        """Update the test accuracy and the global confusion matrix; return ``{'loss': loss}``."""
        x, y = test_batch
        y = y.long()
        logits = self.model(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy.update(preds, y)
        GlobalConfusionMatrix().update(preds, y)
        return {'loss': loss}

    def test_epoch_end(self, outputs: List[Any]) -> None:
        """Log the sample count, accumulated accuracy and mean loss over the test steps."""
        stacked_losses = torch.stack([step_out['loss'] for step_out in outputs])
        self.log('sample_num', self.accuracy.total.item())
        self.log(f'test/acc/{self.participant_name}', self.accuracy.compute())
        self.log(f'test/loss/{self.participant_name}', stacked_losses.mean().item())
class CNNLightning(BaseParticipantModel, pl.LightningModule):
    """Participant model wrapping ``CNN_OriginalFedAvg``.

    Args:
        only_digits: forwarded to ``CNN_OriginalFedAvg``.
        input_channels: forwarded to ``CNN_OriginalFedAvg``.
        *args, **kwargs: forwarded to the base-class constructor.
    """

    def __init__(self, only_digits=False, input_channels=1, *args, **kwargs):
        network = CNN_OriginalFedAvg(only_digits=only_digits, input_channels=input_channels)
        super().__init__(*args, model=network, **kwargs)
        self.model = network
        self.accuracy = Accuracy()
        self.train_accuracy = Accuracy()

    def training_step(self, train_batch, batch_idx):
        """Log per-participant training accuracy and loss; return the loss tensor."""
        x, y = train_batch
        y = y.long()
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        batch_acc = self.train_accuracy(preds, y).item()
        self.log(f'train/acc/{self.participant_name}', batch_acc)
        self.log(f'train/loss/{self.participant_name}', loss.mean().item())
        return loss

    def test_step(self, test_batch, batch_idx):
        """Update the test accuracy with this batch; return ``{'loss': loss}``."""
        x, y = test_batch
        y = y.long()
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        self.accuracy.update(torch.argmax(logits, dim=1), y)
        return {'loss': loss}

    def test_epoch_end(self, outputs: List[Any]) -> None:
        """Log the sample count, accumulated accuracy and mean loss over the test steps."""
        stacked_losses = torch.stack([step_out['loss'] for step_out in outputs])
        self.log('sample_num', self.accuracy.total.item())
        self.log(f'test/acc/{self.participant_name}', self.accuracy.compute())
        self.log(f'test/loss/{self.participant_name}', stacked_losses.mean().item())
class LinearProbe(LightningModule):
    """
    A class to train and evaluate a linear probe on top of representations
    learned from a noisy clip student.

    The backbone is frozen; only the final linear layer ``self.output`` is trained.
    """

    def __init__(self, args):
        """
        Args:
            args: hyper-parameter object; the attributes read by this class
                include ``dataset``, ``distortion``, ``fixed_mask``, ``encoder``,
                ``checkpoint_path``, ``emb_dim`` and ``num_classes``.

        Raises:
            ValueError: if ``args.dataset`` or ``args.encoder`` is not supported.
        """
        super(LinearProbe, self).__init__()
        self.hparams = args

        # (1) Set up the dataset transforms.
        if self.hparams.dataset not in ["ImageNet100", "CIFAR10", "CIFAR100"]:
            raise ValueError("Unsupported dataset selected.")

        if self.hparams.distortion == "None":
            self.train_set_transform = ImageNetBaseTransform(self.hparams)
            self.val_set_transform = ImageNetBaseTransformVal(self.hparams)
        elif self.hparams.distortion == "multi":
            self.train_set_transform = ImageNetDistortTrainMulti(self.hparams)
            self.val_set_transform = ImageNetDistortValMulti(self.hparams)
        else:
            # Set up the train and val sets to use the same mask if needed!
            self.train_set_transform = ImageNetDistortTrain(self.hparams)
            if self.hparams.fixed_mask:
                self.val_set_transform = ImageNetDistortVal(
                    self.hparams, fixed_distortion=self.train_set_transform.distortion
                )
            else:
                self.val_set_transform = ImageNetDistortVal(self.hparams)

        # (2) Backbone: either the visual encoder of a trained student CLIP
        # checkpoint ("clip") or the visual part of a freshly loaded RN101 ("clean").
        if self.hparams.encoder == "clip":
            saved_student = NoisyCLIP.load_from_checkpoint(self.hparams.checkpoint_path)
            self.backbone = saved_student.noisy_visual_encoder
        elif self.hparams.encoder == "clean":
            saved_student = clip.load('RN101', 'cpu', jit=False)[0]
            self.backbone = saved_student.visual
        else:
            # Fail here with a clear message; previously an unknown encoder left
            # ``self.backbone`` unset and crashed later with an AttributeError.
            raise ValueError(f"Unsupported encoder selected: {self.hparams.encoder!r}")

        # Freeze the backbone.
        self.backbone.eval()
        for param in self.backbone.parameters():
            param.requires_grad = False

        # This is the meat: the only trainable layer.
        self.output = nn.Linear(self.hparams.emb_dim, self.hparams.num_classes)

        # Set up training and validation metrics.
        self.criterion = nn.CrossEntropyLoss()
        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)
        self.test_top_1 = Accuracy(top_k=1)
        self.test_top_5 = Accuracy(top_k=5)

    def forward(self, x):
        """
        Given a set of images x with shape [N, c, h, w], get their embeddings and then logits.

        Returns: Logits with shape [N, n_classes]
        """
        # Grab the noisy image embeddings. The backbone is forced back into eval
        # mode on every call and never receives gradients.
        self.backbone.eval()
        with torch.no_grad():
            # Both supported encoders take half-precision input here; the result
            # is cast back to float32 before the linear layer.
            noisy_embeddings = self.backbone(x.type(torch.float16)).float()

        return self.output(noisy_embeddings)

    def configure_optimizers(self):
        """Return Adam over the linear layer plus a cosine-annealing scheduler."""
        if not hasattr(self.hparams, 'weight_decay'):
            self.hparams.weight_decay = 0

        opt = torch.optim.Adam(
            self.output.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
        )

        if self.hparams.dataset == "ImageNet100":
            # Divide N_train by number of distributed iters.
            num_steps = 126689 // (self.hparams.batch_size * self.hparams.gpus)
            if self.hparams.use_subset:
                num_steps = num_steps * self.hparams.subset_ratio
        else:
            num_steps = 500

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=num_steps)
        return [opt], [scheduler]

    def train_dataloader(self):
        """Return the training DataLoader.

        Raises:
            NotImplementedError: for any dataset other than ``"ImageNet100"``.
        """
        if self.hparams.dataset != "ImageNet100":
            # Previously this path ended in an UnboundLocalError.
            raise NotImplementedError(
                f"No built-in train dataloader for dataset {self.hparams.dataset!r}."
            )

        train_dataset = ImageNet100(
            root=self.hparams.dataset_dir,
            split='train',
            transform=self.train_set_transform,
        )

        N_train = len(train_dataset)
        if self.hparams.use_subset:
            # Second argument: ceil(N_train * subset_ratio / num_classes).
            per_class = int(np.ceil(N_train * self.hparams.subset_ratio / self.hparams.num_classes))
            train_dataset = few_shot_dataset(train_dataset, per_class)

        return DataLoader(
            train_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.workers,
            pin_memory=True,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return the validation DataLoader (batch size is 4x the training batch size).

        Raises:
            NotImplementedError: for any dataset other than ``"ImageNet100"``.
        """
        if self.hparams.dataset != "ImageNet100":
            # Previously this path ended in an UnboundLocalError.
            raise NotImplementedError(
                f"No built-in val dataloader for dataset {self.hparams.dataset!r}."
            )

        val_dataset = ImageNet100(
            root=self.hparams.dataset_dir,
            split='val',
            transform=self.val_set_transform,
        )
        self.N_val = 5000

        return DataLoader(
            val_dataset,
            batch_size=4 * self.hparams.batch_size,
            num_workers=self.hparams.workers,
            pin_memory=True,
            shuffle=False,
        )

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one batch; return the loss."""
        x, y = batch

        # Log a sample grid once, on the very first batch of the first epoch.
        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Train_Sample', img_grid(x), self.current_epoch)

        logits = self.forward(x)
        loss = self.criterion(logits, y)
        self.log("train_loss", loss, prog_bar=True, on_step=True,
                 on_epoch=True, logger=True, sync_dist=True, sync_dist_op='sum')

        return loss

    def validation_step(self, batch, batch_idx):
        """Update the top-1/top-5 validation metrics with one batch."""
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Val_Sample', img_grid(x), self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        self.log("val_top_1", self.val_top_1(pred_probs, y), prog_bar=False, logger=False)
        self.log("val_top_5", self.val_top_5(pred_probs, y), prog_bar=False, logger=False)

    def validation_epoch_end(self, outputs):
        """Log the accumulated validation accuracies and reset both metrics."""
        self.log("val_top_1", self.val_top_1.compute(), prog_bar=True, logger=True)
        self.log("val_top_5", self.val_top_5.compute(), prog_bar=True, logger=True)

        self.val_top_1.reset()
        self.val_top_5.reset()

    def test_step(self, batch, batch_idx):
        """Update the top-1/top-5 test metrics with one batch."""
        x, y = batch

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        self.log("test_top_1", self.test_top_1(pred_probs, y), prog_bar=False, logger=False)
        self.log("test_top_5", self.test_top_5(pred_probs, y), prog_bar=False, logger=False)

    def test_epoch_end(self, outputs):
        """Log the accumulated test accuracies and reset both metrics."""
        self.log("test_top_1", self.test_top_1.compute(), prog_bar=True, logger=True)
        self.log("test_top_5", self.test_top_5.compute(), prog_bar=True, logger=True)

        self.test_top_1.reset()
        self.test_top_5.reset()

    def predict(self, batch, batch_idx, hiddens):
        """Return the arg-max class index for every sample in the batch."""
        x, y = batch

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)
        preds = pred_probs.argmax(dim=-1)

        return preds
class CIFAR10Classifier(pl.LightningModule):
    """Pretrained ResNet-50 with frozen weights and a new 10-way final layer.

    Keyword arguments are stored in ``self.args``; ``"lr"`` is required by
    ``configure_optimizers`` and ``"accelerator"`` (optional) decides whether
    the validation/test loss is logged with ``sync_dist=True``.
    """

    def __init__(self, **kwargs):
        """
        Initializes the network, optimizer and scheduler
        """
        super(CIFAR10Classifier, self).__init__()
        self.model_conv = models.resnet50(pretrained=True)
        # Freeze the pretrained weights; the replacement ``fc`` layer created
        # below is new and therefore stays trainable.
        for param in self.model_conv.parameters():
            param.requires_grad = False
        num_ftrs = self.model_conv.fc.in_features
        num_classes = 10
        self.model_conv.fc = nn.Linear(num_ftrs, num_classes)

        self.scheduler = None
        self.optimizer = None
        self.args = kwargs

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

    def forward(self, x):
        """Return the ResNet output for ``x``."""
        out = self.model_conv(x)
        return out

    def _sync_dist(self):
        """Return True when an accelerator was configured (missing key counts as None)."""
        return self.args.get("accelerator") is not None

    def training_step(self, train_batch, batch_idx):
        """Compute the training loss for one batch and update the training accuracy.

        The first image of the first batch is kept as ``self.reference_image``
        for the activation logging done in ``training_epoch_end``.
        """
        if batch_idx == 0:
            self.reference_image = (train_batch[0][0]).unsqueeze(0)
            print("\n\nREFERENCE IMAGE!!!")
            print(self.reference_image.shape)
        x, y = train_batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        self.train_acc(y_hat, y)
        # Log the metric object, not ``compute()``: Lightning then handles
        # computing/resetting, instead of the value accumulating forever.
        self.log("train_acc", self.train_acc)
        return {"loss": loss}

    def test_step(self, test_batch, batch_idx):
        """Log test loss/accuracy and return the accumulated accuracy under ``"test_acc"``."""
        x, y = test_batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        if self._sync_dist():
            self.log("test_loss", loss, sync_dist=True)
        else:
            self.log("test_loss", loss)
        self.test_acc(y_hat, y)
        self.log("test_acc", self.test_acc)
        return {"test_acc": self.test_acc.compute()}

    def validation_step(self, val_batch, batch_idx):
        """Log validation loss/accuracy; the loss is returned under two keys."""
        x, y = val_batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        if self._sync_dist():
            self.log("val_loss", loss, sync_dist=True)
        else:
            self.log("val_loss", loss)
        self.val_acc(y_hat, y)
        self.log("val_acc", self.val_acc)
        return {"val_step_loss": loss, "val_loss": loss}

    def configure_optimizers(self):
        """
        Initializes the optimizer and learning rate scheduler

        :return: output - Initialized optimizer and scheduler
        """
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.args["lr"])
        self.scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=0.2,
                patience=3,
                min_lr=1e-6,
                verbose=True,
            ),
            "monitor": "val_loss",
        }
        return [self.optimizer], [self.scheduler]

    def makegrid(self, output, numrows):
        """Tile the channels of the first sample of ``output`` into one 2-D array.

        Channels are stacked vertically, ``numrows`` at a time (each new channel
        is placed above the previous ones); every completed stack becomes one
        more column block of the grid. Channels left over after the last full
        stack are not included.

        :param output: activation tensor indexed as ``[sample][channel]``
        :param numrows: number of channel images per column block
        :return: the assembled numpy array
        """
        outer = torch.Tensor.cpu(output).detach()
        # NOTE: the previous version opened a matplotlib figure here on every
        # call and never used or closed it; that leak has been removed.
        side = outer.shape[2]
        column = np.array([]).reshape(0, side)
        grid = np.array([]).reshape(numrows * side, 0)
        stacked = 0
        for channel in range(outer.shape[1]):
            column = np.concatenate((outer[0][channel], column), axis=0)
            stacked += 1
            if stacked == numrows:
                grid = np.concatenate((grid, column), axis=1)
                column = np.array([]).reshape(0, side)
                stacked = 0
        return grid

    def showActivations(self, x):
        """Log the input image and a grid of the first conv layer's activations."""
        # logging reference image
        self.logger.experiment.add_image(
            "input", torch.Tensor.cpu(x[0][0]), self.current_epoch, dataformats="HW"
        )

        # logging layer 1 activations
        out = self.model_conv.conv1(x)
        c = self.makegrid(out, 4)
        self.logger.experiment.add_image("layer 1", c, self.current_epoch, dataformats="HW")

    def training_epoch_end(self, outputs):
        """Log activations for the reference image and, in epoch 0, the model graph."""
        self.showActivations(self.reference_image)

        # Logging graph
        if self.current_epoch == 0:
            sampleImg = torch.rand((1, 3, 64, 64))
            # A fresh instance is traced rather than ``self``.
            self.logger.experiment.add_graph(CIFAR10Classifier(), sampleImg)
class NoisyCLIP(LightningModule):
    """Student-teacher training of a CLIP visual encoder on distorted images."""

    def __init__(self, args):
        """
        A class that trains OpenAI CLIP in a student-teacher fashion to classify distorted images.

        Given two identical pre-trained networks, Teacher - T() and Student S(), we freeze T() and train S().

        Given a batch of {x1, x2, ..., xN}, apply a given distortion to each one to obtain noisy images {y1, y2, ..., yN}.

        Feed original images to T() and obtain embeddings {T(x1), ..., T(xN)} and feed distorted images to S() and obtain embeddings {S(y1), ..., S(yN)}.

        Maximize the similarity between the pairs {(T(x1), S(y1)), ..., (T(xN), S(yN))} while minimizing the similarity between all non-matched pairs.

        Args:
            args: hyper-parameter object (stored as ``self.hparams``).

        Raises:
            ValueError: if no text-label file is given.
            NotImplementedError: if the dataset is not ImageNet100.
        """
        super(NoisyCLIP, self).__init__()
        self.hparams = args
        self.world_size = self.hparams.num_nodes * self.hparams.gpus

        #(1) Load the correct dataset class names
        if self.hparams.dataset == "ImageNet100" or self.hparams.dataset == "Imagenet-100":
            self.N_val = 5000 # Default ImageNet validation set, only 100 classes.

            #Retrieve the text labels for classes in order to do zero-shot classification.
            if self.hparams.mapping_and_text_file is None:
                raise ValueError('No file from which to read text labels was specified.')

            # NOTE: pickle can execute arbitrary code while loading - only point
            # this at trusted files. The context manager closes the handle,
            # which the previous ``pickle.load(open(...))`` never did.
            with open(self.hparams.mapping_and_text_file, 'rb') as label_file:
                text_labels = pickle.load(label_file)
            self.text_list = ['A photo of '+label.strip().replace('_',' ') for label in text_labels]
        else:
            raise NotImplementedError('Handling of the dataset not implemented yet.')

        #Set up the dataset transforms.
        #Any unsupported dataset has already raised above, so no second dataset
        #check is needed. With ``increasing`` set, the transforms are instead
        #built per call in the dataloader methods.
        if not hasattr(self.hparams, 'increasing') or not self.hparams.increasing:
            if self.hparams.distortion == "None":
                self.train_set_transform = ImageNetBaseTrainContrastive(self.hparams)
                self.val_set_transform = ImageNetBaseTransformVal(self.hparams)
            elif self.hparams.distortion == "multi":
                self.train_set_transform = ImageNetDistortTrainMultiContrastive(self.hparams)
                self.val_set_transform = ImageNetDistortValMulti(self.hparams)
            else:
                #Set up the train and val sets to use the same mask if needed!
                self.train_set_transform = ImageNetDistortTrainContrastive(self.hparams)

                if self.hparams.fixed_mask:
                    self.val_set_transform = ImageNetDistortVal(self.hparams, fixed_distortion=self.train_set_transform.distortion)
                else:
                    self.val_set_transform = ImageNetDistortVal(self.hparams)

        #(2) set up the teacher CLIP network - freeze it and don't use gradients!
        self.logit_scale = self.hparams.logit_scale
        self.baseclip = clip.load(self.hparams.baseclip_type, self.hparams.device, jit=False)[0]
        self.baseclip.eval()
        self.baseclip.requires_grad_(False)

        #(3) set up the student CLIP network - unfreeze it and use gradients!
        self.noisy_visual_encoder = clip.load(self.hparams.baseclip_type, self.hparams.device, jit=False)[0].visual
        self.noisy_visual_encoder.train()

        #(4) set up the training and validation accuracy metrics.
        self.train_top_1 = Accuracy(top_k=1)
        self.train_top_5 = Accuracy(top_k=5)
        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)

    def criterion(self, input1, input2, reduction='mean'):
        """
        Compute the loss selected by ``hparams.loss_type``.

        Args:
            input1: Embeddings of the clean/noisy images from the teacher/student. Size [N, embedding_dim].
            input2: Embeddings of the clean/noisy images from the teacher/student (the ones not used as input1). Size [N, embedding_dim].
            reduction: how to scale the final loss; ``'mean'`` divides by the
                batch size, anything else returns the loss undivided. The
                ``'mse'`` loss type ignores this and returns ``F.mse_loss`` directly.

        Raises:
            ValueError: if ``hparams.loss_type`` is not recognised.
        """
        bsz = input1.shape[0]

        # Use the simclr style InfoNCE
        if self.hparams.loss_type == 'simclr':
            # Create similarity matrix between embeddings; rows of input1 and
            # input2 are interleaved so each pair sits in adjacent rows.
            full_tensor = torch.cat([input1.unsqueeze(1),input2.unsqueeze(1)], dim=1).view(2*bsz, -1)
            full_tensor = full_tensor / full_tensor.norm(dim=-1, keepdim=True)
            sim_mat = full_tensor @ full_tensor.t()
            # NOTE(review): looks like a leftover debug print - confirm before removing.
            print(torch.sum(sim_mat < 0))

            # Calculate logits used for the contrastive loss.
            exp_sim_mat = torch.exp(sim_mat/self.hparams.loss_tau)
            mask = torch.ones_like(exp_sim_mat) - torch.eye(2*bsz).type_as(exp_sim_mat)
            logmat = -torch.log(exp_sim_mat)+torch.log(torch.sum(mask*exp_sim_mat, 1))

            #Grab the two off-diagonal similarities
            pair_index = np.arange(0,2*bsz,2)
            part1 = torch.sum(torch.diag(logmat, diagonal=1)[pair_index])
            part2 = torch.sum(torch.diag(logmat, diagonal=-1)[pair_index])

            #Take the mean of the two off-diagonals
            loss = (part1 + part2)/2

        #Use the CLIP-style InfoNCE
        elif self.hparams.loss_type == 'clip':
            # Create similarity matrix between embeddings.
            tensor1 = input1 / input1.norm(dim=-1, keepdim=True)
            tensor2 = input2 / input2.norm(dim=-1, keepdim=True)
            sim_mat = (1/self.hparams.loss_tau)*tensor1 @ tensor2.t()

            #Calculate the cross entropy between the similarities of the positive pairs, counted two ways
            #The i-th row matches the i-th column, so the targets are 0..bsz-1 (built once, used twice).
            targets = torch.LongTensor(np.arange(bsz)).to(self.device)
            part1 = F.cross_entropy(sim_mat, targets)
            part2 = F.cross_entropy(sim_mat.t(), targets)

            #Take the mean of the two directions
            loss = (part1+part2)/2

        #Take the simple MSE between the clean and noisy embeddings
        elif self.hparams.loss_type == 'mse':
            return F.mse_loss(input2, input1)

        elif self.hparams.loss_type.startswith('simclr_'):
            assert self.hparams.loss_type in ['simclr_ss', 'simclr_st', 'simclr_both']
            # Various schemes for the negative examples
            teacher_embeds = F.normalize(input1, dim=1)
            student_embeds = F.normalize(input2, dim=1)
            # First compute positive examples by taking <S(x_i), T(x_i)>/T for all i
            pos_term = (teacher_embeds * student_embeds).sum(dim=1) / self.hparams.loss_tau
            # Then generate the negative term by constructing various similarity matrices
            if self.hparams.loss_type == 'simclr_ss':
                cov = torch.mm(student_embeds, student_embeds.t())
                sim = torch.exp(cov / self.hparams.loss_tau) # shape is [bsz, bsz]
                neg_term = torch.log(sim.sum(dim=1) - sim.diag())
            elif self.hparams.loss_type == 'simclr_st':
                cov = torch.mm(student_embeds, teacher_embeds.t())
                sim = torch.exp(cov / self.hparams.loss_tau) # shape is [bsz, bsz]
                neg_term = torch.log(sim.sum(dim=1)) # Not removing the diagonal here!
            else:
                cat_embeds = torch.cat([student_embeds, teacher_embeds])
                cov = torch.mm(student_embeds, cat_embeds.t())
                sim = torch.exp(cov / self.hparams.loss_tau) # shape is [bsz, 2 * bsz]
                # and take row-wise sums w/o diagonals
                neg_term = torch.log(sim.sum(dim=1) - sim.diag())
            # Final loss (summed here and mean-reduced below)
            loss = -1 * (pos_term - neg_term).sum()

        else:
            raise ValueError('Loss function not understood.')

        return loss/bsz if reduction == 'mean' else loss

    def configure_optimizers(self):
        """Return Adam over the student encoder plus a cosine-annealing scheduler."""
        optim = torch.optim.Adam(self.noisy_visual_encoder.parameters(), lr=self.hparams.lr)
        num_steps = 126689//(self.hparams.batch_size * self.hparams.gpus) #divide N_train by number of distributed iters
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=num_steps)
        return [optim], [sched]

    def encode_noisy_image(self, image):
        """
        Return S(yi) where S() is the student network and yi is distorted images.
        The input is cast to float16 before being encoded.
        """
        return self.noisy_visual_encoder(image.type(torch.float16))

    def forward(self, image_features, text=None):
        """
        Given a set of noisy image embeddings, calculate the cosine similarity (scaled by temperature) of each image with each class text prompt.
        Calculates the similarity in two ways: logits per image (size = [N, n_classes]), logits per text (size = [n_classes, N]).
        This is mainly used for validation and classification.

        Args:
            image_features: the noisy image embeddings S(yi) where S() is the student and yi = Distort(xi). Shape [N, embedding_dim]
            text: unused.

        Returns:
            ``(logits_per_image, logits_per_text)``.
        """
        #Encode the class prompts with the teacher (done on every call) on the correct device
        text_features = self.baseclip.encode_text(clip.tokenize(self.text_list).to(self.device))

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logits_per_image = self.logit_scale * image_features.type(torch.float16) @ text_features.type(torch.float16).t()
        logits_per_text = self.logit_scale * text_features.type(torch.float16) @ image_features.type(torch.float16).t()

        return logits_per_image, logits_per_text

    # Training methods - here we are concerned with contrastive loss (or MSE) between clean and noisy image embeddings.
    def training_step(self, train_batch, batch_idx):
        """
        Takes a batch of clean and noisy images and returns their respective embeddings.

        Returns:
            embed_clean: T(xi) where T() is the teacher and xi are clean images. Shape [N, embed_dim]
            embed_noisy: S(yi) where S() is the student and yi are noisy images. Shape [N, embed_dim]
        """
        image_clean, image_noisy, _labels = train_batch

        embed_clean = self.baseclip.encode_image(image_clean)
        embed_noisy = self.encode_noisy_image(image_noisy)

        return {'embed_clean': embed_clean, 'embed_noisy': embed_noisy}

    def training_step_end(self, outputs):
        """
        Given all the clean and noisy image embeddings from across GPUs from training_step, gather them onto a single GPU and calculate overall loss.
        """
        embed_clean_full = outputs['embed_clean']
        embed_noisy_full = outputs['embed_noisy']

        loss = self.criterion(embed_clean_full, embed_noisy_full)

        self.log('train_loss', loss, prog_bar=False, logger=True, sync_dist=True, on_step=True, on_epoch=True)

        return loss

    # Validation methods - here we are concerned with similarity between noisy image embeddings and classification text embeddings.
    def validation_step(self, test_batch, batch_idx):
        """
        Grab the noisy image embeddings: S(yi), where S() is the student and yi = Distort(xi). Done on each GPU.
        Return these to be evaluated in validation step end.
        """
        images_noisy, labels = test_batch

        # Log a sample grid once, on the first batch of the first epoch.
        if batch_idx == 0 and self.current_epoch < 1:
            self.logger.experiment.add_image('Val_Sample', img_grid(images_noisy), self.current_epoch)

        image_features = self.encode_noisy_image(images_noisy)

        return {'image_features': image_features, 'labels': labels}

    def validation_step_end(self, outputs):
        """
        Gather the noisy image features and their labels from each GPU.
        Then calculate their similarities, convert to probabilities, and calculate accuracy on each GPU.
        """
        image_features_full = outputs['image_features']
        labels_full = outputs['labels']

        image_logits, _ = self.forward(image_features_full)
        image_logits = image_logits.float()
        image_probs = image_logits.softmax(dim=-1)

        self.log('val_top_1_step', self.val_top_1(image_probs, labels_full), prog_bar=False, logger=False)
        self.log('val_top_5_step', self.val_top_5(image_probs, labels_full), prog_bar=False, logger=False)

    def validation_epoch_end(self, outputs):
        """
        Gather the zero-shot validation accuracies from across GPUs and reduce, then reset both metrics.
        """
        self.log('val_top_1', self.val_top_1.compute(), prog_bar=True, logger=True)
        self.log('val_top_5', self.val_top_5.compute(), prog_bar=True, logger=True)

        self.val_top_1.reset()
        self.val_top_5.reset()

    # Default dataloaders - can be overwritten by datamodule.
    def train_dataloader(self):
        """Return the contrastive training DataLoader."""
        if hasattr(self.hparams, 'increasing') and self.hparams.increasing:
            # The transform is rebuilt with the current epoch on every call.
            datatf = ImageNetDistortTrainContrastive(self.hparams, epoch=self.current_epoch)
        else:
            datatf = self.train_set_transform

        train_dataset = ImageNet100(
            root=self.hparams.dataset_dir,
            split = 'train',
            transform = None
        )
        train_contrastive = ContrastiveUnsupervisedDataset(train_dataset, transform_contrastive=datatf, return_label=True)

        train_dataloader = DataLoader(train_contrastive, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,
                                        pin_memory=True, shuffle=True)

        return train_dataloader

    def val_dataloader(self):
        """Return the validation DataLoader and record the dataset size in ``self.N_val``."""
        if hasattr(self.hparams, 'increasing') and self.hparams.increasing:
            datatf = ImageNetDistortVal(self.hparams, epoch=self.current_epoch)
        else:
            datatf = self.val_set_transform

        val_dataset = ImageNet100(
            root=self.hparams.dataset_dir,
            split = 'val',
            transform = datatf
        )
        self.N_val = len(val_dataset)

        val_dataloader = DataLoader(val_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,
                                        pin_memory=True, shuffle=False)

        return val_dataloader

    def test_dataloader(self):
        """Use the validation DataLoader for testing."""
        return self.val_dataloader()
class BertNewsClassifier(pl.LightningModule):
    """News-topic classifier: a frozen BERT encoder plus a small trainable head."""

    def __init__(self, **kwargs):
        """Build the network and the accuracy metrics.

        :param kwargs: Free-form run arguments, stored in ``self.args``.
            ``lr`` is read by ``configure_optimizers``.
        """
        super(BertNewsClassifier, self).__init__()
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

        self.PRE_TRAINED_MODEL_NAME = "bert-base-uncased"
        self.bert_model = BertModel.from_pretrained(
            self.PRE_TRAINED_MODEL_NAME)
        # The pre-trained encoder is frozen; only the head below is trained.
        for param in self.bert_model.parameters():
            param.requires_grad = False
        self.drop = nn.Dropout(p=0.2)
        # assigning labels
        self.class_names = ["world", "Sports", "Business", "Sci/Tech"]
        n_classes = len(self.class_names)

        self.fc1 = nn.Linear(self.bert_model.config.hidden_size, 512)
        self.out = nn.Linear(512, n_classes)

        self.args = kwargs

    def forward(self, input_ids, attention_mask):
        """Run the encoder and the classification head.

        :param input_ids: Input data
        :param attention_mask: Attention mask value
        :return: output - un-normalised class scores from the final linear layer
        """
        pooled_output = self.bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask).pooler_output
        output = F.relu(self.fc1(pooled_output))
        output = self.drop(output)
        output = self.out(output)
        return output

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Add the model's command-line options to an existing parser.

        :param parent_parser: Application specific parser
        :return: Returns the augmented argument parser
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--lr",
            type=float,
            default=0.001,
            metavar="LR",
            help="learning rate (default: 0.001)",
        )
        return parser

    def training_step(self, train_batch, batch_idx):
        """Train on one batch and log the accuracy and loss.

        :param train_batch: Batch data
        :param batch_idx: Batch indices
        :return: output - dict holding the training loss under ``"loss"``
        """
        input_ids = train_batch["input_ids"]
        attention_mask = train_batch["attention_mask"]
        targets = train_batch["targets"]
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, targets)
        self.train_acc(y_hat, targets)
        self.log("train_acc", self.train_acc.compute().cpu())
        self.log("train_loss", loss.cpu())
        return {"loss": loss}

    def test_step(self, test_batch, batch_idx):
        """Update and log the test accuracy for one batch.

        :param test_batch: Batch data
        :param batch_idx: Batch indices
        """
        input_ids = test_batch["input_ids"]
        attention_mask = test_batch["attention_mask"]
        targets = test_batch["targets"]
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)
        self.test_acc(y_hat, targets)
        self.log("test_acc", self.test_acc.compute().cpu())

    def validation_step(self, val_batch, batch_idx):
        """Validate on one batch and log the accuracy and loss.

        :param val_batch: Batch data
        :param batch_idx: Batch indices
        """
        input_ids = val_batch["input_ids"]
        attention_mask = val_batch["attention_mask"]
        targets = val_batch["targets"]
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, targets)
        self.val_acc(y_hat, targets)
        self.log("val_acc", self.val_acc.compute().cpu())
        # "val_loss" is the quantity the LR scheduler monitors.
        self.log("val_loss", loss, sync_dist=True)

    def configure_optimizers(self):
        """Create the optimizer and the plateau-based learning-rate scheduler.

        :return: output - Initialized optimizer and scheduler
        """
        # Fall back to the same default the CLI option declares, so the model
        # can be built without an explicit ``lr`` instead of raising KeyError.
        learning_rate = self.args.get("lr", 0.001)
        optimizer = AdamW(self.parameters(), lr=learning_rate)
        scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode="min",
                factor=0.2,
                patience=2,
                min_lr=1e-6,
                verbose=True,
            ),
            "monitor": "val_loss",
        }
        return [optimizer], [scheduler]
'batch_size': 4, 'learning_rate': .001, 'epochs': 1, 'adam_epsilon': 10**-9 } res = tensorflow_unet.gis_tf_objective(test_config) x_test, y_test = perturbed_tf_gis_test_data() test_acc = res[1].evaluate(x_test, y_test, batch_size=test_config['batch_size']) print(res[0]) print(test_acc) print("PyTorch Model evaluation...") acc, pt_model = pytorch_unet.gis_pt_objective(test_config) test = perturbed_pt_gis_test_data() test_set = PT_GISDataset(test) testloader = DataLoader(test_set, batch_size=int(test_config['batch_size'])) accuracy = Accuracy() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") for sample in tqdm(testloader): cuda_in = sample[0].to(device) out = pt_model(cuda_in) output = out.to('cpu').squeeze(1) accuracy(output, sample[1]) # cuda_in = cuda_in.detach() # label = label.detach() print("AVERAGE ACCURACY:") print(accuracy.compute()[0])
class BertNewsClassifier(pl.LightningModule):
    """News-topic classifier: a frozen BERT encoder plus a small trainable head."""

    def __init__(self, **kwargs):
        """Build the network and the accuracy metrics.

        :param kwargs: Free-form run arguments, stored in ``self.args``.
            ``lr`` is read by ``configure_optimizers``.
        """
        super(BertNewsClassifier, self).__init__()
        self.PRE_TRAINED_MODEL_NAME = "bert-base-uncased"
        self.bert_model = BertModel.from_pretrained(
            self.PRE_TRAINED_MODEL_NAME)
        # The pre-trained encoder is frozen; only the head below is trained.
        for param in self.bert_model.parameters():
            param.requires_grad = False
        self.drop = nn.Dropout(p=0.2)
        # assigning labels
        self.class_names = ["World", "Sports", "Business", "Sci/Tech"]
        n_classes = len(self.class_names)

        self.fc1 = nn.Linear(self.bert_model.config.hidden_size, 512)
        self.out = nn.Linear(512, n_classes)
        # Aliases exposing the BERT embedding layer under other names.
        self.bert_model.embedding = self.bert_model.embeddings
        self.embedding = self.bert_model.embeddings

        # Populated by configure_optimizers.
        self.scheduler = None
        self.optimizer = None
        self.args = kwargs

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

    def forward(self, input_ids, attention_mask):
        """Run the encoder and the classification head.

        :param input_ids: Input data
        :param attention_mask: Attention mask value
        :return: output - un-normalised class scores from the final linear layer
        """
        output = self.bert_model(input_ids=input_ids,
                                 attention_mask=attention_mask)
        output = F.relu(self.fc1(output.pooler_output))
        output = self.drop(output)
        output = self.out(output)
        return output

    def training_step(self, train_batch, batch_idx):
        """Train on one batch and log the accuracy and loss.

        :param train_batch: Batch data
        :param batch_idx: Batch indices
        :return: output - dict with the training loss and the accuracy
        """
        input_ids = train_batch["input_ids"].to(self.device)
        attention_mask = train_batch["attention_mask"].to(self.device)
        targets = train_batch["targets"].to(self.device)
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, targets)
        self.train_acc(y_hat, targets)
        # Compute once and reuse for both the log entry and the return value.
        train_acc = self.train_acc.compute()
        self.log("train_acc", train_acc)
        self.log("train_loss", loss)
        return {"loss": loss, "acc": train_acc}

    def test_step(self, test_batch, batch_idx):
        """Evaluate one test batch.

        :param test_batch: Batch data
        :param batch_idx: Batch indices
        :return: output - dict with this batch's accuracy under ``"test_acc"``
        """
        input_ids = test_batch["input_ids"].to(self.device)
        attention_mask = test_batch["attention_mask"].to(self.device)
        targets = test_batch["targets"].to(self.device)
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)
        # Per-batch accuracy, returned to the caller.
        test_acc = accuracy_score(y_hat.cpu(), targets.cpu())
        # Metric-object accuracy, logged.
        self.test_acc(y_hat, targets)
        self.log("test_acc", self.test_acc.compute())
        return {"test_acc": torch.tensor(test_acc)}

    def validation_step(self, val_batch, batch_idx):
        """Validate on one batch and log the accuracy and loss.

        :param val_batch: Batch data
        :param batch_idx: Batch indices
        :return: output - dict with the validation loss and the accuracy
        """
        input_ids = val_batch["input_ids"].to(self.device)
        attention_mask = val_batch["attention_mask"].to(self.device)
        targets = val_batch["targets"].to(self.device)
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, targets)
        self.val_acc(y_hat, targets)
        # Compute once and reuse for both the log entry and the return value.
        val_acc = self.val_acc.compute()
        self.log("val_acc", val_acc)
        # "val_loss" is the quantity the LR scheduler monitors.
        self.log("val_loss", loss, sync_dist=True)
        return {"val_step_loss": loss, "acc": val_acc}

    def configure_optimizers(self):
        """Create the optimizer and the plateau-based learning-rate scheduler.

        :return: output - Initialized optimizer and scheduler
        """
        # Default to 0.001 when ``lr`` was not supplied, rather than raising
        # KeyError on a model built without arguments.
        self.optimizer = AdamW(self.parameters(),
                               lr=self.args.get("lr", 0.001))
        self.scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=0.2,
                patience=2,
                min_lr=1e-6,
                verbose=True,
            ),
            "monitor": "val_loss",
        }
        return [self.optimizer], [self.scheduler]
class BertNewsClassifier(pl.LightningModule):  #pylint: disable=too-many-ancestors,too-many-instance-attributes
    """Frozen BERT encoder followed by a two-layer classification head."""

    def __init__(self, **kwargs):
        """Create the encoder, the head, the metrics and the prediction buffers."""
        super().__init__()
        self.pre_trained_model_name = "bert-base-uncased"  #pylint: disable=invalid-name
        self.bert_model = BertModel.from_pretrained(self.pre_trained_model_name)
        # Freeze every encoder weight; only the head is optimised.
        for encoder_param in self.bert_model.parameters():
            encoder_param.requires_grad = False
        self.drop = nn.Dropout(p=0.2)

        # assigning labels
        self.class_names = ["World", "Sports", "Business", "Sci/Tech"]
        hidden_size = self.bert_model.config.hidden_size
        self.fc1 = nn.Linear(hidden_size, 512)
        self.out = nn.Linear(512, len(self.class_names))

        # self.bert_model.embedding = self.bert_model.embeddings
        # self.embedding = self.bert_model.embeddings

        # Filled in by configure_optimizers.
        self.scheduler = None
        self.optimizer = None
        self.args = kwargs

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

        # Predictions / targets collected over the test steps.
        self.preds = []
        self.target = []

    def _unpack_batch(self, batch):
        """Move the three tensors of a batch dict onto the module's device."""
        return (
            batch["input_ids"].to(self.device),
            batch["attention_mask"].to(self.device),
            batch["targets"].to(self.device),
        )

    def compute_bert_outputs(  #pylint: disable=no-self-use
        self, model_bert, embedding_input, attention_mask=None, head_mask=None
    ):
        """Run the BERT encoder and pooler on pre-computed embeddings.

        Args:
            model_bert : the bert model
            embedding_input : input for bert embeddings.
            attention_mask : attention mask; all-ones when omitted
            head_mask : head mask; one ``None`` per hidden layer when omitted
        Returns:
            output : tuple of (sequence output, pooled output) followed by
                the remaining encoder outputs
        """
        if attention_mask is None:
            batch_dim, seq_dim = embedding_input.shape[0], embedding_input.shape[1]
            attention_mask = torch.ones(  #pylint: disable=no-member
                batch_dim, seq_dim
            ).to(embedding_input)

        model_dtype = next(model_bert.parameters()).dtype

        # Two singleton axes are inserted after the first axis, then the mask
        # is cast to the model dtype (fp16 compatibility).
        additive_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        additive_mask = additive_mask.to(dtype=model_dtype)
        # 1 -> 0.0 and 0 -> -10000.0.
        additive_mask = (1.0 - additive_mask) * -10000.0

        if head_mask is None:
            head_mask = [None] * model_bert.config.num_hidden_layers
        else:
            if head_mask.dim() == 1:
                head_mask = (
                    head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                )
                head_mask = head_mask.expand(
                    model_bert.config.num_hidden_layers, -1, -1, -1, -1
                )
            elif head_mask.dim() == 2:
                # We can specify head_mask for each layer
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            # switch to float if needed + fp16 compatibility
            head_mask = head_mask.to(dtype=model_dtype)

        encoder_outputs = model_bert.encoder(
            embedding_input, additive_mask, head_mask=head_mask
        )
        sequence_output = encoder_outputs[0]
        pooled_output = model_bert.pooler(sequence_output)
        return (sequence_output, pooled_output) + encoder_outputs[1:]

    def forward(self, input_ids, attention_mask=None):
        """Classify a batch of token ids.

        Args:
            input_ids: Input data
            attention_mask: Attention mask value
        Returns:
            output - un-normalised class scores from the final linear layer
        """
        embedded = self.bert_model.embeddings(input_ids)
        bert_outputs = self.compute_bert_outputs(
            self.bert_model, embedded, attention_mask
        )
        hidden = torch.tanh(self.fc1(bert_outputs[1]))
        hidden = self.drop(hidden)
        return self.out(hidden)

    def training_step(self, train_batch, batch_idx):
        """Run one training batch and log its accuracy and loss.

        Args:
            train_batch: Batch data
            batch_idx: Batch indices
        Returns:
            output - dict with the training loss and the accuracy
        """
        input_ids, attention_mask, targets = self._unpack_batch(train_batch)
        scores = self.forward(input_ids, attention_mask)
        y_hat = torch.max(scores, dim=1)[1]  #pylint: disable=no-member
        loss = F.cross_entropy(scores, targets)

        self.train_acc(y_hat, targets)
        self.log("train_acc", self.train_acc.compute())
        self.log("train_loss", loss)
        return {"loss": loss, "acc": self.train_acc.compute()}

    def test_step(self, test_batch, batch_idx):
        """Evaluate one test batch and remember its predictions.

        Args:
            test_batch: Batch data
            batch_idx: Batch indices
        Returns:
            output - dict with this batch's accuracy under "test_acc"
        """
        input_ids, attention_mask, targets = self._unpack_batch(test_batch)
        scores = self.forward(input_ids, attention_mask)
        y_hat = torch.max(scores, dim=1)[1]  #pylint: disable=no-member

        batch_accuracy = accuracy_score(y_hat.cpu(), targets.cpu())
        self.test_acc(y_hat, targets)
        self.preds += y_hat.tolist()
        self.target += targets.tolist()
        self.log("test_acc", self.test_acc.compute())
        return {"test_acc": torch.tensor(batch_accuracy)}  #pylint: disable=no-member

    def validation_step(self, val_batch, batch_idx):
        """Run one validation batch and log its accuracy and loss.

        Args:
            val_batch: Batch data
            batch_idx: Batch indices
        Returns:
            output - dict with the validation loss and the accuracy
        """
        input_ids, attention_mask, targets = self._unpack_batch(val_batch)
        scores = self.forward(input_ids, attention_mask)
        y_hat = torch.max(scores, dim=1)[1]  #pylint: disable=no-member
        loss = F.cross_entropy(scores, targets)

        self.val_acc(y_hat, targets)
        self.log("val_acc", self.val_acc.compute())
        self.log("val_loss", loss, sync_dist=True)
        return {"val_step_loss": loss, "acc": self.val_acc.compute()}

    def configure_optimizers(self):
        """Create the AdamW optimizer and a plateau-based LR scheduler.

        Returns:
            output - Initialized optimizer and scheduler
        """
        learning_rate = self.args.get("lr", 0.001)
        self.optimizer = AdamW(self.parameters(), lr=learning_rate)
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode="min",
            factor=0.2,
            patience=2,
            min_lr=1e-6,
            verbose=True,
        )
        # The scheduler reacts to the "val_loss" value logged during validation.
        self.scheduler = {"scheduler": plateau, "monitor": "val_loss"}
        return [self.optimizer], [self.scheduler]
class Baseline(LightningModule):
    """
    Class for training a classification model on distorted ImageNet in an end-to-end fashion.
    """

    def __init__(self, args):
        """Set up transforms, the encoder, the loss and the accuracy metrics.

        Args:
            args: hyperparameter namespace; the attributes read here are
                num_nodes, gpus, lr, dataset, distortion, fixed_mask and encoder.

        Raises:
            ValueError: if the dataset or the encoder name is not supported.
        """
        super(Baseline, self).__init__()
        self.hparams = args
        self.world_size = self.hparams.num_nodes * self.hparams.gpus
        self.lr = self.hparams.lr  #initialise this specially as a tuneable parameter

        #(1) Set up the dataset
        #Here, we use a 100-class subset of ImageNet
        if self.hparams.dataset != "ImageNet100":
            raise ValueError("Unsupported dataset selected.")
        else:
            if self.hparams.distortion == "None":
                self.train_set_transform = ImageNetBaseTransform(self.hparams)
                self.val_set_transform = ImageNetBaseTransformVal(self.hparams)
            elif self.hparams.distortion == "multi":
                self.train_set_transform = ImageNetDistortTrainMulti(
                    self.hparams)
                self.val_set_transform = ImageNetDistortValMulti(self.hparams)
            else:
                #If we are using the ImageNet dataset, then set up the train and val sets to use the same mask if needed!
                self.train_set_transform = ImageNetDistortTrain(self.hparams)

                if self.hparams.fixed_mask:
                    self.val_set_transform = ImageNetDistortVal(
                        self.hparams,
                        fixed_distortion=self.train_set_transform.distortion)
                else:
                    self.val_set_transform = ImageNetDistortVal(self.hparams)

        #(2) Grab the correct baseline pre-trained model
        if self.hparams.encoder == 'resnet':
            self.encoder = RESNET_finetune(self.hparams)
        elif self.hparams.encoder == 'clip':
            self.encoder = CLIP_finetune(self.hparams)
        else:
            raise ValueError("Please select a valid encoder model.")

        #(3) Set up our criterion - here we use reduction as "sum" so that we are able to average over all validation sets
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

        self.train_top_1 = Accuracy(top_k=1)
        self.train_top_5 = Accuracy(top_k=5)

        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)

        self.test_top_1 = Accuracy(top_k=1)
        self.test_top_5 = Accuracy(top_k=5)

    def forward(self, x):
        """Return the encoder's output for ``x``."""
        return self.encoder(x)

    def configure_optimizers(self):
        """Create Adam over the encoder parameters and a cosine LR schedule.

        Returns:
            ([optimizer], [scheduler])
        """
        opt = torch.optim.Adam(self.encoder.parameters(), lr=self.lr)

        if self.hparams.dataset == "ImageNet100":
            # divide N_train by number of distributed iters.
            # gpus == 0 (CPU run) is treated as one device so the division
            # cannot raise ZeroDivisionError, and the result is clamped to at
            # least one step because T_max must be positive.
            num_devices = max(1, self.hparams.gpus)
            num_steps = max(
                1, 126689 // (self.hparams.batch_size * num_devices))
        else:
            num_steps = 500

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt,
                                                               T_max=num_steps)

        return [opt], [scheduler]

    #DATALOADERS
    def _imagenet100_loader(self, split, transform, shuffle):
        """Build (dataset, loader) for one ImageNet100 split.

        Raises:
            ValueError: if the configured dataset is not ImageNet100.
        """
        if self.hparams.dataset != "ImageNet100":
            raise ValueError("Unsupported dataset selected.")

        dataset = ImageNet100(root=self.hparams.dataset_dir,
                              split=split,
                              transform=transform)
        loader = DataLoader(dataset,
                            batch_size=self.hparams.batch_size,
                            num_workers=self.hparams.workers,
                            pin_memory=True,
                            shuffle=shuffle)
        return dataset, loader

    def train_dataloader(self):
        """Return the shuffled training loader."""
        _, train_dataloader = self._imagenet100_loader(
            'train', self.train_set_transform, shuffle=True)
        return train_dataloader

    def val_dataloader(self):
        """Return the unshuffled validation loader and record its size in ``N_val``."""
        val_dataset, val_dataloader = self._imagenet100_loader(
            'val', self.val_set_transform, shuffle=False)
        # Take the size from the dataset instead of hard-coding 5000.
        self.N_val = len(val_dataset)
        return val_dataloader

    def test_dataloader(self):
        """Return the unshuffled test loader (built from the 'val' split)."""
        _, test_dataloader = self._imagenet100_loader(
            'val', self.val_set_transform, shuffle=False)
        return test_dataloader

    #TRAINING
    def training_step(self, batch, batch_idx):
        """
        Given a batch of images, train the model for one step.
        Calculates the crossentropy loss as well as top_1 and top_5 per batch.

        Inputs:
            batch - tuple (images, labels) to train on
            batch_idx - the index of the current batch

        Returns:
            loss - the crossentropy loss between the model's logits and the true classes of the inputs
        """
        x, y = batch

        # Log one sample grid, once, at the very start of training.
        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Train_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        loss = self.criterion(logits, y)

        self.log("train_loss",
                 loss,
                 prog_bar=True,
                 on_step=True,
                 on_epoch=True,
                 logger=True,
                 sync_dist=True,
                 sync_dist_op='sum')
        self.log("train_top_1",
                 self.train_top_1(pred_probs, y),
                 prog_bar=False,
                 logger=False)
        self.log("train_top_5",
                 self.train_top_5(pred_probs, y),
                 prog_bar=False,
                 logger=False)

        return loss

    def training_epoch_end(self, outputs):
        """Log the accumulated training accuracies, then reset the metrics."""
        self.log("train_top_1",
                 self.train_top_1.compute(),
                 prog_bar=True,
                 logger=True)
        self.log("train_top_5",
                 self.train_top_5.compute(),
                 prog_bar=True,
                 logger=True)

        self.train_top_1.reset()
        self.train_top_5.reset()

    #VALIDATION
    def validation_step(self, batch, batch_idx):
        """Update the validation accuracies with one batch."""
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Val_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        self.log("val_top_1",
                 self.val_top_1(pred_probs, y),
                 prog_bar=False,
                 logger=False)
        self.log("val_top_5",
                 self.val_top_5(pred_probs, y),
                 prog_bar=False,
                 logger=False)

    def validation_epoch_end(self, outputs):
        """Log the accumulated validation accuracies, then reset the metrics."""
        self.log("val_top_1",
                 self.val_top_1.compute(),
                 prog_bar=True,
                 logger=True)
        self.log("val_top_5",
                 self.val_top_5.compute(),
                 prog_bar=True,
                 logger=True)

        self.val_top_1.reset()
        self.val_top_5.reset()

    #TESTING
    def test_step(self, batch, batch_idx):
        """Update the test accuracies with one batch."""
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Test_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        self.log("test_top_1",
                 self.test_top_1(pred_probs, y),
                 prog_bar=False,
                 logger=False)
        self.log("test_top_5",
                 self.test_top_5(pred_probs, y),
                 prog_bar=False,
                 logger=False)

    def test_epoch_end(self, outputs):
        """Log the accumulated test accuracies, then reset the metrics."""
        self.log("test_top_1",
                 self.test_top_1.compute(),
                 prog_bar=True,
                 logger=True)
        self.log("test_top_5",
                 self.test_top_5.compute(),
                 prog_bar=True,
                 logger=True)

        self.test_top_1.reset()
        self.test_top_5.reset()
class CIFAR10Classifier(pl.LightningModule):  #pylint: disable=too-many-ancestors,too-many-instance-attributes
    """Cifar10 model class: a frozen pre-trained ResNet-50 with a new 10-way head."""

    def __init__(self, **kwargs):
        """Initializes the network and the metrics.

        Args:
            kwargs: run arguments stored in ``self.args``; ``lr`` and
                ``accelerator`` are read by other methods.
        """
        super(CIFAR10Classifier, self).__init__()  #pylint: disable=super-with-arguments
        self.model_conv = models.resnet50(pretrained=True)
        # Freeze the backbone; the replacement ``fc`` layer created below is
        # the only part that stays trainable.
        for param in self.model_conv.parameters():
            param.requires_grad = False
        num_ftrs = self.model_conv.fc.in_features
        num_classes = 10
        self.model_conv.fc = nn.Linear(num_ftrs, num_classes)

        # Populated by configure_optimizers.
        self.scheduler = None
        self.optimizer = None
        self.args = kwargs

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

        # Predictions / targets collected over the test steps.
        self.preds = []
        self.target = []

    def forward(self, x_var):
        """Forward function."""
        out = self.model_conv(x_var)
        return out

    def training_step(self, train_batch, batch_idx):
        """Training Step

        Args:
            train_batch : training batch
            batch_idx : batch id number
        Returns:
            dict holding the training loss under "loss"
        """
        if batch_idx == 0:
            # Keep the first image of the epoch's first batch; it is used by
            # training_epoch_end to visualise activations.
            self.reference_image = (train_batch[0][0]).unsqueeze(0)  #pylint: disable=attribute-defined-outside-init
            # self.reference_image.resize((1,1,28,28))
            print("\n\nREFERENCE IMAGE!!!")
            print(self.reference_image.shape)
        x_var, y_var = train_batch
        output = self.forward(x_var)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, y_var)
        self.log("train_loss", loss)
        self.train_acc(y_hat, y_var)
        self.log("train_acc", self.train_acc.compute())
        return {"loss": loss}

    def test_step(self, test_batch, batch_idx):
        """Testing step

        Args:
            test_batch : test batch data
            batch_idx : tests batch id
        Returns:
            dict holding the test accuracy under "test_acc"
        """
        x_var, y_var = test_batch
        output = self.forward(x_var)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, y_var)
        # Synchronise the logged loss only when an accelerator was configured.
        accelerator = self.args.get("accelerator", None)
        self.log("test_loss", loss, sync_dist=accelerator is not None)
        self.test_acc(y_hat, y_var)
        self.preds += y_hat.tolist()
        self.target += y_var.tolist()

        # Compute once and reuse for both the log entry and the return value.
        test_acc = self.test_acc.compute()
        self.log("test_acc", test_acc)
        return {"test_acc": test_acc}

    def validation_step(self, val_batch, batch_idx):
        """Validation step.

        Args:
            val_batch : val batch data
            batch_idx : val batch id
        Returns:
            dict holding the validation loss under "val_step_loss" and "val_loss"
        """
        x_var, y_var = val_batch
        output = self.forward(x_var)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, y_var)
        # Synchronise the logged loss only when an accelerator was configured.
        accelerator = self.args.get("accelerator", None)
        self.log("val_loss", loss, sync_dist=accelerator is not None)
        self.val_acc(y_hat, y_var)
        self.log("val_acc", self.val_acc.compute())
        return {"val_step_loss": loss, "val_loss": loss}

    def configure_optimizers(self):
        """Initializes the optimizer and learning rate scheduler.

        Returns:
            output - Initialized optimizer and scheduler
        """
        self.optimizer = torch.optim.Adam(self.parameters(),
                                          lr=self.args.get("lr", 0.001))
        self.scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=0.2,
                patience=3,
                min_lr=1e-6,
                verbose=True,
            ),
            "monitor": "val_loss",
        }
        return [self.optimizer], [self.scheduler]

    def makegrid(self, output, numrows):  #pylint: disable=no-self-use
        """Tile the channel maps of the first sample into one 2-D array.

        Channels are stacked vertically in groups of ``numrows`` (each new
        channel is placed above the previous ones) and completed groups are
        appended side by side. Channels left over after the last complete
        group are not included.

        Args:
            output : Tensor output, indexed as [sample][channel]
            numrows : num of rows.
        Returns:
            c_array : grid array
        """
        outer = torch.Tensor.cpu(output).detach()
        # No pyplot figure is opened here: the grid is returned as an array,
        # and a figure created on every call would never be used or closed.
        b_array = np.array([]).reshape(0, outer.shape[2])
        c_array = np.array([]).reshape(numrows * outer.shape[2], 0)
        rows_in_column = 0
        for channel in range(outer.shape[1]):
            img = outer[0][channel]
            b_array = np.concatenate((img, b_array), axis=0)
            rows_in_column += 1
            if rows_in_column == numrows:
                c_array = np.concatenate((c_array, b_array), axis=1)
                b_array = np.array([]).reshape(0, outer.shape[2])
                rows_in_column = 0
        return c_array

    def show_activations(self, x_var):
        """Log the reference image and the first conv layer's activations.

        Args:
            x_var: x variable
        """
        # logging reference image
        self.logger.experiment.add_image("input",
                                         torch.Tensor.cpu(x_var[0][0]),
                                         self.current_epoch,
                                         dataformats="HW")

        # logging layer 1 activations
        out = self.model_conv.conv1(x_var)
        c_grid = self.makegrid(out, 4)
        self.logger.experiment.add_image("layer 1",
                                         c_grid,
                                         self.current_epoch,
                                         dataformats="HW")

    def training_epoch_end(self, outputs):
        """Training epoch end.

        Args:
            outputs: outputs of train end
        """
        self.show_activations(self.reference_image)

        # Logging graph
        if self.current_epoch == 0:
            sample_img = torch.rand((1, 3, 64, 64))
            # NOTE(review): this traces a freshly constructed model rather
            # than ``self``, which builds a second pre-trained ResNet-50.
            # Left unchanged; confirm whether that is intentional.
            self.logger.experiment.add_graph(CIFAR10Classifier(), sample_img)
class LightningMNISTClassifier(pl.LightningModule):
    """Three-layer fully connected MNIST digit classifier."""

    def __init__(self, **kwargs):
        """
        Initializes the network

        :param kwargs: Run arguments stored in ``self.args``; ``lr`` is read
            by ``configure_optimizers``.
        """
        super(LightningMNISTClassifier, self).__init__()
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)
        self.args = kwargs

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Adds the model's command-line options to an existing parser

        :param parent_parser: Application specific parser
        :return: Returns the augmented argument parser
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--lr",
            type=float,
            default=0.001,
            metavar="LR",
            help="learning rate (default: 0.001)",
        )
        return parser

    def forward(self, x):
        """
        :param x: Input data

        :return: output - log-probabilities over the 10 digit labels
        """
        batch_size = x.size()[0]

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # layer 1 (b, 1*28*28) -> (b, 128)
        x = self.layer_1(x)
        x = torch.relu(x)

        # layer 2 (b, 128) -> (b, 256)
        x = self.layer_2(x)
        x = torch.relu(x)

        # layer 3 (b, 256) -> (b, 10)
        x = self.layer_3(x)

        # log-probability distribution over labels
        x = torch.log_softmax(x, dim=1)

        return x

    def cross_entropy_loss(self, logits, labels):
        """
        Computes the loss for one batch

        ``forward`` already applies log_softmax, so the negative
        log-likelihood loss is applied to its output.

        :param logits: Output of ``forward``
        :param labels: Target labels

        :return: output - negative log-likelihood loss
        """
        return F.nll_loss(logits, labels)

    def training_step(self, train_batch, batch_idx):
        """
        Training the data as batches and returns training loss on each batch

        :param train_batch: Batch data
        :param batch_idx: Batch indices

        :return: output - dict holding the training loss under ``"loss"``
        """
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        _, y_hat = torch.max(logits, dim=1)
        self.train_acc(y_hat, y)
        self.log("train_acc", self.train_acc.compute())
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, val_batch, batch_idx):
        """
        Performs validation of data in batches and logs accuracy and loss

        :param val_batch: Batch data
        :param batch_idx: Batch indices
        """
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        _, y_hat = torch.max(logits, dim=1)
        self.val_acc(y_hat, y)
        self.log("val_acc", self.val_acc.compute())
        # "val_loss" is the quantity the LR scheduler monitors.
        self.log("val_loss", loss, sync_dist=True)

    def test_step(self, test_batch, batch_idx):
        """
        Performs test and logs the accuracy of the model

        :param test_batch: Batch data
        :param batch_idx: Batch indices
        """
        x, y = test_batch
        output = self.forward(x)
        _, y_hat = torch.max(output, dim=1)
        self.test_acc(y_hat, y)
        self.log("test_acc", self.test_acc.compute())

    def prepare_data(self):
        """
        Prepares the data for training and prediction

        :return: output - an empty dict; no preparation is done here
        """
        return {}

    def configure_optimizers(self):
        """
        Initializes the optimizer and learning rate scheduler

        :return: output - Initialized optimizer and scheduler
        """
        # Fall back to the default declared by add_model_specific_args so a
        # model built without ``lr`` does not raise KeyError.
        learning_rate = self.args.get("lr", 0.001)
        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode="min",
                factor=0.2,
                patience=2,
                min_lr=1e-6,
                verbose=True,
            ),
            "monitor": "val_loss",
        }
        return [optimizer], [scheduler]
class HPALit(pl.LightningModule):
    """Lightning wrapper that trains ``Net`` with a BCE-with-logits loss.

    Datasets, transforms, metrics and dataloaders are all owned by the module.
    """

    def __init__(self, params):
        """Store hyperparameters and build model, metrics, transforms and datasets."""
        super().__init__()
        self.hparams = params
        self.lr = self.hparams.lr
        self.save_hyperparameters()

        self.model = Net(name=self.hparams.model)
        self.criterion = torch.nn.BCEWithLogitsLoss()

        self.train_accuracy = Accuracy(subset_accuracy=True)
        self.val_accuracy = Accuracy(subset_accuracy=True)
        self.train_recall = Recall()
        self.val_recall = Recall()

        # Fold-specific CSV files for the train / validation split.
        self.train_df = pd.read_csv(
            f'data_preprocessing/train_fold_{self.hparams.fold}.csv')
        self.valid_df = pd.read_csv(
            f'data_preprocessing/valid_fold_{self.hparams.fold}.csv')

        # Training adds rotation and flips in front of the steps that the
        # validation pipeline also applies.
        train_steps = [
            A.Rotate(),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Resize(width=self.hparams.img_size,
                     height=self.hparams.img_size),
            A.Normalize(),
            ToTensorV2(),
        ]
        self.train_transforms = A.Compose(train_steps)

        valid_steps = [
            A.Resize(width=self.hparams.img_size,
                     height=self.hparams.img_size),
            A.Normalize(),
            ToTensorV2(),
        ]
        self.valid_transforms = A.Compose(valid_steps)

        self.train_dataset = CellDataset(
            data_dir=self.hparams.data_dir,
            csv_file=self.train_df,
            transform=self.train_transforms,
        )
        self.val_dataset = CellDataset(
            data_dir=self.hparams.data_dir,
            csv_file=self.valid_df,
            transform=self.valid_transforms,
        )

    def train_dataloader(self):
        """Return the shuffled training loader."""
        loader_kwargs = dict(
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.n_workers,
            pin_memory=True,
        )
        return DataLoader(self.train_dataset, **loader_kwargs)

    def val_dataloader(self):
        """Return the unshuffled validation loader."""
        loader_kwargs = dict(
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.n_workers,
            pin_memory=True,
        )
        return DataLoader(self.val_dataset, **loader_kwargs)

    def configure_optimizers(self):
        """Create Adam over the model parameters and a plateau LR scheduler.

        The scheduler is stepped per epoch and monitors 'valid_loss_epoch'.
        """
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', factor=0.5, patience=2, eps=1e-6)
        scheduler_config = {
            'scheduler': plateau,
            'interval': 'epoch',
            'monitor': 'valid_loss_epoch',
        }
        return [optimizer], [scheduler_config]

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and update the training metrics."""
        inputs, labels = batch
        logits = self.model(inputs)
        train_loss = self.criterion(logits, labels)

        probabilities = torch.sigmoid(logits)
        int_labels = labels.type(torch.int)
        self.train_accuracy(probabilities, int_labels)
        self.train_recall(probabilities, int_labels)
        return {'loss': train_loss}

    def training_epoch_end(self, outputs):
        """Log the mean loss and accumulated metrics, then reset the metrics."""
        step_losses = [step['loss'] for step in outputs]
        self.log('train_loss_epoch', torch.stack(step_losses).mean())
        self.log('train_acc_epoch', self.train_accuracy.compute())
        self.log('train_recall_epoch', self.train_recall.compute())

        self.train_accuracy.reset()
        self.train_recall.reset()

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one batch and update the validation metrics."""
        inputs, labels = batch
        logits = self.model(inputs)
        val_loss = self.criterion(logits, labels)

        probabilities = torch.sigmoid(logits)
        int_labels = labels.type(torch.int)
        self.val_accuracy(probabilities, int_labels)
        self.val_recall(probabilities, int_labels)
        return {'valid_loss': val_loss}

    def validation_epoch_end(self, outputs):
        """Log the mean loss and accumulated metrics, then reset the metrics."""
        step_losses = [step['valid_loss'] for step in outputs]
        self.log('valid_loss_epoch', torch.stack(step_losses).mean())
        self.log('valid_acc_epoch', self.val_accuracy.compute())
        self.log('valid_recall_epoch', self.val_recall.compute())

        self.val_accuracy.reset()
        self.val_recall.reset()
class TransferLearning(LightningModule):
    """Train a linear classifier on top of a frozen, pre-trained encoder.

    The encoder is taken from a ``NoisyCLIP`` or ``Baseline`` checkpoint and
    all of its parameters are frozen; only the final ``nn.Linear`` layer is
    optimised. Top-1/top-5 accuracy is tracked, and AUROC is tracked as well
    when the dataset is ``'COVID'``.
    """

    def __init__(self, args):
        """Load the backbone, freeze it and create classifier and metrics.

        Args:
            args: Configuration object stored as ``self.hparams``. Attributes
                read by this class include ``num_nodes``, ``gpus``,
                ``saved_model_type``, ``checkpoint_path``, ``emb_dim``,
                ``num_classes``, ``dataset``, ``dataset_dir``, ``encoder``,
                ``lr``, ``max_epochs``, ``batch_size``, ``workers``,
                ``use_subset`` and ``subset_ratio``.

        Raises:
            ValueError: If ``saved_model_type`` is neither ``'contrastive'``
                nor ``'baseline'``.
        """
        super(TransferLearning, self).__init__()
        self.hparams = args
        self.world_size = self.hparams.num_nodes * self.hparams.gpus

        self.train_set_transform, self.val_set_transform = grab_transforms(
            self.hparams)

        # Grab the correct model - only the embeddings from the final layer
        # are needed.
        if self.hparams.saved_model_type == 'contrastive':
            saved_model = NoisyCLIP.load_from_checkpoint(
                self.hparams.checkpoint_path)
            self.backbone = saved_model.noisy_visual_encoder
        elif self.hparams.saved_model_type == 'baseline':
            saved_model = Baseline.load_from_checkpoint(
                self.hparams.checkpoint_path)
            self.backbone = saved_model.encoder.feature_extractor
        else:
            raise ValueError(
                "Unsupported saved_model_type: "
                f"{self.hparams.saved_model_type!r}")

        # The backbone is a fixed feature extractor.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Set up a classifier with the correct dimensions.
        self.output = nn.Linear(self.hparams.emb_dim, self.hparams.num_classes)

        # Cross-entropy averaged over the samples of each batch.
        self.criterion = nn.CrossEntropyLoss(reduction="mean")

        self.train_top_1 = Accuracy(top_k=1)
        self.train_top_5 = Accuracy(top_k=5)

        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)

        self.test_top_1 = Accuracy(top_k=1)
        self.test_top_5 = Accuracy(top_k=5)

        # Class INFECTED has label 0.
        if self.hparams.dataset == 'COVID':
            self.val_auc = AUROC(pos_label=0)
            self.test_auc = AUROC(pos_label=0)

    def forward(self, x):
        """Return class logits for the batch ``x``.

        The backbone runs in eval mode under ``torch.no_grad()``; only the
        linear head is evaluated with gradients enabled.

        Raises:
            ValueError: If ``hparams.encoder`` is neither ``'clip'`` nor
                ``'resnet'``.
        """
        # Grab the noisy image embeddings.
        self.backbone.eval()
        with torch.no_grad():
            if self.hparams.encoder == "clip":
                # The CLIP branch feeds float16 input and casts back to float32.
                noisy_embeddings = self.backbone(x.type(torch.float16)).float()
            elif self.hparams.encoder == "resnet":
                noisy_embeddings = self.backbone(x)
            else:
                raise ValueError(
                    f"Unsupported encoder: {self.hparams.encoder!r}")

        return self.output(noisy_embeddings.flatten(1))

    def configure_optimizers(self):
        """Return Adam over the linear head and a cosine-annealing schedule.

        ``weight_decay`` defaults to 0 (and is written back to ``hparams``)
        when the configuration does not define it.
        """
        if not hasattr(self.hparams, 'weight_decay'):
            self.hparams.weight_decay = 0

        opt = torch.optim.Adam(self.output.parameters(),
                               lr=self.hparams.lr,
                               weight_decay=self.hparams.weight_decay)

        num_steps = self.hparams.max_epochs

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt,
                                                               T_max=num_steps)

        return [opt], [scheduler]

    def _grab_dataset(self, split):
        """
        Given a split ("train" or "val" or "test"), returns the proper dataset.
        The dataset to use is defined in this object's hparams.

        Any split other than "train" is served from the dataset's held-out
        split and uses the validation transform.

        Args:
            split: the split to use in the dataset
        Returns:
            dataset: the desired dataset with the correct split
        Raises:
            ValueError: if ``hparams.dataset`` is not a supported dataset name
        """
        name = self.hparams.dataset
        root = self.hparams.dataset_dir
        is_train = split == 'train'
        transform = (self.train_set_transform
                     if is_train else self.val_set_transform)

        if name == "CIFAR10":
            return CIFAR10(root=root, train=is_train, transform=transform,
                           download=True)

        if name == "CIFAR100":
            return CIFAR100(root=root, train=is_train, transform=transform,
                            download=True)

        if name == 'STL10':
            return STL10(root=root, split='train' if is_train else 'test',
                         transform=transform, download=True)

        if name == 'COVID':
            # The split name is appended directly to the directory string.
            return torchvision.datasets.ImageFolder(
                root=root + ('train' if is_train else 'test'),
                transform=transform)

        if name in ('ImageNet100B', 'imagenet-100B'):
            return ImageNet100(root=root,
                               split='train' if is_train else 'val',
                               transform=transform)

        raise ValueError(f"Unsupported dataset: {name!r}")

    def train_dataloader(self):
        """Return the shuffled training dataloader.

        When ``hparams.use_subset`` is set, the dataset is first reduced with
        ``few_shot_dataset`` to ``ceil(N * subset_ratio / num_classes)``
        (passed as its second argument).
        """
        train_dataset = self._grab_dataset(split='train')

        if self.hparams.use_subset:
            N_train = len(train_dataset)
            train_dataset = few_shot_dataset(
                train_dataset,
                int(np.ceil(N_train * self.hparams.subset_ratio
                            / self.hparams.num_classes)))

        return DataLoader(train_dataset,
                          batch_size=self.hparams.batch_size,
                          num_workers=self.hparams.workers,
                          pin_memory=True,
                          shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader."""
        val_dataset = self._grab_dataset(split='val')

        # SET SHUFFLE TO TRUE SINCE AUROC FREAKS OUT IF IT GETS AN ALL-1 OR
        # ALL-0 BATCH
        return DataLoader(val_dataset,
                          batch_size=self.hparams.batch_size,
                          num_workers=self.hparams.workers,
                          pin_memory=True,
                          shuffle=True)

    def test_dataloader(self):
        """Return the test dataloader."""
        test_dataset = self._grab_dataset(split='test')

        # SET SHUFFLE TO TRUE SINCE AUROC FREAKS OUT IF IT GETS AN ALL-1 OR
        # ALL-0 BATCH
        return DataLoader(test_dataset,
                          batch_size=self.hparams.batch_size,
                          num_workers=self.hparams.workers,
                          pin_memory=True,
                          shuffle=True)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        x, y = batch

        # Log one image grid for the very first training batch only.
        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Train_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)

        loss = self.criterion(logits, y)

        self.log("train_loss", loss, prog_bar=False, on_step=True,
                 on_epoch=True, logger=True, sync_dist=True,
                 sync_dist_op='sum')

        return loss

    def validation_step(self, batch, batch_idx):
        """Update the validation metrics with one batch."""
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Val_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)  # (N, num_classes)

        if self.hparams.dataset == 'COVID':
            positive_prob = pred_probs[:, 0].flatten()  # class 0 is INFECTED label
            true_labels = y.flatten()

            self.val_auc.update(positive_prob, true_labels)
            self.log("val_auc", self.val_auc, prog_bar=False, logger=False)

        self.log("val_top_1", self.val_top_1(pred_probs, y), prog_bar=False,
                 logger=False)

        # Top-5 is only logged for datasets other than COVID.
        if self.hparams.dataset != 'COVID':
            self.log("val_top_5", self.val_top_5(pred_probs, y),
                     prog_bar=False, logger=False)

    def validation_epoch_end(self, outputs):
        """Log the epoch-level validation metrics and reset their state."""
        self.log("val_top_1", self.val_top_1.compute(), prog_bar=True,
                 logger=True)

        if self.hparams.dataset != 'COVID':
            self.log("val_top_5", self.val_top_5.compute(), prog_bar=True,
                     logger=True)

        if self.hparams.dataset == 'COVID':
            self.log("val_auc", self.val_auc.compute(), prog_bar=True,
                     logger=True)
            self.val_auc.reset()

        self.val_top_1.reset()
        self.val_top_5.reset()

    def test_step(self, batch, batch_idx):
        """Update the test metrics with one batch."""
        x, y = batch

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        if self.hparams.dataset == 'COVID':
            positive_prob = pred_probs[:, 0].flatten()  # class 0 is INFECTED label
            true_labels = y.flatten()

            self.test_auc.update(positive_prob, true_labels)
            self.log("test_auc", self.test_auc, prog_bar=False, logger=False)

        self.log("test_top_1", self.test_top_1(pred_probs, y), prog_bar=False,
                 logger=False)

        if self.hparams.dataset != 'COVID':
            self.log("test_top_5", self.test_top_5(pred_probs, y),
                     prog_bar=False, logger=False)

    def test_epoch_end(self, outputs):
        """Log the epoch-level test metrics and reset their state."""
        self.log("test_top_1", self.test_top_1.compute(), prog_bar=True,
                 logger=True)

        if self.hparams.dataset != 'COVID':
            self.log("test_top_5", self.test_top_5.compute(), prog_bar=True,
                     logger=True)

        if self.hparams.dataset == 'COVID':
            self.log("test_auc", self.test_auc.compute(), prog_bar=True,
                     logger=True)
            self.test_auc.reset()

        self.test_top_1.reset()
        self.test_top_5.reset()
class ImageClassifier(pl.LightningModule):
    """
    Basic image classifier.

    Wraps a `ConvNet` built from the config, trains it with cross-entropy and
    reports accuracy for the training and validation phases once per epoch.
    """

    def __init__(self, cfg: Config) -> None:
        """
        Create the model, loss and metrics.

        Parameters
        ----------
        cfg : Config
            Experiment configuration; passed to `ConvNet` and read by
            `configure_optimizers` (`cfg.optim.optimizer`, `cfg.optim.scheduler`).
        """
        super().__init__()  # type: ignore

        self.logger: Union[LoggerCollection, WandbLogger, Any]
        self.wandb: Run

        self.cfg = cfg

        self.model = ConvNet(self.cfg)
        self.criterion = nn.CrossEntropyLoss()

        # Metrics
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

    # -----------------------------------------------------------------------------------------------
    # Default PyTorch Lightning hooks
    # -----------------------------------------------------------------------------------------------
    def on_fit_start(self) -> None:
        """
        Hook before `trainer.fit()`.

        Attaches current wandb run to `self.wandb`. Nothing is attached when no
        `WandbLogger` is configured.
        """
        if isinstance(self.logger, LoggerCollection):
            for logger in self.logger:  # type: ignore
                if isinstance(logger, WandbLogger):
                    self.wandb = logger.experiment  # type: ignore
        elif isinstance(self.logger, WandbLogger):
            self.wandb = self.logger.experiment  # type: ignore

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """
        Hook on checkpoint saving.

        Adds config and RNG states (torch, numpy, `random`) to the checkpoint file.
        """
        checkpoint['cfg'] = self.cfg
        checkpoint['rng_torch'] = torch.default_generator.get_state()
        checkpoint['rng_numpy'] = np.random.get_state()
        checkpoint['rng_random'] = random.getstate()

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """
        Hook on checkpoint loading.

        Loads RNG states (torch, numpy, `random`) from the checkpoint file.
        """
        torch.default_generator.set_state(checkpoint['rng_torch'])
        np.random.set_state(checkpoint['rng_numpy'])
        random.setstate(checkpoint['rng_random'])

    # ----------------------------------------------------------------------------------------------
    # Optimizers
    # ----------------------------------------------------------------------------------------------
    def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]:  # type: ignore
        """
        Define system optimization procedure.

        See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers.

        Returns
        -------
        Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]
            Single optimizer or a combination of optimizers with learning rate schedulers.
        """
        optimizer: Optimizer = instantiate(
            self.cfg.optim.optimizer,
            params=self.parameters(),
            _convert_='all'
        )

        if self.cfg.optim.scheduler is not None:
            scheduler: _LRScheduler = instantiate(  # type: ignore
                self.cfg.optim.scheduler,
                optimizer=optimizer,
                _convert_='all'
            )
            print(optimizer, scheduler)
            return [optimizer], [scheduler]
        else:
            print(optimizer)
            return optimizer

    # ----------------------------------------------------------------------------------------------
    # Forward
    # ----------------------------------------------------------------------------------------------
    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        """
        Forward pass of the whole system.

        In this simple case just calls the main model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        return self.model(x)

    # ----------------------------------------------------------------------------------------------
    # Loss
    # ----------------------------------------------------------------------------------------------
    def calculate_loss(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute loss value of a batch.

        In this simple case just forwards computation to default `self.criterion`.

        Parameters
        ----------
        outputs : torch.Tensor
            Network outputs with shape (batch_size, n_classes).
        targets : torch.Tensor
            Targets (ground-truth labels) with shape (batch_size).

        Returns
        -------
        torch.Tensor
            Loss value.
        """
        return self.criterion(outputs, targets)

    # ----------------------------------------------------------------------------------------------
    # Training
    # ----------------------------------------------------------------------------------------------
    def training_step(self, batch: list[torch.Tensor], batch_idx: int) -> dict[str, torch.Tensor]:  # type: ignore
        """
        Train on a single batch with loss defined by `self.criterion`.

        Parameters
        ----------
        batch : list[torch.Tensor]
            Training batch.
        batch_idx : int
            Batch index.

        Returns
        -------
        dict[str, torch.Tensor]
            Metric values for a given batch.
        """
        inputs, targets = batch
        outputs = self(inputs)  # basically equivalent to self.forward(data)
        loss = self.calculate_loss(outputs, targets)

        self.train_acc(F.softmax(outputs, dim=1), targets)

        return {
            'loss': loss,
            # no need to return 'train_acc' here since it is always available as `self.train_acc`
        }

    def training_epoch_end(self, outputs: list[Any]) -> None:
        """
        Log training metrics.

        Parameters
        ----------
        outputs : list[Any]
            List of dictionaries returned by `self.training_step` with batch metrics.
        """
        step = self.current_epoch + 1

        metrics = {
            'epoch': float(step),
            'train_acc': float(self.train_acc.compute().item()),
        }
        # The metric is read manually rather than through `self.log`, so its state has to be
        # cleared here; otherwise the next epoch's value would include this epoch's batches.
        self.train_acc.reset()

        # Average additional metrics over all batches
        if outputs:
            for key in outputs[0]:
                metrics[key] = float(self._reduce(outputs, key).item())

        self.logger.log_metrics(metrics, step=step)

    def _reduce(self, outputs: list[Any], key: str):
        """Return the detached mean of `out[key]` over all step outputs."""
        return torch.stack([out[key] for out in outputs]).mean().detach()

    # ----------------------------------------------------------------------------------------------
    # Validation
    # ----------------------------------------------------------------------------------------------
    def validation_step(self, batch: list[torch.Tensor], batch_idx: int) -> dict[str, Any]:  # type: ignore
        """
        Compute validation metrics.

        Parameters
        ----------
        batch : list[torch.Tensor]
            Validation batch.
        batch_idx : int
            Batch index.

        Returns
        -------
        dict[str, torch.Tensor]
            Metric values for a given batch.
        """
        inputs, targets = batch
        outputs = self(inputs)  # basically equivalent to self.forward(data)

        self.val_acc(F.softmax(outputs, dim=1), targets)

        return {
            # 'additional_metric': ...
            # no need to return 'val_acc' here since it is always available as `self.val_acc`
        }

    def validation_epoch_end(self, outputs: list[Any]) -> None:
        """
        Log validation metrics.

        Parameters
        ----------
        outputs : list[Any]
            List of dictionaries returned by `self.validation_step` with batch metrics.
        """
        # During the sanity check no training epoch has finished yet, so the step is not advanced.
        step = self.current_epoch + 1 if not self.trainer.running_sanity_check else self.current_epoch  # type: ignore

        metrics = {
            'epoch': float(step),
            'val_acc': float(self.val_acc.compute().item()),
        }
        # Clear the accumulated state so sanity-check batches and earlier epochs do not leak
        # into the next validation run.
        self.val_acc.reset()

        # Average additional metrics over all batches
        if outputs:
            for key in outputs[0]:
                metrics[key] = float(self._reduce(outputs, key).item())

        self.logger.log_metrics(metrics, step=step)
class IrisClassification(pl.LightningModule):
    """Three-layer fully connected classifier (4 inputs, 3 outputs).

    Accuracy is tracked separately for the train, validation and test
    phases.
    """

    def __init__(self, **kwargs):
        """Create metrics, layers and the loss.

        :param kwargs: Stored as ``self.args``; ``lr`` is read from it in
            ``configure_optimizers``.
        """
        super().__init__()

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()
        self.args = kwargs

        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 3)
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Run the three linear layers, each followed by a ReLU.

        :param x: Input batch
        :return: Output of the last layer after ReLU
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # NOTE(review): ReLU is also applied to the output layer, so the
        # values fed to the cross-entropy loss are clamped to be
        # non-negative. Kept as-is; confirm this is intended.
        x = F.relu(self.fc3(x))
        return x

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Add model specific arguments like learning rate

        :param parent_parser: Application specific parser

        :return: Returns the augmented argument parser
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--lr",
            type=float,
            default=0.01,
            metavar="LR",
            help="learning rate (default: 0.01)",
        )
        return parser

    def configure_optimizers(self):
        """Return an Adam optimizer using ``lr`` from the stored kwargs."""
        return torch.optim.Adam(self.parameters(), lr=self.args.get("lr", 0.01))

    def training_step(self, batch, batch_idx):
        """Compute the training loss and update the training accuracy.

        :return: Dict holding the batch loss under ``"loss"``
        """
        x, y = batch
        logits = self.forward(x)
        _, y_hat = torch.max(logits, dim=1)
        loss = self.cross_entropy_loss(logits, y)
        self.train_acc(y_hat, y)
        # Log the metric object itself so the epoch value is computed from
        # the accumulated state (and the state is reset afterwards), instead
        # of averaging running ``compute()`` snapshots taken after each batch.
        self.log(
            "train_acc",
            self.train_acc,
            on_step=False,
            on_epoch=True,
        )
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and update the validation accuracy."""
        x, y = batch
        logits = self.forward(x)
        _, y_hat = torch.max(logits, dim=1)
        loss = F.cross_entropy(logits, y)
        self.val_acc(y_hat, y)
        self.log("val_acc", self.val_acc)
        self.log("val_loss", loss, sync_dist=True)

    def test_step(self, batch, batch_idx):
        """Update and log the test accuracy."""
        x, y = batch
        logits = self.forward(x)
        _, y_hat = torch.max(logits, dim=1)
        self.test_acc(y_hat, y)
        self.log("test_acc", self.test_acc)
class LitClassifier(pl.LightningModule):
    """MLP classifier over flattened 1x28x28 inputs with 10 output classes."""

    def __init__(self, hidden_dim=128, learning_rate=1e-3):
        """Create metrics and the network.

        Args:
            hidden_dim: Width of the two hidden layers.
            learning_rate: Learning rate passed to Adam.
        """
        super().__init__()
        self.save_hyperparameters()

        self.train_acc = Accuracy()
        # Validation/test metrics only accumulate per step; their value is
        # read once in the matching ``*_epoch_end`` hook.
        self.val_acc = Accuracy(compute_on_step=False)
        self.test_acc = Accuracy(compute_on_step=False)

        self.example_input_array = torch.rand(10, 28 * 28)
        self.dims = (1, 28, 28)
        channels, width, height = self.dims

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, self.hparams.hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(self.hparams.hidden_dim, self.hparams.hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(self.hparams.hidden_dim, 10),
        )

    def forward(self, x):
        """Return the network output for ``x``."""
        x = self.model(x)
        return x

    def training_step(self, batch, batch_idx):
        """Compute the loss and log the per-batch training accuracy.

        Returns:
            A dict holding the batch loss under ``"loss"``.
        """
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = self.train_acc(y_hat, y)
        self.log("train_acc_step", acc)
        return {"loss": loss}

    def training_epoch_end(self, outputs):
        """Log the epoch training accuracy and clear the metric state."""
        self.log("epoch_acc", self.train_acc.compute(), prog_bar=True)
        # The value is read manually, so reset explicitly; otherwise the
        # next epoch's accuracy would include this epoch's batches.
        self.train_acc.reset()

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accumulate validation accuracy."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.val_acc(y_hat, y)
        self.log('valid_loss', loss)

    def validation_epoch_end(self, outputs):
        """Log the epoch validation accuracy and clear the metric state."""
        self.log("epoch_val_acc", self.val_acc.compute(), prog_bar=True)
        self.val_acc.reset()

    def test_step(self, batch, batch_idx):
        """Log the test loss and accumulate test accuracy."""
        x, y = batch
        y_hat = self(x)
        self.test_acc(y_hat, y)
        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)

    def test_epoch_end(self, outputs):
        """Log the test accuracy and clear the metric state."""
        self.log("test_acc", self.test_acc.compute(), prog_bar=True)
        self.test_acc.reset()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return torch.optim.Adam(self.parameters(),
                                lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with model options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_dim', type=int, default=128)
        parser.add_argument('--learning_rate', type=float, default=0.001)
        return parser
class Rockfish(pl.LightningModule):
    """Convolution + transformer-encoder binary classifier.

    The signal input is concatenated with a learned 2-dimensional embedding
    of ``ref_k``, passed through three ``ConvBlock`` layers and a transformer
    encoder, average-pooled and projected to a single logit.

    ``configure_optimizers`` reads ``self.train_ds_len`` and
    ``self.effective_batch_size`` when ``step_size_up`` is not configured;
    neither is set in this class, so they are presumably assigned by the
    calling code before training starts (verify against the caller).
    """

    def __init__(self, conf):
        """Store the hyper-parameters and build all layers and metrics.

        Args:
            conf: Hyper-parameters saved via ``save_hyperparameters``; this
                class reads ``dropout``, ``nhead``, ``dim_ff``, ``nlayers``,
                ``lr``, ``wd`` and ``step_size_up``.
        """
        super().__init__()
        self.save_hyperparameters(conf)

        self.ke = nn.Embedding(4, 2, max_norm=1.)

        # 3 input channels: 1 for the signal plus 2 from the embedding.
        self.conv1 = ConvBlock(3, 256, 13, stride=3, padding=6)
        self.conv2 = ConvBlock(256, 256, 7, stride=1, padding=3)
        self.conv3 = ConvBlock(256, 256, 3, stride=2, padding=1)

        self.pos_encoder = PositionalEncoding(256, self.hparams.dropout)
        encoder_layer = nn.TransformerEncoderLayer(256,
                                                   self.hparams.nhead,
                                                   self.hparams.dim_ff,
                                                   self.hparams.dropout,
                                                   activation='gelu')
        self.encoder = nn.TransformerEncoder(encoder_layer,
                                             self.hparams.nlayers)

        self.fc1 = nn.Linear(256, 1)

        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)

    def forward(self, x, ref_k):
        """Return one logit per sample.

        Args:
            x: Signal batch; presumably (batch, length) - TODO confirm.
            ref_k: Integer indices into the 4-entry embedding table;
                presumably (batch, length) - TODO confirm.
        """
        # Embed and move the embedding dimension into the channel position.
        ref_k = self.ke(ref_k).transpose(1, 2)

        x = torch.unsqueeze(x, 1)
        x = torch.cat((x, ref_k), 1)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        # Sequence dimension first for the positional encoder and encoder.
        x = x.permute(2, 0, 1)
        x = self.pos_encoder(x)
        x = self.encoder(x)

        # Back to channels-before-sequence, then pool over the sequence.
        x = x.permute(1, 2, 0)
        x = F.adaptive_avg_pool1d(x, 1)
        x = x.squeeze(-1)

        return self.fc1(x)

    @staticmethod
    def add_module_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with model options."""
        parser = argparse.ArgumentParser(parents=[parent_parser],
                                         add_help=False)

        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--nhead', type=int, default=8)
        parser.add_argument('--dim_ff', type=int, default=1024)
        parser.add_argument('--nlayers', type=int, default=6)

        parser.add_argument('--wd', type=float, default=1e-4)
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--step_size_up', type=int, default=None)

        return parser

    def configure_optimizers(self):
        """Return AdamW plus a per-step ``triangular2`` cyclic schedule.

        The learning rate cycles between ``lr / 10`` and ``lr``. When
        ``step_size_up`` is not configured, the ramp-up length is four
        epochs' worth of steps, derived from ``self.train_ds_len`` and
        ``self.effective_batch_size``.
        """
        optimizer = optim.AdamW(self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.wd,
                                eps=1e-6,
                                betas=(0.9, 0.98))

        if self.hparams.step_size_up is not None:
            step_size_up = self.hparams.step_size_up
        else:
            steps_per_epoch = self.train_ds_len // self.effective_batch_size
            step_size_up = 4 * steps_per_epoch

        scheduler = optim.lr_scheduler.CyclicLR(optimizer,
                                                base_lr=self.hparams.lr / 10,
                                                max_lr=self.hparams.lr,
                                                step_size_up=step_size_up,
                                                mode='triangular2',
                                                cycle_momentum=False)
        scheduler = {'scheduler': scheduler, 'interval': 'step'}

        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        """Compute the loss and log loss and per-batch accuracy."""
        x, ref_k, y = batch

        out = self(x, ref_k).squeeze(1)
        loss = F.binary_cross_entropy_with_logits(out, y.float())

        self.log('train_loss', loss, prog_bar=True)
        self.log('train_batch_acc', self.train_acc(torch.sigmoid(out), y))

        return loss

    def training_epoch_end(self, outputs):
        """Log the epoch training accuracy and clear the metric state."""
        self.log('train_epoch_acc', self.train_acc.compute())
        # The value is read manually, so reset explicitly; otherwise the
        # next epoch's accuracy would include this epoch's batches.
        self.train_acc.reset()

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accumulate validation accuracy."""
        x, ref_k, y = batch

        out = self(x, ref_k).squeeze(1)
        loss = F.binary_cross_entropy_with_logits(out, y.float())

        self.val_acc(torch.sigmoid(out), y)
        self.log('val_loss', loss)

    def validation_epoch_end(self, outputs):
        """Log the epoch validation accuracy and clear the metric state."""
        val_acc = self.val_acc.compute()
        self.log('val_acc', val_acc, prog_bar=True)
        self.val_acc.reset()
class FastTextLSTMModel(LightningModule):
    """
    Bidirectional LSTM over FastText token embeddings: the final hidden
    states of both directions go through dropout and a linear projection
    that yields one logit per text.
    """

    def __init__(self, ft_embedding_dim, hidden_dim=64):
        """
        Build layers, loss and metrics

        :param ft_embedding_dim: size of one FastText token embedding
        :param hidden_dim: LSTM hidden size per direction
        """
        super().__init__()

        self.lstm_layer = nn.LSTM(
            ft_embedding_dim,
            hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout_layer = nn.Dropout(0.2)
        # Two directions are concatenated, hence hidden_dim * 2 inputs.
        self.out_layer = nn.Linear(hidden_dim * 2, 1)

        self.loss = nn.BCEWithLogitsLoss()

        self.valid_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

    def forward(self, embeddings, labels):
        """
        Forward pass

        :param embeddings: (batch_size, max_tokens_in_text, ft_embedding_dim)
            token embeddings, padded or cut to the max sequence length
        :param labels: (batch_size, 1)
        :return: loss and logits
        """
        _, (last_hidden, _) = self.lstm_layer(embeddings)

        # Put the batch dimension first, then merge both directions into a
        # single feature vector per sample.
        merged_hidden = last_hidden.transpose(0, 1).reshape(
            embeddings.size(0), -1)

        logits = self.out_layer.forward(self.dropout_layer(merged_hidden))

        return self.loss(logits, labels.type_as(logits)), logits

    @staticmethod
    def _to_probabilities(logits, labels):
        """Sigmoid when the largest label equals 1, otherwise softmax over dim 1."""
        if torch.max(labels) == 1:
            return torch.sigmoid(logits)
        return torch.softmax(logits, 1)

    def configure_optimizers(self):
        """Return a one-element list with Adam (lr=1e-3) over all parameters."""
        return [torch.optim.Adam(self.parameters(), lr=1e-3)]

    def training_step(self, batch, _):
        """Return the training loss for one batch."""
        features, targets = batch
        step_loss, _logits = self(features, targets)
        return step_loss

    def validation_step(self, batch, _):
        """Update validation accuracy and log loss and accuracy."""
        features, targets = batch
        step_loss, logits = self(features, targets)

        probabilities = self._to_probabilities(logits, targets)
        self.valid_accuracy.update(probabilities, targets.long())

        self.log("val_loss", step_loss, prog_bar=True)
        self.log("val_acc", self.valid_accuracy)

    def validation_epoch_end(self, outs):
        """Log the validation accuracy computed from the accumulated state."""
        epoch_accuracy = self.valid_accuracy.compute()
        self.log("val_acc_epoch", epoch_accuracy, prog_bar=True)

    def test_step(self, batch, _):
        """Update test accuracy and log loss and accuracy."""
        features, targets = batch
        step_loss, logits = self(features, targets)

        probabilities = self._to_probabilities(logits, targets)
        self.test_accuracy.update(probabilities, targets.long())

        self.log("test_loss", step_loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy)

    def test_epoch_end(self, outs):
        """Log the test accuracy computed from the accumulated state."""
        epoch_accuracy = self.test_accuracy.compute()
        self.log("test_acc_epoch", epoch_accuracy, prog_bar=True)