def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn,
                     expected_fn):
    """Check the ``stat_scores`` counts for class 4 against expected values.

    ``stat_scores`` is called with ``class_index=4``. It must return exactly
    four results (tp, fp, tn, fn). Each result is converted with ``.item()``
    and compared, in that order, with the matching ``expected_*`` argument.
    """
    tp, fp, tn, fn = stat_scores(pred, target, class_index=4)
    checks = (
        (tp, expected_tp),
        (fp, expected_fp),
        (tn, expected_tn),
        (fn, expected_fn),
    )
    for actual, expected in checks:
        assert actual.item() == expected
Пример #2
0
    def speed_test(self, data, target_device=None):
        """Sweep an error threshold, print per-threshold counts and plot a ROC curve.

        Each record of ``data`` is unpacked as ``(d, e)``. A record's label is
        1 when ``d - speed_limit > speed_threshold`` and 0 otherwise. Labels do
        not depend on the threshold, so they are computed once.

        For each threshold from 100 down to 0 in steps of 0.25, a record is
        predicted positive when ``e`` is strictly greater than the threshold.
        ``stat_scores`` (with ``class_index=1``) then yields the tp/fp/tn/fn
        counts. For each threshold the method prints the threshold, the four
        counts, the TPR and the FPR. Finally it plots FPR vs. TPR with ``plt``.

        Args:
            data: iterable of 2-item records ``(d, e)``. It is materialized
                once, so single-pass iterables such as generators work.
                Presumably ``d`` is a speed and ``e`` an error score.
                NOTE(review): the constants look like 50 km/h and 20 km/h
                expressed in m/s; confirm the units.
            target_device: device on which the prediction/label tensors are
                created. It defaults to ``device("cuda", 0)`` (the previous
                hard-coded value). Pass e.g. ``device("cpu")`` when no GPU is
                available.

        A rate whose denominator is zero (e.g. the data contains no positive
        labels, or is empty) is reported as NaN instead of raising
        ``ZeroDivisionError``.
        """
        speed_limit = 13.89
        speed_threshold = 5.55556
        if target_device is None:
            target_device = device("cuda", 0)

        def _rate(num, den):
            # Undefined rate (empty class): NaN, which plt.plot simply skips.
            return num / den if den else float("nan")

        # Materialize once: the records are re-scanned for every threshold.
        records = list(data)
        # Labels are threshold-independent, so build the target tensor once.
        y = [1 if d - speed_limit > speed_threshold else 0 for d, _ in records]
        y_tensor = tensor(y, device=target_device)

        tpr = []
        fpr = []
        error = 100
        while error >= 0:
            pred = [1 if e > error else 0 for _, e in records]
            tps, fps, tns, fns, _ = stat_scores(
                tensor(pred, device=target_device),
                y_tensor,
                class_index=1)
            tp, fp, tn, fn = tps.item(), fps.item(), tns.item(), fns.item()
            tp_rate = _rate(tp, tp + fn)
            fp_rate = _rate(fp, fp + tn)
            print(error, tp, fp, tn, fn, tp_rate, fp_rate)
            tpr.append(tp_rate)
            fpr.append(fp_rate)
            # 0.25 is exactly representable, so the thresholds do not drift.
            error -= 0.25
        plt.plot(fpr, tpr)
        plt.grid(True)
        plt.show()