示例#1
0
def waterfall_chart():
    """Return a fixed demo waterfall: sales and tax credits down to profit before tax.

    The two "total" bars (Net revenue, Profit before tax) are computed by plotly from the
    preceding relative bars; their y entries are placeholders (0).
    """
    labels = ["Sales", "Tax credits", "Net revenue", "Purchases",
              "Other expenses", "Profit before tax"]
    measures = ["relative", "relative", "total", "relative", "relative", "total"]
    amounts = [240000, 32000, 0, -130000, -5000, 0]
    captions = ["+240,000", "+32,000", "", "-130,000", "-5,000", "Total"]

    trace = go.Waterfall(
        name="20",
        orientation="v",
        measure=measures,
        x=labels,
        textposition="outside",
        text=captions,
        y=amounts,
        connector=dict(line=dict(color="rgb(63, 63, 63)")),
    )
    figure = go.Figure(trace)
    figure.update_layout(
        margin=dict(l=30, r=30, t=10, b=10),
        yaxis_title_text="Amount ($)",
    )
    return figure
def waterFall():
    """Build the energy-loss waterfall and return it serialized for the front end.

    Bars, left to right:
      * energy hitting the panels (GTI x DC capacity),
      * one bar per row of ``waterfall.csv``, whose ``Loss Value`` is taken as a percentage
        of that energy,
      * "Expected Energy",
      * "Incidental Loses", sized so that panel energy + every loss bar + this bar equals
        the actual energy,
      * "Actual Energy".

    NOTE(review): ``measure`` is a fixed 18-entry list while there are ``len(csv) + 4`` x
    labels, so the two only line up when the CSV has exactly 14 loss rows — confirm.

    :return: dict with the translated title under ``msgWaterF`` and the figure JSON under
        ``figWaterF``.
    """
    loss_table = pd.read_csv(join(dirname(__file__), 'waterfall.csv'))

    # Hard-coded inputs; they are meant to come from the database for that day.
    gti = 9561.7
    dc_capacity = 19501
    actual_energy = 143100000

    panel_energy = gti * dc_capacity
    loss_energy = (loss_table['Loss Value'] * (panel_energy / 100)).to_list()

    # Balancing term: whatever is still missing between the modelled bars and the actual energy.
    running = [panel_energy, *loss_energy, 0, 0, actual_energy * -1]
    incidental_loss = sum(running) * -1
    bar_values = [panel_energy, *loss_energy, 0, incidental_loss, actual_energy]

    labels = ['Energy hitting Panels', *loss_table['Loss Mode'].to_list(),
              'Expected Energy', 'Incidental Loses', 'Actual Energy']
    measures = ["relative"] * 13 + ["total", "relative", "total", "relative", "total"]

    fig = go.Figure(go.Waterfall(
        name="Expected",
        orientation="v",
        measure=measures,
        x=labels,
        textposition="outside",
        y=bar_values,
        connector=dict(line=dict(color="rgb(63, 63, 63)")),
    ))

    # Same dark-theme styling on both axes.
    axis_style = dict(tickfont=dict(size=10), color='#e6e6e6', gridwidth=0.5,
                      gridcolor="#333", zerolinecolor="#333")
    fig.update_layout(
        xaxis=axis_style,
        yaxis=axis_style,
        title=dict(text="Energy losses waterfall analysis", font=dict(color='#e6e6e6')),
        showlegend=True,
        legend=dict(font=dict(color='#e6e6e6')),
        plot_bgcolor='#282828',
        paper_bgcolor='#222222',
        margin=dict(l=50, r=50, t=50, b=0),
    )

    return dict(msgWaterF=T('Waterfall'), figWaterF=fig.to_json())
示例#3
0
    def plot_path_plotly(pth, title="Waterfall Plotly"):
        """Render a path of ``(_, split, value)`` triples as a plotly waterfall via ``py.iplot``.

        Each bar is the change from the previous value (the first bar starts from 0), so the
        bars end at the successive values. The first bar is labelled "init"; bar i > 0 is
        labelled with the split of entry i - 1.
        """
        running = pd.Series([value for _, _, value in pth])
        previous = running.shift(1).fillna(0)
        steps = running - previous

        labels = ["init"] + [str(split) for _, split, _ in pth[:-1]]
        waterfall = go.Waterfall(
            orientation="v",
            measure=["relative"] * len(pth),
            x=labels,
            textposition="outside",
            y=steps,
            connector=dict(line=dict(color="rgb(63, 63, 63)")),
        )

        figure = go.Figure([waterfall], go.Layout(title=title, showlegend=True))
        py.iplot(figure, filename="basic_waterfall_chart")
                    features=interpreter.test_data[pred, sample, 2:].numpy(), 
                    feature_names=interpreter.feat_names,
                    max_display=2)

