Esempio n. 1
0
    def display_model_analysis(self):
        """
        Displays information about the model used : class name, library name, library version,
        model parameters, ...

        The parameters are read from the model instance's ``__dict__`` and rendered with the
        "double_table.html" template as two tables, each one holding half of the entries
        (the second table gets the extra entry when the count is odd).
        """
        model = self.explainer.model
        model_class = model.__class__

        print_md(f"**Model used :** {model_class.__name__}")

        print_md(f"**Library :** {model_class.__module__}")

        # Look the model's top-level package up directly rather than scanning all of
        # sys.modules: calling hasattr() on every loaded module can trigger lazy imports
        # or warnings in unrelated packages, and a module registered under several keys
        # would make the version line appear more than once.
        root_package = sys.modules.get(model_class.__module__.split('.')[0])
        if root_package is not None and hasattr(root_package, '__version__'):
            print_md(f"**Library version :** {root_package.__version__}")

        print_md("**Model parameters :** ")
        # Build every row once, then split the list in two for the side-by-side tables.
        param_rows = [
            {
                "name": truncate_str(str(key), 50),
                "value": truncate_str(str(value), 300),
            }
            for key, value in model.__dict__.items()
        ]
        half = len(param_rows) // 2
        headers = ["Parameter key", "Parameter value"]

        table_template = template_env.get_template("double_table.html")
        print_html(
            table_template.render(
                columns1=list(headers),
                rows1=param_rows[:half],  # First half of the parameters
                columns2=list(headers),
                rows2=param_rows[half:],  # Second half of the parameters
            ))
        print_md('---')
Esempio n. 2
0
 def test_truncate_str_3(self):
     """Truncating 'this is a test' with a limit of 10 must give 'this is a...'."""
     expected = "this is a..."
     assert truncate_str("this is a test", 10) == expected
Esempio n. 3
0
 def test_truncate_str_2(self):
     """A string shorter than the limit (50) must come back unchanged."""
     original = "this is a test"
     assert truncate_str("this is a test", 50) == original
Esempio n. 4
0
 def test_truncate_str_1(self):
     """Passing the integer 12 with no explicit limit must return 12 as is."""
     result = truncate_str(12)
     assert result == 12
Esempio n. 5
0
    def make_skeleton(self):
        """
        Build the bootstrap-grid layout of the app and store it in ``self.skeleton``.

        Two entries are written:
        - ``'navbar'`` : logo + app title, story title and the menu component,
          each in a 4-unit wide column.
        - ``'body'`` : a first row with the global feature importance graph and
          the dataset table, and a second row with the feature selector graph
          next to the detail graph and the filter panel.
        """
        repo_url = "https://github.com/MAIF/shapash"

        # ---------------------------------------------------------------- navbar
        logo_and_name = dbc.Row(
            [
                html.Img(src=self.logo, height="40px"),
                html.H4("Shapash Monitor", id="shapash_title"),
            ],
            align="center",
        )
        brand_col = dbc.Col(
            html.A(logo_and_name, href=repo_url, target="_blank"),
            md=4,
            align="left",
        )

        story_heading = html.H3(
            truncate_str(self.explainer.title_story, maxlen=40),
            id="shapash_title_story",
        )
        story_col = dbc.Col(
            html.A(
                dbc.Row([story_heading], align="center"),
                href=repo_url,
                target="_blank",
            ),
            md=4,
            align="center",
        )

        menu_col = dbc.Col(self.components['menu'], md=4, align='right')

        navbar_row = dbc.Row(
            [brand_col, story_col, menu_col],
            style={'padding': "5px 15px", "verticalAlign": "middle"},
        )
        self.skeleton['navbar'] = dbc.Container(
            [navbar_row],
            fluid=True,
            style={'height': '50px', 'backgroundColor': self.bkg_color},
        )

        # ------------------------------------------------------------------ body
        # First row: global feature importance (left) and dataset table (right).
        importance_card = html.Div(
            self.draw_component('graph', 'global_feature_importance'),
            className="card",
            id="card_global_feature_importance",
        )
        importance_col = dbc.Col(
            [importance_card],
            md=5,
            align="center",
            style={'padding': '0px 10px'},
        )

        dataset_card = html.Div(
            self.draw_component('table', 'dataset'),
            className="card",
            id='card_dataset',
            style={'cursor': 'pointer'},
        )
        dataset_col = dbc.Col(
            [dataset_card],
            md=7,
            align="center",
            style={'padding': '0px 10px'},
        )

        first_row = dbc.Row(
            [importance_col, dataset_col],
            style={'padding': '15px 10px 0px 10px'},
        )

        # Second row: feature selector (left), detail graph + filter (right).
        selector_card = html.Div(
            self.draw_component('graph', 'feature_selector'),
            className="card",
            id='card_feature_selector',
        )
        selector_col = dbc.Col(
            [selector_card],
            md=5,
            align="center",
            style={'padding': '0px 10px'},
        )

        detail_card = html.Div(
            self.draw_component('graph', 'detail_feature'),
            className="card",
            id='card_detail_feature',
        )
        detail_col = dbc.Col([detail_card], md=8, align="center")

        filter_card = html.Div(
            self.draw_filter(),
            className="card_filter",
            id='card_filter',
        )
        filter_col = dbc.Col([filter_card], md=4, align="center")

        detail_and_filter_col = dbc.Col(
            [dbc.Row([detail_col, filter_col])],
            md=7,
            align="center",
            style={'padding': '0px 10px'},
        )

        second_row = dbc.Row(
            [selector_col, detail_and_filter_col],
            style={'padding': '15px 10px'},
        )

        self.skeleton['body'] = dbc.Container(
            [first_row, second_row],
            className="mt-12",
            fluid=True,
        )
