# Example #1 (score: 0)
    def __init__(self, device, num_classes, topK=3):
        """Build the validation-metric suite for a ``num_classes``-way task.

        Args:
            device: torch device the metric modules are moved to.
            num_classes: number of target classes.
            topK: highest K for which top-K accuracy/F1 metrics are created.
        """
        # Without pytorch-lightning / torchmetrics there is nothing to set up.
        if not pl:
            return

        self.device = device
        self.topK = topK

        # https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/classification/f_beta.py#L221
        # "global" flattens extra dimensions before averaging;
        # "samplewise" would average per sample instead.
        mdmc_average = "global"

        val_metrics = {
            'iou': IoU(num_classes=num_classes),
            'auroc': AUROC(num_classes=num_classes),
            'f1': F1(num_classes=num_classes,
                     multilabel=True,
                     mdmc_average=mdmc_average),
            'avg_precision': AveragePrecision(num_classes=num_classes),
        }
        # BUG FIX: the original stored None under 'hamming_dist' when
        # HammingDistance was unavailable, forcing every consumer of the
        # ModuleDict to None-check (and breaking any loop that calls the
        # metrics). Register the metric only when it actually exists.
        if HammingDistance is not None:
            val_metrics['hamming_dist'] = HammingDistance()

        # Top-K accuracy and F1 for every K in 1..topK.
        for k in range(1, topK + 1):
            val_metrics["top%d" % k] = Accuracy(top_k=k)
            val_metrics["top%d_f1" % k] = F1(top_k=k)

        self.val_metrics = torch.nn.ModuleDict(val_metrics).to(self.device)

        # Used to binarize integer label lists into multi-hot vectors.
        self.class_names = list(range(num_classes))
        self.label_binarizer = MultiLabelBinarizer(classes=self.class_names)
# Example #2 (score: 0)
def test_v1_5_metric_auc_auroc():
    """The AUC/ROC/AUROC metric classes and functionals must emit v1.5 deprecation warnings."""
    deprecation_match = 'It will be removed in v1.5.0'

    # Class-based metrics: constructing each one triggers the warning.
    for metric_cls in (AUC, ROC, AUROC):
        metric_cls.__init__._warned = False
        with pytest.deprecated_call(match=deprecation_match):
            metric_cls()

    # Functional auc: warns, and still computes the correct area.
    x = torch.tensor([0, 1, 2, 3])
    y = torch.tensor([0, 1, 2, 2])
    auc._warned = False
    with pytest.deprecated_call(match=deprecation_match):
        assert auc(x, y) == torch.tensor(4.)

    # Functional roc: warns, and returns the expected fpr/tpr/thresholds.
    preds = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 1, 1])
    roc._warned = False
    with pytest.deprecated_call(match=deprecation_match):
        fpr, tpr, thrs = roc(preds, target, pos_label=1)
    assert torch.equal(fpr, torch.tensor([0., 0., 0., 0., 1.]))
    assert torch.allclose(tpr, torch.tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]), atol=1e-4)
    assert torch.equal(thrs, torch.tensor([4, 3, 2, 1, 0]))

    # Functional auroc: warns, and scores this toy example at chance level.
    preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
    target = torch.tensor([0, 0, 1, 1, 1])
    auroc._warned = False
    with pytest.deprecated_call(match=deprecation_match):
        assert auroc(preds, target) == torch.tensor(0.5)
