Example No. 1
0
    def test_precision(self):
        """Binary precision at explicit split thresholds.

        For each threshold the evaluator must return exactly two precision
        values — one per class (negative, positive) — matching the
        hand-computed scores in ``dict_score``.
        """
        y_true = np.array([1, 1, 0, 0, 0, 1, 1, 0, 0, 1])
        y_predict = np.array(
            [0.57, 0.70, 0.25, 0.30, 0.46, 0.62, 0.76, 0.46, 0.35, 0.56])
        # Expected per-class precision, keyed by the threshold as a string.
        dict_score = {"0.4": {0: 1, 1: 0.71}, "0.6": {0: 0.71, 1: 1}}

        eva = Evaluation("binary")
        split_thresholds = [0.4, 0.6]

        prec_values = eva.precision(y_true,
                                    y_predict,
                                    thresholds=split_thresholds)
        # Round to two decimals so float noise cannot break the comparison.
        fix_prec_values = [[round(pos, 2) for pos in prec_value]
                           for prec_value in prec_values]

        # Walk thresholds and their results in lockstep instead of indexing.
        for threshold, pos_prec_value in zip(split_thresholds,
                                             fix_prec_values):
            expected = dict_score[str(threshold)]
            self.assertEqual(len(pos_prec_value), 2)
            self.assertFloatEqual(expected[0], pos_prec_value[0])
            self.assertFloatEqual(expected[1], pos_prec_value[1])
Example No. 2
0
    def test_precision(self):
        """Binary precision curve over the full range of score thresholds.

        The evaluator is configured through ``EvaluateParam`` with
        ``pos_label=1``; ``precision`` is called with no explicit thresholds
        and its first return element must equal the precomputed
        ``true_score`` table exactly (one ``[neg_precision, pos_precision]``
        pair per evaluated cut point).
        """

        y_true = np.array([1, 1, 0, 0, 0, 1, 1, 0, 0, 1])
        y_predict = np.array(
            [0.57, 0.70, 0.25, 0.30, 0.46, 0.62, 0.76, 0.46, 0.35, 0.56])

        # Hand-computed expected precision pairs for every cut point,
        # in the order the evaluator produces them.
        true_score = [[0.5, 1.0], [0.5, 1.0], [0.5, 1.0], [0.5,
                                                           1.0], [0.5, 1.0],
                      [0.5, 1.0], [0.5, 1.0], [0.5, 1.0], [0.5, 1.0],
                      [0.5, 1.0], [0.5, 1.0], [0.5, 1.0],
                      [0.5555555555555556, 1.0], [0.5555555555555556, 1.0],
                      [0.5555555555555556, 1.0], [0.5555555555555556, 1.0],
                      [0.5555555555555556, 1.0], [0.5555555555555556, 1.0],
                      [0.5555555555555556, 1.0], [0.5555555555555556, 1.0],
                      [0.5555555555555556, 1.0], [0.5555555555555556, 1.0],
                      [0.5555555555555556, 1.0], [0.625, 1.0], [0.625, 1.0],
                      [0.625, 1.0], [0.625, 1.0], [0.625, 1.0], [0.625, 1.0],
                      [0.625, 1.0], [0.625, 1.0], [0.625, 1.0], [0.625, 1.0],
                      [0.625, 1.0], [0.7142857142857143, 1.0],
                      [0.7142857142857143, 1.0], [0.7142857142857143, 1.0],
                      [0.7142857142857143, 1.0], [0.7142857142857143, 1.0],
                      [0.7142857142857143, 1.0], [0.7142857142857143, 1.0],
                      [0.7142857142857143, 1.0], [0.7142857142857143, 1.0],
                      [0.7142857142857143, 1.0], [0.7142857142857143, 1.0],
                      [0.8333333333333334, 1.0], [0.8333333333333334, 1.0],
                      [0.8333333333333334, 1.0], [0.8333333333333334, 1.0],
                      [0.8333333333333334, 1.0], [0.8333333333333334, 1.0],
                      [0.8333333333333334, 1.0], [0.8333333333333334, 1.0],
                      [0.8333333333333334, 1.0], [0.8333333333333334, 1.0],
                      [0.8333333333333334, 1.0], [1.0, 1.0], [1.0, 1.0],
                      [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0,
                                                           1.0], [1.0, 1.0],
                      [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0],
                      [1.0, 0.7142857142857143], [1.0, 0.7142857142857143],
                      [1.0, 0.7142857142857143], [1.0, 0.7142857142857143],
                      [1.0, 0.7142857142857143], [1.0, 0.7142857142857143],
                      [1.0, 0.7142857142857143], [1.0, 0.7142857142857143],
                      [1.0, 0.7142857142857143], [1.0, 0.7142857142857143],
                      [1.0, 0.7142857142857143], [1.0, 0.625], [1.0, 0.625],
                      [1.0, 0.625], [1.0, 0.625], [1.0, 0.625], [1.0, 0.625],
                      [1.0, 0.625], [1.0, 0.625], [1.0, 0.625], [1.0, 0.625],
                      [1.0, 0.625], [1.0, 0.5555555555555556],
                      [1.0, 0.5555555555555556], [1.0, 0.5555555555555556],
                      [1.0, 0.5555555555555556], [1.0, 0.5555555555555556],
                      [1.0, 0.5555555555555556], [1.0, 0.5555555555555556],
                      [1.0, 0.5555555555555556], [1.0, 0.5555555555555556],
                      [1.0, 0.5555555555555556], [1.0, 0.5555555555555556],
                      [0.0, 0.5]]

        eva = Evaluation()
        eva._init_model(EvaluateParam(eval_type=consts.BINARY, pos_label=1))

        # precision() returns a tuple here; the first element is the
        # list of per-threshold [neg, pos] precision pairs under test.
        rs = eva.precision(y_true, y_predict)

        self.assertListEqual(true_score, rs[0])
