Esempio n. 1
0
    def decisions_multiclass(
        df_preds,
        shap_values,
        expected_value,
        X_vald,
        y_vald,
        model_file_path,
        learner_name,
        class_names,
        n_samples=4,
    ):
        """Save SHAP multi-output decision plots for the worst and best rows of ``df_preds``.

        For ``"worst"`` the first ``n_samples`` rows of ``df_preds`` are plotted;
        for ``"best"`` the last ``n_samples`` rows (last row first). Each figure is
        written to
        ``{model_file_path}/{learner_name}_sample_{i}_{decision_type}_decisions.png``.

        Args:
            df_preds: DataFrame with at least ``lp`` and ``target`` columns.
                ``lp`` is used as the row index into ``shap_values`` and
                ``target`` as an index into ``class_names``. Presumably sorted so
                that the worst predictions come first — TODO confirm with caller.
            shap_values: per-class SHAP values, as accepted by
                ``shap.multioutput_decision_plot``.
            expected_value: per-class base values (converted with ``list``).
            X_vald: unused; kept for call compatibility.
            y_vald: unused; kept for call compatibility.
            model_file_path: directory the PNG files are written to.
            learner_name: prefix of the output file names.
            class_names: legend labels, indexed by the ``target`` column.
            n_samples: number of rows plotted from each end of ``df_preds``
                (default 4); capped at ``len(df_preds)``.
        """
        base_values = list(expected_value)
        n_samples = min(n_samples, len(df_preds))
        for decision_type in ["worst", "best"]:
            for i in range(n_samples):
                # "worst" walks forward from the head (0, 1, 2, ...); "best" walks
                # backward from the tail (-1, -2, -3, ...). Position 0 is the head,
                # so the "best" sequence must start at -1, not 0.
                pos = i if decision_type == "worst" else -(i + 1)
                fig = plt.gcf()
                try:
                    shap.multioutput_decision_plot(
                        base_values,
                        shap_values,
                        row_index=df_preds.lp.iloc[pos],
                        show=False,
                        legend_labels=class_names,
                        title=
                        f"It should be {class_names[df_preds.target.iloc[pos]]}",
                    )
                    fig.tight_layout(pad=2.0)
                    fig.savefig(
                        os.path.join(
                            model_file_path,
                            f"{learner_name}_sample_{i}_{decision_type}_decisions.png",
                        ))
                finally:
                    # Release the figure even if plotting or saving fails.
                    plt.close("all")
Esempio n. 2
0
def plot_prediction_desicion(model, X_test, pred, row_idx, legend_labels=None):
    """Show a SHAP multi-output decision plot for a single observation of ``X_test``.

    Each model output (class) is drawn as one line, starting from that output's
    base value (``explainer.expected_value``). The line of the predicted class
    ``pred[row_idx]`` is highlighted with a dashed style.

    Args:
        model: tree model accepted by ``shap.TreeExplainer``.
        X_test: DataFrame of observations; 100 of its rows serve as the
            background sample.
        pred: predicted class indices, one per row of ``X_test``.
        row_idx: positional index of the observation to plot.
        legend_labels: one label per class; defaults to the three age bands
            ``["0-18", "19-70", "70+"]``.
    """
    if legend_labels is None:
        legend_labels = ["0-18", "19-70", "70+"]
    explainer = shap.TreeExplainer(model, data=shap.sample(X_test, 100), feature_dependence="interventional")
    # SHAP values are computed per row, so only the plotted row needs explaining.
    x_row = X_test.iloc[[row_idx]]
    shap_values = explainer.shap_values(x_row)
    shap.multioutput_decision_plot(list(explainer.expected_value), shap_values,
                                   row_index=0,
                                   feature_names=list(X_test.columns),
                                   highlight=int(pred[row_idx]),
                                   legend_labels=legend_labels,
                                   legend_location='lower right')
    plt.show()
Esempio n. 3
0
st.subheader('Decision plot')


def class_labels(row_index):
    """Build one legend label per class: the class name plus its predicted probability.

    ``row_index`` is accepted for call compatibility but is not used. The
    probabilities always come from ``input_preds_proba``, the single input row.
    """
    return [
        f'{model.classes_[i]} (pred: {input_preds_proba[i].round(2)})'
        for i in range(len(expected_value))
    ]


decision_plot, ax = plt.subplots()
# The decision plot draws onto the current figure (the one just created). It
# returns None unless return_objects=True, so its result is deliberately not
# assigned; assigning it would clobber `ax`. show=False leaves rendering to
# Streamlit instead of calling plt.show() from inside the app.
shap.multioutput_decision_plot(
    expected_value,
    input_shap_values,
    row_index=0,
    feature_names=eval_set_features,
    legend_labels=class_labels(0),
    legend_location='lower right',
    highlight=np.argmax(input_preds_proba),  # Highlight the predicted class
    show=False,
)

st.pyplot(decision_plot)