# Example #3 (score: 0)
    def __init__(self, prefix, loss_type: str, threshold=0.5, top_k=(1, 5, 10), n_classes: int = None,
                 multilabel: bool = None, metrics=("precision", "recall", "top_k", "accuracy")):
        """Instantiate the requested metric objects by name.

        Args:
            prefix: string prepended to logged metric names.
            loss_type: loss identifier; stored upper-cased.
            threshold: decision threshold for probabilistic outputs.
            top_k: cut-off values for top-K metrics.
            n_classes: number of classes, if known.
            multilabel: whether the task is multi-label.
            metrics: names of the metrics to build; unknown names are
                skipped with a warning.
        """
        super().__init__()

        self.loss_type = loss_type.upper()
        self.threshold = threshold
        self.n_classes = n_classes
        self.multilabel = multilabel
        self.top_ks = top_k
        self.prefix = prefix

        self.metrics = {}
        for metric in metrics:
            if "precision" == metric:
                self.metrics[metric] = Precision(average=True, is_multilabel=multilabel)
            elif "recall" == metric:
                self.metrics[metric] = Recall(average=True, is_multilabel=multilabel)

            elif "top_k" in metric:
                if n_classes:
                    # Drop cut-offs that exceed the number of classes.
                    top_k = [k for k in top_k if k < n_classes]

                if multilabel:
                    self.metrics[metric] = TopKMultilabelAccuracy(k_s=top_k)
                else:
                    # NOTE(review): k is derived from log(n_classes) rather
                    # than the filtered top_k list — confirm this is intended.
                    self.metrics[metric] = TopKCategoricalAccuracy(k=max(int(np.log(n_classes)), 1),
                                                                   output_transform=None)
            elif "macro_f1" in metric:
                self.metrics[metric] = F1(num_classes=n_classes, average="macro", multilabel=multilabel)
            elif "micro_f1" in metric:
                self.metrics[metric] = F1(num_classes=n_classes, average="micro", multilabel=multilabel)
            elif "mse" == metric:
                self.metrics[metric] = MeanSquaredError()
            elif "auroc" == metric:
                self.metrics[metric] = AUROC(num_classes=n_classes)
            elif "avg_precision" in metric:
                self.metrics[metric] = AveragePrecision(num_classes=n_classes)

            elif "accuracy" in metric:
                # "accuracy@K" selects top-K accuracy; bare "accuracy" is top-1.
                self.metrics[metric] = Accuracy(top_k=int(metric.split("@")[-1]) if "@" in metric else None)

            elif "ogbn" in metric:
                self.metrics[metric] = OGBNodeClfMetrics(NodeEvaluator(metric))
            elif "ogbg" in metric:
                self.metrics[metric] = OGBNodeClfMetrics(GraphEvaluator(metric))
            elif "ogbl" in metric:
                self.metrics[metric] = OGBLinkPredMetrics(LinkEvaluator(metric))
            else:
                print(f"WARNING: metric {metric} doesn't exist")
                # BUG FIX: the original fell through to self.metrics[metric]
                # below, which raised KeyError for unknown metric names.
                continue

            # Needed to add the PytorchGeometric methods as Modules, so they'll be on the correct CUDA device during training
            if isinstance(self.metrics[metric], torchmetrics.metric.Metric):
                setattr(self, metric, self.metrics[metric])

        self.reset_metrics()
    def __init__(self, args):
        """Transfer-learning setup: frozen pretrained backbone + trainable linear head.

        Args:
            args: parsed hyperparameters; stored on ``self.hparams``.
        """
        super(TransferLearning, self).__init__()

        # NOTE(review): direct assignment to self.hparams only works on older
        # pytorch-lightning versions (newer ones expose it as a property) —
        # confirm the pinned PL version.
        self.hparams = args
        # Total number of distributed processes (nodes * GPUs per node).
        self.world_size = self.hparams.num_nodes * self.hparams.gpus

        self.train_set_transform, self.val_set_transform = grab_transforms(
            self.hparams)

        #Grab the correct model - only want the embeddings from the final layer!
        if self.hparams.saved_model_type == 'contrastive':
            saved_model = NoisyCLIP.load_from_checkpoint(
                self.hparams.checkpoint_path)
            self.backbone = saved_model.noisy_visual_encoder
        elif self.hparams.saved_model_type == 'baseline':
            saved_model = Baseline.load_from_checkpoint(
                self.hparams.checkpoint_path)
            self.backbone = saved_model.encoder.feature_extractor

        # Freeze the backbone: only the linear head below receives gradients.
        for param in self.backbone.parameters():
            param.requires_grad = False

        #Set up a classifier with the correct dimensions
        self.output = nn.Linear(self.hparams.emb_dim, self.hparams.num_classes)

        # Criterion: mean-reduced cross-entropy. (An earlier comment here
        # mentioned reduction="sum"; the code has always used "mean".)
        self.criterion = nn.CrossEntropyLoss(reduction="mean")

        # Top-1/top-5 accuracy trackers, one set per stage.
        self.train_top_1 = Accuracy(top_k=1)
        self.train_top_5 = Accuracy(top_k=5)

        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)

        self.test_top_1 = Accuracy(top_k=1)
        self.test_top_5 = Accuracy(top_k=5)

        #class INFECTED has label 0
        if self.hparams.dataset == 'COVID':
            self.val_auc = AUROC(pos_label=0)

            self.test_auc = AUROC(pos_label=0)
