Code Example #1
File: shap.py  Project: kiminh/expybox
    def explain(self, options, instance=None):
        initjs()
        background_data = self._kmeans(options['kmeans_count']) \
            if options['background_data'] == 'kmeans' else options['data']
        nsamples = 'auto' if options['auto_nsamples'] else options['nsamples']
        explainer = KernelExplainer(model=self.predict_function,
                                    data=background_data,
                                    link=options['link'])
        # create sample from train data
        data = self.X_train[np.random.choice(
            self.X_train.shape[0], size=options['sample_size'], replace=False
        ), :]

        shap_values = explainer.shap_values(X=data,
                                            nsamples=nsamples,
                                            l1_reg=options['l1_reg'])

        # limit to only selected class (if any was selected)
        if 'class_to_explain' in options and options['class_to_explain'] != -1:
            shap_values = shap_values[options['class_to_explain']]

        summary_plot(shap_values=shap_values,
                     features=data,
                     feature_names=self.feature_names,
                     plot_type='bar',
                     class_names=self.class_names)
Code Example #2
def main():
    n_features = 100
    random_state = 0

    X_train = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'), categories=None).data
    X_train = preprocess_data(X_train)

    print('qui')
    stc = generate_synthetic_text_classifier(X_train, n_features=n_features, random_state=random_state)

    predict = stc['predict']
    predict_proba = stc['predict_proba']
    words = stc['words_vec']
    vectorizer = stc['vectorizer']
    nbr_terms = stc['nbr_terms']

    X_test = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'), categories=None).data
    X_test = preprocess_data(X_test)
    X_test = vectorizer.transform(X_test).toarray()
    print('quo')

    # reference = np.zeros(nbr_terms)
    # explainer = KernelExplainer(predict_proba, np.reshape(reference, (1, len(reference))))

    sentences_length = list()
    for x in X_test:
        sentences_length.append(len(np.where(x != 0)[0]))

    print('qua')
    avg_nbr_words = np.mean(sentences_length)
    std_nbr_words = np.std(sentences_length)
    words_with_weight = np.where(words != 0)[0]
    print(avg_nbr_words, std_nbr_words)
    print(words_with_weight)
    reference = list()
    for i in range(10):
        nbr_words_in_sentence = int(np.random.normal(avg_nbr_words, std_nbr_words))
        selected_words = np.random.choice(range(nbr_terms), size=nbr_words_in_sentence, replace=False)
        print(i, nbr_words_in_sentence, len(set(selected_words) & set(words_with_weight)))
        while len(set(selected_words) & set(words_with_weight)) > 0:
            nbr_words_in_sentence = int(np.random.normal(avg_nbr_words, std_nbr_words))
            selected_words = np.random.choice(range(nbr_terms), size=nbr_words_in_sentence, replace=False)
            print(i, nbr_words_in_sentence, len(set(selected_words) & set(words_with_weight)))
        sentence = np.zeros(nbr_terms)
        sentence[selected_words] = 1.0
        reference.append(sentence)
        print('')
    reference = np.array(reference)
    print(reference)
    explainer = KernelExplainer(predict_proba, reference) #X_test[:10])

    for x in X_test[:10]:
        expl_val = explainer.shap_values(x, l1_reg='bic')[1]
        gt_val = get_word_importance_explanation(x, stc)
        wbs = word_based_similarity(expl_val, gt_val, use_values=False)
        # print(expl_val)
        # print(gt_val)
        print(wbs, word_based_similarity(expl_val, gt_val, use_values=True))
        print('')
Code Example #3
def LIME_graphSHAP(X, y, feature_names, ylabels, clf, xs_toexplain, labels_toexplain, ax=None, subplots=False, plotlime=True):

    ## Plot explanations on feature space
    if ax is None:
        if subplots:
            if len(xs_toexplain)>=5:
                nrows = int(len(xs_toexplain)/5)
                fig, axs = plt.subplots(nrows=nrows, ncols=int(len(xs_toexplain)/nrows), figsize=(15,3*nrows))
                axs = axs.flatten()
            else:
                fig, axs = plt.subplots(nrows=1, ncols=len(xs_toexplain), figsize=(15,3))
        else:
            fig, ax = plt.subplots()

    # Plot LIME result - loop if several
    for i in range(len(xs_toexplain)):
        
        if subplots:
            plt.sca(axs[i])
            ax = axs[i]
            
        if i==0 or subplots:
            

            # Plot contour of black-box predictions
            plot_classification_contour(X, clf, ax)
            # Plot training set
            plot_training_set(X, y, ax)

            ylim_bak = ax.get_ylim()
            xlim_bak = ax.get_xlim()
            #color_palette = sns.color_palette("bright", n_colors=len(xs_toexplain))
            color_palette = ['lime' for _ in range(len(xs_toexplain))]
        
        ## LIME - Generate explanations
        explainer = lime_assessment.lime_tabular.LimeTabularExplainer(X, feature_names=feature_names, class_names=ylabels, discretize_continuous=False, kernel_width=None)
        exp = explainer.explain_instance(xs_toexplain[i], clf.predict_proba, num_features=2, top_labels=len(ylabels), labels=range(len(ylabels)))
        
        shap_explainer = KernelExplainer(clf.predict_proba, X, nsamples=10000)
        e = shap_explainer.explain(np.reshape(xs_toexplain[i], (1, X.shape[1])))

        # Plot LIME regression
        if plotlime == True:
            plot_lime_regression(X, exp, xs_toexplain[i], labels_toexplain[i], ax, color_palette[i], exp.points_to_plot)
        x_ridge = [-10, 10]
        row = 0
        y_shap = [(0.5 - e.effects[0, row] * x - e.base_value[row])/e.effects[1, row] for x in x_ridge]


        # Plot the SHAP regression line
        plt.sca(ax)
        plt.plot(x_ridge, y_shap, color='red', linestyle=':', linewidth=4, label="other shap regression")

        plt.scatter(xs_toexplain[i][0], xs_toexplain[i][1], color='lime', marker='8', linewidth=4)
        plt.ylim(ylim_bak)
        plt.xlim(xlim_bak)
Code Example #4
File: object.py  Project: vishalbelsare/DALEX
    def fit(self,
            explainer,
            new_observation,
            shap_explainer_type=None,
            **kwargs):
        """Calculate the result of explanation

        Fit method makes calculations in place and changes the attributes.

        Parameters
        -----------
        explainer : Explainer object
            Model wrapper created using the Explainer class.
        new_observation : pd.Series or np.ndarray
            An observation for which a prediction needs to be explained.
        shap_explainer_type : {'TreeExplainer', 'DeepExplainer', 'GradientExplainer', 'LinearExplainer', 'KernelExplainer'}
            String name of the Explainer class (default is `None`, which automatically
            chooses an Explainer to use).
        kwargs: dict
            Keyword parameters passed to the `shap_values` method.

        Returns
        -----------
        None
        """
        from shap import TreeExplainer, DeepExplainer, GradientExplainer, LinearExplainer, KernelExplainer

        checks.check_compatibility(explainer)
        shap_explainer_type = checks.check_shap_explainer_type(
            shap_explainer_type, explainer.model)

        if self.type == 'predict_parts':
            new_observation = checks.check_new_observation_predict_parts(
                new_observation, explainer)

        if shap_explainer_type == "TreeExplainer":
            try:
                self.shap_explainer = TreeExplainer(explainer.model,
                                                    explainer.data.values)
            except:  # https://github.com/ModelOriented/DALEX/issues/371
                self.shap_explainer = TreeExplainer(explainer.model)
        elif shap_explainer_type == "DeepExplainer":
            self.shap_explainer = DeepExplainer(explainer.model,
                                                explainer.data.values)
        elif shap_explainer_type == "GradientExplainer":
            self.shap_explainer = GradientExplainer(explainer.model,
                                                    explainer.data.values)
        elif shap_explainer_type == "LinearExplainer":
            self.shap_explainer = LinearExplainer(explainer.model,
                                                  explainer.data.values)
        elif shap_explainer_type == "KernelExplainer":
            self.shap_explainer = KernelExplainer(
                lambda x: explainer.predict(x), explainer.data.values)

        self.result = self.shap_explainer.shap_values(new_observation.values,
                                                      **kwargs)
        self.new_observation = new_observation
        self.shap_explainer_type = shap_explainer_type
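
Per the docstring above, any extra **kwargs are forwarded to the underlying shap_values call, so KernelExplainer options such as nsamples and l1_reg can be set from fit(). A minimal, hypothetical call follows (the method above belongs to dalex's ShapWrapper class, shown in full in Code Example #11 below; the explainer and obs objects are assumed to already exist and are not defined in this example):

# Hypothetical usage sketch: keyword arguments are passed through to shap_values().
sw = ShapWrapper(type='predict_parts')
sw.fit(explainer,                          # a fitted dalex.Explainer (assumed)
       new_observation=obs,                # a single pd.Series row (assumed)
       shap_explainer_type='KernelExplainer',
       nsamples=100,                       # KernelExplainer sampling budget
       l1_reg='num_features(10)')          # regularization of the local regression
print(sw.result)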
Code Example #5
def main(model, i):

    with open(f'{model}-models.pkl', 'rb') as fh:
        models = pickle.load(fh)

    sampled = np.load(f'{model}-selected.npy', allow_pickle=True)

    sampled_indices = sampled[-1]

    def model_wrapper(X, i=0):
        mu, var = predict_coregionalized(models[0], X, int(i))
        return mu.flatten()

    # NOTE: X (the matrix of inputs to explain) is assumed to be defined
    # elsewhere in this module; it is not created in this snippet.
    shap0 = KernelExplainer(partial(model_wrapper, i=int(i)),
                            shap.kmeans(X, 40))

    shap0_values = shap0.shap_values(X)

    np.save(f'{model}-shap-{i}', shap0_values)
Code Example #6
File: shap.py  Project: kiminh/expybox
    def explain(self, options, instance=None):

        if instance is None:
            raise ValueError("Instance was not provided")

        initjs()
        instance = instance.to_numpy()
        data = self._kmeans(options['kmeans_count']) \
            if options['background_data'] == 'kmeans' else options['data']
        nsamples = 'auto' if options['auto_nsamples'] else options['nsamples']
        explainer = KernelExplainer(model=self.predict_function,
                                    data=data,
                                    link=options['link'])
        shap_values = explainer.shap_values(X=instance,
                                            nsamples=nsamples,
                                            l1_reg=options['l1_reg'])
        if self.is_classification:
            shap_values = shap_values[options['class_to_explain']]
            base_value = explainer.expected_value[[
                options['class_to_explain']
            ]]
        else:
            base_value = explainer.expected_value

        if options['plot_type'] == 'force' or options['plot_type'] == 'both':
            display(
                force_plot(base_value=base_value,
                           shap_values=shap_values,
                           features=instance,
                           feature_names=self.feature_names,
                           show=True,
                           link=options['link']))

        if options['plot_type'] == 'decision' or options['plot_type'] == 'both':
            decision_plot(base_value=base_value,
                          shap_values=shap_values,
                          features=instance,
                          feature_names=list(self.feature_names),
                          show=True,
                          color_bar=True,
                          link=options['link'])
Code Example #7
def test_front_page_model_agnostic():
    from shap import KernelExplainer, DenseData, visualize, initjs
    from sklearn import datasets,neighbors
    from numpy import random, arange

    # print the JS visualization code to the notebook
    initjs()

    # train a k-nearest neighbors classifier on a random subset
    iris = datasets.load_iris()
    random.seed(2)
    inds = arange(len(iris.target))
    random.shuffle(inds)
    knn = neighbors.KNeighborsClassifier()
    knn.fit(iris.data, iris.target == 0)

    # use Shap to explain a single prediction
    background = DenseData(iris.data[inds[:100],:], iris.feature_names) # name the features
    explainer = KernelExplainer(knn.predict, background, nsamples=100)#knn.predict
    x = iris.data[inds[102:103],:]
    visualize(explainer.explain(x))
Code Example #8
File: explainer.py  Project: biolab/orange3-explain
def _explain_other_models(
    model: Model,
    transformed_data: Table,
    transformed_reference_data: Table,
    progress_callback: Callable,
) -> Tuple[List[np.ndarray], np.ndarray, np.ndarray]:
    """
    Computes SHAP values for any learner with KernelExplainer.
    """
    # 1000 rows is a sample size that, for typical data and models, does not take too long
    data_sample, sample_mask = _subsample_data(transformed_data, 1000)

    try:
        ref = kmeans(transformed_reference_data.X, k=10)
    except ValueError:
        # k-means fails with value error when it cannot produce enough clusters
        # in this case we will use sample instead of clusters
        ref = sample(transformed_reference_data.X, nsamples=100)

    explainer = KernelExplainer(
        lambda x:
        (model(x)
         if model.domain.class_var.is_continuous else model(x, model.Probs)),
        ref,
    )

    shap_values = []
    for i, row in enumerate(data_sample.X):
        progress_callback(i / len(data_sample))
        shap_values.append(
            explainer.shap_values(row,
                                  nsamples=100,
                                  silent=True,
                                  l1_reg="num_features(90)"))
    return (
        _join_shap_values(shap_values),
        sample_mask,
        explainer.expected_value,
    )
Code Example #9
class ShapExplainer(Explainer):
    def __init__(self, model, training_data):
        super(ShapExplainer, self).__init__(model, training_data)
        data = self.preprocessor.transform(training_data)
        background_data = kmeans(data, 10)
        self.explainer = KernelExplainer(model=self.model.predict_proba,
                                         data=background_data)

    def explain(self, instance, budget):
        instance = self.preprocessor.transform([instance])[0]
        values = self.explainer.shap_values(X=instance,
                                            nsamples="auto",
                                            l1_reg='num_features(%d)' %
                                            budget)[1]
        pairs = sorted(zip(self.feature_names, values),
                       key=lambda x: abs(x[1]),
                       reverse=True)
        return pairs[:budget]
Code Example #10
def main():
    n_features = (20, 20)
    img_size = (32, 32, 3)
    cell_size = (4, 4)
    colors_p = np.array([0.15, 0.7, 0.15])
    p_border = 1.0

    # img_draft = np.array([ # ['k', 'k', 'k', 'k', 'k', 'k', 'k', 'k'],
    #     ['k', 'k', 'k', 'k', 'k', 'g', 'r', 'k'],
    #     ['g', 'k', 'k', 'k', 'k', 'k', 'k', 'g'],
    #     ['k', 'g', 'k', 'k', 'k', 'b', 'k', 'k'],
    #     ['k', 'g', 'k', 'k', 'g', 'g', 'k', 'b'],
    #     ['k', 'k', 'k', 'k', 'g', 'k', 'k', 'g'],
    #     ['g', 'k', 'k', 'k', 'k', 'k', 'k', 'k'],
    #     ['k', 'k', 'k', 'k', 'k', 'k', 'k', 'k'],
    #     ['k', 'k', 'k', 'k', 'g', 'k', 'k', 'k'],
    #
    # ])
    # img = generate_img_defined(img_draft, img_size=img_size, cell_size=cell_size)
    # plt.imshow(img)
    # plt.xticks(())
    # plt.yticks(())
    # # plt.savefig('../fig/pattern.png', format='png', bbox_inches='tight')
    # plt.show()

    pattern_draft = np.array([  # ['k', 'k', 'k', 'k', 'k', 'k', 'k', 'k'],
        ['k', 'k', 'k', 'k', 'k'],
        ['k', 'k', 'k', 'b', 'k'],
        ['k', 'k', 'g', 'g', 'k'],
        ['k', 'k', 'g', 'k', 'k'],
        ['k', 'k', 'k', 'k', 'k'],
    ])

    pattern = generate_img_defined(pattern_draft,
                                   img_size=(20, 20, 3),
                                   cell_size=cell_size)

    sic = generate_synthetic_image_classifier(img_size=img_size,
                                              cell_size=cell_size,
                                              n_features=n_features,
                                              p_border=p_border,
                                              pattern=pattern)

    pattern = sic['pattern']
    predict = sic['predict']
    predict_proba = sic['predict_proba']

    plt.imshow(pattern)
    plt.xticks(())
    plt.yticks(())
    # plt.savefig('../fig/pattern.png', format='png', bbox_inches='tight')
    plt.show()

    X_test = generate_random_img_dataset(pattern,
                                         nbr_images=1000,
                                         pattern_ratio=0.4,
                                         img_size=img_size,
                                         cell_size=cell_size,
                                         min_nbr_cells=0.1,
                                         max_nbr_cells=0.3,
                                         colors_p=colors_p)

    Y_test = predict(X_test)
    idx = np.where(Y_test == 1)[0][0]

    # x = X_test[idx]
    img_draft = np.array([  # ['k', 'k', 'k', 'k', 'k', 'k', 'k', 'k'],
        ['k', 'k', 'k', 'k', 'k', 'g', 'r', 'k'],
        ['g', 'k', 'k', 'k', 'k', 'k', 'k', 'g'],
        ['k', 'g', 'k', 'k', 'k', 'b', 'k', 'k'],
        ['k', 'g', 'k', 'k', 'g', 'g', 'k', 'b'],
        ['k', 'k', 'k', 'k', 'g', 'k', 'k', 'g'],
        ['g', 'k', 'k', 'k', 'k', 'k', 'k', 'k'],
        ['k', 'k', 'k', 'k', 'k', 'k', 'k', 'k'],
        ['k', 'k', 'k', 'k', 'g', 'k', 'k', 'k'],
    ])
    x = generate_img_defined(img_draft, img_size=img_size, cell_size=cell_size)
    plt.imshow(x)
    plt.xticks(())
    plt.yticks(())
    # plt.savefig('../fig/image.png', format='png', bbox_inches='tight')
    plt.show()

    gt_val = get_pixel_importance_explanation(x, sic)
    max_val = np.nanpercentile(np.abs(gt_val), 99.9)
    plt.imshow(np.reshape(gt_val, img_size[:2]),
               cmap='RdYlBu',
               vmin=-max_val,
               vmax=max_val,
               alpha=0.7)
    plt.xticks(())
    plt.yticks(())
    # plt.savefig('../fig/saliencymap.png', format='png', bbox_inches='tight')
    plt.show()

    # plt.imshow(x)
    # plt.imshow(np.reshape(gt_val, img_size[:2]), cmap='RdYlBu', vmin=-max_val, vmax=max_val, alpha=0.7)
    # plt.xticks(())
    # plt.yticks(())
    # plt.savefig('../fig/saliencymap2.png', format='png', bbox_inches='tight')
    # plt.show()

    lime_explainer = LimeImageExplainer()
    segmenter = SegmentationAlgorithm('quickshift',
                                      kernel_size=1,
                                      max_dist=10,
                                      ratio=0.5)
    tot_num_features = img_size[0] * img_size[1]

    lime_exp = lime_explainer.explain_instance(x,
                                               predict_proba,
                                               top_labels=2,
                                               hide_color=0,
                                               num_samples=10000,
                                               segmentation_fn=segmenter)
    _, lime_expl_val = lime_exp.get_image_and_mask(
        1,
        positive_only=True,
        num_features=tot_num_features,
        hide_rest=False,
        min_weight=0.0)
    max_val = np.nanpercentile(np.abs(lime_expl_val), 99.9)
    plt.imshow(lime_expl_val,
               cmap='RdYlBu',
               vmin=-max_val,
               vmax=max_val,
               alpha=0.7)
    plt.xticks(())
    plt.yticks(())
    plt.title('lime', fontsize=20)
    plt.savefig('../fig/saliencymap_lime.png',
                format='png',
                bbox_inches='tight')
    plt.show()

    background = np.array([np.zeros(img_size).ravel()] * 10)
    shap_explainer = KernelExplainer(predict_proba, background)

    shap_expl_val = shap_explainer.shap_values(x.ravel(), l1_reg='bic')[1]
    shap_expl_val = np.sum(np.reshape(shap_expl_val, img_size), axis=2)
    tmp = np.zeros(shap_expl_val.shape)
    tmp[np.where(shap_expl_val > 0.0)] = 1.0
    shap_expl_val = tmp
    max_val = np.nanpercentile(np.abs(shap_expl_val), 99.9)
    plt.imshow(shap_expl_val,
               cmap='RdYlBu',
               vmin=-max_val,
               vmax=max_val,
               alpha=0.7)
    plt.xticks(())
    plt.yticks(())
    plt.title('shap', fontsize=20)
    plt.savefig('../fig/saliencymap_shap.png',
                format='png',
                bbox_inches='tight')
    plt.show()

    nbr_records = 10
    Xm_test = np.array([x.ravel() for x in X_test[:nbr_records]])
    maple_explainer = MAPLE(Xm_test,
                            Y_test[:nbr_records],
                            Xm_test,
                            Y_test[:nbr_records],
                            n_estimators=5,
                            max_features=0.5,
                            min_samples_leaf=5)

    maple_exp = maple_explainer.explain(x)
    maple_expl_val = maple_exp['coefs'][:-1]
    maple_expl_val = np.sum(np.reshape(maple_expl_val, img_size), axis=2)
    tmp = np.zeros(maple_expl_val.shape)
    tmp[np.where(maple_expl_val > 0.0)] = 1.0
    maple_expl_val = tmp
    max_val = np.nanpercentile(np.abs(maple_expl_val), 99.9)
    plt.imshow(maple_expl_val,
               cmap='RdYlBu',
               vmin=-max_val,
               vmax=max_val,
               alpha=0.7)
    plt.xticks(())
    plt.yticks(())
    plt.title('maple', fontsize=20)
    plt.savefig('../fig/saliencymap_maple.png',
                format='png',
                bbox_inches='tight')
    plt.show()

    lime_f1, lime_pre, lime_rec = pixel_based_similarity(lime_expl_val.ravel(),
                                                         gt_val,
                                                         ret_pre_rec=True)
    shap_f1, shap_pre, shap_rec = pixel_based_similarity(shap_expl_val.ravel(),
                                                         gt_val,
                                                         ret_pre_rec=True)
    maple_f1, maple_pre, maple_rec = pixel_based_similarity(
        maple_expl_val.ravel(), gt_val, ret_pre_rec=True)

    print(lime_f1, lime_pre, lime_rec)
    print(shap_f1, shap_pre, shap_rec)
    print(maple_f1, maple_pre, maple_rec)
Code Example #11
File: object.py  Project: milicanikolic/DALEX
class ShapWrapper:
    """Explanation wrapper for the 'shap' package

    This object uses the shap package to create the model explanation.
    See https://github.com/slundberg/shap

    Parameters
    ----------
    type : {'predict_parts', 'model_parts'}


    Attributes
    ----------
    result : list or numpy.ndarray
        Calculated shap values for `new_observation` data.
    shap_explainer : {shap.TreeExplainer, shap.DeepExplainer,
        shap.GradientExplainer, shap.LinearExplainer, shap.KernelExplainer}
        Explainer object from the 'shap' package.
    shap_explainer_type : {'TreeExplainer', 'DeepExplainer',
        'GradientExplainer', 'LinearExplainer', 'KernelExplainer'}
        String name of the Explainer class.
    new_observation : pandas.Series or pandas.DataFrame
        Observations for which the shap values will be calculated
        (later stored in `result`).
    type : {'predict_parts', 'model_parts'}


    Notes
    ----------
    https://github.com/slundberg/shap
    """
    def __init__(self, type):
        self.shap_explainer = None
        self.type = type
        self.result = None
        self.new_observation = None
        self.shap_explainer_type = None

    def _repr_html_(self):
        return self.result._repr_html_()

    def fit(self,
            explainer,
            new_observation,
            shap_explainer_type=None,
            **kwargs):
        """Calculate the result of explanation

        Fit method makes calculations in place and changes the attributes.

        Parameters
        -----------
        explainer : Explainer object
            Model wrapper created using the Explainer class.
        new_observation : pd.Series or np.ndarray
            An observation for which a prediction needs to be explained.
        shap_explainer_type : {'TreeExplainer', 'DeepExplainer',
            'GradientExplainer', 'LinearExplainer', 'KernelExplainer'}
            String name of the Explainer class (default is None, which automatically
            chooses an Explainer to use).

        Returns
        -----------
        None
        """
        from shap import TreeExplainer, DeepExplainer, GradientExplainer, LinearExplainer, KernelExplainer

        check_compatibility(explainer)
        shap_explainer_type = check_shap_explainer_type(
            shap_explainer_type, explainer.model)

        if self.type == 'predict_parts':
            new_observation = check_new_observation_predict_parts(
                new_observation, explainer)

        if shap_explainer_type == "TreeExplainer":
            self.shap_explainer = TreeExplainer(explainer.model,
                                                explainer.data.values)
        elif shap_explainer_type == "DeepExplainer":
            self.shap_explainer = DeepExplainer(explainer.model,
                                                explainer.data.values)
        elif shap_explainer_type == "GradientExplainer":
            self.shap_explainer = GradientExplainer(explainer.model,
                                                    explainer.data.values)
        elif shap_explainer_type == "LinearExplainer":
            self.shap_explainer = LinearExplainer(explainer.model,
                                                  explainer.data.values)
        elif shap_explainer_type == "KernelExplainer":
            self.shap_explainer = KernelExplainer(
                lambda x: explainer.predict(x), explainer.data.values)

        self.result = self.shap_explainer.shap_values(new_observation.values,
                                                      **kwargs)
        self.new_observation = new_observation
        self.shap_explainer_type = shap_explainer_type

    def plot(self, **kwargs):
        """Plot the Shap Wrapper

        Parameters
        ----------
        kwargs :
            Keyword arguments passed to one of the:
                - shap.force_plot when type is 'predict_parts'
                - shap.summary_plot when type is 'model_parts'
            Exceptions are: `base_value`, `shap_values`,
            `features` and `feature_names`.
            Other parameters: https://github.com/slundberg/shap

        Returns
        -----------
        None

        Notes
        --------
        https://github.com/slundberg/shap
        """
        from shap import force_plot, summary_plot

        if self.type == 'predict_parts':
            if isinstance(self.shap_explainer.expected_value,
                          (np.ndarray, list)):
                base_value = self.shap_explainer.expected_value[1]
            else:
                base_value = self.shap_explainer.expected_value

            shap_values = self.result[1] if isinstance(self.result,
                                                       list) else self.result
            force_plot(base_value=base_value,
                       shap_values=shap_values,
                       features=self.new_observation.values,
                       feature_names=self.new_observation.columns,
                       matplotlib=True,
                       **kwargs)
        elif self.type == 'model_parts':
            summary_plot(shap_values=self.result,
                         features=self.new_observation,
                         **kwargs)
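
A hedged end-to-end sketch of driving this wrapper; the data set, model and label below are illustrative choices, not part of the original example (only ShapWrapper itself and dalex.Explainer are taken as given):

# Illustrative sketch, assuming dalex, shap and scikit-learn are installed.
import dalex as dx
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

data = load_breast_cancer(as_frame=True)
X, y = data.data, data.target

clf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)
exp = dx.Explainer(clf, X, y, label="rf")

sw = ShapWrapper(type='predict_parts')
# Force KernelExplainer here; with None, the type is chosen automatically.
sw.fit(exp, X.iloc[[0]], shap_explainer_type='KernelExplainer', nsamples=100)
sw.plot()  # force_plot, because type is 'predict_parts'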
Code Example #12
def run(black_box, n_records, img_size, cell_size, n_features, p_border,
        colors_p, random_state, filename):

    sic = generate_synthetic_image_classifier(img_size=img_size,
                                              cell_size=cell_size,
                                              n_features=n_features,
                                              p_border=p_border,
                                              random_state=random_state)

    pattern = sic['pattern']
    predict = sic['predict']
    predict_proba = sic['predict_proba']

    X_test = generate_random_img_dataset(pattern,
                                         nbr_images=n_records,
                                         pattern_ratio=0.5,
                                         img_size=img_size,
                                         cell_size=cell_size,
                                         min_nbr_cells=0.1,
                                         max_nbr_cells=0.3,
                                         colors_p=colors_p)

    Y_test_proba = predict_proba(X_test)
    Y_test = predict(X_test)

    lime_explainer = LimeImageExplainer()
    segmenter = SegmentationAlgorithm('quickshift',
                                      kernel_size=1,
                                      max_dist=10,
                                      ratio=0.5)
    tot_num_features = img_size[0] * img_size[1]

    background = np.array([np.zeros(img_size).ravel()] * 10)
    shap_explainer = KernelExplainer(predict_proba, background)

    nbr_records_explainer = 10
    idx_records_train_expl = np.random.choice(range(len(X_test)),
                                              size=nbr_records_explainer,
                                              replace=False)
    idx_records_test_expl = np.random.choice(range(len(X_test)),
                                             size=nbr_records_explainer,
                                             replace=False)

    Xm_train = np.array([x.ravel() for x in X_test[idx_records_train_expl]])
    Xm_test = np.array([x.ravel() for x in X_test[idx_records_test_expl]])

    print(datetime.datetime.now(), 'build maple')
    maple_explainer = MAPLE(Xm_train,
                            Y_test_proba[idx_records_train_expl][:, 1],
                            Xm_test,
                            Y_test_proba[idx_records_test_expl][:, 1],
                            n_estimators=100,
                            max_features=0.5,
                            min_samples_leaf=2)
    print(datetime.datetime.now(), 'build maple done')

    idx = 0
    results = list()
    for x, y in zip(X_test, Y_test):
        print(datetime.datetime.now(),
              'seneca - image',
              'black_box %s' % black_box,
              'n_features %s' % str(n_features),
              'rs %s' % random_state,
              '%s/%s' % (idx, n_records),
              end=' ')

        gt_val = get_pixel_importance_explanation(x, sic)

        lime_exp = lime_explainer.explain_instance(x,
                                                   predict_proba,
                                                   top_labels=2,
                                                   hide_color=0,
                                                   num_samples=10000,
                                                   segmentation_fn=segmenter)
        _, lime_expl_val = lime_exp.get_image_and_mask(
            y,
            positive_only=True,
            num_features=tot_num_features,
            hide_rest=False,
            min_weight=0.0)

        shap_expl_val = shap_explainer.shap_values(x.ravel(), l1_reg='bic')[1]
        shap_expl_val = np.sum(np.reshape(shap_expl_val, img_size), axis=2)
        tmp = np.zeros(shap_expl_val.shape)
        tmp[np.where(shap_expl_val > 0.0)] = 1.0
        shap_expl_val = tmp

        maple_exp = maple_explainer.explain(x)
        maple_expl_val = maple_exp['coefs'][:-1]
        maple_expl_val = np.sum(np.reshape(maple_expl_val, img_size), axis=2)
        tmp = np.zeros(maple_expl_val.shape)
        tmp[np.where(maple_expl_val > 0.0)] = 1.0
        maple_expl_val = tmp

        lime_f1, lime_pre, lime_rec = pixel_based_similarity(
            lime_expl_val.ravel(), gt_val, ret_pre_rec=True)
        shap_f1, shap_pre, shap_rec = pixel_based_similarity(
            shap_expl_val.ravel(), gt_val, ret_pre_rec=True)
        maple_f1, maple_pre, maple_rec = pixel_based_similarity(
            maple_expl_val.ravel(), gt_val, ret_pre_rec=True)

        res = {
            'black_box': black_box,
            'n_records': n_records,
            'img_size': '"%s"' % str(img_size),
            'cell_size': '"%s"' % str(cell_size),
            'n_features': '"%s"' % str(n_features),
            'random_state': random_state,
            'idx': idx,
            'lime_f1': lime_f1,
            'lime_pre': lime_pre,
            'lime_rec': lime_rec,
            'shap_f1': shap_f1,
            'shap_pre': shap_pre,
            'shap_rec': shap_rec,
            'maple_f1': maple_f1,
            'maple_pre': maple_pre,
            'maple_rec': maple_rec,
            'p_border': p_border
        }
        results.append(res)
        print('lime %.2f' % lime_f1, 'shap %.2f' % shap_f1,
              'maple %.2f' % maple_f1)

        idx += 1

    df = pd.DataFrame(data=results)
    df = df[[
        'black_box', 'n_records', 'img_size', 'cell_size', 'n_features',
        'random_state', 'idx', 'lime_f1', 'lime_pre', 'lime_rec', 'shap_f1',
        'shap_pre', 'shap_rec', 'maple_f1', 'maple_pre', 'maple_rec',
        'p_border'
    ]]
    # print(df.head())

    if not os.path.isfile(filename):
        df.to_csv(filename, index=False)
    else:
        df.to_csv(filename, mode='a', index=False, header=False)
Code Example #13
def run(black_box, n_records, n_features, random_state, filename):

    X_train = fetch_20newsgroups(subset='train',
                                 remove=('headers', 'footers', 'quotes'),
                                 categories=None).data
    X_train = preprocess_data(X_train)

    stc = generate_synthetic_text_classifier(X_train,
                                             n_features=n_features,
                                             random_state=random_state)

    predict_proba = stc['predict_proba']
    vectorizer = stc['vectorizer']
    nbr_terms = stc['nbr_terms']

    X_test = fetch_20newsgroups(subset='test',
                                remove=('headers', 'footers', 'quotes'),
                                categories=None).data
    X_test = preprocess_data(X_test)
    X_test_nbrs = vectorizer.transform(X_test).toarray()
    Y_test = predict_proba(X_test_nbrs)

    lime_explainer = LimeTextExplainer(class_names=[0, 1])

    print(datetime.datetime.now(), 'build shap')
    reference = get_reference4shap(X_test_nbrs,
                                   stc['words_vec'],
                                   nbr_terms,
                                   nbr_references=10)
    shap_explainer = KernelExplainer(predict_proba, reference)
    print(datetime.datetime.now(), 'build shap done')

    # print(idx_records_train_expl)
    # print(X_test_nbrs[idx_records_train_expl])
    # print(Y_test[idx_records_train_expl][:, 1])
    # print(np.any(np.isnan(X_test_nbrs[idx_records_train_expl])))
    # print(np.any(np.isnan(X_test_nbrs[idx_records_test_expl])))
    # print(np.any(np.isnan(Y_test[idx_records_train_expl][:, 1])))
    # print(np.any(np.isnan(Y_test[idx_records_test_expl][:, 1])))

    nbr_records_explainer = 100
    idx_records_train_expl = np.random.choice(range(len(X_test)),
                                              size=nbr_records_explainer,
                                              replace=False)
    idx_records_test_expl = np.random.choice(range(len(X_test)),
                                             size=nbr_records_explainer,
                                             replace=False)

    print(datetime.datetime.now(), 'build maple')
    maple_explainer = MAPLE(X_test_nbrs[idx_records_train_expl],
                            Y_test[idx_records_train_expl][:, 1],
                            X_test_nbrs[idx_records_test_expl],
                            Y_test[idx_records_test_expl][:, 1],
                            n_estimators=100,
                            max_features=0.5,
                            min_samples_leaf=2)
    print(datetime.datetime.now(), 'build maple done')

    results = list()
    explained = 0
    for idx, x in enumerate(X_test):
        x_nbrs = X_test_nbrs[idx]

        print(datetime.datetime.now(),
              'seneca - text',
              'black_box %s' % black_box,
              'n_features %s' % n_features,
              'rs %s' % random_state,
              '%s/%s' % (idx, n_records),
              end=' ')

        gt_val_text = get_word_importance_explanation_text(x, stc)
        gt_val = get_word_importance_explanation(x_nbrs, stc)

        try:
            lime_exp = lime_explainer.explain_instance(x,
                                                       predict_proba,
                                                       num_features=n_features)
            lime_expl_val = {e[0]: e[1] for e in lime_exp.as_list()}

            shap_expl_val = shap_explainer.shap_values(x_nbrs, l1_reg='bic')[1]

            maple_exp = maple_explainer.explain(x_nbrs)
            maple_expl_val = maple_exp['coefs'][:-1]
        except ValueError:
            print(datetime.datetime.now(), 'Error in explanation')
            continue

        lime_cs = word_based_similarity_text(lime_expl_val,
                                             gt_val_text,
                                             use_values=True)
        lime_f1, lime_pre, lime_rec = word_based_similarity_text(
            lime_expl_val, gt_val_text, use_values=False, ret_pre_rec=True)

        shap_cs = word_based_similarity(shap_expl_val, gt_val, use_values=True)
        shap_f1, shap_pre, shap_rec = word_based_similarity(shap_expl_val,
                                                            gt_val,
                                                            use_values=False,
                                                            ret_pre_rec=True)

        maple_cs = word_based_similarity(maple_expl_val,
                                         gt_val,
                                         use_values=True)
        maple_f1, maple_pre, maple_rec = word_based_similarity(
            maple_expl_val, gt_val, use_values=False, ret_pre_rec=True)

        # print(gt_val)
        # print(lime_expl_val)
        # print(shap_expl_val)
        # print(maple_expl_val)

        res = {
            'black_box': black_box,
            'n_records': n_records,
            'nbr_terms': nbr_terms,
            'n_features': n_features,
            'random_state': random_state,
            'idx': idx,
            'lime_cs': lime_cs,
            'lime_f1': lime_f1,
            'lime_pre': lime_pre,
            'lime_rec': lime_rec,
            'shap_cs': shap_cs,
            'shap_f1': shap_f1,
            'shap_pre': shap_pre,
            'shap_rec': shap_rec,
            'maple_cs': maple_cs,
            'maple_f1': maple_f1,
            'maple_pre': maple_pre,
            'maple_rec': maple_rec,
        }
        results.append(res)
        print('lime %.2f %.2f' % (lime_cs, lime_f1),
              'shap %.2f %.2f' % (shap_cs, shap_f1),
              'maple %.2f %.2f' % (maple_cs, maple_f1))

        explained += 1
        if explained >= n_records:
            break

    df = pd.DataFrame(data=results)
    df = df[[
        'black_box',
        'n_records',
        'nbr_terms',
        'n_features',
        'random_state',
        'idx',
        'lime_cs',
        'lime_f1',
        'lime_pre',
        'lime_rec',
        'shap_cs',
        'shap_f1',
        'shap_pre',
        'shap_rec',
        'maple_cs',
        'maple_f1',
        'maple_pre',
        'maple_rec',
    ]]
    # print(df.head())

    if not os.path.isfile(filename):
        df.to_csv(filename, index=False)
    else:
        df.to_csv(filename, mode='a', index=False, header=False)
Code Example #14
# Set this flag to True and re-run the notebook to see the SHAP plots
shap_enabled = True

# COMMAND ----------

if shap_enabled:
    from shap import KernelExplainer, summary_plot
    # Sample background data for SHAP Explainer. Increase the sample size to reduce variance.
    train_sample = X_train.sample(n=min(100, len(X_train.index)))

    # Sample a single example from the validation set to explain. Increase the sample size and rerun for more thorough results.
    example = X_val.sample(n=1)

    # Use Kernel SHAP to explain feature importance on the example from the validation set.
    predict = lambda x: model.predict(pd.DataFrame(x, columns=X_train.columns))
    explainer = KernelExplainer(predict, train_sample, link="identity")
    shap_values = explainer.shap_values(example, l1_reg=False)
    summary_plot(shap_values, example)

# COMMAND ----------

# MAGIC %md
# MAGIC # Hyperopt
# MAGIC
# MAGIC Hyperopt is a Python library for "serial and parallel optimization over awkward search spaces, which may include real-valued, discrete, and conditional dimensions"... In simpler terms, it's a Python library for hyperparameter tuning.
# MAGIC
# MAGIC There are two main ways to scale hyperopt with Apache Spark:
# MAGIC * Use single-machine hyperopt with a distributed training algorithm (e.g. MLlib)
# MAGIC * Use distributed hyperopt with single-machine training algorithms, via the SparkTrials class (see the sketch after this cell).
# MAGIC
# MAGIC Resources:
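
To make the second approach concrete, here is a minimal, hypothetical sketch of distributed tuning with SparkTrials (the toy objective and search space are illustrative assumptions, not part of this notebook):

# Minimal sketch: distributed Hyperopt via SparkTrials. Assumes hyperopt and
# an active Spark session are available; the objective is a toy function.
from hyperopt import fmin, tpe, hp, SparkTrials

space = {"x": hp.uniform("x", -10, 10)}

def objective(params):
    # toy loss with its minimum at x = 3
    return (params["x"] - 3.0) ** 2

best = fmin(fn=objective,
            space=space,
            algo=tpe.suggest,
            max_evals=50,
            trials=SparkTrials(parallelism=4))
print(best)  # approximately {'x': 3.0}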
Code Example #15
    def __init__(self, model, training_data):
        super(ShapExplainer, self).__init__(model, training_data)
        data = self.preprocessor.transform(training_data)
        background_data = kmeans(data, 10)
        self.explainer = KernelExplainer(model=self.model.predict_proba,
                                         data=background_data)
Code Example #16
def run(black_box, n_records, n_all_features, n_features, n_coefficients,
        random_state, filename):

    n = n_records
    m = n_all_features

    slc = generate_synthetic_linear_classifier2(n_features=n_features,
                                                n_all_features=n_all_features,
                                                n_coefficients=n_coefficients,
                                                random_state=random_state)

    feature_names = slc['feature_names']
    class_values = slc['class_values']
    predict_proba = slc['predict_proba']
    # predict = slc['predict']

    X_test = np.random.uniform(size=(n, n_all_features))
    Xz = list()
    for x in X_test:
        nz = np.random.randint(0, n_features)
        zeros_idx = np.random.choice(np.arange(n_features),
                                     size=nz,
                                     replace=False)
        x[zeros_idx] = 0.0
        Xz.append(x)
    X_test = np.array(Xz)

    Y_test = predict_proba(X_test)[:, 1]

    lime_explainer = LimeTabularExplainer(X_test,
                                          feature_names=feature_names,
                                          class_names=class_values,
                                          discretize_continuous=False,
                                          discretizer='entropy')

    reference = np.zeros(m)
    shap_explainer = KernelExplainer(
        predict_proba, np.reshape(reference, (1, len(reference))))

    maple_explainer = MAPLE(X_test, Y_test, X_test, Y_test)

    results = list()
    for idx, x in enumerate(X_test):
        gt_val = get_feature_importance_explanation2(x,
                                                     slc,
                                                     n_features,
                                                     n_all_features,
                                                     get_values=True)
        gt_val_bin = get_feature_importance_explanation2(x,
                                                         slc,
                                                         n_features,
                                                         n_all_features,
                                                         get_values=False)

        lime_exp = lime_explainer.explain_instance(x,
                                                   predict_proba,
                                                   num_features=m)
        lime_expl_val = np.array([e[1] for e in lime_exp.as_list()])
        tmp = np.zeros(lime_expl_val.shape)
        tmp[np.where(lime_expl_val != 0.0)] = 1.0
        lime_expl_val_bin = tmp

        shap_expl_val = shap_explainer.shap_values(x, l1_reg='bic')[1]
        tmp = np.zeros(shap_expl_val.shape)
        tmp[np.where(shap_expl_val != 0.0)] = 1.0
        shap_expl_val_bin = tmp

        maple_exp = maple_explainer.explain(x)
        maple_expl_val = maple_exp['coefs'][:-1]
        tmp = np.zeros(maple_expl_val.shape)
        tmp[np.where(maple_expl_val != 0.0)] = 1.0
        maple_expl_val_bin = tmp

        lime_fis = feature_importance_similarity(lime_expl_val, gt_val)
        shap_fis = feature_importance_similarity(shap_expl_val, gt_val)
        maple_fis = feature_importance_similarity(maple_expl_val, gt_val)

        lime_rbs = rule_based_similarity(lime_expl_val_bin, gt_val_bin)
        shap_rbs = rule_based_similarity(shap_expl_val_bin, gt_val_bin)
        maple_rbs = rule_based_similarity(maple_expl_val_bin, gt_val_bin)

        # print(gt_val)
        # print(lime_expl_val)
        # print(shap_expl_val)
        # print(maple_expl_val)

        res = {
            'black_box': black_box,
            'n_records': n_records,
            'n_all_features': n_all_features,
            'n_features': n_features,
            'n_coefficients': n_coefficients,
            'random_state': random_state,
            'idx': idx,
            'lime_cs': lime_fis,
            'shap_cs': shap_fis,
            'maple_cs': maple_fis,
            'lime_f1': lime_rbs,
            'shap_f1': shap_rbs,
            'maple_f1': maple_rbs,
        }
        results.append(res)
        print(datetime.datetime.now(), 'syege - tlsb2',
              'black_box %s' % black_box, 'n_all_features %s' % n_all_features,
              'n_features %s' % n_features,
              'n_coefficients %s' % n_coefficients, 'rs %s' % random_state,
              '%s %s' % (idx, n_records),
              'lime %.2f %.2f' % (lime_fis, lime_rbs),
              'shap %.2f %.2f' % (shap_fis, shap_rbs),
              'maple %.2f %.2f' % (maple_fis, maple_rbs))

    df = pd.DataFrame(data=results)
    df = df[[
        'black_box', 'n_records', 'n_all_features', 'n_features',
        'n_coefficients', 'random_state', 'idx', 'lime_cs', 'shap_cs',
        'maple_cs', 'lime_f1', 'shap_f1', 'maple_f1'
    ]]
    # print(df.head())

    if not os.path.isfile(filename):
        df.to_csv(filename, index=False)
    else:
        df.to_csv(filename, mode='a', index=False, header=False)
Code Example #17
def run(black_box, n_records, n_all_features, n_features, random_state, filename):

    n = n_records
    m = n_all_features

    p_binary = 0.7
    p_parenthesis = 0.3

    slc = generate_synthetic_linear_classifier(expr=None, n_features=n_features, n_all_features=m,
                                               random_state=random_state,
                                               p_binary=p_binary, p_parenthesis=p_parenthesis)
    expr = slc['expr']

    X = slc['X']
    feature_names = slc['feature_names']
    class_values = slc['class_values']
    predict_proba = slc['predict_proba']
    # predict = slc['predict']

    X_test = np.random.uniform(np.min(X), np.max(X), size=(n, m))
    Y_test = predict_proba(X_test)[:, 1]

    lime_explainer = LimeTabularExplainer(X_test, feature_names=feature_names, class_names=class_values,
                                          discretize_continuous=False, discretizer='entropy')

    reference = np.zeros(m)
    shap_explainer = KernelExplainer(predict_proba, np.reshape(reference, (1, len(reference))))

    maple_explainer = MAPLE(X_test, Y_test, X_test, Y_test)

    results = list()
    for idx, x in enumerate(X_test):
        gt_val = get_feature_importance_explanation(x, slc, n_features, get_values=True)

        lime_exp = lime_explainer.explain_instance(x, predict_proba, num_features=m)
        # lime_exp_as_dict = {e[0]: e[1] for e in lime_exp.as_list()}
        # lime_expl_val = np.asarray([lime_exp_as_dict.get('x%s' % i, .0) for i in range(m)])
        lime_expl_val = np.array([e[1] for e in lime_exp.as_list()])

        shap_expl_val = shap_explainer.shap_values(x, l1_reg='bic')[1]

        maple_exp = maple_explainer.explain(x)
        maple_expl_val = maple_exp['coefs'][:-1]

        lime_fis = feature_importance_similarity(lime_expl_val, gt_val)
        shap_fis = feature_importance_similarity(shap_expl_val, gt_val)
        maple_fis = feature_importance_similarity(maple_expl_val, gt_val)

        # print(gt_val)
        # print(lime_expl_val)
        # print(shap_expl_val)
        # print(maple_expl_val)

        res = {
            'black_box': black_box,
            'n_records': n_records,
            'n_all_features': n_all_features,
            'n_features': n_features,
            'random_state': random_state,
            'idx': idx,
            'lime': lime_fis,
            'shap': shap_fis,
            'maple': maple_fis,
            'expr': expr,
        }
        results.append(res)
        print(datetime.datetime.now(), 'syege - tlsb', 'black_box %s' % black_box,
              'n_all_features %s' % n_all_features, 'n_features %s' % n_features, 'rs %s' % random_state,
              '%s %s' % (idx, n_records), expr,
              'lime %.2f' % lime_fis, 'shap %.2f' % shap_fis, 'maple %.2f' % maple_fis)

        if idx > 0 and idx % 10 == 0:
            df = pd.DataFrame(data=results)
            df = df[['black_box', 'n_records', 'n_all_features', 'n_features', 'random_state', 'expr',
                     'idx', 'lime', 'shap', 'maple']]
            # print(df.head())

            if not os.path.isfile(filename):
                df.to_csv(filename, index=False)
            else:
                df.to_csv(filename, mode='a', index=False, header=False)
            results = list()

    df = pd.DataFrame(data=results)
    df = df[['black_box', 'n_records', 'n_all_features', 'n_features', 'random_state', 'expr',
             'idx', 'lime', 'shap', 'maple']]
    # print(df.head())

    if not os.path.isfile(filename):
        df.to_csv(filename, index=False)
    else:
        df.to_csv(filename, mode='a', index=False, header=False)
Code Example #18
    def data_shift_automatic(self,
                             fraction_shift=0.3,
                             fraction_points=0.1,
                             sample=True):
        """
        Performing a dataset shift. For a subset of the instances, a subset of
        the features is shifted upwards. The features to shift are selected
        according to their Shapley value, so that the dataset shift will likely
        impact the predictions of the model applied to the shifted data.

        Parameters
        ----------

        fraction_shift : float
            The fraction of variables to shift (optional).

        fraction_points : float
            The fraction of points from the total dataset (so train and test
            combined) to be shifted (optional).

        sample : bool
            Whether a subsample of the points should be used to select the most
            important features in the dataset, to speed up calculations
            (optional).

        """

        # fit a linear regression to the data
        reg = LinearRegression().fit(self.X, self.y)

        # optionally take a sample of the points to speed up calculations
        if sample:
            X = self.X[np.random.randint(self.X.shape[0], size=100), :]
        else:
            X = self.X

        # build a Shapley explainer on the regression and get SHAP values
        explainer = KernelExplainer(reg.predict, X)
        shap_values = explainer.shap_values(X, nsamples=20, l1_reg="aic")

        # determine most important variables by SHAP value on average
        avg_shap = np.average(np.absolute(shap_values), axis=0).flatten()

        # get number of features to shift
        shift_count = int(fraction_shift * self.X.shape[1])

        # get indices of most important features
        shift_fts = avg_shap.argsort()[::-1][:shift_count]

        # new array for new data with shifted features
        shifted_X = np.zeros_like(self.X)

        # number of points to shift
        ix = int((1 - fraction_points) * self.X.shape[0])

        # shift feature by feature
        for f_ix in range(self.X.shape[1]):

            # place original feature in matrix
            shifted_X[:, f_ix] = self.X[:, f_ix]

            # check if feature has to be shifted
            if f_ix in shift_fts:

                # get feature from data
                ft = self.X[ix:, f_ix]

                # determine the maximum of this feature
                max_f = np.max(ft)

                # shift feature upward according to a Gaussian distribution
                shifted_X[ix:, f_ix] = ft + np.random.normal(
                    max_f, abs(.1 * max_f), shifted_X[ix:, f_ix].shape[0])

        # store the (partially) shifted features back
        self.X = shifted_X
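
The selection step the docstring describes, ranking features by the mean absolute SHAP value of a linear surrogate and keeping the top fraction, can be exercised on its own. A small standalone sketch (the synthetic data, sample sizes and the 0.3 fraction are arbitrary illustrative choices):

# Standalone sketch of the feature-selection step used above.
import numpy as np
from shap import KernelExplainer
from sklearn.linear_model import LinearRegression

rng = np.random.RandomState(0)
X = rng.normal(size=(200, 8))
y = 3.0 * X[:, 0] + 1.5 * X[:, 5] + rng.normal(scale=0.1, size=200)

reg = LinearRegression().fit(X, y)
explainer = KernelExplainer(reg.predict, X[:50])        # small background set
shap_values = explainer.shap_values(X[:50], nsamples=20, l1_reg="aic")

avg_shap = np.average(np.absolute(shap_values), axis=0).flatten()
shift_count = int(0.3 * X.shape[1])                     # fraction_shift = 0.3
shift_fts = avg_shap.argsort()[::-1][:shift_count]
print(shift_fts)  # expected to pick out features 0 and 5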
Code Example #19
def main():
    n_features = (8, 8)
    img_size = (32, 32, 3)
    cell_size = (4, 4)
    colors_p = np.array([0.15, 0.7, 0.15])
    p_border = 0.0

    sic = generate_synthetic_image_classifier(img_size=img_size,
                                              cell_size=cell_size,
                                              n_features=n_features,
                                              p_border=p_border)

    pattern = sic['pattern']
    predict = sic['predict']
    predict_proba = sic['predict_proba']

    plt.imshow(pattern)
    plt.show()

    X_test = generate_random_img_dataset(pattern,
                                         nbr_images=1000,
                                         pattern_ratio=0.4,
                                         img_size=img_size,
                                         cell_size=cell_size,
                                         min_nbr_cells=0.1,
                                         max_nbr_cells=0.3,
                                         colors_p=colors_p)

    Y_test = predict(X_test)

    background = np.array([np.zeros(img_size).ravel()] * 10)
    explainer = KernelExplainer(predict_proba, background)

    x = X_test[-1]
    plt.imshow(x)
    plt.show()
    print(Y_test[-1])
    expl_val = explainer.shap_values(x.ravel())[1]
    # expl_val = (expl_val - np.min(expl_val)) / (np.max(expl_val) - np.min(expl_val))
    print(expl_val)
    print(np.unique(expl_val, return_counts=True))
    print(expl_val.shape)

    sv = np.sum(np.reshape(expl_val, img_size), axis=2)
    sv01 = np.zeros(sv.shape)
    sv01[np.where(sv > 0.0)] = 1.0
    # np.array([1.0 if v > 0.0 else 0.0 for v in expl_val])
    sv = sv01
    print(sv)
    print(sv.shape)

    max_val = np.nanpercentile(np.abs(sv), 99.9)
    # plt.imshow(x)
    plt.imshow(sv, cmap='RdYlBu', vmin=-max_val, vmax=max_val, alpha=0.7)
    plt.show()
    # shap.image_plot(expl_val, x)

    gt_val = get_pixel_importance_explanation(x, sic)
    print(gt_val.shape)
    max_val = np.nanpercentile(np.abs(gt_val), 99.9)
    # plt.imshow(x)
    plt.imshow(np.reshape(gt_val, img_size[:2]),
               cmap='RdYlBu',
               vmin=-max_val,
               vmax=max_val,
               alpha=0.7)
    plt.show()

    print(pixel_based_similarity(sv.ravel(), gt_val))