Пример #1
0
def test_v1_4_0_deprecated_metrics():
    """Verify that the legacy metric entry points warn about removal in v1.4.

    Every call below runs inside ``pytest.deprecated_call`` and therefore must
    emit a deprecation warning whose message matches
    ``'will be removed in v1.4'``. Imports stay next to their use because
    ``precision`` and ``recall`` are exercised from two different modules.
    """
    removal_notice = 'will be removed in v1.4'
    map_shape = (10, 3, 3)
    vec_shape = (10, )

    def _binary(shape):
        # Random 0/1 integer tensor of the requested shape.
        return torch.randint(0, 2, shape)

    # --- functions living in the old ``functional.classification`` module ---
    from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
    with pytest.deprecated_call(match=removal_notice):
        stat_scores_multiple_classes(pred=torch.tensor([0, 1]), target=torch.tensor([0, 1]))

    from pytorch_lightning.metrics.functional.classification import iou
    with pytest.deprecated_call(match=removal_notice):
        iou(_binary(map_shape), _binary(map_shape))

    from pytorch_lightning.metrics.functional.classification import recall as legacy_recall
    with pytest.deprecated_call(match=removal_notice):
        legacy_recall(_binary(map_shape), _binary(map_shape))

    from pytorch_lightning.metrics.functional.classification import precision as legacy_precision
    with pytest.deprecated_call(match=removal_notice):
        legacy_precision(_binary(map_shape), _binary(map_shape))

    from pytorch_lightning.metrics.functional.classification import precision_recall
    with pytest.deprecated_call(match=removal_notice):
        precision_recall(_binary(map_shape), _binary(map_shape))

    # --- deprecated ``class_reduction`` argument of the *new* precision ---
    from pytorch_lightning.metrics.functional import precision as new_precision
    with pytest.deprecated_call(match=removal_notice):
        new_precision(_binary(vec_shape), _binary(vec_shape), class_reduction='micro')

    # --- deprecated ``class_reduction`` argument of the *new* recall ---
    from pytorch_lightning.metrics.functional import recall as new_recall
    with pytest.deprecated_call(match=removal_notice):
        new_recall(_binary(vec_shape), _binary(vec_shape), class_reduction='micro')

    # --- AUC-related helpers from the old module ---
    from pytorch_lightning.metrics.functional.classification import auc
    with pytest.deprecated_call(match=removal_notice):
        auc(torch.rand(10, ).sort().values, torch.rand(10, ))

    from pytorch_lightning.metrics.functional.classification import auroc
    with pytest.deprecated_call(match=removal_notice):
        auroc(torch.rand(10, ), _binary(vec_shape))

    from pytorch_lightning.metrics.functional.classification import multiclass_auroc
    with pytest.deprecated_call(match=removal_notice):
        class_probs = torch.rand(20, 5).softmax(dim=-1)
        class_ids = torch.randint(0, 5, (20, ))
        multiclass_auroc(class_probs, class_ids, num_classes=5)

    from pytorch_lightning.metrics.functional.classification import auc_decorator
    with pytest.deprecated_call(match=removal_notice):
        auc_decorator()

    from pytorch_lightning.metrics.functional.classification import multiclass_auc_decorator
    with pytest.deprecated_call(match=removal_notice):
        multiclass_auc_decorator()
Пример #2
0
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the metric by delegating to ``recall``.

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            Whatever ``recall`` returns for these inputs when called with this
            instance's ``num_classes`` and ``reduction`` settings.
        """
        # Instance-level configuration forwarded unchanged to the functional API.
        settings = {
            'num_classes': self.num_classes,
            'reduction': self.reduction,
        }
        score = recall(pred=pred, target=target, **settings)
        return score
Пример #3
0
    def test_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
        """
        Evaluate one test batch: log the loss and the batch metrics, then store
        a confusion matrix for the batch.

        :param batch: The output of your dataloader; unpacked as ``(inputs, targets)``.
        :param batch_nb: Integer displaying which batch this is

        Returns:
            - nothing is returned explicitly; values are reported through ``self.log``
              and the confusion matrix is appended to ``self.test_conf_matrices``.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        # Under DP / DDP2 give the loss a leading dimension before logging it.
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)

        self.log('test_loss', loss_val)

        # Predicted class index per sample = arg-max over dim 1 of the logits.
        logits = model_out['logits']
        labels_hat = torch.argmax(logits, dim=1)
        y = targets['labels']

        weighted = {'class_reduction': 'weighted'}
        f1 = metrics.f1_score(labels_hat, y, **weighted)
        prec = metrics.precision(labels_hat, y, **weighted)
        rec = metrics.recall(labels_hat, y, **weighted)
        acc = metrics.accuracy(labels_hat, y, **weighted)

        batch_metrics = (
            ('test_batch_prec', prec),
            ('test_batch_f1', f1),
            ('test_batch_recall', rec),
            ('test_batch_weighted_acc', acc),
        )
        for tag, value in batch_metrics:
            self.log(tag, value)

        from pytorch_lightning.metrics.functional import confusion_matrix
        # TODO CHANGE THIS: the number of classes is hard-coded to 50.
        batch_cm = confusion_matrix(
            preds=labels_hat,
            target=y,
            normalize=None,
            num_classes=50,
        )
        self.test_conf_matrices.append(batch_cm)
Пример #4
0
    def validation_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
        """ Run the model on one validation batch and log loss and weighted metrics.

        The batch is unpacked as ``(inputs, targets)``; the loss is logged as
        ``val_loss`` and the weighted precision / F1 / recall / accuracy as
        ``val_prec`` / ``val_f1`` / ``val_recall`` / ``val_acc_weighted``.

        Returns:
            - nothing is returned explicitly in this method body; results go through ``self.log``.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        y = targets["labels"]
        logits = model_out["logits"]

        # Plain accuracy: fraction of samples whose arg-max class equals the label.
        labels_hat = torch.argmax(logits, dim=1)
        n_correct = torch.sum(y == labels_hat).item()
        val_acc = torch.tensor(n_correct / (len(y) * 1.0))

        if self.on_gpu:
            val_acc = val_acc.cuda(loss_val.device.index)

        # Under DP / DDP2 give scalar results a leading dimension.
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)
            val_acc = val_acc.unsqueeze(0)
        # NOTE(review): val_acc is prepared above but is neither logged nor
        # returned in this method body - confirm whether that is intended.

        self.log('val_loss', loss_val)

        weighted = {'class_reduction': 'weighted'}
        f1 = metrics.f1_score(labels_hat, y, **weighted)
        prec = metrics.precision(labels_hat, y, **weighted)
        rec = metrics.recall(labels_hat, y, **weighted)
        acc = metrics.accuracy(labels_hat, y, **weighted)

        for tag, value in (
            ('val_prec', prec),
            ('val_f1', f1),
            ('val_recall', rec),
            ('val_acc_weighted', acc),
        ):
            self.log(tag, value)
def test_precision_recall(pred, target, expected_prec, expected_rec):
    """Check ``precision`` and ``recall`` (``class_reduction='none'``) against expected values.

    Each expected value is turned into a tensor and converted with ``.to(result)``
    so it can be compared to the computed result via ``torch.allclose``.
    """
    unreduced = {'class_reduction': 'none'}
    prec = precision(pred, target, **unreduced)
    rec = recall(pred, target, **unreduced)

    want_prec = torch.tensor(expected_prec).to(prec)
    assert torch.allclose(want_prec, prec)

    want_rec = torch.tensor(expected_rec).to(rec)
    assert torch.allclose(want_rec, rec)