# SHAP waterfall plot for one (pred, sample) entry of the interpreter, with max_display=2.
# The feature values are test_data[pred, sample, 2:], i.e. the first two entries of the last
# axis are skipped (presumably non-feature columns — TODO confirm).
# NOTE(review): the base value is hard-coded to 0. The commented-out variant below passed
# interpreter.explainer.expected_value[0] instead; confirm which baseline is intended.
# du.visualization.shap_waterfall_plot(interpreter.explainer.expected_value[0], interpreter.feat_scores[pred, sample],
du.visualization.shap_waterfall_plot(0, interpreter.feat_scores[pred, sample],
                                     interpreter.test_data[pred, sample, 2:], interpreter.feat_names,
                                     max_display=2)

# +
fig = go.Figure()

# Two horizontal waterfalls on a shared two-level (year, quarter) category axis, each
# starting from base 1000; "total" entries take y None and are computed by plotly.
# Both traces must give the same two levels on `y`. A one-level `[[...]]` array has a
# single top-level entry, so it cannot line up with the nine measures/x values.
fig.add_trace(go.Waterfall(
    y = [["2016", "2017", "2017", "2017", "2017", "2018", "2018", "2018", "2018"],
        ["initial", "q1", "q2", "q3", "total", "q1", "q2", "q3", "total"]],
    measure = ["absolute", "relative", "relative", "relative", "total", "relative", "relative", "relative", "total"],
    x = [1, 2, 3, -1, None, 1, 2, -4, None],
    base = 1000,
    orientation='h'
))

fig.add_trace(go.Waterfall(
    y = [["2016", "2017", "2017", "2017", "2017", "2018", "2018", "2018", "2018"],
        ["initial", "q1", "q2", "q3", "total", "q1", "q2", "q3", "total"]],
    measure = ["absolute", "relative", "relative", "relative", "total", "relative", "relative", "relative", "total"],
    x = [1.1, 2.2, 3.3, -1.1, None, 1.1, 2.2, -4.4, None],
    base = 1000,
    orientation='h'
))

