Example #1
0
    def test_get_raw_explanation_no_datasets_mixin(self, boston,
                                                   mimic_explainer):
        """Derive a raw explanation from an explanation without datasets.

        A mimic explainer using ``LGBMExplainableModel`` is fit on a random
        forest regressor. The local importance values of its global
        explanation are used to build a synthetic engineered-feature local
        explanation that carries no evaluation/dataset data, so it does not
        quack like ``_DatasetsMixin``. ``get_raw_explanation`` must still
        produce a valid local raw explanation from it.
        """
        x_train = boston[DatasetConstants.X_TRAIN]
        model = create_sklearn_random_forest_regressor(
            x_train, boston[DatasetConstants.Y_TRAIN])

        explainer = mimic_explainer(model, x_train, LGBMExplainableModel)
        global_explanation = explainer.explain_global(
            boston[DatasetConstants.X_TEST])
        assert global_explanation.method == LIGHTGBM_METHOD

        # Build an engineered, regression-task local explanation from the
        # global explanation's local importances. No eval or dataset data is
        # passed, so the result should not carry the datasets mixin.
        kwargs = {
            ExplainParams.METHOD: global_explanation.method,
            ExplainParams.FEATURES: global_explanation.features,
            ExplainParams.MODEL_TASK: ExplainType.REGRESSION,
            ExplainParams.LOCAL_IMPORTANCE_VALUES:
                global_explanation._local_importance_values,
            ExplainParams.EXPECTED_VALUES: 0,
            ExplainParams.CLASSIFICATION: False,
            ExplainParams.IS_ENG: True,
        }
        synthetic_explanation = _create_local_explanation(**kwargs)

        # Rows of the map are raw features and columns are engineered
        # (training) columns. The rectangular identity maps raw feature i
        # onto engineered column i.
        num_engineered_feats = x_train.shape[1]
        feature_map = np.eye(5, num_engineered_feats)
        # One name per raw feature (one per row of the map).
        raw_names = [str(i) for i in range(feature_map.shape[0])]
        assert not _DatasetsMixin._does_quack(synthetic_explanation)
        # The source explanation is local, so the derived raw one is too.
        local_raw_explanation = synthetic_explanation.get_raw_explanation(
            [feature_map], raw_feature_names=raw_names)
        self.validate_local_explanation_regression(synthetic_explanation,
                                                   local_raw_explanation,
                                                   feature_map,
                                                   has_eng_eval_data=False,
                                                   has_raw_eval_data=False,
                                                   has_dataset_data=False)

    def test_does_quack_datasets_negative(self):
        """Check that _does_quack rejects incomplete dataset objects.

        The first candidate is built from ``BaseValid`` alone. Each later
        candidate combines a mixin with ``BaseValid``, where the mixin
        defines all but one of ``init_data``, ``eval_data``,
        ``eval_y_predicted`` and ``eval_y_predicted_proba``.
        """
        def quacks(*mixins):
            # Build a fresh 'InvalidDatasets' type from the mixins followed
            # by BaseValid, instantiate it, and ask the duck-type check.
            candidate_type = type('InvalidDatasets', mixins + (BaseValid,), {})
            return _DatasetsMixin._does_quack(candidate_type())

        # Shared read-only properties. Each getter builds its value on every
        # access, just as a plain @property method would.
        init_data = property(lambda self: 'a_dataset_id')
        eval_data = property(lambda self: [[.2, .4, .01], [.3, .2, 0]])
        always_none = property(lambda self: None)

        assert not quacks()

        # Missing init_data.
        no_train_data = type('NoTrainData', (object,), {
            'eval_data': eval_data,
            'eval_y_predicted': always_none,
            'eval_y_predicted_proba': always_none,
        })
        assert not quacks(no_train_data)

        # Missing eval_data.
        no_test_data = type('NoTestData', (object,), {
            'init_data': init_data,
            'eval_y_predicted': always_none,
            'eval_y_predicted_proba': always_none,
        })
        assert not quacks(no_test_data)

        # Missing eval_y_predicted.
        no_eval_y_predicted = type('NoEvalYPredicted', (object,), {
            'init_data': init_data,
            'eval_data': eval_data,
            'eval_y_predicted_proba': always_none,
        })
        assert not quacks(no_eval_y_predicted)

        # Missing eval_y_predicted_proba.
        no_eval_y_predicted_proba = type('NoEvalYPredictedProba', (object,), {
            'init_data': init_data,
            'eval_data': eval_data,
            'eval_y_predicted': always_none,
        })
        assert not quacks(no_eval_y_predicted_proba)
 def test_does_quack_datasets_mixin(self):
     """An instance of a type deriving from _DatasetsValid must quack."""
     valid_datasets = type('ValidDatasets', (_DatasetsValid,), {})()
     assert _DatasetsMixin._does_quack(valid_datasets)