Example 1
class IrisClassification(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__()

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()
        self.args = kwargs

        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 3)
        self.cross_entropy_loss = nn.CrossEntropyLoss()

        self.lr = kwargs.get("lr", 0.01)
        self.momentum = kwargs.get("momentum", 0.9)
        self.weight_decay = kwargs.get("weight_decay", 0.1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # return raw logits; CrossEntropyLoss applies log-softmax internally
        return x

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(),
                               lr=self.lr,
                               momentum=self.momentum,
                               weight_decay=self.weight_decay)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.train_acc(torch.argmax(logits, dim=1), y)
        self.log("train_acc",
                 self.train_acc.compute(),
                 on_step=False,
                 on_epoch=True)
        self.log("loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        self.val_acc(torch.argmax(logits, dim=1), y)
        self.log("val_acc", self.val_acc.compute())
        self.log("val_loss", loss, sync_dist=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        self.test_acc(torch.argmax(logits, dim=1), y)
        self.log("test_loss", loss)
        self.log("test_acc", self.test_acc.compute())
Example 2
def test_wrong_params(top_k, threshold):
    preds, target = _input_mcls_prob.preds, _input_mcls_prob.target

    with pytest.raises(ValueError):
        acc = Accuracy(threshold=threshold, top_k=top_k)
        acc(preds, target)
        acc.compute()

    with pytest.raises(ValueError):
        accuracy(preds, target, threshold=threshold, top_k=top_k)
Example 3
class SimpleModel(LightningModule):
    def __init__(self, vocab_size, embedding_dim=32):
        super().__init__()

        self.embeddings_layer = nn.Embedding(vocab_size, embedding_dim)
        self.loss = nn.BCEWithLogitsLoss()
        self.valid_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

    def forward(self, inputs, labels):
        # Subclasses are expected to override this and return a (loss, logits) tuple.
        raise NotImplementedError("forward not implemented")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return [optimizer]

    def training_step(self, batch, _):
        inputs, labels = batch
        loss, logits = self(inputs, labels)
        return loss

    def validation_step(self, batch, _):
        inputs, labels = batch
        val_loss, logits = self(inputs, labels)
        if torch.max(labels) == 1:
            pred = torch.sigmoid(logits)
        else:
            pred = torch.softmax(logits, 1)
        self.valid_accuracy.update(pred, labels.long())
        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", self.valid_accuracy)

    def validation_epoch_end(self, outs):
        self.log("val_acc_epoch", self.valid_accuracy.compute(), prog_bar=True)

    def test_step(self, batch, _):
        inputs, labels = batch
        test_loss, logits = self(inputs, labels)
        if torch.max(labels) == 1:
            pred = torch.sigmoid(logits)
        else:
            pred = torch.softmax(logits, 1)
        self.test_accuracy.update(pred, labels.long())
        self.log("test_loss", test_loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy)

    def test_epoch_end(self, outs):
        self.log("test_acc_epoch", self.test_accuracy.compute(), prog_bar=True)
Example 4
def segmentation_model_attack(model,
                              model_type,
                              config,
                              num_classes=NUM_CLASSES):
    """Evaluate the model on salt-and-pepper-perturbed segmentation images and return the accuracy;
    the difference between this and the accuracy on clean data is a measure of resiliency."""

    if model_type == "pt":
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        test = perturbed_pt_gis_test_data()
        test_set = PT_GISDataset(test)
        testloader = DataLoader(test_set, batch_size=int(config['batch_size']))
        accuracy = Accuracy()
        for sample in tqdm(testloader):
            cuda_in = sample[0].to(device)
            out = model(cuda_in)
            output = out.to('cpu').squeeze(1)
            accuracy(output, sample[1])
            # cuda_in = cuda_in.detach()
            # label = label.detach()
        return accuracy.compute().item()
    elif model_type == "tf":
        x_test, y_test = perturbed_tf_gis_test_data()
        test_acc = model.evaluate(x_test,
                                  y_test,
                                  batch_size=config['batch_size'])
        return test_acc
    else:
        print("Unknown model type, failure.")
        return None
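To make the docstring's resiliency measure concrete, here is a hedged sketch of how it could be computed; evaluate_on_clean_data is a hypothetical helper (an un-perturbed counterpart of segmentation_model_attack) that is not defined in these snippets:

def segmentation_resiliency(model, model_type, config):
    # Hypothetical: same evaluation loop as segmentation_model_attack, but on clean test data.
    clean_acc = evaluate_on_clean_data(model, model_type, config)
    perturbed_acc = segmentation_model_attack(model, model_type, config)
    if clean_acc is None or perturbed_acc is None:
        return None
    # A smaller drop under salt-and-pepper noise indicates a more resilient model.
    return clean_acc - perturbed_acc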
Example 5
class ModelParallelClassificationModel(LightningModule):

    def __init__(self, lr: float = 0.01, num_blocks: int = 5):
        super().__init__()
        self.lr = lr
        self.num_blocks = num_blocks

        self.train_acc = Accuracy()
        self.valid_acc = Accuracy()
        self.test_acc = Accuracy()

    def make_block(self):
        return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())

    def configure_sharded_model(self) -> None:
        self.model = nn.Sequential(*(self.make_block() for x in range(self.num_blocks)), nn.Linear(32, 3))

    def forward(self, x):
        x = self.model(x)
        # Ensure output is in float32 for softmax operation
        x = x.float()
        logits = F.softmax(x, dim=1)
        return logits

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', self.train_acc(logits, y), prog_bar=True, sync_dist=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        self.log('val_loss', F.cross_entropy(logits, y), prog_bar=False, sync_dist=True)
        self.log('val_acc', self.valid_acc(logits, y), prog_bar=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        self.log('test_loss', F.cross_entropy(logits, y), prog_bar=False, sync_dist=True)
        self.log('test_acc', self.test_acc(logits, y), prog_bar=True, sync_dist=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x, y = batch
        logits = self.forward(x)
        self.test_acc(logits, y)
        return self.test_acc.compute()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
        return [optimizer], [{
            'scheduler': lr_scheduler,
            'interval': 'step',
        }]
Example 6
def str_accuracy(m: Accuracy, detail: bool = False):
    backup = m.correct, m.total  # compute() may sync/reduce the metric state, so save the raw counts
    metric = m.compute()
    m.correct, m.total = backup  # restore the local counts afterwards
    if math.isnan(metric) or math.isinf(metric):
        return 'N/A'
    elif not detail:
        return f'{metric * 100:.2f}%'
    else:
        return f'{metric * 100:.2f}%(= {m.correct}/{m.total})'
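A small usage sketch for str_accuracy, assuming the pre-0.11 torchmetrics Accuracy API used throughout these examples (newer releases require a task argument); the numbers are illustrative:

import torch
from torchmetrics import Accuracy

acc = Accuracy()
acc(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))  # update with one batch: 2 of 3 correct
print(str_accuracy(acc))                # e.g. '66.67%'
print(str_accuracy(acc, detail=True))   # e.g. '66.67%(= 2/3)'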
Example 7
def test_topk_accuracy(preds, target, exp_result, k, subset_accuracy):
    topk = Accuracy(top_k=k, subset_accuracy=subset_accuracy)

    for batch in range(preds.shape[0]):
        topk(preds[batch], target[batch])

    assert topk.compute() == exp_result

    # Test functional
    total_samples = target.shape[0] * target.shape[1]

    preds = preds.view(total_samples, 4, -1)
    target = target.view(total_samples, -1)

    assert accuracy(preds, target, top_k=k,
                    subset_accuracy=subset_accuracy) == exp_result
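For reference, a tiny hand-check of what top-k accuracy computes, again using the pre-0.11 torchmetrics API; the probabilities below are made up for illustration:

import torch
from torchmetrics import Accuracy

preds = torch.tensor([[0.1, 0.7, 0.2],    # top-2 classes: {1, 2}
                      [0.5, 0.4, 0.1]])   # top-2 classes: {0, 1}
target = torch.tensor([2, 1])             # both targets fall inside their top-2 set

acc = Accuracy(top_k=2)
assert acc(preds, target) == 1.0          # every sample counts as correct under top-2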
Example 8
class Densenet121Lightning(BaseParticipantModel, pl.LightningModule):
    def __init__(self,
                 num_classes,
                 *args,
                 weights=None,
                 pretrain=True,
                 **kwargs):
        model = torchvision.models.densenet121(pretrained=pretrain)
        model.classifier = Linear(in_features=1024,
                                  out_features=num_classes,
                                  bias=True)
        super().__init__(*args, model=model, **kwargs)
        self.model = model
        self.accuracy = Accuracy()
        self.train_accuracy = Accuracy()
        self.criterion = CrossEntropyLoss(weight=weights)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        y = y.long()
        logits = self.model(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.log('train/acc/{}'.format(self.participant_name),
                 self.train_accuracy(preds, y))
        self.log('train/loss/{}'.format(self.participant_name), loss.item())
        return loss

    def test_step(self, test_batch, batch_idx):
        x, y = test_batch
        y = y.long()
        logits = self.model(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy.update(preds, y)
        GlobalConfusionMatrix().update(preds, y)
        return {'loss': loss}

    def test_epoch_end(self, outputs: List[Any]) -> None:
        loss_list = [o['loss'] for o in outputs]
        loss = torch.stack(loss_list)
        self.log('sample_num', self.accuracy.total.item())
        self.log(f'test/acc/{self.participant_name}', self.accuracy.compute())
        self.log(f'test/loss/{self.participant_name}', loss.mean().item())
Example 9
class CNNLightning(BaseParticipantModel, pl.LightningModule):

    def __init__(self, only_digits=False, input_channels=1, *args, **kwargs):
        model = CNN_OriginalFedAvg(only_digits=only_digits, input_channels=input_channels)
        super().__init__(*args, model=model, **kwargs)
        self.model = model
        # self.model.apply(init_weights)
        self.accuracy = Accuracy()
        self.train_accuracy = Accuracy()

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        y = y.long()
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)

        self.log(f'train/acc/{self.participant_name}', self.train_accuracy(preds, y).item())
        self.log(f'train/loss/{self.participant_name}', loss.mean().item())
        return loss

    def test_step(self, test_batch, batch_idx):
        x, y = test_batch
        y = y.long()
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy.update(preds, y)
        return {'loss': loss}

    def test_epoch_end(
            self, outputs: List[Any]
    ) -> None:
        loss_list = [o['loss'] for o in outputs]
        loss = torch.stack(loss_list)
        self.log('sample_num', self.accuracy.total.item())
        self.log(f'test/acc/{self.participant_name}', self.accuracy.compute())
        self.log(f'test/loss/{self.participant_name}', loss.mean().item())
Example 10
class LinearProbe(LightningModule):
    """
    A class to train and evaluate a linear probe on top of representations learned from a noisy clip student.
    """
    def __init__(self, args):
        super(LinearProbe, self).__init__()
        self.hparams = args

        #(1) Set up the dataset
        #Here, we use a 100-class subset of ImageNet
        if self.hparams.dataset not in ["ImageNet100", "CIFAR10", "CIFAR100"]:
            raise ValueError("Unsupported dataset selected.")
        else:
            if self.hparams.distortion == "None":
                self.train_set_transform = ImageNetBaseTransform(self.hparams)
                self.val_set_transform = ImageNetBaseTransformVal(self.hparams)
            elif self.hparams.distortion == "multi":
                self.train_set_transform = ImageNetDistortTrainMulti(self.hparams)
                self.val_set_transform = ImageNetDistortValMulti(self.hparams)

            else:
                #If we are using the ImageNet dataset, then set up the train and val sets to use the same mask if needed!
                self.train_set_transform = ImageNetDistortTrain(self.hparams)

                if self.hparams.fixed_mask:
                    self.val_set_transform = ImageNetDistortVal(self.hparams, fixed_distortion=self.train_set_transform.distortion)
                else:
                    self.val_set_transform = ImageNetDistortVal(self.hparams)

        #This should be initialised as a trained student CLIP network
        if self.hparams.encoder == "clip":
            saved_student = NoisyCLIP.load_from_checkpoint(self.hparams.checkpoint_path)
            self.backbone = saved_student.noisy_visual_encoder
        elif self.hparams.encoder == "clean":
            saved_student = clip.load('RN101', 'cpu', jit=False)[0]
            self.backbone = saved_student.visual


        self.backbone.eval()
        for param in self.backbone.parameters():
            param.requires_grad = False

        #This is the meat
        self.output = nn.Linear(self.hparams.emb_dim, self.hparams.num_classes)

        #Set up training and validation metrics
        self.criterion = nn.CrossEntropyLoss()

        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)
        self.test_top_1 = Accuracy(top_k=1)
        self.test_top_5 = Accuracy(top_k=5)


    def forward(self, x):
        """
        Given a set of images x with shape [N, c, h, w], get their embeddings and then logits.

        Returns: Logits with shape [N, n_classes]
        """

        #Grab the noisy image embeddings
        self.backbone.eval()
        with torch.no_grad():
            if self.hparams.encoder == "clip":
                noisy_embeddings = self.backbone(x.type(torch.float16)).float()
            elif self.hparams.encoder == "clean":
                noisy_embeddings = self.backbone(x.type(torch.float16)).float()

        return self.output(noisy_embeddings)

    def configure_optimizers(self):
        if not hasattr(self.hparams, 'weight_decay'):
            self.hparams.weight_decay = 0

        opt = torch.optim.Adam(self.output.parameters(), lr = self.hparams.lr, weight_decay = self.hparams.weight_decay)

        if self.hparams.dataset == "ImageNet100":
            num_steps = 126689//(self.hparams.batch_size * self.hparams.gpus) #divide N_train by number of distributed iters
            if self.hparams.use_subset:
                num_steps = num_steps * self.hparams.subset_ratio
        else:
            num_steps = 500

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=num_steps)

        return [opt], [scheduler]

    def train_dataloader(self):
        if self.hparams.dataset == "ImageNet100":
            train_dataset = ImageNet100(
                root=self.hparams.dataset_dir,
                split = 'train',
                transform = self.train_set_transform
            )
        N_train = len(train_dataset)
        if self.hparams.use_subset:
            train_dataset = few_shot_dataset(train_dataset, int(np.ceil(N_train*self.hparams.subset_ratio/self.hparams.num_classes)))

        train_dataloader = DataLoader(train_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=True)

        return train_dataloader

    def val_dataloader(self):
        if self.hparams.dataset == "ImageNet100":
            val_dataset = ImageNet100(
                root=self.hparams.dataset_dir,
                split = 'val',
                transform = self.val_set_transform
            )

            self.N_val = 5000

        val_dataloader = DataLoader(val_dataset, batch_size=4*self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=False)

        return val_dataloader

    def training_step(self, batch, batch_idx):
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Train_Sample', img_grid(x), self.current_epoch)

        logits = self.forward(x)

        loss = self.criterion(logits, y)

        self.log("train_loss", loss, prog_bar=True, on_step=True, \
                    on_epoch=True, logger=True, sync_dist=True, sync_dist_op='sum')

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Val_Sample', img_grid(x), self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        self.log("val_top_1", self.val_top_1(pred_probs, y), prog_bar=False, logger=False)
        self.log("val_top_5", self.val_top_5(pred_probs, y), prog_bar=False, logger=False)

    def validation_epoch_end(self, outputs):
        self.log("val_top_1", self.val_top_1.compute(), prog_bar=True, logger=True)
        self.log("val_top_5", self.val_top_5.compute(), prog_bar=True, logger=True)
        self.val_top_1.reset()
        self.val_top_5.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        self.log("test_top_1", self.test_top_1(pred_probs, y), prog_bar=False, logger=False)
        self.log("test_top_5", self.test_top_5(pred_probs, y), prog_bar=False, logger=False)

    def test_epoch_end(self, outputs):
        self.log("test_top_1", self.test_top_1.compute(), prog_bar=True, logger=True)
        self.log("test_top_5", self.test_top_5.compute(), prog_bar=True, logger=True)
        self.test_top_1.reset()
        self.test_top_5.reset()

    def predict(self, batch, batch_idx, hiddens):
        x, y = batch
        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)
        preds = pred_probs.argmax(dim=-1)
        return preds
Example 11
class CIFAR10Classifier(pl.LightningModule):
    def __init__(self, **kwargs):
        """
        Initializes the network, optimizer and scheduler
        """
        super(CIFAR10Classifier, self).__init__()
        self.model_conv = models.resnet50(pretrained=True)
        for param in self.model_conv.parameters():
            param.requires_grad = False
        num_ftrs = self.model_conv.fc.in_features
        num_classes = 10
        self.model_conv.fc = nn.Linear(num_ftrs, num_classes)

        self.scheduler = None
        self.optimizer = None
        self.args = kwargs

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

    def forward(self, x):
        out = self.model_conv(x)
        return out

    def training_step(self, train_batch, batch_idx):
        if batch_idx == 0:
            self.reference_image = (train_batch[0][0]).unsqueeze(0)
            #self.reference_image.resize((1,1,28,28))
            print("\n\nREFERENCE IMAGE!!!")
            print(self.reference_image.shape)
        x, y = train_batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        self.train_acc(y_hat, y)
        self.log("train_acc", self.train_acc.compute())
        return {"loss": loss}

    def test_step(self, test_batch, batch_idx):

        x, y = test_batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        if self.args["accelerator"] is not None:
            self.log("test_loss", loss, sync_dist=True)
        else:
            self.log("test_loss", loss)
        self.test_acc(y_hat, y)
        self.log("test_acc", self.test_acc.compute())
        return {"test_acc": self.test_acc.compute()}

    def validation_step(self, val_batch, batch_idx):

        x, y = val_batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        if self.args["accelerator"] is not None:
            self.log("val_loss", loss, sync_dist=True)
        else:
            self.log("val_loss", loss)
        self.val_acc(y_hat, y)
        self.log("val_acc", self.val_acc.compute())
        return {"val_step_loss": loss, "val_loss": loss}

    def configure_optimizers(self):
        """
        Initializes the optimizer and learning rate scheduler

        :return: output - Initialized optimizer and scheduler
        """
        self.optimizer = torch.optim.Adam(self.parameters(),
                                          lr=self.args["lr"])
        self.scheduler = {
            "scheduler":
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=0.2,
                patience=3,
                min_lr=1e-6,
                verbose=True,
            ),
            "monitor":
            "val_loss",
        }
        return [self.optimizer], [self.scheduler]

    def makegrid(self, output, numrows):
        outer = torch.Tensor.cpu(output).detach()
        plt.figure(figsize=(20, 5))
        b = np.array([]).reshape(0, outer.shape[2])
        c = np.array([]).reshape(numrows * outer.shape[2], 0)
        i = 0
        j = 0
        while i < outer.shape[1]:
            img = outer[0][i]
            b = np.concatenate((img, b), axis=0)
            j += 1
            if j == numrows:
                c = np.concatenate((c, b), axis=1)
                b = np.array([]).reshape(0, outer.shape[2])
                j = 0

            i += 1
        return c

    def showActivations(self, x):

        # logging reference image
        self.logger.experiment.add_image("input",
                                         torch.Tensor.cpu(x[0][0]),
                                         self.current_epoch,
                                         dataformats="HW")

        # logging layer 1 activations
        out = self.model_conv.conv1(x)
        c = self.makegrid(out, 4)
        self.logger.experiment.add_image("layer 1",
                                         c,
                                         self.current_epoch,
                                         dataformats="HW")

    def training_epoch_end(self, outputs):
        self.showActivations(self.reference_image)

        # Logging graph
        if (self.current_epoch == 0):
            sampleImg = torch.rand((1, 3, 64, 64))
            self.logger.experiment.add_graph(CIFAR10Classifier(), sampleImg)
Example 12
class NoisyCLIP(LightningModule):
    def __init__(self, args):
        """
        A class that trains OpenAI CLIP in a student-teacher fashion to classify distorted images.

        Given two identical pre-trained networks, Teacher - T() and Student S(), we freeze T() and train S().

        Given a batch of {x1, x2, ..., xN}, apply a given distortion to each one to obtain noisy images {y1, y2, ..., yN}.

        Feed original images to T() and obtain embeddings {T(x1), ..., T(xN)} and feed distorted images to S() and obtain embeddings {S(y1), ..., S(yN)}.

        Maximize the similarity between the pairs {(T(x1), S(y1)), ..., (T(xN), S(yN))} while minimizing the similarity between all non-matched pairs.
        """
        super(NoisyCLIP, self).__init__()
        self.hparams = args
        self.world_size = self.hparams.num_nodes * self.hparams.gpus

        #(1) Load the correct dataset class names
        if self.hparams.dataset == "ImageNet100" or self.hparams.dataset == "Imagenet-100":
            self.N_val = 5000 # Default ImageNet validation set, only 100 classes.

            #Retrieve the text labels for classes in order to do zero-shot classification.
            if self.hparams.mapping_and_text_file is None:
                raise ValueError('No file from which to read text labels was specified.')

            text_labels = pickle.load(open(self.hparams.mapping_and_text_file, 'rb'))
            self.text_list = ['A photo of '+label.strip().replace('_',' ') for label in text_labels]
        else:
            raise NotImplementedError('Handling of the dataset not implemented yet.')

        #Set up the dataset
        #Here, we use a 100-class subset of ImageNet
        if self.hparams.dataset != "ImageNet100" and self.hparams.dataset != "Imagenet-100":
            raise ValueError("Unsupported dataset selected.")
        elif not hasattr(self.hparams, 'increasing') or not self.hparams.increasing:
            if self.hparams.distortion == "None":
                self.train_set_transform = ImageNetBaseTrainContrastive(self.hparams)
                self.val_set_transform = ImageNetBaseTransformVal(self.hparams)
            elif self.hparams.distortion == "multi":
                self.train_set_transform = ImageNetDistortTrainMultiContrastive(self.hparams)
                self.val_set_transform = ImageNetDistortValMulti(self.hparams)
            else:
                #If we are using the ImageNet dataset, then set up the train and val sets to use the same mask if needed!
                self.train_set_transform = ImageNetDistortTrainContrastive(self.hparams)

                if self.hparams.fixed_mask:
                    self.val_set_transform = ImageNetDistortVal(self.hparams, fixed_distortion=self.train_set_transform.distortion)
                else:
                    self.val_set_transform = ImageNetDistortVal(self.hparams)

        #(2) set up the teacher CLIP network - freeze it and don't use gradients!
        self.logit_scale = self.hparams.logit_scale
        self.baseclip = clip.load(self.hparams.baseclip_type, self.hparams.device, jit=False)[0]
        self.baseclip.eval()
        self.baseclip.requires_grad_(False)

        #(3) set up the student CLIP network - unfreeze it and use gradients!
        self.noisy_visual_encoder = clip.load(self.hparams.baseclip_type, self.hparams.device, jit=False)[0].visual
        self.noisy_visual_encoder.train()

        #(4) set up the training and validation accuracy metrics.
        self.train_top_1 = Accuracy(top_k=1)
        self.train_top_5 = Accuracy(top_k=5)
        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)

    def criterion(self, input1, input2, reduction='mean'):
        """
        Args:
            input1: Embeddings of the clean/noisy images from the teacher/student. Size [N, embedding_dim].
            input2: Embeddings of the clean/noisy images from the teacher/student (the ones not used as input1). Size [N, embedding_dim].
            reduction: how to scale the final loss
        """
        bsz = input1.shape[0]

        # Use the simclr style InfoNCE
        if self.hparams.loss_type == 'simclr':
            # Create similarity matrix between embeddings.
            full_tensor = torch.cat([input1.unsqueeze(1),input2.unsqueeze(1)], dim=1).view(2*bsz, -1)
            #tensor1 = full_tensor.expand(2*bsz,2*bsz,-1)
            #tensor2 = full_tensor.permute(1,0,2).expand(2*bsz,2*bsz,-1)
            #sim_mat = torch.nn.CosineSimilarity(dim=-1)(tensor1,tensor2)
            full_tensor = full_tensor / full_tensor.norm(dim=-1, keepdim=True)
            sim_mat = full_tensor @ full_tensor.t()
            print(torch.sum(sim_mat < 0))
            # Calculate logits used for the contrastive loss.
            exp_sim_mat = torch.exp(sim_mat/self.hparams.loss_tau)
            mask = torch.ones_like(exp_sim_mat) - torch.eye(2*bsz).type_as(exp_sim_mat)
            logmat = -torch.log(exp_sim_mat)+torch.log(torch.sum(mask*exp_sim_mat, 1))

            #Grab the two off-diagonal similarities
            part1 = torch.sum(torch.diag(logmat, diagonal=1)[np.arange(0,2*bsz,2)])
            part2 = torch.sum(torch.diag(logmat, diagonal=-1)[np.arange(0,2*bsz,2)])

            #Take the mean of the two off-diagonals
            loss = (part1 + part2)/2

        #Use the CLIP-style InfoNCE
        elif self.hparams.loss_type == 'clip':
            # Create similarity matrix between embeddings.
            tensor1 = input1 / input1.norm(dim=-1, keepdim=True)
            tensor2 = input2 / input2.norm(dim=-1, keepdim=True)
            sim_mat = (1/self.hparams.loss_tau)*tensor1 @ tensor2.t()

            #Calculate the cross entropy between the similarities of the positive pairs, counted two ways
            part1 = F.cross_entropy(sim_mat, torch.LongTensor(np.arange(bsz)).to(self.device))
            part2 = F.cross_entropy(sim_mat.t(), torch.LongTensor(np.arange(bsz)).to(self.device))

            #Take the mean of the two off-diagonals
            loss = (part1+part2)/2

        #Take the simple MSE between the clean and noisy embeddings
        elif self.hparams.loss_type == 'mse':
            return F.mse_loss(input2, input1)

        elif self.hparams.loss_type.startswith('simclr_'):
            assert self.hparams.loss_type in ['simclr_ss', 'simclr_st', 'simclr_both']
            # Various schemes for the negative examples
            teacher_embeds = F.normalize(input1, dim=1)
            student_embeds = F.normalize(input2, dim=1)
            # First compute positive examples by taking <S(x_i), T(x_i)>/T for all i
            pos_term = (teacher_embeds * student_embeds).sum(dim=1) / self.hparams.loss_tau
            # Then generate the negative term by constructing various similarity matrices
            if self.hparams.loss_type == 'simclr_ss':
                cov = torch.mm(student_embeds, student_embeds.t())
                sim = torch.exp(cov / self.hparams.loss_tau) # shape is [bsz, bsz]
                neg_term = torch.log(sim.sum(dim=1) - sim.diag())
            elif self.hparams.loss_type == 'simclr_st':
                cov = torch.mm(student_embeds, teacher_embeds.t())
                sim = torch.exp(cov / self.hparams.loss_tau) # shape is [bsz, bsz]
                neg_term = torch.log(sim.sum(dim=1)) # Not removing the diagonal here!
            else:
                cat_embeds = torch.cat([student_embeds, teacher_embeds])
                cov = torch.mm(student_embeds, cat_embeds.t())
                sim = torch.exp(cov / self.hparams.loss_tau) # shape is [bsz, 2 * bsz]
                # and take the row-wise sums without the diagonal terms
                neg_term = torch.log(sim.sum(dim=1) - sim.diag())
            # Final loss: the negative of (pos_term - neg_term), summed here and divided by bsz on return
            loss = -1 * (pos_term - neg_term).sum()

        else:
            raise ValueError('Loss function not understood.')

        return loss/bsz if reduction == 'mean' else loss


    def configure_optimizers(self):
        optim = torch.optim.Adam(self.noisy_visual_encoder.parameters(), lr=self.hparams.lr)
        num_steps = 126689//(self.hparams.batch_size * self.hparams.gpus) #divide N_train by number of distributed iters
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=num_steps)
        return [optim], [sched]

    def encode_noisy_image(self, image):
        """
        Return S(yi) where S() is the student network and yi is distorted images.
        """
        return self.noisy_visual_encoder(image.type(torch.float16))

    def forward(self, image_features, text=None):
        """
        Given a set of noisy image embeddings, calculate the cosine similarity (scaled by temperature) of each image with each class text prompt.
        Calculates the similarity in two ways: logits per image (size = [N, n_classes]), logits per text (size = [n_classes, N]).
        This is mainly used for validation and classification.

        Args:
            image_features: the noisy image embeddings S(yi) where S() is the student and yi = Distort(xi). Shape [N, embedding_dim]
        """

        #load the pre-computed text features and load them on the correct device
        text_features = self.baseclip.encode_text(clip.tokenize(self.text_list).to(self.device))

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logits_per_image = self.logit_scale * image_features.type(torch.float16) @ text_features.type(torch.float16).t()
        logits_per_text = self.logit_scale * text_features.type(torch.float16) @ image_features.type(torch.float16).t()

        return logits_per_image, logits_per_text

    # Training methods - here we are concerned with contrastive loss (or MSE) between clean and noisy image embeddings.
    def training_step(self, train_batch, batch_idx):
        """
        Takes a batch of clean and noisy images and returns their respective embeddings.

        Returns:
            embed_clean: T(xi) where T() is the teacher and xi are clean images. Shape [N, embed_dim]
            embed_noisy: S(yi) where S() is the student and yi are noisy images. Shape [N, embed_dim]
        """
        image_clean, image_noisy, labels = train_batch
        embed_clean = self.baseclip.encode_image(image_clean)
        embed_noisy = self.encode_noisy_image(image_noisy)
        return {'embed_clean': embed_clean, 'embed_noisy': embed_noisy}

    def training_step_end(self, outputs):
        """
        Given all the clean and noisy image embeddings from across GPUs from training_step, gather them onto a single GPU and calculate the overall loss.
        """
        embed_clean_full = outputs['embed_clean']
        embed_noisy_full = outputs['embed_noisy']
        loss = self.criterion(embed_clean_full, embed_noisy_full)
        self.log('train_loss', loss, prog_bar=False, logger=True, sync_dist=True, on_step=True, on_epoch=True)
        return loss

    # Validation methods - here we are concerned with similarity between noisy image embeddings and classification text embeddings.
    def validation_step(self, test_batch, batch_idx):
        """
        Grab the noisy image embeddings: S(yi), where S() is the student and yi = Distort(xi). Done on each GPU.
        Return these to be evaluated in validation step end.
        """
        images_noisy, labels = test_batch
        if batch_idx == 0 and self.current_epoch < 1:
            self.logger.experiment.add_image('Val_Sample', img_grid(images_noisy), self.current_epoch)
        image_features = self.encode_noisy_image(images_noisy)
        return {'image_features': image_features, 'labels': labels}

    def validation_step_end(self, outputs):
        """
        Gather the noisy image features and their labels from each GPU.
        Then calculate their similarities, convert to probabilities, and calculate accuracy on each GPU.
        """
        image_features_full = outputs['image_features']
        labels_full = outputs['labels']

        image_logits, _ = self.forward(image_features_full)
        image_logits = image_logits.float()
        image_probs = image_logits.softmax(dim=-1)
        self.log('val_top_1_step', self.val_top_1(image_probs, labels_full), prog_bar=False, logger=False)
        self.log('val_top_5_step', self.val_top_5(image_probs, labels_full), prog_bar=False, logger=False)

    def validation_epoch_end(self, outputs):
        """
        Gather the zero-shot validation accuracies from across GPUs and reduce.
        """
        self.log('val_top_1', self.val_top_1.compute(), prog_bar=True, logger=True)
        self.log('val_top_5', self.val_top_5.compute(), prog_bar=True, logger=True)
        self.val_top_1.reset()
        self.val_top_5.reset()

    # Default dataloaders - can be overwritten by datamodule.
    def train_dataloader(self):
        if hasattr(self.hparams, 'increasing') and self.hparams.increasing:
            datatf = ImageNetDistortTrainContrastive(self.hparams, epoch=self.current_epoch)
        else:
            datatf = self.train_set_transform

        train_dataset = ImageNet100(
            root=self.hparams.dataset_dir,
            split = 'train',
            transform = None
        )
        train_contrastive = ContrastiveUnsupervisedDataset(train_dataset, transform_contrastive=datatf, return_label=True)

        train_dataloader = DataLoader(train_contrastive, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=True)

        return train_dataloader

    def val_dataloader(self):
        if hasattr(self.hparams, 'increasing') and self.hparams.increasing:
            datatf = ImageNetDistortVal(self.hparams, epoch=self.current_epoch)
        else:
            datatf = self.val_set_transform

        val_dataset = ImageNet100(
            root=self.hparams.dataset_dir,
            split = 'val',
            transform = datatf
        )
        self.N_val = len(val_dataset)

        val_dataloader = DataLoader(val_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=False)

        return val_dataloader

    def test_dataloader(self):
        return self.val_dataloader()
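To clarify the 'clip' branch of criterion above, here is a self-contained toy version of the CLIP-style InfoNCE loss; the embeddings and temperature are made-up values, not taken from the class:

import torch
import torch.nn.functional as F

teacher = F.normalize(torch.randn(4, 8), dim=1)                   # stand-in for T(x_i)
student = F.normalize(teacher + 0.05 * torch.randn(4, 8), dim=1)  # stand-in for S(y_i), close to teacher
sim = (teacher @ student.t()) / 0.1                               # cosine similarities, tau = 0.1
labels = torch.arange(4)                                          # matched pairs lie on the diagonal
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
print(loss)  # small when each matched pair is the most similar in its row/column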
Example 13
class BertNewsClassifier(pl.LightningModule):
    def __init__(self, **kwargs):
        """
        Initializes the network, optimizer and scheduler
        """
        super(BertNewsClassifier, self).__init__()

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

        self.PRE_TRAINED_MODEL_NAME = "bert-base-uncased"
        self.bert_model = BertModel.from_pretrained(
            self.PRE_TRAINED_MODEL_NAME)
        for param in self.bert_model.parameters():
            param.requires_grad = False
        self.drop = nn.Dropout(p=0.2)
        # assigning labels
        self.class_names = ["world", "Sports", "Business", "Sci/Tech"]
        n_classes = len(self.class_names)

        self.fc1 = nn.Linear(self.bert_model.config.hidden_size, 512)
        self.out = nn.Linear(512, n_classes)

        self.args = kwargs

    def forward(self, input_ids, attention_mask):
        """
        :param input_ids: Input data
        :param attention_mask: Attention mask value

        :return: output - Type of news for the given news snippet
        """
        pooled_output = self.bert_model(
            input_ids=input_ids, attention_mask=attention_mask).pooler_output
        output = F.relu(self.fc1(pooled_output))
        output = self.drop(output)
        output = self.out(output)
        return output

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Adds model-specific arguments to the application parser

        :param parent_parser: Application specific parser

        :return: Returns the augmented argument parser
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--lr",
            type=float,
            default=0.001,
            metavar="LR",
            help="learning rate (default: 0.001)",
        )
        return parser

    def training_step(self, train_batch, batch_idx):
        """
        Training the data as batches and returns training loss on each batch

        :param train_batch Batch data
        :param batch_idx: Batch indices

        :return: output - Training loss
        """
        input_ids = train_batch["input_ids"]
        attention_mask = train_batch["attention_mask"]
        targets = train_batch["targets"]
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, targets)
        self.train_acc(y_hat, targets)
        self.log("train_acc", self.train_acc.compute().cpu())
        self.log("train_loss", loss.cpu())
        return {"loss": loss}

    def test_step(self, test_batch, batch_idx):
        """
        Performs test and computes the accuracy of the model

        :param test_batch: Batch data
        :param batch_idx: Batch indices

        :return: output - Testing accuracy
        """
        input_ids = test_batch["input_ids"]
        attention_mask = test_batch["attention_mask"]
        targets = test_batch["targets"]
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)
        self.test_acc(y_hat, targets)
        self.log("test_acc", self.test_acc.compute().cpu())

    def validation_step(self, val_batch, batch_idx):
        """
        Performs validation of data in batches

        :param val_batch: Batch data
        :param batch_idx: Batch indices

        :return: output - valid step loss
        """

        input_ids = val_batch["input_ids"]
        attention_mask = val_batch["attention_mask"]
        targets = val_batch["targets"]
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, targets)
        self.val_acc(y_hat, targets)
        self.log("val_acc", self.val_acc.compute().cpu())
        self.log("val_loss", loss, sync_dist=True)

    def configure_optimizers(self):
        """
        Initializes the optimizer and learning rate scheduler

        :return: output - Initialized optimizer and scheduler
        """
        optimizer = AdamW(self.parameters(), lr=self.args["lr"])
        scheduler = {
            "scheduler":
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode="min",
                factor=0.2,
                patience=2,
                min_lr=1e-6,
                verbose=True,
            ),
            "monitor":
            "val_loss",
        }
        return [optimizer], [scheduler]
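A minimal inference sketch for the classifier above, assuming the Hugging Face BertTokenizer; the tokenizer settings and sample headline are illustrative only:

import torch
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertNewsClassifier(lr=0.001)
model.eval()

enc = tokenizer("Stocks rallied after the quarterly earnings report.",
                return_tensors="pt", padding="max_length",
                truncation=True, max_length=64)
with torch.no_grad():
    logits = model(enc["input_ids"], enc["attention_mask"])
print(model.class_names[int(torch.argmax(logits, dim=1))])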
Example 14
    test_config = {
        'batch_size': 4,
        'learning_rate': .001,
        'epochs': 1,
        'adam_epsilon': 10**-9
    }
    res = tensorflow_unet.gis_tf_objective(test_config)
    x_test, y_test = perturbed_tf_gis_test_data()
    test_acc = res[1].evaluate(x_test,
                               y_test,
                               batch_size=test_config['batch_size'])
    print(res[0])
    print(test_acc)

    print("PyTorch Model evaluation...")
    acc, pt_model = pytorch_unet.gis_pt_objective(test_config)
    test = perturbed_pt_gis_test_data()
    test_set = PT_GISDataset(test)
    testloader = DataLoader(test_set,
                            batch_size=int(test_config['batch_size']))
    accuracy = Accuracy()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    for sample in tqdm(testloader):
        cuda_in = sample[0].to(device)
        out = pt_model(cuda_in)
        output = out.to('cpu').squeeze(1)
        accuracy(output, sample[1])
        # cuda_in = cuda_in.detach()
        # label = label.detach()
    print("AVERAGE ACCURACY:")
    print(accuracy.compute().item())
Example 15
class BertNewsClassifier(pl.LightningModule):
    def __init__(self, **kwargs):
        """
        Initializes the network, optimizer and scheduler
        """
        super(BertNewsClassifier, self).__init__()
        self.PRE_TRAINED_MODEL_NAME = "bert-base-uncased"
        self.bert_model = BertModel.from_pretrained(
            self.PRE_TRAINED_MODEL_NAME)
        for param in self.bert_model.parameters():
            param.requires_grad = False
        self.drop = nn.Dropout(p=0.2)
        # assigning labels
        self.class_names = ["World", "Sports", "Business", "Sci/Tech"]
        n_classes = len(self.class_names)

        self.fc1 = nn.Linear(self.bert_model.config.hidden_size, 512)
        self.out = nn.Linear(512, n_classes)
        self.bert_model.embedding = self.bert_model.embeddings
        self.embedding = self.bert_model.embeddings

        self.scheduler = None
        self.optimizer = None
        self.args = kwargs

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

    def forward(self, input_ids, attention_mask):
        """
        :param input_ids: Input data
        :param attention_mask: Attention mask value

        :return: output - Type of news for the given news snippet
        """
        output = self.bert_model(input_ids=input_ids,
                                 attention_mask=attention_mask)
        output = F.relu(self.fc1(output.pooler_output))
        output = self.drop(output)
        output = self.out(output)
        return output

    def training_step(self, train_batch, batch_idx):
        """
        Training the data as batches and returns training loss on each batch

        :param train_batch Batch data
        :param batch_idx: Batch indices

        :return: output - Training loss
        """
        input_ids = train_batch["input_ids"].to(self.device)
        attention_mask = train_batch["attention_mask"].to(self.device)
        targets = train_batch["targets"].to(self.device)
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, targets)
        self.train_acc(y_hat, targets)
        self.log("train_acc", self.train_acc.compute())
        self.log("train_loss", loss)
        return {"loss": loss, "acc": self.train_acc.compute()}

    def test_step(self, test_batch, batch_idx):
        """
        Performs test and computes the accuracy of the model

        :param test_batch: Batch data
        :param batch_idx: Batch indices

        :return: output - Testing accuracy
        """
        input_ids = test_batch["input_ids"].to(self.device)
        attention_mask = test_batch["attention_mask"].to(self.device)
        targets = test_batch["targets"].to(self.device)
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)
        test_acc = accuracy_score(y_hat.cpu(), targets.cpu())
        self.test_acc(y_hat, targets)
        self.log("test_acc", self.test_acc.compute())
        return {"test_acc": torch.tensor(test_acc)}

    def validation_step(self, val_batch, batch_idx):
        """
        Performs validation of data in batches

        :param val_batch: Batch data
        :param batch_idx: Batch indices

        :return: output - valid step loss
        """

        input_ids = val_batch["input_ids"].to(self.device)
        attention_mask = val_batch["attention_mask"].to(self.device)
        targets = val_batch["targets"].to(self.device)
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, targets)
        self.val_acc(y_hat, targets)
        self.log("val_acc", self.val_acc.compute())
        self.log("val_loss", loss, sync_dist=True)
        return {"val_step_loss": loss, "acc": self.val_acc.compute()}

    def configure_optimizers(self):
        """
        Initializes the optimizer and learning rate scheduler

        :return: output - Initialized optimizer and scheduler
        """
        self.optimizer = AdamW(self.parameters(), lr=self.args["lr"])
        self.scheduler = {
            "scheduler":
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=0.2,
                patience=2,
                min_lr=1e-6,
                verbose=True,
            ),
            "monitor":
            "val_loss",
        }
        return [self.optimizer], [self.scheduler]
Example 16
class BertNewsClassifier(pl.LightningModule):  #pylint: disable=too-many-ancestors,too-many-instance-attributes
    """Bert Model Class."""

    def __init__(self, **kwargs):
        """Initializes the network, optimizer and scheduler."""
        super(BertNewsClassifier, self).__init__()  #pylint: disable=super-with-arguments
        self.pre_trained_model_name = "bert-base-uncased"  #pylint: disable=invalid-name
        self.bert_model = BertModel.from_pretrained(self.pre_trained_model_name)
        for param in self.bert_model.parameters():
            param.requires_grad = False
        self.drop = nn.Dropout(p=0.2)
        # assigning labels
        self.class_names = ["World", "Sports", "Business", "Sci/Tech"]
        n_classes = len(self.class_names)

        self.fc1 = nn.Linear(self.bert_model.config.hidden_size, 512)
        self.out = nn.Linear(512, n_classes)
        # self.bert_model.embedding = self.bert_model.embeddings
        # self.embedding = self.bert_model.embeddings

        self.scheduler = None
        self.optimizer = None
        self.args = kwargs

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

        self.preds = []
        self.target = []

    def compute_bert_outputs(  #pylint: disable=no-self-use
        self, model_bert, embedding_input, attention_mask=None, head_mask=None
    ):
        """Computes Bert Outputs.

        Args:
            model_bert : the bert model
            embedding_input : input for bert embeddings.
            attention_mask : attention  mask
            head_mask : head mask
        Returns:
            output : the bert output
        """
        if attention_mask is None:
            attention_mask = torch.ones(  #pylint: disable=no-member
                embedding_input.shape[0], embedding_input.shape[1]
            ).to(embedding_input)

        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        extended_attention_mask = extended_attention_mask.to(
            dtype=next(model_bert.parameters()).dtype
        )  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(
                    -1
                ).unsqueeze(-1)
                head_mask = head_mask.expand(
                    model_bert.config.num_hidden_layers, -1, -1, -1, -1
                )
            elif head_mask.dim() == 2:
                head_mask = (
                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
                )  # We can specify head_mask for each layer
            head_mask = head_mask.to(
                dtype=next(model_bert.parameters()).dtype
            )  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * model_bert.config.num_hidden_layers

        encoder_outputs = model_bert.encoder(
            embedding_input, extended_attention_mask, head_mask=head_mask
        )
        sequence_output = encoder_outputs[0]
        pooled_output = model_bert.pooler(sequence_output)
        outputs = (
            sequence_output,
            pooled_output,
        ) + encoder_outputs[1:]
        return outputs

    def forward(self, input_ids, attention_mask=None):
        """ Forward function.
        Args:
            input_ids: Input data
            attention_mask: Attention mask value

        Returns:
             output - Type of news for the given news snippet
        """
        embedding_input = self.bert_model.embeddings(input_ids)
        outputs = self.compute_bert_outputs(
            self.bert_model, embedding_input, attention_mask
        )
        pooled_output = outputs[1]
        output = torch.tanh(self.fc1(pooled_output))
        output = self.drop(output)
        output = self.out(output)
        return output

    def training_step(self, train_batch, batch_idx):
        """Training the data as batches and returns training loss on each
        batch.

        Args:
            train_batch Batch data
            batch_idx: Batch indices

        Returns:
            output - Training loss
        """
        input_ids = train_batch["input_ids"].to(self.device)
        attention_mask = train_batch["attention_mask"].to(self.device)
        targets = train_batch["targets"].to(self.device)
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)  #pylint: disable=no-member
        loss = F.cross_entropy(output, targets)
        self.train_acc(y_hat, targets)
        self.log("train_acc", self.train_acc.compute())
        self.log("train_loss", loss)
        return {"loss": loss, "acc": self.train_acc.compute()}

    def test_step(self, test_batch, batch_idx):
        """Performs test and computes the accuracy of the model.

        Args:
             test_batch: Batch data
             batch_idx: Batch indices

        Returns:
             output - Testing accuracy
        """
        input_ids = test_batch["input_ids"].to(self.device)
        attention_mask = test_batch["attention_mask"].to(self.device)
        targets = test_batch["targets"].to(self.device)
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)  #pylint: disable=no-member
        test_acc = accuracy_score(y_hat.cpu(), targets.cpu())
        self.test_acc(y_hat, targets)
        self.preds += y_hat.tolist()
        self.target += targets.tolist()
        self.log("test_acc", self.test_acc.compute())
        return {"test_acc": torch.tensor(test_acc)}  #pylint: disable=no-member

    def validation_step(self, val_batch, batch_idx):
        """Performs validation of data in batches.

        Args:
             val_batch: Batch data
             batch_idx: Batch indices

        Returns:
             output - valid step loss
        """

        input_ids = val_batch["input_ids"].to(self.device)
        attention_mask = val_batch["attention_mask"].to(self.device)
        targets = val_batch["targets"].to(self.device)
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)  #pylint: disable=no-member
        loss = F.cross_entropy(output, targets)
        self.val_acc(y_hat, targets)
        self.log("val_acc", self.val_acc.compute())
        self.log("val_loss", loss, sync_dist=True)
        return {"val_step_loss": loss, "acc": self.val_acc.compute()}

    def configure_optimizers(self):
        """Initializes the optimizer and learning rate scheduler.

        Returns:
             output - Initialized optimizer and scheduler
        """
        self.optimizer = AdamW(self.parameters(), lr=self.args.get("lr", 0.001))
        self.scheduler = {
            "scheduler":
                torch.optim.lr_scheduler.ReduceLROnPlateau(
                    self.optimizer,
                    mode="min",
                    factor=0.2,
                    patience=2,
                    min_lr=1e-6,
                    verbose=True,
                ),
            "monitor":
                "val_loss",
        }
        return [self.optimizer], [self.scheduler]
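A tiny standalone illustration of the attention-mask expansion performed inside compute_bert_outputs above; the example mask is made up:

import torch

attention_mask = torch.tensor([[1, 1, 0]])            # [batch, seq_len], last token is padding
extended = attention_mask[:, None, None, :].float()   # -> [batch, 1, 1, seq_len], as in compute_bert_outputs
extended = (1.0 - extended) * -10000.0                # padded positions get a large negative additive bias
print(extended)  # tensor([[[[-0., -0., -10000.]]]])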
Example 17
class Baseline(LightningModule):
    """
    Class for training a classification model on distorted ImageNet in an end-to-end fashion.
    """
    def __init__(self, args):
        super(Baseline, self).__init__()

        self.hparams = args
        self.world_size = self.hparams.num_nodes * self.hparams.gpus

        self.lr = self.hparams.lr  # initialise this separately as a tunable parameter

        #(1) Set up the dataset
        #Here, we use a 100-class subset of ImageNet
        if self.hparams.dataset != "ImageNet100":
            raise ValueError("Unsupported dataset selected.")
        else:
            if self.hparams.distortion == "None":
                self.train_set_transform = ImageNetBaseTransform(self.hparams)
                self.val_set_transform = ImageNetBaseTransformVal(self.hparams)
            elif self.hparams.distortion == "multi":
                self.train_set_transform = ImageNetDistortTrainMulti(
                    self.hparams)
                self.val_set_transform = ImageNetDistortValMulti(self.hparams)
            else:
                #If we are using the ImageNet dataset, then set up the train and val sets to use the same mask if needed!
                self.train_set_transform = ImageNetDistortTrain(self.hparams)

                if self.hparams.fixed_mask:
                    self.val_set_transform = ImageNetDistortVal(
                        self.hparams,
                        fixed_distortion=self.train_set_transform.distortion)
                else:
                    self.val_set_transform = ImageNetDistortVal(self.hparams)

        #(2) Grab the correct baseline pre-trained model
        if self.hparams.encoder == 'resnet':
            self.encoder = RESNET_finetune(self.hparams)
        elif self.hparams.encoder == 'clip':
            self.encoder = CLIP_finetune(self.hparams)
        else:
            raise ValueError("Please select a valid encoder model.")

        #(3) Set up our criterion - here we use reduction as "sum" so that we are able to average over all validation sets
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

        self.train_top_1 = Accuracy(top_k=1)
        self.train_top_5 = Accuracy(top_k=5)

        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)

        self.test_top_1 = Accuracy(top_k=1)
        self.test_top_5 = Accuracy(top_k=5)

    def forward(self, x):
        return self.encoder(x)

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.encoder.parameters(), lr=self.lr)

        if self.hparams.dataset == "ImageNet100":
            num_steps = 126689 // (
                self.hparams.batch_size * self.hparams.gpus
            )  #divide N_train by number of distributed iters
        else:
            num_steps = 500

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt,
                                                               T_max=num_steps)

        return [opt], [scheduler]

    #DATALOADERS
    def train_dataloader(self):
        if self.hparams.dataset == "ImageNet100":
            train_dataset = ImageNet100(root=self.hparams.dataset_dir,
                                        split='train',
                                        transform=self.train_set_transform)

        train_dataloader = DataLoader(train_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=True)

        return train_dataloader

    def val_dataloader(self):
        if self.hparams.dataset == "ImageNet100":
            val_dataset = ImageNet100(root=self.hparams.dataset_dir,
                                      split='val',
                                      transform=self.val_set_transform)

            self.N_val = 5000

        val_dataloader = DataLoader(val_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=False)

        return val_dataloader

    def test_dataloader(self):
        if self.hparams.dataset == "ImageNet100":
            test_dataset = ImageNet100(root=self.hparams.dataset_dir,
                                       split='val',
                                       transform=self.val_set_transform)

        test_dataloader = DataLoader(test_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=False)

        return test_dataloader

    #TRAINING
    def training_step(self, batch, batch_idx):
        """
        Given a batch of images, train the model for one step.
        Calculates the crossentropy loss as well as top_1 and top_5 per batch

        Inputs:
            batch - the images to train on, shape [batch_size, num_channels, height, width]
            batch_idx - the index of the current batch
        
        Returns:
            loss - the crossentropy loss between the model's logits and the true classes of the inputs
        """
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Train_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        loss = self.criterion(logits, y)

        self.log("train_loss",
                 loss,
                 prog_bar=True,
                 on_step=True,
                 on_epoch=True,
                 logger=True,
                 sync_dist=True,
                 sync_dist_op='sum')
        self.log("train_top_1",
                 self.train_top_1(pred_probs, y),
                 prog_bar=False,
                 logger=False)
        self.log("train_top_5",
                 self.train_top_5(pred_probs, y),
                 prog_bar=False,
                 logger=False)

        return loss

    def training_epoch_end(self, outputs):
        self.log("train_top_1",
                 self.train_top_1.compute(),
                 prog_bar=True,
                 logger=True)
        self.log("train_top_5",
                 self.train_top_5.compute(),
                 prog_bar=True,
                 logger=True)

        self.train_top_1.reset()
        self.train_top_5.reset()

    #VALIDATION
    def validation_step(self, batch, batch_idx):
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Val_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        self.log("val_top_1",
                 self.val_top_1(pred_probs, y),
                 prog_bar=False,
                 logger=False)
        self.log("val_top_5",
                 self.val_top_5(pred_probs, y),
                 prog_bar=False,
                 logger=False)

    def validation_epoch_end(self, outputs):
        self.log("val_top_1",
                 self.val_top_1.compute(),
                 prog_bar=True,
                 logger=True)
        self.log("val_top_5",
                 self.val_top_5.compute(),
                 prog_bar=True,
                 logger=True)

        self.val_top_1.reset()
        self.val_top_5.reset()

    #TESTING
    def test_step(self, batch, batch_idx):
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Test_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        self.log("test_top_1",
                 self.test_top_1(pred_probs, y),
                 prog_bar=False,
                 logger=False)
        self.log("test_top_5",
                 self.test_top_5(pred_probs, y),
                 prog_bar=False,
                 logger=False)

    def test_epoch_end(self, outputs):
        self.log("test_top_1",
                 self.test_top_1.compute(),
                 prog_bar=True,
                 logger=True)
        self.log("test_top_5",
                 self.test_top_5.compute(),
                 prog_bar=True,
                 logger=True)

        self.test_top_1.reset()
        self.test_top_5.reset()
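
The Baseline class above relies on torchmetrics' stateful update / compute / reset cycle for its top-1 and top-5 metrics. A minimal standalone sketch of that cycle, assuming torchmetrics' Accuracy with the same top_k argument (batch size and class count here are arbitrary placeholders):

import torch
from torchmetrics import Accuracy

# Per-batch update, epoch-end compute, then reset -- the same pattern as above.
top_5 = Accuracy(top_k=5)

for _ in range(3):  # pretend these are three batches
    probs = torch.rand(8, 100).softmax(dim=-1)   # (batch, num_classes) probabilities
    target = torch.randint(0, 100, (8,))         # integer class labels
    batch_acc = top_5(probs, target)             # updates state, returns the per-batch value

epoch_acc = top_5.compute()  # accuracy accumulated over all batches seen so far
top_5.reset()                # clear state before the next epoch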
Example n. 18
class CIFAR10Classifier(pl.LightningModule):  #pylint: disable=too-many-ancestors,too-many-instance-attributes
    """Cifar10 model class."""
    def __init__(self, **kwargs):
        """Initializes the network, optimizer and scheduler."""
        super(CIFAR10Classifier, self).__init__()  #pylint: disable=super-with-arguments
        self.model_conv = models.resnet50(pretrained=True)
        for param in self.model_conv.parameters():
            param.requires_grad = False
        num_ftrs = self.model_conv.fc.in_features
        num_classes = 10
        self.model_conv.fc = nn.Linear(num_ftrs, num_classes)

        self.scheduler = None
        self.optimizer = None
        self.args = kwargs

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

        self.preds = []
        self.target = []

    def forward(self, x_var):
        """Forward function."""
        out = self.model_conv(x_var)
        return out

    def training_step(self, train_batch, batch_idx):
        """Training Step
        Args:
             train_batch : training batch
             batch_idx : batch id number
        Returns:
            train accuracy
        """
        if batch_idx == 0:
            self.reference_image = (train_batch[0][0]).unsqueeze(0)  #pylint: disable=attribute-defined-outside-init
            # self.reference_image.resize((1,1,28,28))
            print("\n\nREFERENCE IMAGE!!!")
            print(self.reference_image.shape)
        x_var, y_var = train_batch
        output = self.forward(x_var)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, y_var)
        self.log("train_loss", loss)
        self.train_acc(y_hat, y_var)
        self.log("train_acc", self.train_acc.compute())
        return {"loss": loss}

    def test_step(self, test_batch, batch_idx):
        """Testing step
        Args:
             test_batch : test batch data
             batch_idx : tests batch id
        Returns:
             test accuracy
        """

        x_var, y_var = test_batch
        output = self.forward(x_var)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, y_var)
        accelerator = self.args.get("accelerator", None)
        if accelerator is not None:
            self.log("test_loss", loss, sync_dist=True)
        else:
            self.log("test_loss", loss)
        self.test_acc(y_hat, y_var)
        self.preds += y_hat.tolist()
        self.target += y_var.tolist()

        self.log("test_acc", self.test_acc.compute())
        return {"test_acc": self.test_acc.compute()}

    def validation_step(self, val_batch, batch_idx):
        """Testing step.

        Args:
             val_batch : val batch data
             batch_idx : val batch id
        Returns:
             validation accuracy
        """

        x_var, y_var = val_batch
        output = self.forward(x_var)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, y_var)
        accelerator = self.args.get("accelerator", None)
        if accelerator is not None:
            self.log("val_loss", loss, sync_dist=True)
        else:
            self.log("val_loss", loss)
        self.val_acc(y_hat, y_var)
        self.log("val_acc", self.val_acc.compute())
        return {"val_step_loss": loss, "val_loss": loss}

    def configure_optimizers(self):
        """Initializes the optimizer and learning rate scheduler.

        Returns:
             output - Initialized optimizer and scheduler
        """
        self.optimizer = torch.optim.Adam(self.parameters(),
                                          lr=self.args.get("lr", 0.001))
        self.scheduler = {
            "scheduler":
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=0.2,
                patience=3,
                min_lr=1e-6,
                verbose=True,
            ),
            "monitor":
            "val_loss",
        }
        return [self.optimizer], [self.scheduler]

    def makegrid(self, output, numrows):  #pylint: disable=no-self-use
        """Makes grids.

        Args:
             output : Tensor output
             numrows : num of rows.
        Returns:
             c_array : grid array
        """
        outer = torch.Tensor.cpu(output).detach()
        plt.figure(figsize=(20, 5))
        b_array = np.array([]).reshape(0, outer.shape[2])
        c_array = np.array([]).reshape(numrows * outer.shape[2], 0)
        i = 0
        j = 0
        while i < outer.shape[1]:
            img = outer[0][i]
            b_array = np.concatenate((img, b_array), axis=0)
            j += 1
            if j == numrows:
                c_array = np.concatenate((c_array, b_array), axis=1)
                b_array = np.array([]).reshape(0, outer.shape[2])
                j = 0

            i += 1
        return c_array

    def show_activations(self, x_var):
        """Showns activation
        Args:
             x_var: x variable
        """

        # logging reference image
        self.logger.experiment.add_image("input",
                                         torch.Tensor.cpu(x_var[0][0]),
                                         self.current_epoch,
                                         dataformats="HW")

        # logging layer 1 activations
        out = self.model_conv.conv1(x_var)
        c_grid = self.makegrid(out, 4)
        self.logger.experiment.add_image("layer 1",
                                         c_grid,
                                         self.current_epoch,
                                         dataformats="HW")

    def training_epoch_end(self, outputs):
        """Training epoch end.

        Args:
             outputs: outputs of train end
        """
        self.show_activations(self.reference_image)

        # Logging graph
        if self.current_epoch == 0:
            sample_img = torch.rand((1, 3, 64, 64))
            self.logger.experiment.add_graph(CIFAR10Classifier(), sample_img)
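
The self.preds / self.target lists filled in test_step can be post-processed once testing is done. A rough sketch, assuming scikit-learn is available, the dataloaders are defined elsewhere, and a PyTorch Lightning version that still accepts test_dataloaders:

import pytorch_lightning as pl
from sklearn.metrics import classification_report, confusion_matrix

# Hypothetical driver code: train_loader, val_loader and test_loader are assumed
# to be built elsewhere from the CIFAR-10 dataset.
model = CIFAR10Classifier(lr=0.001)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, train_loader, val_loader)
trainer.test(model, test_dataloaders=test_loader)

# The lists accumulated in test_step hold plain Python ints at this point.
print(confusion_matrix(model.target, model.preds))
print(classification_report(model.target, model.preds))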
Example n. 19
class LightningMNISTClassifier(pl.LightningModule):
    def __init__(self, **kwargs):
        """mlflow.start_run()
        Initializes the network
        """
        super(LightningMNISTClassifier, self).__init__()

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)
        self.args = kwargs

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--lr",
            type=float,
            default=0.001,
            metavar="LR",
            help="learning rate (default: 0.001)",
        )
        return parser

    def forward(self, x):
        """
        :param x: Input data

        :return: output - mnist digit label for the input image
        """
        batch_size = x.size()[0]

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # layer 1 (b, 1*28*28) -> (b, 128)
        x = self.layer_1(x)
        x = torch.relu(x)

        # layer 2 (b, 128) -> (b, 256)
        x = self.layer_2(x)
        x = torch.relu(x)

        # layer 3 (b, 256) -> (b, 10)
        x = self.layer_3(x)

        # probability distribution over labels
        x = torch.log_softmax(x, dim=1)

        return x

    def cross_entropy_loss(self, logits, labels):
        """
        Initializes the loss function

        :return: output - Initialized cross entropy loss function
        """
        return F.nll_loss(logits, labels)

    def training_step(self, train_batch, batch_idx):
        """
        Training the data as batches and returns training loss on each batch
        :param train_batch: Batch data
        :param batch_idx: Batch indices

        :return: output - Training loss
        """
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        _, y_hat = torch.max(logits, dim=1)
        self.train_acc(y_hat, y)
        self.log("train_acc", self.train_acc.compute())
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, val_batch, batch_idx):
        """
        Performs validation of data in batches

        :param val_batch: Batch data
        :param batch_idx: Batch indices

        :return: output - valid step loss
        """
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        _, y_hat = torch.max(logits, dim=1)
        self.val_acc(y_hat, y)
        self.log("val_acc", self.val_acc.compute())
        self.log("val_loss", loss, sync_dist=True)

    def test_step(self, test_batch, batch_idx):
        """
        Performs test and computes the accuracy of the model

        :param test_batch: Batch data
        :param batch_idx: Batch indices

        :return: output - Testing accuracy
        """
        x, y = test_batch
        output = self.forward(x)
        _, y_hat = torch.max(output, dim=1)

        self.test_acc(y_hat, y)
        self.log("test_acc", self.test_acc.compute())

    def prepare_data(self):
        """
        Prepares the data for training and prediction
        """
        return {}

    def configure_optimizers(self):
        """
        Initializes the optimizer and learning rate scheduler

        :return: output - Initialized optimizer and scheduler
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.args["lr"])
        scheduler = {
            "scheduler":
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode="min",
                factor=0.2,
                patience=2,
                min_lr=1e-6,
                verbose=True,
            ),
            "monitor":
            "val_loss",
        }
        return [optimizer], [scheduler]
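
A rough sketch of wiring add_model_specific_args into a training script; the dataloaders are assumed to be built elsewhere, and parsing into kwargs is one reasonable convention rather than the only one:

from argparse import ArgumentParser

import pytorch_lightning as pl

parser = ArgumentParser(description="MNIST classifier")
parser = LightningMNISTClassifier.add_model_specific_args(parser)
args = parser.parse_args()

model = LightningMNISTClassifier(**vars(args))   # kwargs end up in self.args, so self.args["lr"] works
trainer = pl.Trainer(max_epochs=3)
trainer.fit(model, train_loader, val_loader)     # train_loader / val_loader assumed defined elsewhere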
Example n. 20
class HPALit(pl.LightningModule):
    def __init__(self, params):
        super().__init__()

        self.hparams = params
        self.lr = self.hparams.lr
        self.save_hyperparameters()
        self.model = Net(name=self.hparams.model)
        self.criterion = torch.nn.BCEWithLogitsLoss()

        self.train_accuracy = Accuracy(subset_accuracy=True)
        self.val_accuracy = Accuracy(subset_accuracy=True)
        self.train_recall = Recall()
        self.val_recall = Recall()

        self.train_df = pd.read_csv(
            f'data_preprocessing/train_fold_{self.hparams.fold}.csv')
        self.valid_df = pd.read_csv(
            f'data_preprocessing/valid_fold_{self.hparams.fold}.csv')

        self.train_transforms = A.Compose([
            A.Rotate(),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Resize(width=self.hparams.img_size,
                     height=self.hparams.img_size),
            A.Normalize(),
            ToTensorV2(),
        ])

        self.valid_transforms = A.Compose([
            A.Resize(width=self.hparams.img_size,
                     height=self.hparams.img_size),
            A.Normalize(),
            ToTensorV2(),
        ])

        self.train_dataset = CellDataset(data_dir=self.hparams.data_dir,
                                         csv_file=self.train_df,
                                         transform=self.train_transforms)
        self.val_dataset = CellDataset(data_dir=self.hparams.data_dir,
                                       csv_file=self.valid_df,
                                       transform=self.valid_transforms)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.n_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.n_workers,
            pin_memory=True,
        )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               'min',
                                                               factor=0.5,
                                                               patience=2,
                                                               eps=1e-6)
        lr_scheduler = {
            'scheduler': scheduler,
            'interval': 'epoch',
            'monitor': 'valid_loss_epoch'
        }

        return [optimizer], [lr_scheduler]

    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self.model(x)

        train_loss = self.criterion(pred, y)
        self.train_accuracy(torch.sigmoid(pred), y.type(torch.int))
        self.train_recall(torch.sigmoid(pred), y.type(torch.int))

        return {'loss': train_loss}

    def training_epoch_end(self, outputs):
        train_loss_epoch = torch.stack([x['loss'] for x in outputs]).mean()
        self.log('train_loss_epoch', train_loss_epoch)
        self.log('train_acc_epoch', self.train_accuracy.compute())
        self.log('train_recall_epoch', self.train_recall.compute())
        self.train_accuracy.reset()
        self.train_recall.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self.model(x)

        val_loss = self.criterion(pred, y)
        self.val_accuracy(torch.sigmoid(pred), y.type(torch.int))
        self.val_recall(torch.sigmoid(pred), y.type(torch.int))

        return {'valid_loss': val_loss}

    def validation_epoch_end(self, outputs):
        val_loss_epoch = torch.stack([x['valid_loss'] for x in outputs]).mean()
        self.log('valid_loss_epoch', val_loss_epoch)
        self.log('valid_acc_epoch', self.val_accuracy.compute())
        self.log('valid_recall_epoch', self.val_recall.compute())
        self.val_accuracy.reset()
        self.val_recall.reset()
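
HPALit uses Accuracy(subset_accuracy=True) because its targets are multi-label. A minimal standalone sketch of what subset accuracy measures, assuming torchmetrics' semantics (a sample only counts as correct if every one of its labels is predicted correctly):

import torch
from torchmetrics import Accuracy

# Subset (exact-match) accuracy on multi-label predictions, threshold 0.5.
metric = Accuracy(subset_accuracy=True)

probs = torch.tensor([[0.9, 0.2, 0.8],    # predicted labels: [1, 0, 1]
                      [0.7, 0.6, 0.1]])   # predicted labels: [1, 1, 0]
target = torch.tensor([[1, 0, 1],         # exact match -> counted as correct
                       [1, 0, 0]])        # one label wrong -> counted as incorrect
print(metric(probs, target))              # -> tensor(0.5)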
class TransferLearning(LightningModule):
    def __init__(self, args):
        super(TransferLearning, self).__init__()

        self.hparams = args
        self.world_size = self.hparams.num_nodes * self.hparams.gpus

        self.train_set_transform, self.val_set_transform = grab_transforms(
            self.hparams)

        #Grab the correct model - only want the embeddings from the final layer!
        if self.hparams.saved_model_type == 'contrastive':
            saved_model = NoisyCLIP.load_from_checkpoint(
                self.hparams.checkpoint_path)
            self.backbone = saved_model.noisy_visual_encoder
        elif self.hparams.saved_model_type == 'baseline':
            saved_model = Baseline.load_from_checkpoint(
                self.hparams.checkpoint_path)
            self.backbone = saved_model.encoder.feature_extractor

        for param in self.backbone.parameters():
            param.requires_grad = False

        #Set up a classifier with the correct dimensions
        self.output = nn.Linear(self.hparams.emb_dim, self.hparams.num_classes)

        #Set up the criterion - note that reduction="mean" is used here, unlike the "sum" reduction in the Baseline class
        self.criterion = nn.CrossEntropyLoss(reduction="mean")

        self.train_top_1 = Accuracy(top_k=1)
        self.train_top_5 = Accuracy(top_k=5)

        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)

        self.test_top_1 = Accuracy(top_k=1)
        self.test_top_5 = Accuracy(top_k=5)

        #class INFECTED has label 0
        if self.hparams.dataset == 'COVID':
            self.val_auc = AUROC(pos_label=0)

            self.test_auc = AUROC(pos_label=0)

    def forward(self, x):
        #Grab the noisy image embeddings
        self.backbone.eval()
        with torch.no_grad():
            if self.hparams.encoder == "clip":
                noisy_embeddings = self.backbone(x.type(torch.float16)).float()
            elif self.hparams.encoder == "resnet":
                noisy_embeddings = self.backbone(x)

        return self.output(noisy_embeddings.flatten(1))

    def configure_optimizers(self):
        if not hasattr(self.hparams, 'weight_decay'):
            self.hparams.weight_decay = 0

        opt = torch.optim.Adam(self.output.parameters(),
                               lr=self.hparams.lr,
                               weight_decay=self.hparams.weight_decay)

        num_steps = self.hparams.max_epochs

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt,
                                                               T_max=num_steps)

        return [opt], [scheduler]

    def _grab_dataset(self, split):
        """
        Given a split ("train", "val" or "test"), returns the corresponding dataset.
        The dataset to use is defined in this object's hparams.

        Args:
            split: the split to use in the dataset
        Returns:
            dataset: the desired dataset with the correct split
        """
        if self.hparams.dataset == "CIFAR10":
            if split == 'train':
                train = True
                transform = self.train_set_transform
            else:
                train = False
                transform = self.val_set_transform

            dataset = CIFAR10(root=self.hparams.dataset_dir,
                              train=train,
                              transform=transform,
                              download=True)

        elif self.hparams.dataset == "CIFAR100":
            if split == 'train':
                train = True
                transform = self.train_set_transform
            else:
                train = False
                transform = self.val_set_transform

            dataset = CIFAR100(root=self.hparams.dataset_dir,
                               train=train,
                               transform=transform,
                               download=True)

        elif self.hparams.dataset == 'STL10':
            if split == 'train':
                stlsplit = 'train'
                transform = self.train_set_transform
            else:
                stlsplit = 'test'
                transform = self.val_set_transform

            dataset = STL10(root=self.hparams.dataset_dir,
                            split=stlsplit,
                            transform=transform,
                            download=True)

        elif self.hparams.dataset == 'COVID':
            if split == 'train':
                covidsplit = 'train'
                transform = self.train_set_transform
            else:
                covidsplit = 'test'
                transform = self.val_set_transform

            dataset = torchvision.datasets.ImageFolder(
                root=self.hparams.dataset_dir + covidsplit,
                transform=transform)

        elif self.hparams.dataset == 'ImageNet100B' or self.hparams.dataset == 'imagenet-100B':
            if split == 'train':
                transform = self.train_set_transform
            else:
                split = 'val'
                transform = self.val_set_transform

            dataset = ImageNet100(root=self.hparams.dataset_dir,
                                  split=split,
                                  transform=transform)

        return dataset

    def train_dataloader(self):
        train_dataset = self._grab_dataset(split='train')

        N_train = len(train_dataset)
        if self.hparams.use_subset:
            train_dataset = few_shot_dataset(
                train_dataset,
                int(
                    np.ceil(N_train * self.hparams.subset_ratio /
                            self.hparams.num_classes)))

        train_dataloader = DataLoader(train_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=True)

        return train_dataloader

    def val_dataloader(self):
        val_dataset = self._grab_dataset(split='val')

        N_val = len(val_dataset)

        #Shuffle the validation set, since AUROC cannot be computed on an all-positive or all-negative batch
        val_dataloader = DataLoader(val_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=True)

        return val_dataloader

    def test_dataloader(self):
        test_dataset = self._grab_dataset(split='test')

        N_test = len(test_dataset)

        #Shuffle the test set, since AUROC cannot be computed on an all-positive or all-negative batch
        test_dataloader = DataLoader(test_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=True)

        return test_dataloader

    def training_step(self, batch, batch_idx):
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Train_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)

        loss = self.criterion(logits, y)

        self.log("train_loss", loss, prog_bar=False, on_step=True, \
                    on_epoch=True, logger=True, sync_dist=True, sync_dist_op='sum')

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch

        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Val_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)  #(N, num_classes)

        if self.hparams.dataset == 'COVID':
            positive_prob = pred_probs[:,
                                       0].flatten()  #class 0 is INFECTED label
            true_labels = y.flatten()

            self.val_auc.update(positive_prob, true_labels)

            self.log("val_auc", self.val_auc, prog_bar=False, logger=False)

        self.log("val_top_1",
                 self.val_top_1(pred_probs, y),
                 prog_bar=False,
                 logger=False)

        if self.hparams.dataset != 'COVID':
            self.log("val_top_5",
                     self.val_top_5(pred_probs, y),
                     prog_bar=False,
                     logger=False)

    def validation_epoch_end(self, outputs):
        self.log("val_top_1",
                 self.val_top_1.compute(),
                 prog_bar=True,
                 logger=True)

        if self.hparams.dataset != 'COVID':
            self.log("val_top_5",
                     self.val_top_5.compute(),
                     prog_bar=True,
                     logger=True)

        if self.hparams.dataset == 'COVID':
            self.log("val_auc",
                     self.val_auc.compute(),
                     prog_bar=True,
                     logger=True)

            self.val_auc.reset()

        self.val_top_1.reset()
        self.val_top_5.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        if self.hparams.dataset == 'COVID':
            positive_prob = pred_probs[:,
                                       0].flatten()  #class 0 is INFECTED label
            true_labels = y.flatten()

            self.test_auc.update(positive_prob, true_labels)

            self.log("test_auc", self.test_auc, prog_bar=False, logger=False)

        self.log("test_top_1",
                 self.test_top_1(pred_probs, y),
                 prog_bar=False,
                 logger=False)
        if self.hparams.dataset != 'COVID':
            self.log("test_top_5",
                     self.test_top_5(pred_probs, y),
                     prog_bar=False,
                     logger=False)

    def test_epoch_end(self, outputs):
        self.log("test_top_1",
                 self.test_top_1.compute(),
                 prog_bar=True,
                 logger=True)
        if self.hparams.dataset != 'COVID':
            self.log("test_top_5",
                     self.test_top_5.compute(),
                     prog_bar=True,
                     logger=True)

        if self.hparams.dataset == 'COVID':
            self.log("test_auc",
                     self.test_auc.compute(),
                     prog_bar=True,
                     logger=True)

            self.test_auc.reset()

        self.test_top_1.reset()
        self.test_top_5.reset()
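
TransferLearning freezes the saved backbone and trains only the linear head. A stripped-down sketch of that pattern in plain PyTorch; torchvision's resnet18 and the random batch below are placeholders standing in for the saved encoder and the real data:

import torch
import torch.nn as nn
import torchvision

backbone = torchvision.models.resnet18(pretrained=False)
backbone.fc = nn.Identity()          # expose the 512-dim features
for param in backbone.parameters():
    param.requires_grad = False      # backbone stays frozen

head = nn.Linear(512, 100)           # only this layer is trained
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

x = torch.randn(4, 3, 224, 224)      # placeholder batch
y = torch.randint(0, 100, (4,))

backbone.eval()
with torch.no_grad():                # gradients never flow into the backbone
    feats = backbone(x)

optimizer.zero_grad()
loss = criterion(head(feats), y)
loss.backward()
optimizer.step()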
Example n. 22
class ImageClassifier(pl.LightningModule):
    """
    Basic image classifier.
    """

    def __init__(self, cfg: Config) -> None:
        super().__init__()  # type: ignore

        self.logger: Union[LoggerCollection, WandbLogger, Any]
        self.wandb: Run

        self.cfg = cfg

        self.model = ConvNet(self.cfg)
        self.criterion = nn.CrossEntropyLoss()

        # Metrics
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

    # -----------------------------------------------------------------------------------------------
    # Default PyTorch Lightning hooks
    # -----------------------------------------------------------------------------------------------
    def on_fit_start(self) -> None:
        """
        Hook before `trainer.fit()`.

        Attaches current wandb run to `self.wandb`.
        """
        if isinstance(self.logger, LoggerCollection):
            for logger in self.logger:  # type: ignore
                if isinstance(logger, WandbLogger):
                    self.wandb = logger.experiment  # type: ignore
        elif isinstance(self.logger, WandbLogger):
            self.wandb = self.logger.experiment  # type: ignore

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """
        Hook on checkpoint saving.

        Adds config and RNG states to the checkpoint file.
        """
        checkpoint['cfg'] = self.cfg
        checkpoint['rng_torch'] = torch.default_generator.get_state()
        checkpoint['rng_numpy'] = np.random.get_state()
        checkpoint['rng_random'] = random.getstate()

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """
        Hook on checkpoint loading.

        Loads RNG states from the checkpoint file.
        """
        torch.default_generator.set_state(checkpoint['rng_torch'])
        np.random.set_state(checkpoint['rng_numpy'])
        random.setstate(checkpoint['rng_random'])

    # ----------------------------------------------------------------------------------------------
    # Optimizers
    # ----------------------------------------------------------------------------------------------
    def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]:  # type: ignore
        """
        Define system optimization procedure.

        See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers.

        Returns
        -------
        Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]
            Single optimizer or a combination of optimizers with learning rate schedulers.
        """
        optimizer: Optimizer = instantiate(
            self.cfg.optim.optimizer,
            params=self.parameters(),
            _convert_='all'
        )

        if self.cfg.optim.scheduler is not None:
            scheduler: _LRScheduler = instantiate(  # type: ignore
                self.cfg.optim.scheduler,
                optimizer=optimizer,
                _convert_='all'
            )
            print(optimizer, scheduler)
            return [optimizer], [scheduler]
        else:
            print(optimizer)
            return optimizer

    # ----------------------------------------------------------------------------------------------
    # Forward
    # ----------------------------------------------------------------------------------------------
    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        """
        Forward pass of the whole system.

        In this simple case just calls the main model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        return self.model(x)

    # ----------------------------------------------------------------------------------------------
    # Loss
    # ----------------------------------------------------------------------------------------------
    def calculate_loss(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute loss value of a batch.

        In this simple case just forwards computation to default `self.criterion`.

        Parameters
        ----------
        outputs : torch.Tensor
            Network outputs with shape (batch_size, n_classes).
        targets : torch.Tensor
            Targets (ground-truth labels) with shape (batch_size).

        Returns
        -------
        torch.Tensor
            Loss value.
        """
        return self.criterion(outputs, targets)

    # ----------------------------------------------------------------------------------------------
    # Training
    # ----------------------------------------------------------------------------------------------
    def training_step(self, batch: list[torch.Tensor], batch_idx: int) -> dict[str, torch.Tensor]:  # type: ignore
        """
        Train on a single batch with loss defined by `self.criterion`.

        Parameters
        ----------
        batch : list[torch.Tensor]
            Training batch.
        batch_idx : int
            Batch index.

        Returns
        -------
        dict[str, torch.Tensor]
            Metric values for a given batch.
        """
        inputs, targets = batch
        outputs = self(inputs)  # basically equivalent to self.forward(inputs)
        loss = self.calculate_loss(outputs, targets)

        self.train_acc(F.softmax(outputs, dim=1), targets)

        return {
            'loss': loss,
            # no need to return 'train_acc' here since it is always available as `self.train_acc`
        }

    def training_epoch_end(self, outputs: list[Any]) -> None:
        """
        Log training metrics.

        Parameters
        ----------
        outputs : list[Any]
            List of dictionaries returned by `self.training_step` with batch metrics.
        """
        step = self.current_epoch + 1

        metrics = {
            'epoch': float(step),
            'train_acc': float(self.train_acc.compute().item()),
        }

        # Average additional metrics over all batches
        for key in outputs[0]:
            metrics[key] = float(self._reduce(outputs, key).item())

        self.logger.log_metrics(metrics, step=step)

    def _reduce(self, outputs: list[Any], key: str):
        return torch.stack([out[key] for out in outputs]).mean().detach()

    # ----------------------------------------------------------------------------------------------
    # Validation
    # ----------------------------------------------------------------------------------------------
    def validation_step(self, batch: list[torch.Tensor], batch_idx: int) -> dict[str, Any]:  # type: ignore
        """
        Compute validation metrics.

        Parameters
        ----------
        batch : list[torch.Tensor]
            Validation batch.
        batch_idx : int
            Batch index.

        Returns
        -------
        dict[str, torch.Tensor]
            Metric values for a given batch.
        """

        inputs, targets = batch
        outputs = self(inputs)  # basically equivalent to self.forward(inputs)

        self.val_acc(F.softmax(outputs, dim=1), targets)

        return {
            # 'additional_metric': ...
            # no need to return 'val_acc' here since it is always available as `self.val_acc`
        }

    def validation_epoch_end(self, outputs: list[Any]) -> None:
        """
        Log validation metrics.

        Parameters
        ----------
        outputs : list[Any]
            List of dictionaries returned by `self.validation_step` with batch metrics.
        """
        step = self.current_epoch + 1 if not self.trainer.running_sanity_check else self.current_epoch  # type: ignore

        metrics = {
            'epoch': float(step),
            'val_acc': float(self.val_acc.compute().item()),
        }

        # Average additional metrics over all batches
        for key in outputs[0]:
            metrics[key] = float(self._reduce(outputs, key).item())

        self.logger.log_metrics(metrics, step=step)
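
The checkpoint hooks above stash the torch, NumPy and random RNG states. A quick standalone sketch of why restoring those states reproduces exactly the same draws (the dict keys simply mirror the checkpoint entries above):

import random

import numpy as np
import torch

state = {
    'rng_torch': torch.default_generator.get_state(),
    'rng_numpy': np.random.get_state(),
    'rng_random': random.getstate(),
}

a = (torch.rand(1), np.random.rand(), random.random())

# Restore the saved states and draw again: identical values come out.
torch.default_generator.set_state(state['rng_torch'])
np.random.set_state(state['rng_numpy'])
random.setstate(state['rng_random'])

b = (torch.rand(1), np.random.rand(), random.random())
assert torch.equal(a[0], b[0]) and a[1] == b[1] and a[2] == b[2]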
class IrisClassification(pl.LightningModule):
    def __init__(self, **kwargs):
        super(IrisClassification, self).__init__()

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()
        self.args = kwargs

        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 3)
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return x

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Add model specific arguments like learning rate

        :param parent_parser: Application specific parser

        :return: Returns the augmented argument parser
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--lr",
            type=float,
            default=0.01,
            metavar="LR",
            help="learning rate (default: 0.001)",
        )
        return parser

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), self.args.get("lr", 0.01))

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        _, y_hat = torch.max(logits, dim=1)
        loss = self.cross_entropy_loss(logits, y)
        self.train_acc(y_hat, y)
        self.log(
            "train_acc",
            self.train_acc.compute(),
            on_step=False,
            on_epoch=True,
        )
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        _, y_hat = torch.max(logits, dim=1)
        loss = F.cross_entropy(logits, y)
        self.val_acc(y_hat, y)
        self.log("val_acc", self.val_acc.compute())
        self.log("val_loss", loss, sync_dist=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        _, y_hat = torch.max(logits, dim=1)
        self.test_acc(y_hat, y)
        self.log("test_acc", self.test_acc.compute())
class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim=128, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)
        self.test_acc = Accuracy(compute_on_step=False)
        self.example_input_array = torch.rand(10, 28 * 28)
        self.dims = (1, 28, 28)
        channels, width, height = self.dims

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, self.hparams.hidden_dim),
            nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(self.hparams.hidden_dim, self.hparams.hidden_dim),
            nn.ReLU(), nn.Dropout(0.1), nn.Linear(self.hparams.hidden_dim, 10))

    def forward(self, x):
        x = self.model(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = self.train_acc(y_hat, y)
        self.log("train_acc_step", acc)
        return {"loss": loss}

    def training_epoch_end(self, outputs):
        self.log("epoch_acc", self.train_acc.compute(), prog_bar=True)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.val_acc(y_hat, y)
        self.log('valid_loss', loss)

    def validation_epoch_end(self, outputs):
        self.log("epoch_val_acc", self.val_acc.compute(), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        self.test_acc(y_hat, y)
        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)

    def test_epoch_end(self, outputs):
        self.log("test_acc", self.test_acc.compute(), prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(),
                                lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_dim', type=int, default=128)
        parser.add_argument('--learning_rate', type=float, default=0.001)
        return parser
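
Because LitClassifier sets self.example_input_array, Lightning can trace the model without an explicit sample. A rough sketch, assuming a Lightning version with LightningModule.to_onnx; the file name is arbitrary:

import torch

model = LitClassifier(hidden_dim=128, learning_rate=1e-3)
model.eval()

with torch.no_grad():
    logits = model(model.example_input_array)   # sanity check: (10, 10) logits
print(logits.shape)

# to_onnx falls back to example_input_array when no input sample is given.
model.to_onnx("lit_classifier.onnx", export_params=True)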
Example n. 25
class Rockfish(pl.LightningModule):
    def __init__(self, conf):
        super().__init__()
        self.save_hyperparameters(conf)

        self.ke = nn.Embedding(4, 2, max_norm=1.)

        self.conv1 = ConvBlock(3, 256, 13, stride=3, padding=6)
        self.conv2 = ConvBlock(256, 256, 7, stride=1, padding=3)
        self.conv3 = ConvBlock(256, 256, 3, stride=2, padding=1)

        self.pos_encoder = PositionalEncoding(256, self.hparams.dropout)

        encoder_layer = nn.TransformerEncoderLayer(256,
                                                   self.hparams.nhead,
                                                   self.hparams.dim_ff,
                                                   self.hparams.dropout,
                                                   activation='gelu')
        self.encoder = nn.TransformerEncoder(encoder_layer,
                                             self.hparams.nlayers)

        self.fc1 = nn.Linear(256, 1)

        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)

    def forward(self, x, ref_k):
        ref_k = self.ke(ref_k).transpose(1, 2)

        x = torch.unsqueeze(x, 1)
        x = torch.cat((x, ref_k), 1)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        x = x.permute(2, 0, 1)
        x = self.pos_encoder(x)
        x = self.encoder(x)

        x = x.permute(1, 2, 0)
        x = F.adaptive_avg_pool1d(x, 1)
        x = x.squeeze(-1)

        return self.fc1(x)

    @staticmethod
    def add_module_specific_args(parent_parser):
        parser = argparse.ArgumentParser(parents=[parent_parser],
                                         add_help=False)

        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--nhead', type=int, default=8)
        parser.add_argument('--dim_ff', type=int, default=1024)
        parser.add_argument('--nlayers', type=int, default=6)

        parser.add_argument('--wd', type=float, default=1e-4)
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--step_size_up', type=int, default=None)

        return parser

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.wd,
                                eps=1e-6,
                                betas=(0.9, 0.98))

        if self.hparams.step_size_up is not None:
            step_size_up = self.hparams.step_size_up
        else:
            steps_per_epoch = self.train_ds_len // self.effective_batch_size
            step_size_up = 4 * steps_per_epoch

        scheduler = optim.lr_scheduler.CyclicLR(optimizer,
                                                self.hparams.lr / 10,
                                                self.hparams.lr,
                                                step_size_up=step_size_up,
                                                mode='triangular2',
                                                cycle_momentum=False)

        scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        x, ref_k, y = batch

        out = self(x, ref_k).squeeze(1)
        loss = F.binary_cross_entropy_with_logits(out, y.float())

        self.log('train_loss', loss, prog_bar=True)
        self.log('train_batch_acc', self.train_acc(torch.sigmoid(out), y))

        return loss

    def training_epoch_end(self, outputs):
        self.log('train_epoch_acc', self.train_acc.compute())

    def validation_step(self, batch, batch_idx):
        x, ref_k, y = batch

        out = self(x, ref_k).squeeze(1)
        loss = F.binary_cross_entropy_with_logits(out, y.float())

        self.val_acc(torch.sigmoid(out), y)
        self.log('val_loss', loss)

    def validation_epoch_end(self, outputs):
        val_acc = self.val_acc.compute()
        self.log('val_acc', val_acc, prog_bar=True)
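
Rockfish returns its CyclicLR inside a dict with 'interval': 'step', so Lightning steps the scheduler after every optimizer step rather than once per epoch. A rough plain-PyTorch equivalent of that behaviour; the model, batch sizes and step counts below are placeholders:

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer, base_lr=1e-5, max_lr=1e-4,
    step_size_up=100, mode='triangular2', cycle_momentum=False)

for batch in range(5):                     # placeholder training loop
    loss = model(torch.randn(8, 10)).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()                       # per optimizer step, matching 'interval': 'step'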
Example n. 26
class FastTextLSTMModel(LightningModule):
    """
    Runs an LSTM over the tokens' FastText embeddings, takes the final bidirectional hidden state, and applies dropout and a linear projection.
    """
    def __init__(self, ft_embedding_dim, hidden_dim=64):
        super().__init__()

        self.lstm_layer = nn.LSTM(ft_embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout_layer = nn.Dropout(0.2)
        self.out_layer = nn.Linear(hidden_dim * 2, 1)

        self.loss = nn.BCEWithLogitsLoss()
        self.valid_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

    def forward(self, embeddings, labels):
        """
        Forward pass
        :param embeddings: (batch_size, max_tokens_in_text, ft_embedding_dim)
        text -> ["hello", ",", "world",  ..] -> [9, 56, 72, ..] + padding or cutting to max sequence length
        :param labels: (batch_size, 1)
        :return: loss and logits
        """
        batch_size = embeddings.size(0)
        output, (final_hidden_state, final_cell_state) = self.lstm_layer(embeddings)
        final_hidden_state = final_hidden_state.transpose(0, 1)
        final_hidden_state = final_hidden_state.reshape(batch_size, -1)
        text_hidden = self.dropout_layer(final_hidden_state)
        logits = self.out_layer(text_hidden)
        loss = self.loss(logits, labels.type_as(logits))
        return loss, logits

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return [optimizer]

    def training_step(self, batch, _):
        inputs, labels = batch
        loss, logits = self(inputs, labels)
        return loss

    def validation_step(self, batch, _):
        inputs, labels = batch
        val_loss, logits = self(inputs, labels)
        if torch.max(labels) == 1:
            pred = torch.sigmoid(logits)
        else:
            pred = torch.softmax(logits, 1)
        self.valid_accuracy.update(pred, labels.long())
        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", self.valid_accuracy)

    def validation_epoch_end(self, outs):
        self.log("val_acc_epoch", self.valid_accuracy.compute(), prog_bar=True)

    def test_step(self, batch, _):
        inputs, labels = batch
        test_loss, logits = self(inputs, labels)
        if torch.max(labels) == 1:
            pred = torch.sigmoid(logits)
        else:
            pred = torch.softmax(logits, 1)
        self.test_accuracy.update(pred, labels.long())
        self.log("test_loss", test_loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy)

    def test_epoch_end(self, outs):
        self.log("test_acc_epoch", self.test_accuracy.compute(), prog_bar=True)