# Example #5 (score: 0)
 def __init__(self):
     """Register per-task metric collections as submodules."""
     super().__init__()
     # Single-class (binary) task: BCE-with-logits loss, softmax-wrapped
     # accuracy, and AUROC.
     self.metrics_singleclass = nn.ModuleDict({
         'loss':
         BCEWithLogitsLoss(),
         'acc':
         SoftMaxWrapper(Accuracy(), multiclass=False),
         'auc':
         AUROC(),
     })
     # Multi-class task: cross-entropy loss and accuracy (no AUROC here).
     self.metrics_multiclass = nn.ModuleDict({
         'loss':
         CrossEntropyLoss(),
         'acc':
         SoftMaxWrapper(Accuracy(), multiclass=True),
     })
# Example #6 (score: 0)
    def evaluate_metrics(self, batch_labels,
                         batch_predictions) -> Dict[str, torch.Tensor]:
        """Compute the per-batch metrics for the current task mode.

        Any metric that raises (e.g. on a degenerate batch) is logged and
        reported as NaN instead of aborting the evaluation.
        """
        if self.hparams.mode == 'classification':
            # transformers convention is to output classification as two neurons.
            # In order to convert this to a class label we take the argmax.
            probs = nn.Softmax(dim=1)(batch_predictions)
            preds = torch.argmax(probs, dim=1).squeeze()
            probs_of_positive_class = probs[:, 1]
            batch_labels = batch_labels.squeeze()
            metrics = {
                'AUROC':
                lambda: AUROC()(probs_of_positive_class, batch_labels),
                'AveragePrecision':
                lambda: AveragePrecision()
                (probs_of_positive_class, batch_labels),
                'Accuracy':
                lambda: Accuracy()(preds, batch_labels),
            }
        else:
            # Regression: predictions are used as-is.
            preds = batch_predictions
            metrics = {
                'MAE': lambda: MAE()(preds, batch_labels),
                'RMSE': lambda: RMSE()(preds, batch_labels),
                'MSE': lambda: MSE()(preds, batch_labels),
                # sklearn metrics work the other way round metric_fn(y_true, y_pred)
                'R2': lambda: r2_score(batch_labels.cpu(), preds.cpu()),
            }

        results = {}
        for name, metric_fn in metrics.items():
            try:
                results[name] = metric_fn().item()
            except Exception as e:
                logger.info(f'unable to calculate {name} metric')
                logger.info(e)
                results[name] = np.nan

        return results