fig.update_layout(
    waterfallgroupgap = 0.5,
示例#5
0
# Two-level category axis: outer level = statement section, inner level = the variables in
# `vrs` followed by the closing total. The section sizes (5+3+3+3+1+1 = 16 labels) have to
# match len(vrs) + 1 — NOTE(review): confirm against how `vrs` is defined.
x = [
    ['Revenue']*5 + ['Expenses']*3 + ['Benefits']*3 + ['Change In Int']*3 + ['Income Tax'] + ['DE'],
    vrs + ['Distributable Earnings'],
]

# Total 'Value' per variable. Variables listed in `flip` have their sign inverted before
# plotting; the last entry is the grand total.
y = avg.groupby(by=['Variable'])['Value'].sum()
y = [y[vr]*-1 if vr in flip else y[vr] for vr in vrs]
y = [*y, sum(y)]

text = ['${:,.2f}'.format(val) for val in y]

trace = go.Waterfall(
    orientation = "v",
    # Every bar is a step except the last, which plotly draws as the running total.
    # (The previous comprehension reused `x` as its loop variable, shadowing the axis labels.)
    measure = ["relative"] * (len(x[0]) - 1) + ["total"],
    x=x,
    y=y,
    text=text,
    textposition = "outside",
)

fig = go.Figure([trace], go.Layout(height=1000))
fig

# COMMAND ----------

# MAGIC %md
# MAGIC ## More Advanced Analysis 
# MAGIC Let's look at the correlations of ROTC, PVDE, ROA, etc. with the IDF Moderate Index.
# MAGIC 
# MAGIC Below, I specify the variables that I will be looking at and filter the original dataframe to the columns and values for this analysis. 
示例#6
0
def plot_waterfall(df_cust, customer_id, n_top, thres, base_value, shaps):
    """
    Build a horizontal waterfall of Shapley values for a given customer.

    The ``n_top`` features with the largest absolute Shapley value are shown individually.
    All remaining features are summed into a single "others (n=...)" bar, placed first.

    params:
        df_cust :
            A customer description DataFrame; its columns name the features.
        customer_id :
            The SK_ID_CURR value of the customer whose application decision is explained.
            NOTE(review): not used by this function.
        n_top:
            Number of top criteria to display individually.
        thres:
            Threshold risk value above which a customer's loan is denied.
            NOTE(review): not used; the threshold line is drawn at x = 0.
        base_value :
            The aggregated base value of the selected shap explainers; the waterfall
            starts here.
        shaps :
            A numpy array of aggregated Shapley values. ``shaps.T`` must form a single
            column with one row per column of ``df_cust``.

    returns:
        The waterfall figure for the selected customer.
        Loan applications with a final score below 0 are denied.
    """
    # One row per feature, sorted by increasing absolute contribution.
    df_waterfall = pd.DataFrame(shaps.T, index=df_cust.columns)
    df_waterfall.columns = ['values']
    df_waterfall['abs'] = df_waterfall['values'].abs()
    df_waterfall.sort_values(by='abs', inplace=True)

    # Split explicitly at len - n_top. The previous iloc[:-n_top] form is empty for
    # n_top == 0, which silently dropped every feature.
    n_others = max(len(df_waterfall) - n_top, 0)
    df_top = df_waterfall.iloc[n_others:]
    df_others = pd.DataFrame(df_waterfall.iloc[:n_others].sum(axis=0)).T
    df_others.index = [f'others (n={n_others})']
    # DataFrame.append was removed in pandas 2.0; concat is the replacement.
    df_waterfall = pd.concat([df_others, df_top])

    # Number of individually shown features (== n_top unless there are fewer features).
    n_shown = len(df_top)

    # Plot waterfall
    fig = go.Figure(
        go.Waterfall(base=base_value,
                     orientation='h',
                     y=df_waterfall.index,
                     x=df_waterfall['values']),
        layout=go.Layout(
            height=200 + (25 * n_shown),
            xaxis_title='Confidence score',
            yaxis_title='Criteria',
            yaxis_side='right',
            yaxis_tickfont=dict(size=10),
            margin_l=10,
            margin_r=10,
            margin_t=30,
            margin_b=10))

    # Mark the base value next to the first ("others") bar.
    fig.add_shape(type='line', x0=base_value, x1=base_value, y0=-1, y1=1)
    fig.add_annotation(text='Base value', x=base_value, y=0)

    # Mark the final score just above the last bar.
    final_value = df_waterfall['values'].sum() + base_value
    fig.add_shape(type='line',
                  x0=final_value,
                  x1=final_value,
                  y0=n_shown,
                  y1=n_shown + 1)
    fig.add_annotation(text='score = {:.3}'.format(final_value),
                       x=final_value,
                       y=n_shown + 1)

    # Decision threshold at score 0, spanning the whole plot.
    fig.add_shape(type='line',
                  x0=0,
                  x1=0,
                  y0=-1,
                  y1=n_shown + 1,
                  line_color='red',
                  line_dash='dot')

    return fig
示例#7
0
    def graph_local_explanation(self, x_explain: Union[pd.Series, pd.DataFrame, np.ndarray],
                                cols: Optional[List[str]] = None, n_cols: Optional[int] = None,
                                info_values: Optional[Union[pd.DataFrame, pd.Series]] = None) -> go.Figure:
        """
        Create a waterfall plotly figure showing the influence of each feature on a single prediction of the model.

        The figure is built as follows:
          * It starts at a "start value": the predicted probability (column 1 of
            ``predict_proba``) minus the sum of all contributions.
          * It adds one bar per kept feature, largest absolute contribution first.
          * It adds one bar for the ``'rest'`` contribution returned by
            ``explain_filtered_local``.
          * It ends at the predicted probability.

        You can filter the columns shown in the graph and limit their number. If both are used,
        the filter is applied first, and at most ``n_cols`` of the filtered columns are kept.

        :param x_explain: the single observation to explain (one row)
        :param cols: the columns to keep (if None - default - all the columns are kept)
        :param n_cols: the maximum number of columns in the graph (if None - default - all the columns are kept)
        :param info_values: values shown in the hover text of each feature (defaults to ``x_explain``);
            must have the same columns, in the same order, as ``x_explain``
        :return: the waterfall figure

        :raises ValueError: if x_explain does not hold exactly one observation, or if the columns
            of info_values differ from those of x_explain
        """
        info_values = x_explain if info_values is None else info_values

        # Check the row count on the normalized frame: for a raw pd.Series, shape[0] is its
        # length (the number of features), not a number of rows.
        x_explain = self._get_dataframe_from_mixed_input(x_explain)
        if x_explain.shape[0] != 1:
            raise ValueError('can only explain single observations, if you only have one sample, use reshape(1, -1)')

        # first row of info_values, as a Series indexed by column name
        info_values = self._get_dataframe_from_mixed_input(info_values).iloc[0, :]

        # Index.equals also handles different lengths; an elementwise != raised pandas' own
        # "Lengths must match" error instead of this message.
        if not info_values.index.equals(x_explain.columns):
            raise ValueError(
                'columns differ from x_explain ({}) and info_values({})'.format(x_explain.columns, info_values.index)
            )

        cols = cols or x_explain.columns.to_list()
        # NOTE(review): assumes the first element maps column name -> contribution and always
        # contains a 'rest' key (popped below) — confirm against explain_filtered_local.
        importance_dict = self.explain_filtered_local(x_explain, cols=cols, n_cols=n_cols)[0]

        output_value = self._model_to_explain.predict_proba(x_explain.values)[0, 1]
        # Computed before popping 'rest', so the start value accounts for every contribution.
        start_value = output_value - sum(importance_dict.values())
        rest = importance_dict.pop('rest')

        sorted_importances = sorted(importance_dict.items(), key=lambda importance: abs(importance[1]), reverse=True)
        hovertext = ['start value',
                     *['{} = {}'.format(col_name, info_values[col_name]) for col_name, _ in sorted_importances],
                     'rest', 'output_value = {}'.format(output_value)]
        fig = go.Figure(go.Waterfall(
            orientation="v",
            measure=['absolute', *['relative' for _ in importance_dict], 'relative', 'absolute'],
            y=[start_value, *map(operator.itemgetter(1), sorted_importances), rest, output_value],
            textposition="outside",
            x=['start_value', *map(operator.itemgetter(0), sorted_importances), 'rest', 'output_value'],
            connector={"line": {"color": GREY}},
            decreasing={"marker": {"color": '#DB643D'}},
            increasing={"marker": {"color": '#3DDC97'}},
            totals={"marker": {"color": BLUE}},
            hovertext=hovertext,
            hoverinfo='text+delta',
        ))
        fig.update_layout(
            title="explanation",
            showlegend=False
        )
        return fig
示例#8
0
def _gross_vs_previous_bars(df_plot):
    """Return (GROSS, PREVIOUS) bar traces of df_plot's GROSS/Previous columns summed per ZONE_COM."""
    pv = pd.pivot_table(df_plot,
                        index=['ZONE_COM'],
                        values=['GROSS', 'Previous'],
                        aggfunc='sum',  # string form; passing builtin sum is deprecated in pandas 2.x
                        fill_value=0)
    gross = go.Bar(x=pv.index,
                   y=pv['GROSS'],
                   name='GROSS',
                   marker_color='rgb(55, 83, 109)')
    previous = go.Bar(x=pv.index,
                      y=pv['Previous'],
                      name='PREVIOUS',
                      marker_color='rgb(26, 118, 255)')
    return gross, previous


def updates_table(Year, month, zone_com):
    """Build the table records and the three charts for the selected year/month/territory.

    ``Year == "All Year"`` means no filtering: every dataset is used in full. Otherwise
    ``all_m``, the monthly bars and the weekly waterfall are filtered on year, month and
    territory, and the yearly bars on year and territory.

    :return: tuple of
        (table records,
         yearly GROSS-vs-previous bar figure dict,
         monthly GROSS-vs-previous bar figure dict,
         weekly GROSS waterfall figure).
    """
    if Year == "All Year":
        all_plot = all_m.copy()
        df_plot_y = df.copy()
        df_plot_m = df.copy()
        st2 = fd_s.copy()
    else:
        all_plot = all_m[(all_m['Year'] == Year) & (all_m['Month'] == month) &
                         (all_m['Territory'] == zone_com)]
        df_plot_y = df[(df['Year'] == Year) & (df['Territory'] == zone_com)]
        # Filtered only in this branch: previously the "All Year" copies were overwritten by a
        # Year == "All Year" filter, which left the monthly charts empty.
        df_plot_m = df[(df['Year'] == Year) & (df['Month'] == month) &
                       (df['Territory'] == zone_com)]
        st2 = fd_s[(fd_s['Year'] == Year) & (fd_s['Month'] == month) &
                   (fd_s['Territory'] == zone_com)]

    MTD, MTD_1 = _gross_vs_previous_bars(df_plot_y)
    MTD_m, MTD_1_m = _gross_vs_previous_bars(df_plot_m)

    # NOTE(review): the 15 measures are hard-coded; they assume st2 has 15 weekly rows in
    # this order — confirm.
    trace_1 = go.Waterfall(x=st2.Week,
                           measure=[
                               "relative", "relative", "relative", "relative",
                               "total", "relative", "relative", "relative",
                               "relative", "total", "relative", "relative",
                               "total", "relative", "total"
                           ],
                           y=st2['GROSS'],
                           connector={
                               "mode": "between",
                               "line": {
                                   "width": 4,
                                   "color": "rgb(0, 0, 0)",
                                   "dash": "solid"
                               }
                           },
                           decreasing={
                               "marker": {
                                   "color": "Maroon",
                                   "line": {
                                       "color": "red",
                                       "width": 2
                                   }
                               }
                           },
                           increasing={"marker": {
                               "color": "Teal"
                           }},
                           totals={
                               "marker": {
                                   "color": "deep sky blue",
                                   "line": {
                                       "color": 'blue',
                                       "width": 3
                                   }
                               }
                           })

    fig = go.Figure(data=[trace_1])

    return all_plot.to_dict('records'), {
        'data': [MTD, MTD_1],
        'layout': go.Layout(title=' GROSS YOY% {}'.format(Year),
                            barmode='group')
    }, {
        'data': [MTD_m, MTD_1_m],
        'layout': go.Layout(title=' GROSS MOM% {}'.format(Year),
                            barmode='group')
    }, fig
示例#9
0
# Weekly GROSS waterfall over the full fd_s data. Runs of weekly steps are closed by
# subtotal bars (4+1, 4+1, 2+1, 1+1 = 15 entries).
tracex = go.Waterfall(
    x=fd_s.Week,
    y=fd_s['GROSS'],
    measure=(["relative"] * 4 + ["total"]
             + ["relative"] * 4 + ["total"]
             + ["relative"] * 2 + ["total"]
             + ["relative", "total"]),
    connector=dict(
        mode="between",
        line=dict(width=4, color="rgb(0, 0, 0)", dash="solid"),
    ),
    decreasing=dict(marker=dict(color="Maroon", line=dict(color="red", width=2))),
    increasing=dict(marker=dict(color="Teal")),
    totals=dict(marker=dict(color="deep sky blue", line=dict(color='blue', width=3))),
)
示例#10
0
def build_plot(pred_df, pred):
    """Return a Dash ``dcc.Graph`` holding a feature-contribution waterfall.

    ``pred_df`` supplies a ``feature`` column (bar labels, in plotting order) and a ``weight``
    column. Each bar is the step from the previous row's weight to its own (the first starts
    from 0), so bar i ends at ``weight[i]``. The bar text shows that weight.

    :param pred_df: DataFrame with ``feature`` and ``weight`` columns.
    :param pred: prediction shown in the title; only the first 5 characters of ``str(pred)``
        are used.
    :return: ``dcc.Graph`` with id ``'sk'``.
    """
    x = pred_df["feature"].values
    y_raw = pred_df['weight'].values
    measures = ["relative"] * len(x)
    # Increments between consecutive weights, so the relative bars end at each weight.
    y = [val if i == 0 else val - y_raw[i - 1] for i, val in enumerate(y_raw)]
    text = [str(val) for val in y_raw]

    fig = go.Figure(
        go.Waterfall(name="Feature Contribution",
                     orientation="v",
                     measure=measures,
                     x=x,
                     textposition="outside",
                     text=text,
                     y=y,
                     connector={"line": {"color": "rgb(63, 63, 63)"}},
                     decreasing={
                         "marker": {
                             "color": "Maroon",
                             "line": {
                                 "color": "red",
                                 "width": 2
                             }
                         }
                     },
                     increasing={"marker": {
                         "color": "Teal"
                     }},
                     totals={
                         "marker": {
                             "color": "deep sky blue",
                             "line": {
                                 "color": 'blue',
                                 "width": 3
                             }
                         }
                     }))

    # Pad the y-range by the mean absolute weight on both sides so the outside labels fit.
    rng_mean = np.mean(np.abs(y_raw))
    ymax = max(y_raw) + rng_mean
    ymin = min(y_raw) - rng_mean
    fig.layout.height = 700
    fig.layout.showlegend = True
    fig.layout.yaxis = dict(range=[ymin, ymax])
    fig.layout.title = f"Probability of Success {str(pred)[:5]}%"

    return dcc.Graph(id='sk', figure=fig)
示例#11
0
                    ]),
                #Graph
                dbc.Row([
                  dbc.Col(
                    [
                        dcc.Graph(id='waterfall1',

                                        figure = {'data':[

                        go.Waterfall(
                    name="week over week",
                    orientation = "v",
                    textposition = "outside",
                    text = df_sort['change'],
                    x =df_sort['Week'],
                    y = df_sort['change'],
                    #connector = {"line":{"color":"rgb(63, 63, 63)"}},
                    decreasing = {"marker":{"color":"Maroon", "line":{"color":"red", "width":3}}},
                    increasing = {"marker":{"color":"Teal"}},
                    totals = {"marker":{"color":"deep sky blue", "line":{"color":'blue', "width":4}}}
                    )

                                                ],

                                        'layout':go.Layout(paper_bgcolor = '#ffffff', plot_bgcolor='#ffffff', waterfallgap = 0.3)}
                                        ),
                        html.H5("Week over Week variation",style={'text-align':'center','padding':'0.5% 0.5% 0.5% 0.5%'})
                    ],
                    md=12,
                ), 
                    ]),
