Esempio n. 1
0
    def test_compute_report_multiple_df(
        global_clusters,
        global_y,
        global_predicted_proba,
        global_summary_df,
        global_aggregate_summary_df,
        global_summary_dfs,
        global_ys,
        global_predicted_probas,
        global_aggregate_summary_dfs_eval_set,
    ):
        """Check what ``_compute_report`` stores when ``hasmultiple_dfs`` is True.

        ``create_summary_df`` and ``aggregate_summary_df`` are patched so that
        consecutive calls return the fixture frames in a fixed order; the test
        then verifies the attributes left on the inspector.

        NOTE(review): the side-effect ordering presumes the first call serves
        the main dataset and the following two serve the additional datasets
        -- confirm against ``InspectorShap._compute_report``.
        """
        inspector = InspectorShap(model=MockModel(), algotype="kmeans")
        inspector.hasmultiple_dfs = True
        inspector.normalize_proba = False

        inspector.clusters = global_clusters
        inspector.y = global_y
        inspector.predicted_proba = global_predicted_proba
        inspector.ys = global_ys
        inspector.predicted_probas = global_predicted_probas
        target_summary_df = global_summary_df
        target_summary_dfs = global_summary_dfs
        aggregated_summary_df = global_aggregate_summary_df
        aggregated_summary_dfs = global_aggregate_summary_dfs_eval_set

        with patch.object(InspectorShap,
                          "create_summary_df") as mock_create_summary_df:
            with patch.object(
                    InspectorShap,
                    "aggregate_summary_df") as mock_aggregate_summary_df:
                # Set returns for each call of methods
                mock_create_summary_df.side_effect = [
                    target_summary_df, target_summary_dfs[0],
                    target_summary_dfs[1]
                ]
                mock_aggregate_summary_df.side_effect = [
                    aggregated_summary_df,
                    aggregated_summary_dfs[0],
                    aggregated_summary_dfs[1],
                ]
                inspector._compute_report()

        assert inspector.agg_summary_df.equals(aggregated_summary_df)
        assert inspector.summary_df.equals(target_summary_df)
        # enumerate() supplies the position of each stored frame; iterating the
        # collection directly would try to unpack every frame into
        # (index, item) instead of pairing it with its expected counterpart.
        for index, item in enumerate(inspector.agg_summary_dfs):
            assert item.equals(aggregated_summary_dfs[index])
        for index, item in enumerate(inspector.summary_dfs):
            assert item.equals(target_summary_dfs[index])
    def test_compute_report_single_df(global_clusters, global_y,
                                      global_predicted_proba,
                                      global_summary_df,
                                      global_aggregate_summary_df):
        """Check ``_compute_report`` when ``hasmultiple_dfs`` is False.

        Both summary builders are patched; the test verifies the arguments
        they were last called with and the frames stored on the inspector.
        """
        expected_normalize = False
        expected_summary = global_summary_df
        expected_aggregate = global_aggregate_summary_df

        inspector = InspectorShap(model=MockModel(), algotype='kmeans')
        inspector.hasmultiple_dfs = False
        inspector.normalize_proba = expected_normalize
        inspector.clusters = global_clusters
        inspector.y = global_y
        inspector.predicted_proba = global_predicted_proba

        create_patcher = patch.object(InspectorShap, 'create_summary_df')
        aggregate_patcher = patch.object(InspectorShap, 'aggregate_summary_df')

        with create_patcher as create_mock, aggregate_patcher as aggregate_mock:
            create_mock.return_value = expected_summary
            aggregate_mock.return_value = expected_aggregate

            inspector._compute_report()

            # The aggregation step must have received the created summary.
            aggregate_args, _ = aggregate_mock.call_args
            pd.testing.assert_frame_equal(aggregate_args[0], expected_summary)

            # The summary builder must have received clusters, y and the
            # predicted probabilities positionally, in that order.
            create_args, create_kwargs = create_mock.call_args
            expected_positionals = (global_clusters, global_y,
                                    global_predicted_proba)
            for position, expected_series in enumerate(expected_positionals):
                pd.testing.assert_series_equal(create_args[position],
                                               expected_series)
            assert create_kwargs['normalize'] == expected_normalize

        # The results returned by the patched methods end up on the inspector.
        pd.testing.assert_frame_equal(inspector.agg_summary_df,
                                      expected_aggregate)
        pd.testing.assert_frame_equal(inspector.summary_df, expected_summary)