Esempio n. 6
0
def generate_fig_univariate_categorical(
    df_all: pd.DataFrame,
    col: str,
    hue: str,
    nb_cat_max: int = 7,
) -> plt.Figure:
    """
    Build a horizontal bar chart showing how a categorical feature is distributed.

    For every value of ``hue`` the share (in percent) of each category of ``col``
    is computed and drawn as a bar annotated with its percentage. When the
    feature has more than ``nb_cat_max`` distinct categories, the dataframe is
    passed through ``_merge_small_categories`` before plotting so that the
    chart stays readable.

    Parameters
    ----------
    df_all : pd.DataFrame
        Dataframe holding both the feature column and the hue column
    col : str
        Name of the feature column to describe
    hue : str
        Name of the column that separates the groups (ex. 'train' and 'test')
    nb_cat_max : int
        Maximum number of categories to display. Above this number the
        smallest categories are grouped together

    Returns
    -------
    matplotlib.pyplot.Figure
    """
    # Row count for every (category, hue) pair.
    pair_counts = df_all.groupby([col, hue]).agg({col: 'count'})
    df_cat = pair_counts.rename(columns={col: "count"}).reset_index()

    # Share of each category inside its own hue group, in percent.
    hue_totals = df_cat.groupby(hue)['count'].transform('sum')
    df_cat['Percent'] = df_cat['count'] * 100 / hue_totals

    # Numeric categories are ordered by value first, then handled as strings.
    if pd.api.types.is_numeric_dtype(df_cat[col].dtype):
        df_cat = df_cat.sort_values(col, ascending=True)
        df_cat[col] = df_cat[col].astype(str)

    per_category = df_cat.groupby([col]).agg({'count': 'sum'}).reset_index()
    nb_cat = per_category[col].nunique()

    if nb_cat > nb_cat_max:
        df_cat = _merge_small_categories(
            df_cat=df_cat,
            col=col,
            hue=hue,
            nb_cat_max=nb_cat_max,
        )

    fig, ax = plt.subplots(figsize=(7, 4))

    sns.barplot(
        data=df_cat,
        x='Percent',
        y=col,
        hue=hue,
        palette=dict_color_palette,
        ax=ax,
    )

    # Write the percentage next to the end of each bar.
    for bar in ax.patches:
        bar_width = bar.get_width()
        label = f"{np.nan_to_num(bar_width, nan=0):.1f}%"
        anchor = (bar_width, bar.get_y() + bar.get_height() / 2)
        ax.annotate(
            label,
            xy=anchor,
            xytext=(5, 0),
            textcoords='offset points',
            ha="left",
            va="center",
        )

    # Shrink current axis by 20% to leave room for the legend
    bounds = ax.get_position()
    ax.set_position([bounds.x0, bounds.y0, bounds.width * 0.8, bounds.height])

    # Put a legend to the right of the current axis
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    # Removes the top and right plot borders
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    # Keep long category names from taking over the figure.
    ax.yaxis.set_ticklabels([
        truncate_str(tick.get_text(), maxlen=45)
        for tick in ax.yaxis.get_ticklabels()
    ])

    return fig