Exemple #1
0
def mock_tree_shap_explainer(monkeypatch, request):
    """
    Build a `TreeShap` explainer around a `MockPredictor`, with
    `shap.TreeExplainer` replaced by a `MockTreeExplainer` instance.

    Intended for use as a parametrized pytest fixture (the decorator is not
    visible in this excerpt).

    Parameters
    ----------
    monkeypatch:
        Object used to override the `TreeExplainer` attribute of `shap`.
    request:
        Object whose `param` attribute is a tuple `(n_outs, model_output)`:
        the number of outputs of the mock predictor and the `model_output`
        string forwarded to the `TreeShap` constructor.

    Returns
    -------
    The `TreeShap` instance constructed from the mock predictor.
    """

    out_dim, output_kind = request.param
    rng_seed = 0  # the same seed is shared by both mock objects

    mock_model = MockPredictor(out_dim=out_dim, out_type=output_kind, seed=rng_seed)

    # Note that the replacement is a MockTreeExplainer *instance*, not a class.
    monkeypatch.setattr(
        shap,
        "TreeExplainer",
        MockTreeExplainer(mock_model, seed=rng_seed),
    )

    return TreeShap(mock_model, model_output=output_kind)
Exemple #2
0
def make_tree_shap(dirname: Optional[Path] = None) -> TreeShap:
    """
    Train an XGBoost model on the Adult dataset and fit a `TreeShap`
    explainer to it.

    The NumPy global random state is seeded with 0 before the data is
    shuffled. The first 30000 shuffled rows form the training set; the rows
    after index 30000 are used as the evaluation set for early stopping.

    Parameters
    ----------
    dirname:
        If not `None`, the fitted explainer is saved by calling
        `save(dirname)` on it before returning.

    Returns
    -------
    The `TreeShap` explainer, constructed with `model_output="raw"` and
    `task="classification"` and fitted on the training features.
    """
    np.random.seed(0)

    # Shuffle features and labels together so that rows stay aligned,
    # then split the label column back off.
    dataset = fetch_adult()
    shuffled = np.random.permutation(np.c_[dataset.data, dataset.target])
    features, labels = shuffled[:, :-1], shuffled[:, -1]

    n_train = 30000
    X_train = features[:n_train, :]
    y_train = labels[:n_train]
    # NOTE(review): row `n_train` itself ends up in neither split
    # (train is [:n_train], test is [n_train + 1:]). Kept as-is to leave
    # the trained model unchanged; confirm whether this is intentional.
    X_test = features[n_train + 1:, :]
    y_test = labels[n_train + 1:]

    train_matrix = xgboost.DMatrix(X_train, label=y_train)
    eval_matrix = xgboost.DMatrix(X_test, label=y_test)

    booster = xgboost.train(
        {
            "eta": 0.01,
            "objective": "binary:logistic",
            "subsample": 0.5,
            "base_score": np.mean(y_train),
            "eval_metric": "logloss",
        },
        train_matrix,
        5000,
        evals=[(eval_matrix, "test")],
        verbose_eval=100,
        early_stopping_rounds=20,
    )

    explainer = TreeShap(booster, model_output="raw", task="classification")
    explainer.fit(X_train)

    if dirname is None:
        return explainer

    explainer.save(dirname)
    return explainer
Exemple #3
0
def tree_explainer(rf_classifier, iris_data):
    """
    Build a `TreeShap` explainer for `rf_classifier` and fit it on the
    training split of `iris_data`.

    Parameters
    ----------
    rf_classifier:
        Predictor passed to `TreeShap` as `predictor`.
    iris_data:
        Mapping providing `['metadata']['feature_names']` and `['X_train']`.

    Returns
    -------
    The fitted `TreeShap` explainer, constructed with
    `model_output='probability'`.
    """
    feature_names = iris_data['metadata']['feature_names']
    explainer = TreeShap(
        predictor=rf_classifier,
        model_output='probability',
        feature_names=feature_names,
    )
    explainer.fit(iris_data['X_train'])
    return explainer