# Decision plot expander explanations
with st.beta_expander("More on decision plots"):
    st.markdown("""
     Just like the force plot, the [**decision plot**](https://slundberg.github.io/shap/notebooks/plots) shows how each feature has contributed in moving away or towards the base value (the grey line, aka. the average model output on the evaluation dataset) to the predicted value of the specific instance (inputed on the left side bar), but allows us to visualize those effects **for each class**.
It also show the impact of less influencial features more clearly.

From SHAP documentation:
- *The x-axis represents the model's output. In this case, the units are log odds. (SHAP doesn't support probability output for multiclass)*
Esempio n. 4
0
    # Draw one decision line per model output for sample `row_index`. The eight
    # feature labels show the corresponding `state` values.
    # NOTE(review): `shap_model`, `state` and `row_index` are defined outside this
    # snippet; `state` must have at least 8 entries for the labels below.
    shap.multioutput_decision_plot(
        base_values=list(shap_model.expected_value),
        shap_values=list(shap_model.shap_values),
        row_index=row_index,
        # "F<k>: <value>" label strings passed as `features` (shap normally
        # expects feature values here). NOTE(review): confirm how they render.
        features=[
            f'F0: {state[0]}', f'F1: {state[1]}', f'F2: {state[2]}',
            f'F3: {state[3]}', f'F4: {state[4]}', f'F5: {state[5]}',
            f'F6: {state[6]}', f'F7: {state[7]}'
        ],
        feature_names=None,
        feature_order=None,  # no reordering; "importance" would sort by impact
        feature_display_range=None,
        highlight=None,  # no line emphasised; alternative: [np.argmax(predictions[row_index])]
        link="identity",  # x-axis in raw model-output units (no logit transform)
        plot_color=plt.get_cmap("tab20c"),  # colormap used for the decision lines
        axis_color="#333333",
        y_demarc_color="#333333",
        alpha=None,
        color_bar=False,
        auto_size_plot=True,  # let shap choose the figure size
        #title=f'F0:{possibilits[row_index][0]} | F1:{possibilits[row_index][1]} | F2:{possibilits[row_index][2]} | Action={actions[row_index]}',#None,
        xlim=None,
        show=False,  # do not call plt.show(); the figure stays open for the caller
        return_objects=False,
        ignore_warnings=False,
        new_base_value=None,
        legend_labels=None,
        legend_location="best",
    )
Esempio n. 5
0
        # Dependence plot of column `feature` against the SHAP values of output
        # `j`, over the first EXPLAIN_OBSERVATION_NO rows of `x`. `feature` is an
        # integer position, since it is used as `x.columns[feature]`.
        # NOTE(review): `feature`, `j`, `shap_values`, `x` and `model` come from
        # enclosing loops that are not visible in this snippet.
        shap.dependence_plot(feature,
                             shap_values[j],
                             x.iloc[:EXPLAIN_OBSERVATION_NO, :],
                             show=False)
        plt.title(f'SHAP dependece {x.columns[feature]}', size=18)
        plt.tight_layout()
        # One sub-directory per class (model.classes_[j]). savefig does not
        # create it, so reports/explain/dependence/<class> must already exist.
        # NOTE(review): the file name has no separator before the column name,
        # and the title above spells "dependece"; both look like typos.
        plt.savefig(os.path.join('reports', 'explain', 'dependence',
                                 f'{model.classes_[j]}',
                                 f'SHAP-dependence{x.columns[feature]}.png'),
                    dpi=300)
        plt.show()

y_pred_all = model.predict(x)

# Loop-invariant plot inputs, converted once.
decision_base_values = list(base_value)
decision_feature_names = list(x.columns)
decision_legend_labels = list(model.classes_)
decision_plot_dir = os.path.join('reports', 'explain', 'decision-plot')
# savefig does not create missing directories.
os.makedirs(decision_plot_dir, exist_ok=True)

for row_index in range(EXPLAIN_TRAJECTORIES_NO):
    # Keep a handle on the new figure. The decision plot draws onto the current
    # figure, and it is this figure (not some earlier `fig`) that is resized below.
    fig = plt.figure()
    shap.multioutput_decision_plot(decision_base_values,
                                   shap_values,
                                   row_index=row_index,
                                   feature_names=decision_feature_names,
                                   legend_labels=decision_legend_labels,
                                   legend_location='upper left',
                                   show=False)
    # Override the size chosen by the plot's auto-sizing.
    fig.set_size_inches(6, 8)
    plt.title(
        f"Real class: {y[row_index]}, Pred class: {y_pred_all[row_index]}")
    plt.tight_layout()
    plt.savefig(os.path.join(decision_plot_dir,
                             f'SHAP-decision{row_index}.png'),
                dpi=300)
    plt.show()
Esempio n. 6
0
# Hex colours for the eight action classes, in action order.
_ACTION_PALETTE_HEX = (
    '#322e2f', '#375f9f', '#12a4d9', '#3caea3', '#b9d604', '#f6d55c',
    '#ed553b', '#b20238',
)


def _hex_to_rgb(hex_color):
    """Convert a ``'#rrggbb'`` string (``#`` optional) to ``[r, g, b, 1]`` floats in [0, 1]."""
    hex_color = hex_color.lstrip('#')
    # Divide by 255 (not 256) so that 'ff' maps exactly to 1.0.
    rgb = [int(hex_color[i:i + 2], 16) / 255 for i in (0, 2, 4)]
    rgb.append(1)
    return rgb


def _create_tab_b(total=8, num=1, palette=_ACTION_PALETTE_HEX):
    """Return ``total`` RGBA colours, repeating each palette colour ``num`` times in order.

    Raises:
        ValueError: if ``palette`` cannot supply ``total`` colours at ``num``
            repetitions each.
    """
    colors = [
        _hex_to_rgb(hex_color)
        for hex_color in palette
        for _ in range(num)
    ][:total]
    if len(colors) < total:
        raise ValueError(
            f"palette has {len(palette)} colours; cannot provide {total} "
            f"with num={num}")
    return colors


def plot_multi_output(runs=1):
    """Save one SHAP multi-output decision plot per sample as a PDF.

    For each of the first ``runs`` samples, draws the decision plot and recolours
    the per-output lines with the action palette. The line at the predicted
    index (``argmax`` of ``predictions[row_index]``) is emphasised. An "Actions"
    colour bar is added, and the figure is written to
    ``{args.output_path}multioutput/multioutput_{row_index}.pdf``.

    Relies on names defined outside this function: ``args`` (``output_path``,
    ``isx``, ``isy``), ``shap_model`` (``expected_value``, ``shap_values``,
    ``possibilits``) and ``predictions``.
    """
    # args.output_path is concatenated directly, so it presumably ends with a
    # path separator — TODO confirm.
    Path(f"{args.output_path}multioutput").mkdir(parents=True, exist_ok=True)

    colors = _create_tab_b()
    # Eight-bin colormap built from the action colours, used for the colour bar.
    cm = LinearSegmentedColormap.from_list('cmap_name', colors, N=8)

    def add_subplot_axes(ax, rect):
        """Add an "Actions" colour-bar axes inside ``ax``; ``rect`` is (x, y, w, h) in ax-relative units."""
        fig = plt.gcf()
        box = ax.get_position()
        # Convert the rect origin from ax coordinates to figure coordinates.
        inax_position = ax.transAxes.transform(rect[0:2])
        infig_position = fig.transFigure.inverted().transform(inax_position)
        x = infig_position[0]
        y = infig_position[1]
        width = box.width * rect[2]
        height = box.height * rect[3]

        subax = fig.add_axes([x, y, width, height])
        # Shrink tick labels in proportion to the sub-axes size.
        x_labelsize = subax.get_xticklabels()[0].get_size() * rect[2]**0.5
        y_labelsize = subax.get_yticklabels()[0].get_size() * rect[3]**0.5
        subax.xaxis.set_tick_params(labelsize=x_labelsize)
        subax.yaxis.set_tick_params(labelsize=y_labelsize)

        plt.title("Actions")

        # NOTE(review): this hexbin appears to exist only to give fig.colorbar a
        # mappable that uses `cm`; it is drawn on the main axes at y=-1.
        hexbins = ax.hexbin(list(range(0, 8)), [-1] * 8,
                            C=colors,
                            bins=8,
                            gridsize=8,
                            cmap=cm)
        fig.colorbar(hexbins, cax=subax, orientation='vertical')

        return subax

    for row_index in range(runs):

        # Quantise each state component into 0.025-wide buckets for the labels.
        state = [np.ceil(s / .025) for s in shap_model.possibilits[row_index]]

        shap.multioutput_decision_plot(
            base_values=list(shap_model.expected_value),
            shap_values=list(shap_model.shap_values),
            row_index=row_index,
            features=[f'F{k}: {state[k]}' for k in range(8)],
            feature_names=None,
            feature_order="importance",
            feature_display_range=None,
            highlight=None,
            link="identity",
            plot_color=plt.get_cmap("tab20c"),
            axis_color="#333333",
            y_demarc_color="#333333",
            alpha=None,
            color_bar=False,
            auto_size_plot=True,
            xlim=None,
            show=False,
            return_objects=False,
            ignore_warnings=False,
            new_base_value=None,
            legend_labels=None,
            legend_location="best",
        )

        ax = plt.gca()
        fig = plt.gcf()
        fig.set_size_inches(args.isx, args.isy)
        plt.grid(False)

        predicted = np.argmax(predictions[row_index])
        # NOTE(review): the first 8 Line2D objects are assumed to be plot
        # decorations; the remaining ones are taken to be the per-output lines.
        for idx, line in enumerate(ax.get_lines()[8:]):
            line.set_color(colors[idx])
            line.set_linewidth(1)
            line.set_alpha(1)
            # Emphasise the predicted output's line.
            if idx == predicted:
                line.set_linewidth(1.5)
                line.set_linestyle("-.")

        add_subplot_axes(ax, [0.93, 0.05, .05, 0.7])

        plt.savefig(
            f"{args.output_path}multioutput/multioutput_{row_index}.pdf",
            bbox_inches='tight')
        plt.close()