示例#12
0
    def get_change_fte_wf(dates_range, wf_type, n_functions_selected,
                          functions_selected, le_selected):
        """Build a waterfall of the FTE change between two periods, broken down by function.

        The chart is built from the selected legal entities as follows:
          * It starts at their total FTE in the start period.
          * It adds one relative bar per function, plus an aggregated 'Другие' ("others") bar
            placed last.
          * It ends at their total FTE in the end period.

        :param dates_range: pair of 1-based positions into ``dates_list`` (start, end).
        :param wf_type: ``'top_n'`` keeps the ``n_functions_selected`` functions with the largest
            absolute change; any other value keeps only ``functions_selected``.
        :param n_functions_selected: number of functions shown individually in ``'top_n'`` mode.
        :param functions_selected: function name(s) shown individually in the other mode.
        :param le_selected: legal entity short name(s) to include.
        :return: the plotly waterfall figure.
        """
        if not isinstance(le_selected, list):
            le_selected = [le_selected]
        if not isinstance(functions_selected, list):
            functions_selected = [functions_selected]
        start_date = dates_list[dates_range[0] - 1]
        start_period = datetime.strftime(start_date, '%Y_%m')
        end_date = dates_list[dates_range[1] - 1]
        end_period = datetime.strftime(end_date, '%Y_%m')
        n_cases = n_functions_selected

        # The period values are strftime output built above, not raw user text.
        # NOTE(review): the query groups by month_start but selects period — confirm the two
        # correspond one-to-one in hc_data_main.
        query = '''
                SELECT period, legal_entity_short_eng, function, SUM(fte) as fte 
                FROM hc_data_main
                WHERE period = "{}" OR period = "{}" 
                GROUP BY month_start, legal_entity_short_eng, function
            '''.format(start_period, end_period)
        df_periods_total = pd.read_sql(query, con=engine)
        df_le = df_periods_total[
            df_periods_total['legal_entity_short_eng'].isin(le_selected)]

        # Opening and closing totals (measure 'absolute').
        df_start = df_le[df_le['period'] == start_period]
        df_start = df_start.groupby('period').agg(fte=('fte', 'sum')).reset_index()
        df_start['title'] = 'FTE на ' + datetime.strftime(start_date, '%B %Y')
        df_start['measure'] = 'absolute'
        df_end = df_le[df_le['period'] == end_period]
        df_end = df_end.groupby('period').agg(fte=('fte', 'sum')).reset_index()
        df_end['title'] = 'FTE на ' + datetime.strftime(end_date, '%B %Y')
        df_end['measure'] = 'absolute'

        df_change = df_le.copy()
        # Missing function names become 'Не опознаны' ("unidentified"). Assign the result back:
        # fillna(inplace=True) on a column selection is chained assignment and stops working
        # under pandas Copy-on-Write.
        df_change['function'] = df_change['function'].fillna('Не опознаны')
        dff = pd.pivot_table(df_change,
                             index='function',
                             columns='period',
                             values='fte',
                             aggfunc='sum',
                             fill_value=0).reset_index()

        dff.rename(columns={
            start_period: 'start_fte',
            end_period: 'end_fte'
        },
                   inplace=True)
        dff['change'] = dff['end_fte'] - dff['start_fte']
        dff['rank'] = dff['change'].abs().rank(method='first', ascending=False)
        # Functions not shown individually are pooled under 'Другие' ("others").
        if wf_type == 'top_n':
            dff.loc[dff['rank'] > n_cases, 'function'] = 'Другие'
        else:
            dff.loc[~dff['function'].isin(functions_selected),
                    'function'] = 'Другие'

        dff = dff.groupby(['function']).agg(change=('change', 'sum')).reset_index()
        dff['measure'] = 'relative'
        # Largest increase first, 'Другие' always last, in a single multi-key sort. Re-sorting
        # by 'sorter' alone used the default non-stable quicksort, which could scramble the
        # change ordering.
        dff['sorter'] = (dff['function'] == 'Другие').astype(int)
        dff = dff.sort_values(by=['sorter', 'change'], ascending=[True, False])
        dff = dff.rename(columns={'function': 'title', 'change': 'fte'})

        df_result = pd.concat([
            df_start[['title', 'fte', 'measure']],
            dff[['title', 'fte', 'measure']],
            df_end[['title', 'fte', 'measure']],
        ]).round({'fte': 1})
        fig = go.Figure(
            go.Waterfall(
                orientation="v",
                measure=df_result['measure'],
                x=df_result['title'],
                textposition="outside",
                text=df_result['fte'],
                y=df_result['fte'],
                hovertemplate=
                '<span style="color: #000000">%{x}: %{text}</span><extra></extra>',
                decreasing={"marker": {
                    "color": 'rgba(211, 94, 96, 1)'
                }},
                increasing={"marker": {
                    "color": 'rgba(135, 186, 91, 1)'
                }},
                totals={"marker": {
                    "color": 'rgba(114, 147, 203, 1)'
                }}))
        title_string = 'Изменения по функциям в рамках общего изменения численности с {}'.format(
            str(datetime.strftime(start_date, '%b %Y')) + " по " +
            str(datetime.strftime(end_date, '%b %Y')))
        fig.update_layout(title_text=title_string,
                          autosize=True,
                          margin=dict(l=30, r=30, b=100, t=60),
                          plot_bgcolor="#EDEDED",
                          paper_bgcolor="#EDEDED",
                          hovermode='x',
                          legend=dict(font=dict(size=10), orientation="h"))

        # Size the axis on the highest bar top, including intermediate running totals. With
        # large opposing changes the running total can exceed both the start and end FTE,
        # which the bare max of the plotted values missed.
        max_value = df_result['fte'].max()
        deltas = df_result.loc[df_result['measure'] == 'relative', 'fte']
        if not deltas.empty:
            running_peak = df_start['fte'].sum() + deltas.cumsum().max()
            max_value = max(max_value, running_peak)
        fig.update_yaxes(range=[0, max_value * 1.3])

        return fig
