def test_output_format_checked():
    input_features, y_true = pd.DataFrame(data=[range(15)]), pd.Series(range(15))
    with pytest.raises(
        ValueError,
        match="Parameter output_format must be either text or dict. Received bar",
    ):
        explain_predictions(None, input_features, output_format="bar")
    with pytest.raises(
        ValueError,
        match="Parameter output_format must be either text or dict. Received xml",
    ):
        explain_prediction(None, input_features=input_features, training_data=None, output_format="xml")

    input_features, y_true = pd.DataFrame(data=range(15)), pd.Series(range(15))
    with pytest.raises(
        ValueError,
        match="Parameter output_format must be either text or dict. Received foo",
    ):
        explain_predictions_best_worst(None, input_features, y_true=y_true, output_format="foo")
def test_output_format_checked():
    input_features, y_true = pd.DataFrame(data=[range(15)]), pd.Series(range(15))
    with pytest.raises(
        ValueError,
        match="Parameter output_format must be either text, dict, or dataframe. Received bar",
    ):
        explain_predictions(pipeline=MagicMock(), input_features=input_features, y=None,
                            indices_to_explain=0, output_format="bar")
    with pytest.raises(
        ValueError,
        match="Parameter output_format must be either text, dict, or dataframe. Received xml",
    ):
        explain_prediction(pipeline=MagicMock(), input_features=input_features, y=None,
                           index_to_explain=0, output_format="xml")

    input_features, y_true = pd.DataFrame(data=range(15)), pd.Series(range(15))
    with pytest.raises(
        ValueError,
        match="Parameter output_format must be either text, dict, or dataframe. Received foo",
    ):
        explain_predictions_best_worst(pipeline=MagicMock(), input_features=input_features,
                                       y_true=y_true, output_format="foo")
def test_explain_prediction_value_error(test_features):
    with pytest.raises(
        ValueError,
        match="features must be stored in a dataframe or datatable with exactly one row.",
    ):
        explain_prediction(None, input_features=test_features, training_data=None)
def test_explain_prediction_errors():
    with pytest.raises(ValueError, match="Explained indices should be between"):
        explain_prediction(MagicMock(), pd.DataFrame({"a": [0, 1, 2, 3, 4]}), y=None, index_to_explain=5)
    with pytest.raises(ValueError, match="Explained indices should be between"):
        explain_prediction(MagicMock(), pd.DataFrame({"a": [0, 1, 2, 3, 4]}), y=None, index_to_explain=-1)
def test_explain_prediction(mock_normalize_shap_values, mock_compute_shap_values,
                            problem_type, output_format, shap_values,
                            normalized_shap_values, answer, input_type):
    # The mock_* arguments are presumably injected by @patch decorators and the
    # remaining arguments by @pytest.mark.parametrize, both applied above this
    # test and not shown in this excerpt.
    mock_compute_shap_values.return_value = shap_values
    mock_normalize_shap_values.return_value = normalized_shap_values
    pipeline = MagicMock()
    pipeline.problem_type = problem_type
    pipeline.classes_ = ["class_0", "class_1", "class_2"]

    # By the time we call transform, we are looking at only one row of the input data.
    pipeline.compute_estimator_features.return_value = pd.DataFrame({
        "a": [10], "b": [20], "c": [30], "d": [40]
    })

    features = pd.DataFrame({"a": [1], "b": [2]})
    training_data = pd.DataFrame()
    if input_type == "ww":
        features = ww.DataTable(features)
        training_data = ww.DataTable(training_data)

    table = explain_prediction(pipeline, features, output_format=output_format,
                               top_k=2, training_data=training_data)

    if isinstance(table, str):
        compare_two_tables(table.splitlines(), answer)
    else:
        assert table == answer
def test_explain_prediction(mock_normalize_shap_values, mock_compute_shap_values,
                            problem_type, output_format, shap_values,
                            normalized_shap_values, answer, input_type):
    mock_compute_shap_values.return_value = shap_values
    mock_normalize_shap_values.return_value = normalized_shap_values
    pipeline = MagicMock()
    pipeline.problem_type = problem_type
    pipeline.classes_ = ["class_0", "class_1", "class_2"]

    # By the time we call transform, we are looking at only one row of the input data.
    pipeline.compute_estimator_features.return_value = ww.DataTable(pd.DataFrame({
        "a": [10], "b": [20], "c": [30], "d": [40]
    }))

    features = pd.DataFrame({"a": [1], "b": [2]})
    if input_type == "ww":
        features = ww.DataTable(features)

    with warnings.catch_warnings(record=True) as warn:
        warnings.simplefilter("always")
        table = explain_prediction(pipeline, features, y=None, output_format=output_format,
                                   index_to_explain=0, top_k_features=2)
        assert str(warn[0].message).startswith(
            "The explain_prediction function will be deleted in the next release"
        )

    if isinstance(table, str):
        compare_two_tables(table.splitlines(), answer)
    elif isinstance(table, pd.DataFrame):
        pd.testing.assert_frame_equal(table, answer)
    else:
        assert table == answer