Example #1
0
    def test_unsupported_versions_from_json(self, unsupported_version):
        """Reject JSON payloads whose metadata version is unsupported or absent.

        Two payloads are checked:

        * one whose ``metadata.version`` holds ``unsupported_version``;
        * one whose key is misspelled (``versio``), so no version field exists.

        Both must make ``CounterfactualExplanations.from_json`` raise
        ``UserConfigValidationException`` with the matching message.
        """
        json_str = json.dumps({'metadata': {'version': unsupported_version}})
        with pytest.raises(UserConfigValidationException) as ucve:
            CounterfactualExplanations.from_json(json_str)

        # Inspect the raised exception itself (``ucve.value``). ``str()`` of
        # the ExceptionInfo wrapper is its repr in pytest >= 5: that form
        # quotes the message and may truncate it.
        assert "Incompatible version {} found in json input".format(
            unsupported_version) in str(ucve.value)

        # Misspelled key: the payload carries no ``version`` field at all.
        json_str = json.dumps({'metadata': {'versio': unsupported_version}})
        with pytest.raises(UserConfigValidationException) as ucve:
            CounterfactualExplanations.from_json(json_str)

        assert "No version field in the json input" in str(ucve.value)
Example #2
0
 def test_no_counterfactuals_found_summary_importance(
         self, desired_class, sample_custom_query_10, total_CFs, version):
     """Round-trip global importance output with two emptied CF examples.

     The counterfactual frames of examples 0 and 9 are cleared to mimic
     queries for which no counterfactuals were found. Both the live object
     and its JSON-recovered copy must still validate, with local and summary
     importance available, and the two must compare equal.
     """
     explanations = self.exp.global_feature_importance(
         query_instances=sample_custom_query_10,
         desired_class=desired_class,
         total_CFs=total_CFs)

     # Indices 0 and 9 are presumably the first and last of the ten query
     # rows; confirm against the sample_custom_query_10 fixture.
     for position in (0, 9):
         cf_example = explanations.cf_examples_list[position]
         cf_example.final_cfs_df = None
         cf_example.final_cfs_df_sparse = None

     query_count = sample_custom_query_10.shape[0]
     self.verify_counterfactual_explanations(
         explanations, None, query_count, version,
         local_importance_available=True,
         summary_importance_available=True)

     recovered = CounterfactualExplanations.from_json(explanations.to_json())
     self.verify_counterfactual_explanations(
         recovered, None, query_count, version,
         local_importance_available=True,
         summary_importance_available=True)

     assert explanations == recovered
Example #3
0
    def test_summary_importance_output(self, desired_class,
                                       sample_custom_query_10, total_CFs,
                                       version):
        """Validate global-importance output and its JSON round trip.

        The generated explanations must:

        * pass verification;
        * serialize with the expected ``metadata.version``;
        * deserialize into an object that also passes verification;
        * compare equal to the original after deserialization.
        """
        counterfactual_explanations = self.exp.global_feature_importance(
            query_instances=sample_custom_query_10,
            desired_class=desired_class,
            total_CFs=total_CFs)

        self.verify_counterfactual_explanations(
            counterfactual_explanations,
            total_CFs,
            sample_custom_query_10.shape[0],
            version,
            local_importance_available=True,
            summary_importance_available=True)

        json_output = counterfactual_explanations.to_json()
        assert json_output is not None
        assert json.loads(json_output).get('metadata').get(
            'version') == version

        recovered_counterfactual_explanations = CounterfactualExplanations.from_json(
            json_output)
        # Verify the deserialized object, not the original again; otherwise
        # the round trip itself would go unchecked.
        self.verify_counterfactual_explanations(
            recovered_counterfactual_explanations,
            total_CFs,
            sample_custom_query_10.shape[0],
            version,
            local_importance_available=True,
            summary_importance_available=True)

        assert recovered_counterfactual_explanations == counterfactual_explanations
Example #4
0
    def test_counterfactual_explanations_output(self, desired_class,
                                                sample_custom_query_1,
                                                total_CFs, version):
        """Validate local counterfactual output and its JSON round trip.

        The generated explanations must:

        * pass verification;
        * serialize with the expected ``metadata.version``;
        * deserialize into an object that also passes verification;
        * compare equal to the original after deserialization.
        """
        counterfactual_explanations = self.exp.generate_counterfactuals(
            query_instances=sample_custom_query_1,
            desired_class=desired_class,
            total_CFs=total_CFs)

        self.verify_counterfactual_explanations(counterfactual_explanations,
                                                total_CFs,
                                                sample_custom_query_1.shape[0],
                                                version)

        json_output = counterfactual_explanations.to_json()
        assert json_output is not None
        assert json.loads(json_output).get('metadata').get(
            'version') == version

        recovered_counterfactual_explanations = CounterfactualExplanations.from_json(
            json_output)
        # Verify the deserialized object, not the original again; otherwise
        # the round trip itself would go unchecked.
        self.verify_counterfactual_explanations(
            recovered_counterfactual_explanations,
            total_CFs,
            sample_custom_query_1.shape[0],
            version)

        assert recovered_counterfactual_explanations == counterfactual_explanations
