示例#1
0
def validate_state_directory(path, manager_type):
    """Assert that a saved manager directory has the expected layout.

    Every entry under ``path / manager_type`` must be named with a UUID
    string, and the presence of its config, data and generators
    directories must match what is expected for ``manager_type``.

    :param path: Root directory of the saved state. Must support the
        ``/`` operator, since it is joined with ``manager_type``.
    :param manager_type: Name of the manager sub-directory to check.
    :raises AssertionError: If ``manager_type`` is not listed under
        ``path`` or a component directory has an unexpected layout.
    :raises ValueError: If a component directory name is not a
        well-formed UUID string.
    """
    assert manager_type in os.listdir(path)

    manager_root = path / manager_type
    for component_name in os.listdir(manager_root):
        # UUID() raises when the name is not a well-formed UUID string.
        UUID(component_name, version=4)

        manager = DirectoryManager(manager_root, component_name)
        config_dir = manager.get_config_directory()
        data_dir = manager.get_data_directory()
        generators_dir = manager.get_generators_directory()

        # Expected existence of the (config, data, generators) directories.
        if manager_type == ManagerNames.EXPLAINER:
            expected = (False, True, False)
        elif manager_type == ManagerNames.COUNTERFACTUAL:
            expected = (True, True, True)
        elif manager_type == ManagerNames.ERROR_ANALYSIS:
            expected = (True, True, False)
        elif manager_type == ManagerNames.CAUSAL:
            expected = (False, True, False)
        else:
            # Other manager types only get the UUID name check.
            continue

        want_config, want_data, want_generators = expected
        assert bool(config_dir.exists()) == want_config
        assert bool(data_dir.exists()) == want_data
        assert bool(generators_dir.exists()) == want_generators
    def test_counterfactual_manager_save_load(self, tmpdir):
        """Round-trip two counterfactual configs through save and load.

        After the first round trip, the pickled explainers are deleted
        from the saved state and the insights are loaded again; every
        config must still end up with an explainer.
        """
        def assert_two_results(insights):
            # Two results are expected and the first one must be set.
            assert len(insights.counterfactual.get()) == 2
            first_result = insights.counterfactual.get()[0]
            assert first_result is not None

        X_train, X_test, y_train, y_test, feature_names, _ = \
            create_iris_data()

        model = create_lightgbm_classifier(X_train, y_train)
        X_train['target'] = y_train
        X_test['target'] = y_test

        rai_insights = RAIInsights(model=model,
                                   train=X_train,
                                   test=X_test.iloc[0:10],
                                   target_column='target',
                                   task_type='classification')

        # Same request twice, differing only in the desired class.
        for desired_class in (2, 1):
            rai_insights.counterfactual.add(
                total_CFs=10,
                desired_class=desired_class,
                features_to_vary=[feature_names[0]],
                permitted_range={feature_names[0]: [2.0, 5.0]})
        rai_insights.counterfactual.compute()
        assert_two_results(rai_insights)

        save_dir = tmpdir.mkdir('save-dir')
        rai_insights.save(save_dir)
        reloaded = RAIInsights.load(save_dir)
        assert_two_results(reloaded)

        # Remove every pickled dice-ml explainer so that the next load
        # cannot read it from disk.
        cf_root = save_dir / "counterfactual"
        for sub_dir_name in DirectoryManager.list_sub_directories(cf_root):
            generators_dir = DirectoryManager(
                parent_directory_path=cf_root,
                sub_directory_name=sub_dir_name).get_generators_directory()
            os.remove(generators_dir / "explainer.pkl")

        reloaded_again = RAIInsights.load(save_dir)
        configs = reloaded_again.counterfactual._counterfactual_config_list
        assert len(configs) == 2
        for index in range(2):
            assert configs[index].explainer is not None
