def test_calibration_display_non_binary(pyplot, iris_data, constructor_name):
    """Check that CalibrationDisplay rejects the data supplied by ``iris_data``.

    A decision tree is fitted on the fixture data and, depending on
    ``constructor_name``, either ``from_estimator`` or ``from_predictions``
    is expected to raise a ``ValueError`` with the matching message.
    """
    features, target = iris_data
    tree = DecisionTreeClassifier()
    tree.fit(features, target)
    probabilities = tree.predict_proba(features)

    # Any value other than "from_estimator" exercises ``from_predictions``.
    if constructor_name != "from_estimator":
        expected = "y should be a 1d array, got an array of shape"
        with pytest.raises(ValueError, match=expected):
            CalibrationDisplay.from_predictions(target, probabilities)
        return

    expected = "to be a binary classifier, but got"
    with pytest.raises(ValueError, match=expected):
        CalibrationDisplay.from_estimator(tree, features, target)
# Example #2
# 0
def multi_class_calibration(y_true,
                            y_prob,
                            n_bins=5,
                            strategy='uniform',
                            names=None,
                            ref_line=True,
                            ax=None,
                            **kwargs):
    """
    Displays a multi-class Calibration Curve (one line per class in one-v-rest setup)

    All curves are drawn on a single Axes: the one passed as `ax` or, when `ax` is None, the
    Axes created while plotting the first class, which is then reused for the remaining classes.

    :param y_true: True Labels (array-like): note that labels must be sequential starting from 0
    :param y_prob: Predicted Probabilities (array-like, one column per class)
    :param n_bins: Number of bins to use (see CalibrationDisplay.from_predictions)
    :param strategy: Strategy for bins (see CalibrationDisplay.from_predictions)
    :param names:  Class names to use (one per column of y_prob); when None, the class indices
        are passed to utils.default as the fallback
    :param ref_line: Whether to plot reference line (see CalibrationDisplay.from_predictions)
    :param ax: Axes to draw on (see CalibrationDisplay.from_predictions)
    :param kwargs: Keyword arguments passed on to plot
    :return: Dict of Calibration Displays (by name)
    """
    # Coerce to arrays so plain lists support the element-wise comparison and column slicing
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    n_classes = y_prob.shape[1]

    # Iterate over classes
    names = utils.default(names, np.arange(n_classes))
    displays = {}
    for cls, name in zip(range(n_classes), names):
        _y_true = (y_true == cls).astype(int)  # Get positive class for this label
        _y_prob = y_prob[:, cls]  # Get probability assigned to this class
        display = CalibrationDisplay.from_predictions(_y_true,
                                                      _y_prob,
                                                      n_bins=n_bins,
                                                      strategy=strategy,
                                                      name=name,
                                                      ref_line=ref_line,
                                                      ax=ax,
                                                      **kwargs)
        if ax is None:
            # Without an explicit Axes each call would open its own figure; keep the first
            # one so that every subsequent class is drawn on the same plot.
            ax = display.ax_
        displays[name] = display
    return displays