Example #5
0
    def test_KD_tree_counterfactual_explanations_output(
            self, desired_range, sample_custom_query_2, total_CFs):
        """Counterfactuals from the regression explainer survive a JSON round trip.

        Uses ``self.exp_regr`` with a ``desired_range`` target. The generated
        object, its JSON form and the restored object must all be non-None,
        and the restored object must equal the original.
        """
        generated = self.exp_regr.generate_counterfactuals(
            query_instances=sample_custom_query_2,
            total_CFs=total_CFs,
            desired_range=desired_range)
        assert generated is not None

        serialized = generated.to_json()
        assert serialized is not None

        restored = CounterfactualExplanations.from_json(serialized)
        assert restored is not None
        assert generated == restored
Example #6
0
 def test_no_counterfactuals_found(self, desired_class,
                                   sample_custom_query_1, total_CFs,
                                   version):
     """An explanation whose first CF example is empty still validates and round-trips."""
     explanations = self.exp.generate_counterfactuals(
         query_instances=sample_custom_query_1, desired_class=desired_class,
         total_CFs=total_CFs)

     # Simulate "no counterfactuals found" for the first (presumably only)
     # query instance.
     first_example = explanations.cf_examples_list[0]
     first_example.final_cfs_df = None
     first_example.final_cfs_df_sparse = None

     n_queries = sample_custom_query_1.shape[0]
     self.verify_counterfactual_explanations(explanations, None, n_queries,
                                             version)

     recovered = CounterfactualExplanations.from_json(explanations.to_json())
     self.verify_counterfactual_explanations(recovered, None, n_queries,
                                             version)
     assert explanations == recovered
Example #7
0
    def test_serialization_deserialization_counterfactual_explanations_class(
            self):
        """A default-constructed empty explanation has version '1.0' and round-trips.

        The object is built with no ``version`` argument. Its empty example
        list, missing importances and metadata version are checked before a
        JSON round trip, and the restored object must equal the original.
        """
        empty = CounterfactualExplanations(
            cf_examples_list=[],
            local_importance=None,
            summary_importance=None)

        examples = empty.cf_examples_list
        assert examples is not None
        assert len(examples) == 0
        assert empty.summary_importance is None
        assert empty.local_importance is None

        metadata = empty.metadata
        assert metadata is not None
        assert metadata['version'] is not None
        assert metadata['version'] == '1.0'

        round_tripped = CounterfactualExplanations.from_json(empty.to_json())
        assert empty == round_tripped
Example #8
0
    def test_empty_counterfactual_explanations_object(self, version):
        """An explanation with no examples validates and round-trips for ``version``."""
        empty = CounterfactualExplanations(
            cf_examples_list=[],
            local_importance=None,
            summary_importance=None,
            version=version)
        self.verify_counterfactual_explanations(empty, None, 0, version)

        serialized = empty.to_json()
        assert serialized is not None

        restored = CounterfactualExplanations.from_json(serialized)
        self.verify_counterfactual_explanations(restored, None, 0, version)

        assert empty == restored
Example #9
0
    def load_result(self, data_directory_path):
        """Load a previously saved counterfactual result from a directory.

        Populates these attributes from per-field ``<name>.json`` files inside
        ``data_directory_path``:

        * ``self.counterfactual_obj`` (``None`` when the metadata file is absent);
        * ``self.has_computation_failed``;
        * ``self.failure_reason``;
        * ``self.is_computed``.

        :param data_directory_path: Directory holding the saved JSON files.
            Must support ``/`` joining and ``exists()`` (e.g. a
            ``pathlib.Path``).
        :raises FileNotFoundError: If a required JSON file is missing.
        :raises KeyError: If the metadata file has no ``version`` entry.
        """
        def _read_json(name):
            # Each field lives in its own ``<name>.json`` file. Read as UTF-8
            # explicitly; the platform default encoding (e.g. cp1252 on
            # Windows) is not what JSON specifies.
            # NOTE(review): assumes the saving side wrote ASCII/UTF-8 (true
            # for json.dump's default ensure_ascii=True); confirm.
            with open(data_directory_path / (name + '.json'), 'r',
                      encoding='utf-8') as result_file:
                return json.load(result_file)

        metadata_file_path = (data_directory_path /
                              (_CommonSchemaConstants.METADATA + '.json'))

        if metadata_file_path.exists():
            metadata = _read_json(_CommonSchemaConstants.METADATA)

            # The recorded schema version selects the set of per-key files
            # that make up the saved explanation.
            if metadata['version'] == _SchemaVersions.V1:
                cf_schema_keys = _V1SchemaConstants.ALL
            else:
                cf_schema_keys = _V2SchemaConstants.ALL

            counterfactual_examples_dict = {
                key: _read_json(key) for key in cf_schema_keys}

            # Reassemble the pieces into one JSON document so that the
            # existing from_json path performs version validation and
            # deserialization.
            counterfactuals_json_str = json.dumps(counterfactual_examples_dict)
            self.counterfactual_obj = \
                CounterfactualExplanations.from_json(counterfactuals_json_str)
        else:
            self.counterfactual_obj = None

        self.has_computation_failed = _read_json(
            CounterfactualConfig.HAS_COMPUTATION_FAILED)
        self.failure_reason = _read_json(CounterfactualConfig.FAILURE_REASON)
        self.is_computed = _read_json(CounterfactualConfig.IS_COMPUTED)