Example #1
# Imports assumed for these snippets: pandas, pytest, woodwork, and the
# standard-library pieces used by the later examples. The explanation
# functions are taken from EvalML's prediction-explanations module (path
# assumed from its public API). compare_two_tables, used in examples #5 and
# #6, appears to be a local test helper and is not imported here.
import warnings
from unittest.mock import MagicMock

import pandas as pd
import pytest
import woodwork as ww

from evalml.model_understanding.prediction_explanations import (
    explain_prediction,
    explain_predictions,
    explain_predictions_best_worst,
)


def test_output_format_checked():
    input_features, y_true = pd.DataFrame(data=[range(15)]), pd.Series(
        range(15))
    with pytest.raises(
            ValueError,
            match="Parameter output_format must be either text or dict. Received bar"
    ):
        explain_predictions(None, input_features, output_format="bar")
    with pytest.raises(
            ValueError,
            match="Parameter output_format must be either text or dict. Received xml"
    ):
        explain_prediction(None,
                           input_features=input_features,
                           training_data=None,
                           output_format="xml")

    input_features, y_true = pd.DataFrame(data=range(15)), pd.Series(range(15))
    with pytest.raises(
            ValueError,
            match="Parameter output_format must be either text or dict. Received foo"
    ):
        explain_predictions_best_worst(None,
                                       input_features,
                                       y_true=y_true,
                                       output_format="foo")
Example #2
def test_output_format_checked():
    input_features, y_true = pd.DataFrame(data=[range(15)]), pd.Series(
        range(15))
    with pytest.raises(
            ValueError,
            match="Parameter output_format must be either text, dict, or dataframe. Received bar"
    ):
        explain_predictions(pipeline=MagicMock(),
                            input_features=input_features,
                            y=None,
                            indices_to_explain=0,
                            output_format="bar")
    with pytest.raises(
            ValueError,
            match="Parameter output_format must be either text, dict, or dataframe. Received xml"
    ):
        explain_prediction(pipeline=MagicMock(),
                           input_features=input_features,
                           y=None,
                           index_to_explain=0,
                           output_format="xml")

    input_features, y_true = pd.DataFrame(data=range(15)), pd.Series(range(15))
    with pytest.raises(
            ValueError,
            match="Parameter output_format must be either text, dict, or dataframe. Received foo"
    ):
        explain_predictions_best_worst(pipeline=MagicMock(),
                                       input_features=input_features,
                                       y_true=y_true,
                                       output_format="foo")
Example #3
def test_explain_prediction_value_error(test_features):
    # test_features is supplied by a fixture/parametrize decorator defined in
    # the original test module (not shown in this snippet).
    with pytest.raises(
            ValueError,
            match="features must be stored in a dataframe or datatable with exactly one row."
    ):
        explain_prediction(None,
                           input_features=test_features,
                           training_data=None)
Example #4
def test_explain_prediction_errors():
    with pytest.raises(ValueError,
                       match="Explained indices should be between"):
        explain_prediction(MagicMock(),
                           pd.DataFrame({"a": [0, 1, 2, 3, 4]}),
                           y=None,
                           index_to_explain=5)

    with pytest.raises(ValueError,
                       match="Explained indices should be between"):
        explain_prediction(MagicMock(),
                           pd.DataFrame({"a": [0, 1, 2, 3, 4]}),
                           y=None,
                           index_to_explain=-1)
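
The bounds check behind the "Explained indices should be between" error can be pictured as follows; the helper name and the message tail are assumptions, since the test only matches on the prefix.

def _validate_index_to_explain(input_features, index_to_explain):
    # Hypothetical sketch: reject indices outside [0, n_rows - 1], matching
    # the "Explained indices should be between" prefix asserted above.
    n_rows = input_features.shape[0]
    if index_to_explain < 0 or index_to_explain >= n_rows:
        raise ValueError(
            f"Explained indices should be between 0 and {n_rows - 1}")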
Example #5
# The mock_* arguments and the remaining parameters below are injected by
# @patch and @pytest.mark.parametrize decorators omitted from this snippet.
def test_explain_prediction(mock_normalize_shap_values,
                            mock_compute_shap_values, problem_type,
                            output_format, shap_values, normalized_shap_values,
                            answer, input_type):
    mock_compute_shap_values.return_value = shap_values
    mock_normalize_shap_values.return_value = normalized_shap_values
    pipeline = MagicMock()
    pipeline.problem_type = problem_type
    pipeline.classes_ = ["class_0", "class_1", "class_2"]

    # By the time we call transform, we are looking at only one row of the input data.
    pipeline.compute_estimator_features.return_value = pd.DataFrame({
        "a": [10],
        "b": [20],
        "c": [30],
        "d": [40]
    })
    features = pd.DataFrame({"a": [1], "b": [2]})
    training_data = pd.DataFrame()
    if input_type == "ww":
        features = ww.DataTable(features)
        training_data = ww.DataTable(training_data)
    table = explain_prediction(pipeline,
                               features,
                               output_format=output_format,
                               top_k=2,
                               training_data=training_data)

    if isinstance(table, str):
        compare_two_tables(table.splitlines(), answer)
    else:
        assert table == answer
Example #6
def test_explain_prediction(mock_normalize_shap_values,
                            mock_compute_shap_values, problem_type,
                            output_format, shap_values, normalized_shap_values,
                            answer, input_type):
    mock_compute_shap_values.return_value = shap_values
    mock_normalize_shap_values.return_value = normalized_shap_values
    pipeline = MagicMock()
    pipeline.problem_type = problem_type
    pipeline.classes_ = ["class_0", "class_1", "class_2"]

    # By the time we call transform, we are looking at only one row of the input data.
    pipeline.compute_estimator_features.return_value = ww.DataTable(
        pd.DataFrame({
            "a": [10],
            "b": [20],
            "c": [30],
            "d": [40]
        }))
    features = pd.DataFrame({"a": [1], "b": [2]})
    if input_type == "ww":
        features = ww.DataTable(features)

    with warnings.catch_warnings(record=True) as warn:
        warnings.simplefilter("always")
        table = explain_prediction(pipeline,
                                   features,
                                   y=None,
                                   output_format=output_format,
                                   index_to_explain=0,
                                   top_k_features=2)
        assert str(warn[0].message).startswith(
            "The explain_prediction function will be deleted in the next release"
        )
    if isinstance(table, str):
        compare_two_tables(table.splitlines(), answer)
    elif isinstance(table, pd.DataFrame):
        pd.testing.assert_frame_equal(table, answer)
    else:
        assert table == answer
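
Example #6 also asserts a deprecation warning. The emission side could look roughly like the sketch below; only the message prefix comes from the test (which matches with startswith), so the remainder of the message and the warning category are assumptions.

import warnings

# Hypothetical sketch of how the warning asserted in example #6 might be
# emitted near the top of explain_prediction; the real EvalML implementation
# and full message may differ.
warnings.warn(
    "The explain_prediction function will be deleted in the next release",
    DeprecationWarning)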