class TransferLearning(LightningModule):
    def __init__(self, args):
        """Build the transfer-learning model: load a pretrained encoder,
        freeze it, and attach a fresh linear classifier.

        Args:
            args: parsed hyperparameter namespace, kept on ``self.hparams``.
        """
        super(TransferLearning, self).__init__()

        # NOTE(review): assigning self.hparams directly assumes an older
        # pytorch-lightning API (it became a read-only property later) —
        # verify against the pinned PL version.
        self.hparams = args
        # World size = processes across all nodes and GPUs.
        self.world_size = self.hparams.num_nodes * self.hparams.gpus

        self.train_set_transform, self.val_set_transform = grab_transforms(
            self.hparams)

        #Grab the correct model - only want the embeddings from the final layer!
        if self.hparams.saved_model_type == 'contrastive':
            saved_model = NoisyCLIP.load_from_checkpoint(
                self.hparams.checkpoint_path)
            self.backbone = saved_model.noisy_visual_encoder
        elif self.hparams.saved_model_type == 'baseline':
            saved_model = Baseline.load_from_checkpoint(
                self.hparams.checkpoint_path)
            self.backbone = saved_model.encoder.feature_extractor

        # The backbone is a frozen feature extractor; no gradients flow to it.
        for param in self.backbone.parameters():
            param.requires_grad = False

        #Set up a classifier with the correct dimensions
        self.output = nn.Linear(self.hparams.emb_dim, self.hparams.num_classes)

        # Mean-reduced cross-entropy. (A stale comment previously claimed
        # reduction="sum"; the code uses "mean".)
        self.criterion = nn.CrossEntropyLoss(reduction="mean")

        # Stage-specific accuracy metrics (top-1 and top-5).
        self.train_top_1 = Accuracy(top_k=1)
        self.train_top_5 = Accuracy(top_k=5)

        self.val_top_1 = Accuracy(top_k=1)
        self.val_top_5 = Accuracy(top_k=5)

        self.test_top_1 = Accuracy(top_k=1)
        self.test_top_5 = Accuracy(top_k=5)

        #class INFECTED has label 0
        if self.hparams.dataset == 'COVID':
            self.val_auc = AUROC(pos_label=0)

            self.test_auc = AUROC(pos_label=0)

    def forward(self, x):
        """Embed x with the frozen backbone, then classify with the linear head."""
        # Keep the backbone in eval mode and out of the autograd graph.
        self.backbone.eval()
        with torch.no_grad():
            if self.hparams.encoder == "clip":
                # CLIP runs in half precision: cast inputs down, outputs back up.
                embeddings = self.backbone(x.type(torch.float16)).float()
            elif self.hparams.encoder == "resnet":
                embeddings = self.backbone(x)

        return self.output(embeddings.flatten(1))

    def configure_optimizers(self):
        """Adam over the linear head only, with one cosine annealing cycle."""
        # Older hparams may lack weight_decay; default it to 0 in place.
        if not hasattr(self.hparams, 'weight_decay'):
            self.hparams.weight_decay = 0

        optimizer = torch.optim.Adam(
            self.output.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )

        # Anneal over the full training run (one step per epoch).
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.hparams.max_epochs)

        return [optimizer], [lr_scheduler]

    def _grab_dataset(self, split):
        """
        Given a split ("train" or "val" or "test") and a dataset, returns the proper dataset.
        Dataset needed is defined in this object's hparams

        Args:
            split: the split to use in the dataset
        Returns:
            dataset: the desired dataset with the correct split
        """
        # Common rule for every dataset: the train split gets the (augmented)
        # training transform, everything else the evaluation transform.
        is_train = split == 'train'
        transform = self.train_set_transform if is_train else self.val_set_transform

        if self.hparams.dataset == "CIFAR10":
            dataset = CIFAR10(root=self.hparams.dataset_dir,
                              train=is_train,
                              transform=transform,
                              download=True)

        elif self.hparams.dataset == "CIFAR100":
            dataset = CIFAR100(root=self.hparams.dataset_dir,
                               train=is_train,
                               transform=transform,
                               download=True)

        elif self.hparams.dataset == 'STL10':
            # STL10 has no separate val split; evaluation uses its test split.
            dataset = STL10(root=self.hparams.dataset_dir,
                            split='train' if is_train else 'test',
                            transform=transform,
                            download=True)

        elif self.hparams.dataset == 'COVID':
            # ImageFolder layout: <dataset_dir>train and <dataset_dir>test.
            dataset = torchvision.datasets.ImageFolder(
                root=self.hparams.dataset_dir + ('train' if is_train else 'test'),
                transform=transform)

        elif self.hparams.dataset in ('ImageNet100B', 'imagenet-100B'):
            dataset = ImageNet100(root=self.hparams.dataset_dir,
                                  split='train' if is_train else 'val',
                                  transform=transform)

        # BUG FIX: the original repeated the COVID and ImageNet100B branches a
        # second time; those elif arms were unreachable dead code and have
        # been removed. NOTE(review): an unrecognized hparams.dataset still
        # raises UnboundLocalError here, matching the original behavior.
        return dataset

    def train_dataloader(self):
        """Build the shuffled training DataLoader (optionally on a few-shot subset)."""
        train_dataset = self._grab_dataset(split='train')

        N_train = len(train_dataset)
        # Optionally shrink the training set; the argument is
        # ceil(N_train * subset_ratio / num_classes) — presumably a per-class
        # sample count; confirm against few_shot_dataset's signature.
        if self.hparams.use_subset:
            train_dataset = few_shot_dataset(
                train_dataset,
                int(
                    np.ceil(N_train * self.hparams.subset_ratio /
                            self.hparams.num_classes)))

        train_dataloader = DataLoader(train_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=True)

        return train_dataloader

    def val_dataloader(self):
        """Build the validation DataLoader.

        Shuffling is deliberately on: AUROC misbehaves when a batch contains
        only one label value (see the inline note below).
        """
        val_dataset = self._grab_dataset(split='val')
        # (Removed an unused `N_val = len(val_dataset)` local.)

        #SET SHUFFLE TO TRUE SINCE AUROC FREAKS OUT IF IT GETS AN ALL-1 OR ALL-0 BATCH
        val_dataloader = DataLoader(val_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=True)

        return val_dataloader

    def test_dataloader(self):
        """Build the test DataLoader.

        Shuffling is deliberately on: AUROC misbehaves when a batch contains
        only one label value (see the inline note below).
        """
        test_dataset = self._grab_dataset(split='test')
        # (Removed an unused `N_test = len(test_dataset)` local.)

        #SET SHUFFLE TO TRUE SINCE AUROC FREAKS OUT IF IT GETS AN ALL-1 OR ALL-0 BATCH
        test_dataloader = DataLoader(test_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.workers,\
                                        pin_memory=True, shuffle=True)

        return test_dataloader

    def training_step(self, batch, batch_idx):
        """One supervised step: forward pass, cross-entropy loss, logging."""
        x, y = batch

        # Log a grid of input images once, at the very start of training.
        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Train_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)

        loss = self.criterion(logits, y)

        # NOTE(review): sync_dist_op was removed in later pytorch-lightning
        # releases; this call assumes an older PL version — confirm.
        self.log("train_loss", loss, prog_bar=False, on_step=True, \
                    on_epoch=True, logger=True, sync_dist=True, sync_dist_op='sum')

        return loss

    def validation_step(self, batch, batch_idx):
        """Update validation metrics: top-1 always, top-5 or AUROC by dataset."""
        x, y = batch

        # Log a grid of validation images once, at the first step of epoch 0.
        if batch_idx == 0 and self.current_epoch == 0:
            self.logger.experiment.add_image('Val_Sample', img_grid(x),
                                             self.current_epoch)

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)  #(N, num_classes)

        if self.hparams.dataset == 'COVID':
            positive_prob = pred_probs[:,
                                       0].flatten()  #class 0 is INFECTED label
            true_labels = y.flatten()

            # Accumulate AUROC state; the epoch value is computed in
            # validation_epoch_end.
            self.val_auc.update(positive_prob, true_labels)

            self.log("val_auc", self.val_auc, prog_bar=False, logger=False)

        self.log("val_top_1",
                 self.val_top_1(pred_probs, y),
                 prog_bar=False,
                 logger=False)

        # Top-5 is tracked only for the non-COVID datasets.
        if self.hparams.dataset != 'COVID':
            self.log("val_top_5",
                     self.val_top_5(pred_probs, y),
                     prog_bar=False,
                     logger=False)

    def validation_epoch_end(self, outputs):
        """Log epoch-level validation metrics, then reset their state."""
        self.log("val_top_1",
                 self.val_top_1.compute(),
                 prog_bar=True,
                 logger=True)

        # COVID runs track AUROC instead of top-5 (mirrors validation_step).
        if self.hparams.dataset != 'COVID':
            self.log("val_top_5",
                     self.val_top_5.compute(),
                     prog_bar=True,
                     logger=True)
        else:
            self.log("val_auc",
                     self.val_auc.compute(),
                     prog_bar=True,
                     logger=True)
            self.val_auc.reset()

        self.val_top_1.reset()
        self.val_top_5.reset()

    def test_step(self, batch, batch_idx):
        """Update test metrics: top-1 always, top-5 or AUROC by dataset."""
        x, y = batch

        logits = self.forward(x)
        pred_probs = logits.softmax(dim=-1)

        if self.hparams.dataset == 'COVID':
            positive_prob = pred_probs[:,
                                       0].flatten()  #class 0 is INFECTED label
            true_labels = y.flatten()

            # Accumulate AUROC state; the epoch value is computed in
            # test_epoch_end.
            self.test_auc.update(positive_prob, true_labels)

            self.log("test_auc", self.test_auc, prog_bar=False, logger=False)

        self.log("test_top_1",
                 self.test_top_1(pred_probs, y),
                 prog_bar=False,
                 logger=False)
        # Top-5 is tracked only for the non-COVID datasets.
        if self.hparams.dataset != 'COVID':
            self.log("test_top_5",
                     self.test_top_5(pred_probs, y),
                     prog_bar=False,
                     logger=False)

    def test_epoch_end(self, outputs):
        """Log epoch-level test metrics, then reset their state."""
        self.log("test_top_1",
                 self.test_top_1.compute(),
                 prog_bar=True,
                 logger=True)

        # COVID runs track AUROC instead of top-5 (mirrors test_step).
        if self.hparams.dataset != 'COVID':
            self.log("test_top_5",
                     self.test_top_5.compute(),
                     prog_bar=True,
                     logger=True)
        else:
            self.log("test_auc",
                     self.test_auc.compute(),
                     prog_bar=True,
                     logger=True)
            self.test_auc.reset()

        self.test_top_1.reset()
        self.test_top_5.reset()