Example No. 3
0
    def test_multi_precision(self):
        """Multi-class precision restricted to a filter of labels.

        ``gt_score`` holds the expected precision for each queried label;
        per the expected table, a label that never appears in the data
        (7 here) is reported as -1.
        """
        y_true = np.array([1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5])
        y_predict = np.array([1, 1, 2, 2, 3, 2, 1, 1, 1, 1, 3, 3, 3, 3, 2, 4, 4, 4, 4, 4, 6, 6, 6, 6, 6])
        # Expected precision per queried label.
        gt_score = {2: 0.25, 3: 0.8, 5: 0, 6: 0, 7: -1}

        eva = Evaluation("multi")
        result_filter = [2, 3, 5, 6, 7]
        precision_scores = eva.precision(y_true, y_predict, result_filter=result_filter)
        # Check each requested label directly rather than indexing by position.
        for label in result_filter:
            self.assertFloatEqual(gt_score[label], precision_scores[label])
Example No. 4
0
    def test_multi_precision(self):
        """Multi-class precision via the EvaluateParam initialisation path.

        Scores are compared after rounding to two decimals, in the order
        the evaluator returns them.
        """
        y_true = np.array([
            1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5,
            5, 5, 5
        ])
        y_predict = np.array([
            1, 1, 2, 2, 3, 2, 1, 1, 1, 1, 3, 3, 3, 3, 2, 4, 4, 4, 4, 4, 6, 6,
            6, 6, 6
        ])
        # Expected per-label precision, in label order.
        gt_score = [0.33333333, 0.25, 0.8, 1., 0., 0.]

        eva = Evaluation()
        eva._init_model(EvaluateParam(eval_type=consts.MULTY, pos_label=1))
        precision_scores, _ = eva.precision(y_true, y_predict)
        # Use a unittest assertion instead of a bare `assert`: it reports the
        # differing values on failure and is not stripped under `python -O`.
        for actual, expected in zip(precision_scores, gt_score):
            self.assertEqual(round(actual, 2), round(expected, 2))
Example No. 5
0
    def evaluate(self, labels, pred_prob, pred_labels,
                 evaluate_param: EvaluateParam):
        """Run a host-side precision evaluation.

        Binary evaluation scores the raw probabilities (``pred_prob``);
        multi-class evaluation scores the predicted labels
        (``pred_labels``). Returns the precision results, or None for an
        unknown eval type.
        """
        LOGGER.info("@ start host evaluate")
        eva = Evaluation()
        if evaluate_param.eval_type == consts.BINARY:
            eva.eval_type = consts.BINARY
            predict_res = pred_prob
        elif evaluate_param.eval_type == consts.MULTY:
            eva.eval_type = consts.MULTY
            predict_res = pred_labels
        else:
            LOGGER.warning(
                "unknown classification type, return None as evaluation results"
            )
            # Bug fix: previously execution fell through and called
            # precision() with predict_res=None; return None here as the
            # warning message promises.
            return None

        eva.pos_label = evaluate_param.pos_label
        precision_res, cuts, thresholds = eva.precision(
            labels=labels, pred_scores=predict_res)

        LOGGER.info("@ evaluation report:" + str(precision_res))
        return precision_res