예제 #1
0
    def test_group_names_do_not_match_groups(self):
        """check_consistency raises ValueError when group_names has the wrong length.

        ``groups`` contains two distinct values (0 and 1) but only a single
        entry is supplied in ``group_names``, so a ValueError with a fixed
        message is expected.
        """
        target = GroupMetricSet()

        target.model_type = GroupMetricSet.BINARY_CLASSIFICATION
        target.y_true = [0, 1, 0, 0]
        target.y_pred = [1, 1, 1, 0]
        target.groups = [0, 1, 1, 0]

        # Some wholly synthetic metrics; each reports a value for both
        # groups 0 and 1, the same groups that appear in ``target.groups``.
        first_metric = GroupMetricResult()
        first_metric.overall = 0.2
        first_metric.by_group[0] = 0.3
        first_metric.by_group[1] = 0.4
        second_metric = GroupMetricResult()
        second_metric.overall = 0.6
        second_metric.by_group[0] = 0.7
        second_metric.by_group[1] = 0.8
        metric_dict = {
            GroupMetricSet.GROUP_ACCURACY_SCORE: first_metric,
            GroupMetricSet.GROUP_MISS_RATE: second_metric
        }

        target.metrics = metric_dict

        # One name for two distinct groups -> inconsistent.
        target.group_names = ['First']
        target.group_title = "Some string"
        with pytest.raises(ValueError) as exception_context:
            target.check_consistency()
        expected = "Count of group_names not the same as the number of unique groups"
        assert exception_context.value.args[0] == expected
예제 #2
0
    def test_length_mismatch_groups(self):
        """check_consistency raises ValueError when groups is shorter than y_true/y_pred."""
        target = GroupMetricSet()
        target.y_true = [0, 1, 0, 1]
        target.y_pred = [0, 1, 1, 0]
        # Deliberately one element shorter than y_true and y_pred.
        target.groups = [0, 1, 1]

        with pytest.raises(ValueError) as exception_context:
            target.check_consistency()
        expected = "Lengths of y_true, y_pred and groups must match"
        assert exception_context.value.args[0] == expected
예제 #3
0
    def test_metric_has_bad_groups(self):
        """A metric whose by_group keys differ from ``groups`` is rejected.

        ``groups`` holds the values 0 and 1, but the metric only reports a
        value for group 0, so check_consistency must raise ValueError naming
        the offending metric.
        """
        target = GroupMetricSet()
        target.y_true = [0, 1, 1, 1, 0]
        target.y_pred = [1, 1, 1, 0, 0]
        target.groups = [0, 1, 0, 1, 1]

        # Only group 0 is populated; group 1 is intentionally missing.
        incomplete = GroupMetricResult()
        incomplete.by_group[0] = 0.1
        target.metrics = {'bad_metric': incomplete}

        message = "The groups for metric bad_metric do not match the groups property"
        with pytest.raises(ValueError) as raised:
            target.check_consistency()
        assert raised.value.args[0] == message