Example #1
import pytest
import torch
# Import paths assumed from pytorch_lightning 1.3, where these wrappers were deprecated
from pytorch_lightning.metrics import (ConfusionMatrix, F1, FBeta,
                                       HammingDistance, StatScores)
from pytorch_lightning.metrics.functional import (confusion_matrix, f1, fbeta,
                                                  hamming_distance, stat_scores)


def test_v1_5_metric_classif_mix():
    ConfusionMatrix.__init__._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        ConfusionMatrix(num_classes=1)

    FBeta.__init__._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        FBeta(num_classes=1)

    F1.__init__._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        F1(num_classes=1)

    HammingDistance.__init__._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        HammingDistance()

    StatScores.__init__._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        StatScores()

    target = torch.tensor([1, 1, 0, 0])
    preds = torch.tensor([0, 1, 0, 0])
    confusion_matrix._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        assert torch.equal(
            confusion_matrix(preds, target, num_classes=2).float(),
            torch.tensor([[2.0, 0.0], [1.0, 1.0]]))

    target = torch.tensor([0, 1, 2, 0, 1, 2])
    preds = torch.tensor([0, 2, 1, 0, 0, 1])
    fbeta._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        assert torch.allclose(fbeta(preds, target, num_classes=3, beta=0.5),
                              torch.tensor(0.3333),
                              atol=1e-4)

    f1._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        assert torch.allclose(f1(preds, target, num_classes=3),
                              torch.tensor(0.3333),
                              atol=1e-4)

    target = torch.tensor([[0, 1], [1, 1]])
    preds = torch.tensor([[0, 1], [0, 1]])
    hamming_distance._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        assert hamming_distance(preds, target) == torch.tensor(0.25)

    preds = torch.tensor([1, 0, 2, 1])
    target = torch.tensor([1, 1, 2, 0])
    stat_scores._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        assert torch.equal(stat_scores(preds, target, reduce="micro"),
                           torch.tensor([2, 2, 6, 2, 4]))
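
The test relies on resetting each deprecated callable's private `_warned` attribute so its one-shot warning fires again. A minimal sketch of that warn-once pattern (the decorator and message here are illustrative, not Lightning's actual implementation):

import warnings

def deprecated_once(fn):
    # Warn only the first time the wrapped callable runs; tests can set
    # wrapper._warned = False to make the warning observable again
    def wrapper(*args, **kwargs):
        if not wrapper._warned:
            warnings.warn("It will be removed in v1.5.0", DeprecationWarning)
            wrapper._warned = True
        return fn(*args, **kwargs)
    wrapper._warned = False
    return wrapper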
Example #2
    def __init__(self, lr: float, num_classes: int, *args, **kwargs):
        super().__init__()
        self.lr = lr
        self.num_classes = num_classes
        self.train_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()
        self.val_acc = pl.metrics.Accuracy()
        self.confusion = ConfusionMatrix(self.num_classes)

        self.init_layers(*args, **kwargs)
        self.save_hyperparameters()
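
A hypothetical training_step (not part of the original snippet) showing how the metrics created above are typically fed; the forward call and loss choice are assumptions, and torch is assumed to be imported at module level:

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        # calling a pl.metrics object updates its running state and
        # returns the value for the current batch
        acc = self.train_acc(logits.argmax(dim=1), y)
        self.log("train_acc", acc)
        return loss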
Example #3
import torch
# Assumed import: the legacy pytorch_lightning.metrics ConfusionMatrix, which
# can be constructed without num_classes and infers it from the inputs
from pytorch_lightning.metrics import ConfusionMatrix


def metrics(logits, targets):
    preds = torch.argmax(logits, dim=1)
    cm = ConfusionMatrix()(preds, targets)
    if len(cm.size()) == 0:
        # Only one distinct label appeared, so the matrix collapsed to a
        # scalar; rebuild it as a 2x2 with the count on the diagonal
        idx = preds[0].item()
        n = cm.item()
        cm = torch.zeros((2, 2))
        cm[idx, idx] = n
    # cm[i, j] is the number of observations in true group i predicted as
    # group j, so cm[0, 1] counts false positives and cm[1, 0] false negatives
    tp, tn, fp, fn = cm[1, 1], cm[0, 0], cm[0, 1], cm[1, 0]
    metrics = {'tp': tp, 'fp': fp, 'fn': fn, 'tn': tn}
    return metrics
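
For example, on a small illustrative binary batch:

logits = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
targets = torch.tensor([0, 1, 1, 1])
print(metrics(logits, targets))
# argmax gives preds [0, 1, 0, 1], so tp=2, tn=1, fp=0, fn=1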
Example #4
    def print_pycm(self):
        cm = ConfusionMatrix(self.gts, self.preds)

        for cls_name in cm.classes:
            print('============' * 5)
            print('Class Name : [{}]'.format(cls_name))  # treat this class as the positive class when deriving TP, FN, FP, TN
            TP = cm.TP[cls_name]
            TN = cm.TN[cls_name]
            FP = cm.FP[cls_name]
            FN = cm.FN[cls_name]
            acc = cm.ACC[cls_name]
            pre = cm.PPV[cls_name]
            rec = cm.TPR[cls_name]
            spec = cm.TNR[cls_name]

            # pycm reports undefined metric values as the string 'None'
            if acc == 'None':
                acc = 0.0
            if pre == 'None':
                pre = 0.0
            if rec == 'None':
                rec = 0.0
            if spec == 'None':
                spec = 0.0

            print('TP : {}, FN : {}, FP : {}, TN : {}'.format(TP, FN, FP, TN))
            print('Accuracy : {:.4f}, Precision : {:.4f}, Recall(Sensitivity) : {:.4f}, Specificity : {:.4f}'.
                format(acc, pre, rec, spec))
            print('============' * 5)
        cm.print_matrix()
        auc_list = list(cm.AUC.values())
        print('AUROC : ', auc_list)
        
        auroc_mean = 0
        for auc in auc_list:
            if auc == 'None':
                auroc_mean += 0
            else:
                auroc_mean += auc
        auroc_mean = auroc_mean / len(auc_list)
        print("AUROC mean: {:.4f}".format(auroc_mean))
        
        self.gts = []
        self.preds = []
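
For reference, a minimal pycm setup that produces the per-class dictionaries used above (the labels are made up):

from pycm import ConfusionMatrix

cm = ConfusionMatrix(actual_vector=[0, 1, 2, 0], predict_vector=[0, 2, 2, 0])
print(cm.TP)   # per-class dict keyed by label, here {0: 2, 1: 0, 2: 1}
print(cm.PPV)  # undefined entries come back as the string 'None', hence the checks above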
Example #5
def accuracy_test(classifiers: T.Dict[str, T.Callable[[int],
                                                      FewshotClassifier]],
                  collections: T.Dict[str, T.Dict[str, U.data.Dataset]],
                  as_bit=True,
                  as_half=False,
                  as_cuda=True,
                  n_support=10):

    cast_tensor = (lambda t: t.cuda()) if as_cuda else (lambda t: t)
    # parenthesized so the conditional selects one of two lambdas, rather
    # than the conditional living inside a single lambda's body
    cast_integer = ((lambda t: torch.tensor(t, device="cuda"))
                    if as_cuda else (lambda t: t))

    for classifier_name, classifier_constructor in classifiers.items():
        classifier = classifier_constructor(416 if as_bit else 52)

        if as_half:
            classifier.half()
        if as_cuda:
            classifier.cuda()

        for collection_name, datasets in collections.items():

            confusion_matrix = ConfusionMatrix(num_classes=len(datasets))

            key_list = list(datasets.keys())
            value_list = [datasets[k] for k in key_list]

            queries_ds = U.data.dmap(sum(value_list[1:], value_list[0]),
                                     transform=cast_tensor)
            labels_ds = U.data.dmap(
                [x for i, ds in enumerate(value_list) for x in [i] * len(ds)],
                transform=cast_integer)

            support_ds_list = []
            for dataset in datasets:
                support_ds = U.data.dconst()
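
The snippet is truncated here; the usual continuation feeds query batches through the classifier and accumulates the results in the confusion matrix. A self-contained sketch of that update/compute pattern, assuming the current torchmetrics API rather than whatever ConfusionMatrix the original imports:

import torch
from torchmetrics import ConfusionMatrix  # assumed import

confmat = ConfusionMatrix(task="multiclass", num_classes=3)
for _ in range(4):  # stand-in for iterating over query batches
    preds = torch.randint(0, 3, (8,))
    labels = torch.randint(0, 3, (8,))
    confmat.update(preds, labels)
print(confmat.compute())  # accumulated 3x3 matrix, rows = true labels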
Example #6
 def on_validation_epoch_start(self):
     self.confmat_metric = ConfusionMatrix(num_classes=6)
     if self.num_class_linear_flag is not None:
         self.confmat_linear = ConfusionMatrix(num_classes=6)
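
Hypothetical companion hooks (not in the original) showing where such a metric would be updated and read out:

 def validation_step(self, batch, batch_idx):
     x, y = batch
     self.confmat_metric.update(self(x).argmax(dim=1), y)

 def on_validation_epoch_end(self):
     # compute() aggregates every batch seen since the metric was created
     print(self.confmat_metric.compute())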
Example #7
 def __init__(self, config, trial=None):
     if hasattr(config.net_config, "use_detector_number"):
         self.use_detector_number = config.net_config.use_detector_number
         if self.use_detector_number:
             if not hasattr(config.net_config, "num_detectors"):
                 raise IOError(
                     "net config must contain 'num_detectors' property if 'use_detector_number' set to true"
                 )
             config.system_config.n_samples = config.system_config.n_samples + 3
             if config.net_config.num_detectors == 308:
                 self.detector_num_factor_x = 1. / 13
                 self.detector_num_factor_y = 1. / 10
             else:
                 raise IOError("num detectors " +
                               str(config.net_config.num_detector) +
                               " not supported")
     else:
         self.use_detector_number = False
     super(LitWaveform, self).__init__(config, trial)
     if config.net_config.net_class.endswith("RecurrentWaveformNet"):
         self.squeeze_index = 2
     else:
         self.squeeze_index = 1
     self.test_has_phys = False
     if hasattr(self.config.dataset_config, "test_dataset_params"):
         if self.config.dataset_config.test_dataset_params.label_name == "phys" and not hasattr(
                 self.config.dataset_config.test_dataset_params,
                 "label_index"):
             self.test_has_phys = True
     if hasattr(self.config.dataset_config, "calgroup"):
         calgroup = self.config.dataset_config.calgroup
     else:
         calgroup = None
     if hasattr(self.config.dataset_config.dataset_params, "label_index"):
         self.target_index = self.config.dataset_config.dataset_params.label_index
     else:
         self.target_index = None
     self.use_accuracy = False
     if config.net_config.criterion_class == "L1Loss":
         metric_name = "mean absolute error"
     elif config.net_config.criterion_class == "MSELoss":
         metric_name = "mean squared error"
     elif config.net_config.criterion_class.startswith(
             "BCE") or config.net_config.criterion_class.startswith(
                 "CrossEntropy"):
         self.use_accuracy = True
         metric_name = "Accuracy"
     else:
         metric_name = "?"
     eval_params = {}
     if hasattr(config, "evaluation_config"):
         eval_params = DictionaryUtility.to_dict(config.evaluation_config)
     self.evaluator = TensorEvaluator(self.logger,
                                      calgroup=calgroup,
                                      target_has_phys=self.test_has_phys,
                                      target_index=self.target_index,
                                      metric_name=metric_name,
                                      **eval_params)
     self.loss_no_reduce = self.criterion_class(
         *config.net_config.criterion_params, reduction="none")
     if self.use_accuracy:
         self.accuracy = Accuracy()
         self.confusion = ConfusionMatrix(2)
         self.softmax = Softmax(dim=1)
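
A hypothetical test step illustrating how the accuracy, confusion matrix, and softmax created at the end of __init__ would work together; the forward call and squeeze details are assumptions:

 def test_step(self, batch, batch_idx):
     x, y = batch
     logits = self(x).squeeze(self.squeeze_index)
     probs = self.softmax(logits)
     self.accuracy.update(probs.argmax(dim=1), y)
     self.confusion.update(probs.argmax(dim=1), y)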
Example #8
 def __init__(self):
     self.confusion_matrix = ConfusionMatrix(7)
     self.has_data = False
     self.is_logging = False