示例#3
0
    def _load(path, rai_insights):
        """Rebuild a CounterfactualManager from its saved directory.

        The instance is created without running ``__init__``; its state is
        filled from ``rai_insights`` and from each sub-directory found
        under ``path``. When a saved explainer could not be loaded, a
        warning is emitted and a new explainer is created instead.

        :param path: The directory path to load the CounterfactualManager
            from.
        :type path: str
        :param rai_insights: The loaded parent RAIInsights.
        :type rai_insights: RAIInsights
        :return: The CounterfactualManager manager after loading.
        :rtype: CounterfactualManager
        """
        inst = CounterfactualManager.__new__(CounterfactualManager)

        # Rehydrate model analysis data from the parent insights object.
        rehydrated = (
            (CounterfactualManager._MODEL, rai_insights.model),
            (CounterfactualManager._TRAIN, rai_insights.train),
            (CounterfactualManager._TEST, rai_insights.test),
            (CounterfactualManager._TARGET_COLUMN,
             rai_insights.target_column),
            (CounterfactualManager._TASK_TYPE, rai_insights.task_type),
            (CounterfactualManager._CATEGORICAL_FEATURES,
             rai_insights.categorical_features),
        )
        for attribute_name, value in rehydrated:
            vars(inst)[attribute_name] = value

        vars(inst)[CounterfactualManager._COUNTERFACTUAL_CONFIG_LIST] = []

        for sub_dir_name in DirectoryManager.list_sub_directories(path):
            dir_manager = DirectoryManager(
                parent_directory_path=path,
                sub_directory_name=sub_dir_name)

            config = CounterfactualConfig.load_config(
                dir_manager.get_config_directory())
            config.load_result(dir_manager.get_data_directory())
            config.load_explainer(dir_manager.get_generators_directory())

            if config.explainer is None:
                # No explainer came back from disk: warn and build a new
                # one from the loaded configuration.
                warnings.warn(
                    'ERROR-LOADING-COUNTERFACTUAL-EXPLAINER: '
                    'There was an error loading the '
                    'counterfactual explainer model. '
                    'Retraining the counterfactual '
                    'explainer.')
                config.explainer = inst._create_diceml_explainer(
                    config.method,
                    config.continuous_features)

            cf_obj = config.counterfactual_obj
            if cf_obj is not None:
                # Validate the serialized output against the schema that
                # matches the version recorded in its metadata.
                schema = CounterfactualManager._get_counterfactual_schema(
                    version=cf_obj.metadata['version'])
                serialized = json.loads(cf_obj.to_json())
                jsonschema.validate(serialized, schema)

            vars(inst)[
                CounterfactualManager._COUNTERFACTUAL_CONFIG_LIST].append(
                    config)

        return inst
    def test_counterfactual_manager_save_load(self, tmpdir):
        """Round-trip two counterfactual configs through save and load.

        Checks that both results and both config ids survive the round
        trip. Then deletes the pickled explainers from the saved state and
        checks that loading again emits the explainer-loading warning and
        still yields an explainer for every config.
        """
        X_train, X_test, y_train, y_test, feature_names, _ = \
            create_iris_data()

        model = create_lightgbm_classifier(X_train, y_train)
        X_train['target'] = y_train
        X_test['target'] = y_test

        rai_insights = RAIInsights(
            model=model,
            train=X_train,
            test=X_test.iloc[0:10],
            target_column='target',
            task_type='classification')

        # Same request twice, differing only in the desired class.
        for desired_class in (2, 1):
            rai_insights.counterfactual.add(
                total_CFs=10, desired_class=desired_class,
                features_to_vary=[feature_names[0]],
                permitted_range={feature_names[0]: [2.0, 5.0]})
        rai_insights.counterfactual.compute()

        configs_before = \
            rai_insights.counterfactual._counterfactual_config_list
        assert len(configs_before) == 2
        assert len(rai_insights.counterfactual.get()) == 2
        for index in range(2):
            assert rai_insights.counterfactual.get()[index] is not None

        save_dir = tmpdir.mkdir('save-dir')
        rai_insights.save(save_dir)
        reloaded = RAIInsights.load(save_dir)

        configs_after = reloaded.counterfactual._counterfactual_config_list
        assert len(reloaded.counterfactual.get()) == 2
        for index in range(2):
            assert reloaded.counterfactual.get()[index] is not None

        # Each id seen before saving must show up among the loaded
        # configs, in either position.
        for index in range(2):
            assert configs_before[index].id in \
                [configs_after[0].id, configs_after[1].id]

        # Remove every pickled dice-ml explainer so that the next load
        # cannot read it from disk.
        cf_root = save_dir / "counterfactual"
        for sub_dir_name in DirectoryManager.list_sub_directories(cf_root):
            generators_dir = DirectoryManager(
                parent_directory_path=cf_root,
                sub_directory_name=sub_dir_name).get_generators_directory()
            os.remove(generators_dir / "explainer.pkl")

        expected_warning = ('ERROR-LOADING-COUNTERFACTUAL-EXPLAINER: '
                            'There was an error loading the '
                            'counterfactual explainer model. '
                            'Retraining the counterfactual '
                            'explainer.')
        with pytest.warns(UserWarning, match=expected_warning):
            reloaded_again = RAIInsights.load(save_dir)

        configs = reloaded_again.counterfactual._counterfactual_config_list
        assert len(configs) == 2
        for index in range(2):
            assert configs[index].explainer is not None