示例#13
0
def predict_hd_summary(data_patient):
    """Predict a patient's likelihood of heart disease and build the summary outputs.

    :param data_patient: JSON accepted by ``pd.read_json`` describing exactly one patient
        (one row; the feature table below is built from the single transposed column).
    :return: tuple of
        (risk bar figure,
         summary sentence,
         action heading,
         recommended action text,
         SHAP waterfall figure).
    """
    # Risk-group cut-offs on the predicted probability (0-1 scale). They are used both for
    # the classification and for drawing the coloured bands, so the two always agree.
    low_cutoff = 0.275685
    high_cutoff = 0.795583

    # read in data and predict likelihood of heart disease (as a percentage)
    x_new = pd.read_json(data_patient)
    y_val = hdpred_model.predict_proba(x_new)[:, 1] * 100
    text_val = str(np.round(y_val[0], 1)) + "%"

    # assign a risk group (scalar comparison on the single patient)
    prob = y_val[0] / 100
    if prob <= low_cutoff:
        risk_grp = 'low risk'
    elif prob <= high_cutoff:
        risk_grp = 'medium risk'
    else:
        risk_grp = 'high risk'

    # assign an action related to the risk group
    rg_actions = {
        'low risk': [
            'Discuss with patient any single large risk factors they may have, and otherwise '
            'continue supporting healthy lifestyle habits. Follow-up in 12 months'
        ],
        'medium risk': [
            'Discuss lifestyle with patient and identify changes to reduce risk. '
            'Schedule follow-up with patient in 3 months on how changes are progressing. '
            'Recommend performing simple tests to assess positive impact of changes.'
        ],
        'high risk': [
            'Immediate follow-up with patient to discuss next steps including additional '
            'follow-up tests, lifestyle changes and medications.'
        ]
    }

    next_action = rg_actions[risk_grp][0]

    # create a single bar plot showing likelihood of heart disease
    fig1 = go.Figure()
    fig1.add_trace(
        go.Bar(y=[''],
               x=y_val,
               marker_color='rgb(112, 128, 144)',
               orientation='h',
               width=1,
               text=text_val,
               textposition='auto',
               hoverinfo='skip'))

    # add coloured bands for the risk groups above the bar
    bot_val = 0.5
    top_val = 1
    for x0, x1, colour in ((0, low_cutoff * 100, "green"),
                           (low_cutoff * 100, high_cutoff * 100, "orange"),
                           (high_cutoff * 100, 1 * 100, "red")):
        fig1.add_shape(type="rect",
                       x0=x0,
                       y0=bot_val,
                       x1=x1,
                       y1=top_val,
                       line=dict(color="white", ),
                       fillcolor=colour)
    for x_pos, label in ((low_cutoff / 2 * 100, "Low risk"),
                         (0.53 * 100, "Medium risk"),
                         (0.9 * 100, "High risk")):
        fig1.add_annotation(x=x_pos,
                            y=0.75,
                            text=label,
                            showarrow=False,
                            font=dict(color="black", size=14))
    fig1.update_layout(margin=dict(l=0, r=50, t=10, b=10),
                       xaxis={'range': [0, 100]})

    # do shap value calculations for basic waterfall plot
    explainer_patient = shap.TreeExplainer(hdpred_model)
    shap_values_patient = explainer_patient.shap_values(x_new)
    updated_fnames = x_new.T.reset_index()
    updated_fnames.columns = ['feature', 'value']
    updated_fnames['shap_original'] = pd.Series(shap_values_patient[0])
    updated_fnames['shap_abs'] = updated_fnames['shap_original'].abs()
    updated_fnames = updated_fnames.sort_values(by=['shap_abs'],
                                                ascending=True)

    # Show the 9 strongest features individually and collapse the weaker remainder into one
    # bar (at most 10 bars). Clamped so that with <= 9 features head() is never given a
    # negative count, which would have double-counted features.
    n_features = updated_fnames.shape[0]
    show_features = min(9, n_features)
    num_other_features = n_features - show_features
    parts = [updated_fnames.tail(show_features)]
    if num_other_features > 0:
        # Sum only the numeric SHAP columns, so the concatenated columns stay float.
        f_group = updated_fnames.head(num_other_features)[['shap_original', 'shap_abs']].sum().to_frame().T
        f_group['feature'] = f"{num_other_features} other features"
        parts.insert(0, f_group)
    plot_data = pd.concat(parts, ignore_index=True)

    # additional things for plotting: text inside large bars, outside (coloured) otherwise
    plot_range = plot_data['shap_original'].cumsum().max(
    ) - plot_data['shap_original'].cumsum().min()
    plot_data['text_pos'] = np.where(
        plot_data['shap_original'].abs() > (1 / 9) * plot_range, "inside",
        "outside")
    plot_data['text_col'] = "white"
    plot_data.loc[(plot_data['text_pos'] == "outside") &
                  (plot_data['shap_original'] < 0), 'text_col'] = "#3283FE"
    plot_data.loc[(plot_data['text_pos'] == "outside") &
                  (plot_data['shap_original'] > 0), 'text_col'] = "#F6222E"

    # model output for this patient: expected value plus every SHAP contribution
    final_value = plot_data['shap_original'].sum() + explainer_patient.expected_value

    fig2 = go.Figure(
        go.Waterfall(name="",
                     orientation="h",
                     measure=['absolute'] + ['relative'] * (len(plot_data) - 1),
                     base=explainer_patient.expected_value,
                     textposition=plot_data['text_pos'],
                     text=plot_data['shap_original'],
                     textfont={"color": plot_data['text_col']},
                     texttemplate='%{text:+.2f}',
                     y=plot_data['feature'],
                     x=plot_data['shap_original'],
                     connector={
                         "mode": "spanning",
                         "line": {
                             "width": 1,
                             "color": "rgb(102, 102, 102)",
                             "dash": "dot"
                         }
                     },
                     decreasing={"marker": {
                         "color": "#3283FE"
                     }},
                     increasing={"marker": {
                         "color": "#F6222E"
                     }},
                     hoverinfo="skip"))
    fig2.update_layout(waterfallgap=0.2,
                       autosize=False,
                       width=800,
                       height=400,
                       paper_bgcolor='rgba(0,0,0,0)',
                       plot_bgcolor='rgba(0,0,0,0)',
                       yaxis=dict(showgrid=True,
                                  zeroline=True,
                                  showline=True,
                                  gridcolor='lightgray'),
                       xaxis=dict(showgrid=False,
                                  zeroline=False,
                                  showline=True,
                                  showticklabels=True,
                                  linecolor='black',
                                  tickcolor='black',
                                  ticks='outside',
                                  ticklen=5),
                       margin={
                           't': 25,
                           'b': 50
                       },
                       shapes=[
                           dict(type='line',
                                yref='paper',
                                y0=0,
                                y1=1.02,
                                xref='x',
                                x0=final_value,
                                x1=final_value,
                                layer="below",
                                line=dict(color="black", width=1, dash="dot"))
                       ])
    fig2.update_yaxes(automargin=True)
    fig2.add_annotation(yref='paper',
                        xref='x',
                        x=explainer_patient.expected_value,
                        y=-0.12,
                        text="E[f(x)] = {:.2f}".format(
                            explainer_patient.expected_value),
                        showarrow=False,
                        font=dict(color="black", size=14))
    fig2.add_annotation(
        yref='paper',
        xref='x',
        x=final_value,
        y=1.075,
        text="f(x) = {:.2f}".format(final_value),
        showarrow=False,
        font=dict(color="black", size=14))

    return fig1,\
        f"Based on the patient's profile, the predicted likelihood of heart disease is {text_val}. " \
        f"This patient is in the {risk_grp} group.",\
        f"Recommended action(s) for a patient in the {risk_grp} group",\
        next_action, \
        fig2