Code Example #1
import pytest
import torch


def test_v1_3_0_deprecated_metrics():
    """Legacy functional metrics slated for removal in v1.3 should raise deprecation warnings."""
    from pytorch_lightning.metrics.functional.classification import to_onehot
    with pytest.deprecated_call(match='will be removed in v1.3'):
        to_onehot(torch.tensor([1, 2, 3]))

    from pytorch_lightning.metrics.functional.classification import to_categorical
    with pytest.deprecated_call(match='will be removed in v1.3'):
        to_categorical(torch.tensor([[0.2, 0.5], [0.9, 0.1]]))

    from pytorch_lightning.metrics.functional.classification import get_num_classes
    with pytest.deprecated_call(match='will be removed in v1.3'):
        get_num_classes(pred=torch.tensor([0, 1]), target=torch.tensor([1, 1]))

    x_binary = torch.tensor([0, 1, 2, 3])
    y_binary = torch.tensor([0, 1, 2, 3])

    from pytorch_lightning.metrics.functional.classification import roc
    with pytest.deprecated_call(match='will be removed in v1.3'):
        roc(pred=x_binary, target=y_binary)

    from pytorch_lightning.metrics.functional.classification import _roc
    with pytest.deprecated_call(match='will be removed in v1.3'):
        _roc(pred=x_binary, target=y_binary)

    x_multy = torch.tensor([
        [0.85, 0.05, 0.05, 0.05],
        [0.05, 0.85, 0.05, 0.05],
        [0.05, 0.05, 0.85, 0.05],
        [0.05, 0.05, 0.05, 0.85],
    ])
    y_multy = torch.tensor([0, 1, 3, 2])

    from pytorch_lightning.metrics.functional.classification import multiclass_roc
    with pytest.deprecated_call(match='will be removed in v1.3'):
        multiclass_roc(pred=x_multy, target=y_multy)

    from pytorch_lightning.metrics.functional.classification import average_precision
    with pytest.deprecated_call(match='will be removed in v1.3'):
        average_precision(pred=x_binary, target=y_binary)

    from pytorch_lightning.metrics.functional.classification import precision_recall_curve
    with pytest.deprecated_call(match='will be removed in v1.3'):
        precision_recall_curve(pred=x_binary, target=y_binary)

    from pytorch_lightning.metrics.functional.classification import multiclass_precision_recall_curve
    with pytest.deprecated_call(match='will be removed in v1.3'):
        multiclass_precision_recall_curve(pred=x_multy, target=y_multy)

    from pytorch_lightning.metrics.functional.reduction import reduce
    with pytest.deprecated_call(match='will be removed in v1.3'):
        reduce(torch.tensor([0, 1, 1, 0]), 'sum')

    from pytorch_lightning.metrics.functional.reduction import class_reduce
    with pytest.deprecated_call(match='will be removed in v1.3'):
        class_reduce(
            torch.randint(1, 10, (50, )).float(),
            torch.randint(10, 20, (50, )).float(),
            torch.randint(1, 100, (50, )).float())
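
The assertions above all rely on `pytest.deprecated_call(match=...)`, which fails the test unless the wrapped call emits a `DeprecationWarning` (or `PendingDeprecationWarning`) whose message matches the given regular expression. A minimal, self-contained sketch of that pattern; `old_to_onehot` is a hypothetical stand-in, not a PyTorch Lightning function:

import warnings

import pytest


def old_to_onehot(x):
    # Hypothetical deprecated helper that still works but warns.
    warnings.warn('old_to_onehot will be removed in v1.3', DeprecationWarning)
    return x


def test_old_to_onehot_warns():
    # Fails unless a matching DeprecationWarning is raised inside the block.
    with pytest.deprecated_call(match='will be removed in v1.3'):
        old_to_onehot([1, 2, 3])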
Code Example #2
    def forward(
            self,
            pred: torch.Tensor,
            target: torch.Tensor,
            sample_weight: Optional[Sequence] = None,
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
        """
        Actual metric computation

        Args:
            pred: predicted probability for each class
            target: ground-truth labels
            sample_weight: Weights for each sample defining the sample's impact on the score

        Return:
            tuple: a tuple with one tuple per class, each holding the false positive rate, true positive rate, and thresholds

        """
        return multiclass_roc(pred=pred,
                              target=target,
                              sample_weight=sample_weight,
                              num_classes=self.num_classes)
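
For reference, a minimal usage sketch of the functional `multiclass_roc` that this method delegates to, using the same legacy `pytorch_lightning.metrics.functional.classification` import path and keyword signature shown in the examples above (predictions are per-class scores, one row per sample):

import torch
from pytorch_lightning.metrics.functional.classification import multiclass_roc

pred = torch.tensor([
    [0.85, 0.05, 0.05, 0.05],
    [0.05, 0.85, 0.05, 0.05],
    [0.05, 0.05, 0.85, 0.05],
    [0.05, 0.05, 0.05, 0.85],
])
target = torch.tensor([0, 1, 3, 2])

# One (fpr, tpr, thresholds) tuple is returned per class.
results = multiclass_roc(pred=pred, target=target, num_classes=4)
for cls, (fpr, tpr, thresholds) in enumerate(results):
    print(cls, fpr.shape, tpr.shape, thresholds.shape)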
Code Example #3
    def test_epoch_end(self, outputs):

        if self.incorrect_type != 'boundary':
            #####  Confusion Matrix  #####
            conf_mtx = confusion_matrix(
                torch.cat([b['preds'] for b in outputs]),
                torch.cat([b['labels'] for b in outputs]),
                normalize=False,
                num_classes=5)

            #####  Normalized Confusion Matrix  #####
            conf_mtx_normalized = confusion_matrix(
                torch.cat([b['preds'] for b in outputs]),
                torch.cat([b['labels'] for b in outputs]),
                normalize=True,
                num_classes=5)

            #####  Weighted Confusion Matrix  #####
            conf_mtx_weighted = conf_mtx.clone()
            for c, w in enumerate(self.weights):
                conf_mtx_weighted[c, :] *= w

            #####  ACCURACY  #####
            accuracy = torch.diag(conf_mtx).sum() / conf_mtx.sum()
            accuracy_weighted = torch.diag(
                conf_mtx_weighted).sum() / conf_mtx_weighted.sum()

            #####  AUC_SCORE  #####
            roc_results = multiclass_roc(
                torch.cat([b['logits'] for b in outputs]),
                torch.cat([b['labels'] for b in outputs]),
                num_classes=5)
            AUROC_str = ''
            AUROC_list = {}
            for cls, roc_cls in enumerate(roc_results):
                fpr, tpr, threshold = roc_cls
                auc_score = auc(fpr, tpr)
                self.logger.experiment.add_scalar(f'val_AUC[{cls}]',
                                                  auc_score,
                                                  self.current_epoch)
                AUROC_str += '\tAUC_SCORE[CLS %d]: \t%.4f\n' % (cls, auc_score)
                AUROC_list['AUC_SCORE[CLS %d]' % cls] = auc_score

            #####  F1  #####
            f1_score = f1(torch.cat([b['preds'] for b in outputs]),
                          torch.cat([b['labels'] for b in outputs]),
                          num_classes=5)

            #####  Average Precision  #####
            # TODO: average precision not yet computed

            #####  PRINT RESULTS  #####
            print('=' * 100)
            print(
                f'[MODEL NAME]: {self.model_name} \t [INCORRECT TYPE]: {self.incorrect_type}'
            )
            print('RESULTS:')
            print('\tAccuracy: \t\t%.4f' % accuracy)
            print('\tWeighted Accuracy: \t%.4f' % accuracy_weighted)
            print('\tF1 Score: \t\t%.4f' % f1_score)
            print(AUROC_str)

            self.metrics_result[self.incorrect_type][self.model_name] = {
                'Accuracy': round(float(accuracy), 4),
                'Weighted Accuracy': round(float(accuracy_weighted), 4),
                'F1_score': round(float(f1_score), 4)
            }
            for key, val in AUROC_list.items():
                self.metrics_result[self.incorrect_type][
                    self.model_name].update({key: round(float(val), 4)})
            print('Confusion Matrix')
            fig, ax = plt.subplots(figsize=(4, 4))
            sn.heatmap(conf_mtx.cpu(),
                       annot=True,
                       cbar=False,
                       annot_kws={"size": 15},
                       fmt='g',
                       cmap='mako')
            plt.show()
            fig, ax = plt.subplots(figsize=(4, 4))
            sn.heatmap(conf_mtx_normalized.cpu(),
                       annot=True,
                       cbar=False,
                       annot_kws={"size": 12},
                       fmt='.2f',
                       cmap='mako')
            plt.show()
            print('=' * 100)

        else:
            tol_correct = 0
            tol_samples = 0
            tol_drop = 0
            for batch in outputs:
                preds = batch['preds']
                labels = batch['labels']
                slope_id = batch['doc_ids']
                ##### Change lizhong's code ####
                for idx, slop_idx in enumerate(slope_id):
                    agree_by_user = bool(
                        slope_df[slope_df['slope_id'] == slop_idx.item()]
                        ['sentiment_correct'].values[0])
                    possible_classes = slope_df[
                        slope_df['slope_id'] ==
                        slop_idx.item()]['label_from_score'].values[0]

                    pred_class = preds[idx]
                    # difference between pred and true label
                    diff = torch.abs(pred_class - possible_classes)

                    # if correct label
                    if agree_by_user:  # True
                        if diff == 0:
                            # correct prediction
                            tol_correct += 1
                            tol_samples += 1
                        elif diff == 1:
                            # discard
                            tol_drop += 1
                        else:
                            # wrong prediction
                            tol_samples += 1
                    # if incorrect label
                    else:  # False
                        if diff == 0:
                            # wrong
                            tol_samples += 1
                        elif diff == 1:
                            # discard
                            tol_drop += 1
                        else:
                            # Correct
                            tol_correct += 1
                            tol_samples += 1

            boundary_accuracy = round(tol_correct / tol_samples, 4)
            self.metrics_result[self.incorrect_type][self.model_name] = {}
            self.metrics_result[self.incorrect_type][
                self.model_name]['boundary_acc'] = boundary_accuracy
            self.metrics_result[self.incorrect_type][
                self.model_name]['total_drop_sample'] = tol_drop
            print('=' * 100)
            print(
                f'[MODEL NAME]: {self.model_name} \t [INCORRECT TYPE]: {self.incorrect_type}'
            )
            print('\tBoundary Accuracy: \t\t%.4f' % boundary_accuracy)
            print('\tTotal Dropped Samples: \t%d' % tol_drop)
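
The boundary branch above reduces to a simple counting rule: when the annotator agreed with the score-derived label, only an exact match counts as correct; when they disagreed, an exact match counts as wrong and anything else as correct; off-by-one predictions are dropped in both cases. A standalone sketch of that rule with plain ints in place of the tensors and `slope_df` lookups (names are illustrative):

def boundary_counts(preds, possible_classes, agree_flags):
    # Returns (correct, counted, dropped) under the boundary rule described above.
    correct = counted = dropped = 0
    for pred, possible, agree in zip(preds, possible_classes, agree_flags):
        diff = abs(pred - possible)
        if diff == 1:
            # Off-by-one predictions are discarded regardless of agreement.
            dropped += 1
        elif agree:
            # Annotator agreed with the score label: only an exact match is correct.
            counted += 1
            correct += int(diff == 0)
        else:
            # Annotator disagreed: an exact match is wrong, anything else is correct.
            counted += 1
            correct += int(diff != 0)
    return correct, counted, dropped


# boundary_accuracy would then be correct / counted, mirroring tol_correct / tol_samples.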