Example #1
0
    def plot(self,
             column: tp.Optional[tp.Label] = None,
             top_n: tp.Optional[int] = 5,
             plot_zones: bool = True,
             ts_trace_kwargs: tp.KwargsLike = None,
             peak_trace_kwargs: tp.KwargsLike = None,
             valley_trace_kwargs: tp.KwargsLike = None,
             recovery_trace_kwargs: tp.KwargsLike = None,
             active_trace_kwargs: tp.KwargsLike = None,
             decline_shape_kwargs: tp.KwargsLike = None,
             recovery_shape_kwargs: tp.KwargsLike = None,
             active_shape_kwargs: tp.KwargsLike = None,
             add_trace_kwargs: tp.KwargsLike = None,
             xref: str = 'x',
             yref: str = 'y',
             fig: tp.Optional[tp.BaseFigure] = None,
             **layout_kwargs) -> tp.BaseFigure:  # pragma: no cover
        """Plot drawdowns.

        Draws `Drawdowns.ts` (when it is set) and marks the drawdown records of the
        selected column with diamond markers: peaks (blue), valleys (red) and
        recoveries (green) of recovered drawdowns, and the end point of active
        drawdowns (orange). Optionally shades the decline (peak to valley), recovery
        (valley to end) and active (peak to end) zones. When `Drawdowns.ts` is not
        set, marker heights come from the `peak_val`, `valley_val` and `end_val`
        record fields.

        Args:
            column (str): Name of the column to plot.
            top_n (int): Filter top N drawdown records by maximum drawdown.
                Pass None to plot all records.
            plot_zones (bool): Whether to plot zones.
            ts_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Drawdowns.ts`.
            peak_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for peak values.
            valley_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for valley values.
            recovery_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for recovery values.
            active_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for active recovery values.
            decline_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for decline zones.
            recovery_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for recovery zones.
            active_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for active recovery zones.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            xref (str): X coordinate axis.
            yref (str): Y coordinate axis.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments for layout.

        Returns:
            Figure with the time series, markers and zones added.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt
        >>> from datetime import datetime, timedelta
        >>> import pandas as pd

        >>> price = pd.Series([1, 2, 1, 2, 3, 2, 1, 2], name='Price')
        >>> price.index = [datetime(2020, 1, 1) + timedelta(days=i) for i in range(len(price))]
        >>> vbt.Drawdowns.from_ts(price, wrapper_kwargs=dict(freq='1 day')).plot()
        ```

        ![](/docs/img/drawdowns_plot.svg)
        """
        from vectorbt._settings import settings
        plotting_cfg = settings['plotting']

        self_col = self.select_one(column=column, group_by=False)
        if top_n is not None:
            # Drawdowns are negative, so the N largest drawdowns are the bottom N values
            self_col = self_col.apply_mask(
                self_col.drawdown.bottom_n_mask(top_n))

        if ts_trace_kwargs is None:
            ts_trace_kwargs = {}
        ts_trace_kwargs = merge_dicts(
            dict(line=dict(color=plotting_cfg['color_schema']['blue'])),
            ts_trace_kwargs)
        if peak_trace_kwargs is None:
            peak_trace_kwargs = {}
        if valley_trace_kwargs is None:
            valley_trace_kwargs = {}
        if recovery_trace_kwargs is None:
            recovery_trace_kwargs = {}
        if active_trace_kwargs is None:
            active_trace_kwargs = {}
        if decline_shape_kwargs is None:
            decline_shape_kwargs = {}
        if recovery_shape_kwargs is None:
            recovery_shape_kwargs = {}
        if active_shape_kwargs is None:
            active_shape_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}

        if fig is None:
            fig = make_figure()
        fig.update_layout(**layout_kwargs)
        y_domain = get_domain(yref, fig)

        if self_col.ts is not None:
            fig = self_col.ts.vbt.plot(trace_kwargs=ts_trace_kwargs,
                                       add_trace_kwargs=add_trace_kwargs,
                                       fig=fig)

        def add_markers(name, color_key, x, y, customdata, hovertemplate,
                        trace_kwargs):
            """Add a diamond-marker scatter trace colored by `color_key`."""
            color = plotting_cfg['contrast_color_schema'][color_key]
            scatter = go.Scatter(
                x=x,
                y=y,
                mode='markers',
                marker=dict(symbol='diamond',
                            color=color,
                            size=7,
                            line=dict(width=1, color=adjust_lightness(color))),
                name=name,
                customdata=customdata,
                hovertemplate=hovertemplate)
            # User-supplied kwargs are applied last so they override the defaults
            scatter.update(**trace_kwargs)
            fig.add_trace(scatter, **add_trace_kwargs)

        def add_zones(x0s, x1s, fillcolor, shape_kwargs):
            """Shade one rectangle per (x0, x1) pair across the y-domain of `yref`."""
            for x0, x1 in zip(x0s, x1s):
                fig.add_shape(**merge_dicts(
                    dict(
                        type="rect",
                        xref=xref,
                        yref="paper",
                        x0=x0,
                        y0=y_domain[0],
                        x1=x1,
                        y1=y_domain[1],
                        fillcolor=fillcolor,
                        opacity=0.2,
                        layer="below",
                        line_width=0,
                    ), shape_kwargs))

        if self_col.count() > 0:
            # Extract information
            id_ = self_col.get_field_arr('id')
            id_title = self_col.get_field_title('id')

            peak_idx = self_col.get_map_field_to_index('peak_idx')
            peak_idx_title = self_col.get_field_title('peak_idx')

            if self_col.ts is not None:
                peak_val = self_col.ts.loc[peak_idx]
            else:
                peak_val = self_col.get_field_arr('peak_val')
            peak_val_title = self_col.get_field_title('peak_val')

            valley_idx = self_col.get_map_field_to_index('valley_idx')
            valley_idx_title = self_col.get_field_title('valley_idx')

            if self_col.ts is not None:
                valley_val = self_col.ts.loc[valley_idx]
            else:
                valley_val = self_col.get_field_arr('valley_val')
            valley_val_title = self_col.get_field_title('valley_val')

            end_idx = self_col.get_map_field_to_index('end_idx')
            end_idx_title = self_col.get_field_title('end_idx')

            if self_col.ts is not None:
                end_val = self_col.ts.loc[end_idx]
            else:
                end_val = self_col.get_field_arr('end_val')
            end_val_title = self_col.get_field_title('end_val')

            drawdown = self_col.drawdown.values
            recovery_return = self_col.recovery_return.values
            decline_duration = np.vectorize(str)(self_col.wrapper.to_timedelta(
                self_col.decline_duration.values,
                to_pd=True,
                silence_warnings=True))
            recovery_duration = np.vectorize(str)(
                self_col.wrapper.to_timedelta(
                    self_col.recovery_duration.values,
                    to_pd=True,
                    silence_warnings=True))
            duration = np.vectorize(str)(self_col.wrapper.to_timedelta(
                self_col.duration.values, to_pd=True, silence_warnings=True))

            status = self_col.get_field_arr('status')

            # A peak that coincides with the previous record's end is not marked,
            # because that record's end marker already sits there (recovery wins)
            peak_mask = peak_idx != np.roll(end_idx, 1)
            if peak_mask.any():
                # Plot peak markers
                add_markers(
                    name='Peak',
                    color_key='blue',
                    x=peak_idx[peak_mask],
                    y=peak_val[peak_mask],
                    customdata=id_[peak_mask][:, None],
                    hovertemplate=(f"{id_title}: %{{customdata[0]}}"
                                   f"<br>{peak_idx_title}: %{{x}}"
                                   f"<br>{peak_val_title}: %{{y}}"),
                    trace_kwargs=peak_trace_kwargs)

            recovered_mask = status == DrawdownStatus.Recovered
            if recovered_mask.any():
                # Apply the mask once instead of once per zone
                rec_id = id_[recovered_mask]
                rec_peak_idx = peak_idx[recovered_mask]
                rec_valley_idx = valley_idx[recovered_mask]
                rec_end_idx = end_idx[recovered_mask]

                # Plot valley markers
                add_markers(
                    name='Valley',
                    color_key='red',
                    x=rec_valley_idx,
                    y=valley_val[recovered_mask],
                    customdata=np.stack(
                        (rec_id, drawdown[recovered_mask],
                         decline_duration[recovered_mask]),
                        axis=1),
                    hovertemplate=(f"{id_title}: %{{customdata[0]}}"
                                   f"<br>{valley_idx_title}: %{{x}}"
                                   f"<br>{valley_val_title}: %{{y}}"
                                   f"<br>Drawdown: %{{customdata[1]:.2%}}"
                                   f"<br>Duration: %{{customdata[2]}}"),
                    trace_kwargs=valley_trace_kwargs)

                if plot_zones:
                    # Plot decline zones (peak -> valley)
                    add_zones(rec_peak_idx, rec_valley_idx, 'red',
                              decline_shape_kwargs)

                # Plot recovery markers
                add_markers(
                    name='Recovery/Peak',
                    color_key='green',
                    x=rec_end_idx,
                    y=end_val[recovered_mask],
                    customdata=np.stack(
                        (rec_id, recovery_return[recovered_mask],
                         recovery_duration[recovered_mask]),
                        axis=1),
                    hovertemplate=(f"{id_title}: %{{customdata[0]}}"
                                   f"<br>{end_idx_title}: %{{x}}"
                                   f"<br>{end_val_title}: %{{y}}"
                                   f"<br>Return: %{{customdata[1]:.2%}}"
                                   f"<br>Duration: %{{customdata[2]}}"),
                    trace_kwargs=recovery_trace_kwargs)

                if plot_zones:
                    # Plot recovery zones (valley -> end)
                    add_zones(rec_valley_idx, rec_end_idx, 'green',
                              recovery_shape_kwargs)

            active_mask = status == DrawdownStatus.Active
            if active_mask.any():
                act_end_idx = end_idx[active_mask]

                # Plot active markers; customdata[1] is the `drawdown` field,
                # so it is labeled "Drawdown" like on the valley markers
                add_markers(
                    name='Active',
                    color_key='orange',
                    x=act_end_idx,
                    y=end_val[active_mask],
                    customdata=np.stack(
                        (id_[active_mask], drawdown[active_mask],
                         duration[active_mask]),
                        axis=1),
                    hovertemplate=(f"{id_title}: %{{customdata[0]}}"
                                   f"<br>{end_idx_title}: %{{x}}"
                                   f"<br>{end_val_title}: %{{y}}"
                                   f"<br>Drawdown: %{{customdata[1]:.2%}}"
                                   f"<br>Duration: %{{customdata[2]}}"),
                    trace_kwargs=active_trace_kwargs)

                if plot_zones:
                    # Plot active drawdown zones (peak -> current end)
                    add_zones(peak_idx[active_mask], act_end_idx, 'orange',
                              active_shape_kwargs)

        return fig
Example #2
0
    def plot(self,
             column: tp.Optional[tp.Label] = None,
             top_n: tp.Optional[int] = 5,
             plot_zones: bool = True,
             ts_trace_kwargs: tp.KwargsLike = None,
             start_trace_kwargs: tp.KwargsLike = None,
             end_trace_kwargs: tp.KwargsLike = None,
             open_shape_kwargs: tp.KwargsLike = None,
             closed_shape_kwargs: tp.KwargsLike = None,
             add_trace_kwargs: tp.KwargsLike = None,
             xref: str = 'x',
             yref: str = 'y',
             fig: tp.Optional[tp.BaseFigure] = None,
             **layout_kwargs) -> tp.BaseFigure:  # pragma: no cover
        """Plot ranges.

        Draws `Ranges.ts` (when it is set) and marks every range record of the
        selected column with diamond markers: the start (blue), and the end of
        closed (green) and open (orange) ranges. Optionally shades each range from
        start to end. When `Ranges.ts` is not set, markers are drawn at y=0.

        Args:
            column (str): Name of the column to plot.
            top_n (int): Filter top N range records by maximum duration.
                Pass None to plot all records.
            plot_zones (bool): Whether to plot zones.
            ts_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Ranges.ts`.
            start_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for start values.
            end_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for end values
                (applied to both the closed and the open end markers).
            open_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for open zones.
            closed_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for closed zones.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            xref (str): X coordinate axis.
            yref (str): Y coordinate axis.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments for layout.

        Returns:
            Figure with the time series, markers and zones added.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt
        >>> from datetime import datetime, timedelta
        >>> import pandas as pd

        >>> price = pd.Series([1, 2, 1, 2, 3, 2, 1, 2], name='Price')
        >>> price.index = [datetime(2020, 1, 1) + timedelta(days=i) for i in range(len(price))]
        >>> vbt.Ranges.from_ts(price >= 2, wrapper_kwargs=dict(freq='1 day')).plot()
        ```

        ![](/docs/img/ranges_plot.svg)
        """
        from vectorbt._settings import settings
        plotting_cfg = settings['plotting']

        self_col = self.select_one(column=column, group_by=False)
        if top_n is not None:
            self_col = self_col.apply_mask(self_col.duration.top_n_mask(top_n))

        if ts_trace_kwargs is None:
            ts_trace_kwargs = {}
        ts_trace_kwargs = merge_dicts(
            dict(line=dict(color=plotting_cfg['color_schema']['blue'])),
            ts_trace_kwargs)
        if start_trace_kwargs is None:
            start_trace_kwargs = {}
        if end_trace_kwargs is None:
            end_trace_kwargs = {}
        if open_shape_kwargs is None:
            open_shape_kwargs = {}
        if closed_shape_kwargs is None:
            closed_shape_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}

        if fig is None:
            fig = make_figure()
        fig.update_layout(**layout_kwargs)
        y_domain = get_domain(yref, fig)

        if self_col.ts is not None:
            fig = self_col.ts.vbt.plot(trace_kwargs=ts_trace_kwargs,
                                       add_trace_kwargs=add_trace_kwargs,
                                       fig=fig)

        def add_markers(name, color_key, x, y, customdata, hovertemplate,
                        trace_kwargs):
            """Add a diamond-marker scatter trace colored by `color_key`."""
            color = plotting_cfg['contrast_color_schema'][color_key]
            scatter = go.Scatter(
                x=x,
                y=y,
                mode='markers',
                marker=dict(symbol='diamond',
                            color=color,
                            size=7,
                            line=dict(width=1, color=adjust_lightness(color))),
                name=name,
                customdata=customdata,
                hovertemplate=hovertemplate)
            # User-supplied kwargs are applied last so they override the defaults
            scatter.update(**trace_kwargs)
            fig.add_trace(scatter, **add_trace_kwargs)

        def add_zones(x0s, x1s, fillcolor, shape_kwargs):
            """Shade one rectangle per (x0, x1) pair across the y-domain of `yref`."""
            for x0, x1 in zip(x0s, x1s):
                fig.add_shape(**merge_dicts(
                    dict(
                        type="rect",
                        xref=xref,
                        yref="paper",
                        x0=x0,
                        y0=y_domain[0],
                        x1=x1,
                        y1=y_domain[1],
                        fillcolor=fillcolor,
                        opacity=0.2,
                        layer="below",
                        line_width=0,
                    ), shape_kwargs))

        if self_col.count() > 0:
            # Extract information
            id_ = self_col.get_field_arr('id')
            id_title = self_col.get_field_title('id')

            start_idx = self_col.get_map_field_to_index('start_idx')
            start_idx_title = self_col.get_field_title('start_idx')
            if self_col.ts is not None:
                start_val = self_col.ts.loc[start_idx]
            else:
                # No time series to read heights from: pin markers to y=0
                start_val = np.full(len(start_idx), 0)

            end_idx = self_col.get_map_field_to_index('end_idx')
            end_idx_title = self_col.get_field_title('end_idx')
            if self_col.ts is not None:
                end_val = self_col.ts.loc[end_idx]
            else:
                end_val = np.full(len(end_idx), 0)

            duration = np.vectorize(str)(self_col.wrapper.to_timedelta(
                self_col.duration.values, to_pd=True, silence_warnings=True))

            status = self_col.get_field_arr('status')

            # Plot start markers (for every range, regardless of status)
            add_markers(
                name='Start',
                color_key='blue',
                x=start_idx,
                y=start_val,
                customdata=id_[:, None],
                hovertemplate=(f"{id_title}: %{{customdata[0]}}"
                               f"<br>{start_idx_title}: %{{x}}"),
                trace_kwargs=start_trace_kwargs)

            closed_mask = status == RangeStatus.Closed
            if closed_mask.any():
                # Apply the mask once instead of once per zone
                closed_end_idx = end_idx[closed_mask]

                # Plot end markers of closed ranges
                add_markers(
                    name='Closed',
                    color_key='green',
                    x=closed_end_idx,
                    y=end_val[closed_mask],
                    customdata=np.stack(
                        (id_[closed_mask], duration[closed_mask]), axis=1),
                    hovertemplate=(f"{id_title}: %{{customdata[0]}}"
                                   f"<br>{end_idx_title}: %{{x}}"
                                   f"<br>Duration: %{{customdata[1]}}"),
                    trace_kwargs=end_trace_kwargs)

                if plot_zones:
                    # Plot closed range zones
                    add_zones(start_idx[closed_mask], closed_end_idx, 'teal',
                              closed_shape_kwargs)

            open_mask = status == RangeStatus.Open
            if open_mask.any():
                open_end_idx = end_idx[open_mask]

                # Plot end markers of open ranges
                add_markers(
                    name='Open',
                    color_key='orange',
                    x=open_end_idx,
                    y=end_val[open_mask],
                    customdata=np.stack(
                        (id_[open_mask], duration[open_mask]), axis=1),
                    hovertemplate=(f"{id_title}: %{{customdata[0]}}"
                                   f"<br>{end_idx_title}: %{{x}}"
                                   f"<br>Duration: %{{customdata[1]}}"),
                    trace_kwargs=end_trace_kwargs)

                if plot_zones:
                    # Plot open range zones
                    add_zones(start_idx[open_mask], open_end_idx, 'orange',
                              open_shape_kwargs)

        return fig
Example #3
0
    def plot(self,
             column: tp.Optional[tp.Label] = None,
             top_n: int = 5,
             plot_ts: bool = True,
             plot_zones: bool = True,
             ts_trace_kwargs: tp.KwargsLike = None,
             peak_trace_kwargs: tp.KwargsLike = None,
             valley_trace_kwargs: tp.KwargsLike = None,
             recovery_trace_kwargs: tp.KwargsLike = None,
             active_trace_kwargs: tp.KwargsLike = None,
             ptv_shape_kwargs: tp.KwargsLike = None,
             vtr_shape_kwargs: tp.KwargsLike = None,
             active_shape_kwargs: tp.KwargsLike = None,
             add_trace_kwargs: tp.KwargsLike = None,
             xref: str = 'x',
             yref: str = 'y',
             fig: tp.Optional[tp.BaseFigure] = None,
             **layout_kwargs) -> tp.BaseFigure:  # pragma: no cover
        """Plot drawdowns over `Drawdowns.ts`.

        Args:
            column (str): Name of the column to plot.
            top_n (int): Filter top N drawdown records by maximum drawdown.
            plot_ts (bool): Whether to plot time series.
            plot_zones (bool): Whether to plot zones.
            ts_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for time series.
            peak_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for peak values.
            valley_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for valley values.
            recovery_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for recovery values.
            active_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for active recovery values.
            ptv_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for PtV zones.
            vtr_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for VtR zones.
            active_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for active VtR zones.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            xref (str): X coordinate axis.
            yref (str): Y coordinate axis.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments for layout.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt
        >>> import pandas as pd

        >>> ts = pd.Series([1, 2, 1, 2, 3, 2, 1, 2])
        >>> vbt.Drawdowns.from_ts(ts, freq='1 days').plot()
        ```

        ![](/vectorbt/docs/img/drawdowns_plot.svg)
        """
        from vectorbt._settings import settings
        plotting_cfg = settings['plotting']

        self_col = self.select_one(column=column, group_by=False)
        if top_n is not None:
            # Drawdowns is negative, thus top_n becomes bottom_n
            self_col = self_col.filter_by_mask(
                self_col.drawdown.bottom_n_mask(top_n))

        if ts_trace_kwargs is None:
            ts_trace_kwargs = {}
        ts_trace_kwargs = merge_dicts(
            dict(line=dict(color=plotting_cfg['color_schema']['blue'])),
            ts_trace_kwargs)
        if peak_trace_kwargs is None:
            peak_trace_kwargs = {}
        if valley_trace_kwargs is None:
            valley_trace_kwargs = {}
        if recovery_trace_kwargs is None:
            recovery_trace_kwargs = {}
        if active_trace_kwargs is None:
            active_trace_kwargs = {}
        if ptv_shape_kwargs is None:
            ptv_shape_kwargs = {}
        if vtr_shape_kwargs is None:
            vtr_shape_kwargs = {}
        if active_shape_kwargs is None:
            active_shape_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}

        if fig is None:
            fig = make_figure()
        fig.update_layout(**layout_kwargs)
        y_domain = get_domain(yref, fig)

        if plot_ts:
            fig = self_col.ts.vbt.plot(trace_kwargs=ts_trace_kwargs,
                                       add_trace_kwargs=add_trace_kwargs,
                                       fig=fig)

        if len(self_col.values) > 0:
            # Extract information
            _id = self_col.values['id']
            start_idx = self_col.values['start_idx']
            valley_idx = self_col.values['valley_idx']
            end_idx = self_col.values['end_idx']
            status = self_col.values['status']

            start_val = self_col.ts.values[start_idx]
            valley_val = self_col.ts.values[valley_idx]
            end_val = self_col.ts.values[end_idx]

            def get_duration_str(from_idx, to_idx):
                if isinstance(self_col.wrapper.index, DatetimeIndexes):
                    duration = self_col.wrapper.index[
                        to_idx] - self_col.wrapper.index[from_idx]
                elif self_col.wrapper.freq is not None:
                    duration = self_col.wrapper.to_time_units(to_idx -
                                                              from_idx)
                else:
                    duration = to_idx - from_idx
                return np.vectorize(str)(duration)

            # Plot peak markers
            peak_mask = start_idx != np.roll(
                end_idx, 1)  # peak and recovery at same time -> recovery wins
            if np.any(peak_mask):
                peak_customdata = _id[peak_mask][:, None]
                peak_scatter = go.Scatter(
                    x=self_col.ts.index[start_idx[peak_mask]],
                    y=start_val[peak_mask],
                    mode='markers',
                    marker=dict(
                        symbol='diamond',
                        color=plotting_cfg['contrast_color_schema']['blue'],
                        size=7,
                        line=dict(width=1,
                                  color=adjust_lightness(
                                      plotting_cfg['contrast_color_schema']
                                      ['blue']))),
                    name='Peak',
                    customdata=peak_customdata,
                    hovertemplate="Drawdown Id: %{customdata[0]}"
                    "<br>Date: %{x}"
                    "<br>Price: %{y}")
                peak_scatter.update(**peak_trace_kwargs)
                fig.add_trace(peak_scatter, **add_trace_kwargs)

            recovery_mask = status == DrawdownStatus.Recovered
            if np.any(recovery_mask):
                # Plot valley markers
                valley_drawdown = (
                    valley_val[recovery_mask] -
                    start_val[recovery_mask]) / start_val[recovery_mask]
                valley_duration = get_duration_str(start_idx[recovery_mask],
                                                   valley_idx[recovery_mask])
                valley_customdata = np.stack(
                    (_id[recovery_mask], valley_drawdown, valley_duration),
                    axis=1)
                valley_scatter = go.Scatter(
                    x=self_col.ts.index[valley_idx[recovery_mask]],
                    y=valley_val[recovery_mask],
                    mode='markers',
                    marker=dict(
                        symbol='diamond',
                        color=plotting_cfg['contrast_color_schema']['red'],
                        size=7,
                        line=dict(width=1,
                                  color=adjust_lightness(
                                      plotting_cfg['contrast_color_schema']
                                      ['red']))),
                    name='Valley',
                    customdata=valley_customdata,
                    hovertemplate="Drawdown Id: %{customdata[0]}"
                    "<br>Date: %{x}"
                    "<br>Price: %{y}"
                    "<br>Drawdown: %{customdata[1]:.2%}"
                    "<br>Duration: %{customdata[2]}")
                valley_scatter.update(**valley_trace_kwargs)
                fig.add_trace(valley_scatter, **add_trace_kwargs)

                if plot_zones:
                    # Plot drawdown zones
                    for i in np.flatnonzero(recovery_mask):
                        fig.add_shape(**merge_dicts(
                            dict(
                                type="rect",
                                xref=xref,
                                yref="paper",
                                x0=self_col.ts.index[start_idx[i]],
                                y0=y_domain[0],
                                x1=self_col.ts.index[valley_idx[i]],
                                y1=y_domain[1],
                                fillcolor='red',
                                opacity=0.2,
                                layer="below",
                                line_width=0,
                            ), ptv_shape_kwargs))

                # Plot recovery markers
                recovery_return = (
                    end_val[recovery_mask] -
                    valley_val[recovery_mask]) / valley_val[recovery_mask]
                recovery_duration = get_duration_str(valley_idx[recovery_mask],
                                                     end_idx[recovery_mask])
                recovery_customdata = np.stack(
                    (_id[recovery_mask], recovery_return, recovery_duration),
                    axis=1)
                recovery_scatter = go.Scatter(
                    x=self_col.ts.index[end_idx[recovery_mask]],
                    y=end_val[recovery_mask],
                    mode='markers',
                    marker=dict(
                        symbol='diamond',
                        color=plotting_cfg['contrast_color_schema']['green'],
                        size=7,
                        line=dict(width=1,
                                  color=adjust_lightness(
                                      plotting_cfg['contrast_color_schema']
                                      ['green']))),
                    name='Recovery/Peak',
                    customdata=recovery_customdata,
                    hovertemplate="Drawdown Id: %{customdata[0]}"
                    "<br>Date: %{x}"
                    "<br>Price: %{y}"
                    "<br>Return: %{customdata[1]:.2%}"
                    "<br>Duration: %{customdata[2]}")
                recovery_scatter.update(**recovery_trace_kwargs)
                fig.add_trace(recovery_scatter, **add_trace_kwargs)

                if plot_zones:
                    # Plot recovery zones
                    for i in np.flatnonzero(recovery_mask):
                        fig.add_shape(**merge_dicts(
                            dict(
                                type="rect",
                                xref=xref,
                                yref="paper",
                                x0=self_col.ts.index[valley_idx[i]],
                                y0=y_domain[0],
                                x1=self_col.ts.index[end_idx[i]],
                                y1=y_domain[1],
                                fillcolor='green',
                                opacity=0.2,
                                layer="below",
                                line_width=0,
                            ), vtr_shape_kwargs))

            # Plot active markers
            active_mask = ~recovery_mask
            if np.any(active_mask):
                active_drawdown = (
                    valley_val[active_mask] -
                    start_val[active_mask]) / start_val[active_mask]
                active_duration = get_duration_str(valley_idx[active_mask],
                                                   end_idx[active_mask])
                active_customdata = np.stack(
                    (_id[active_mask], active_drawdown, active_duration),
                    axis=1)
                active_scatter = go.Scatter(
                    x=self_col.ts.index[end_idx[active_mask]],
                    y=end_val[active_mask],
                    mode='markers',
                    marker=dict(
                        symbol='diamond',
                        color=plotting_cfg['contrast_color_schema']['orange'],
                        size=7,
                        line=dict(width=1,
                                  color=adjust_lightness(
                                      plotting_cfg['contrast_color_schema']
                                      ['orange']))),
                    name='Active',
                    customdata=active_customdata,
                    hovertemplate="Drawdown Id: %{customdata[0]}"
                    "<br>Date: %{x}"
                    "<br>Price: %{y}"
                    "<br>Drawdown: %{customdata[1]:.2%}"
                    "<br>Duration: %{customdata[2]}")
                active_scatter.update(**active_trace_kwargs)
                fig.add_trace(active_scatter, **add_trace_kwargs)

                if plot_zones:
                    # Plot active drawdown zones
                    for i in np.flatnonzero(active_mask):
                        fig.add_shape(**merge_dicts(
                            dict(
                                type="rect",
                                xref=xref,
                                yref="paper",
                                x0=self_col.ts.index[start_idx[i]],
                                y0=y_domain[0],
                                x1=self_col.ts.index[end_idx[i]],
                                y1=y_domain[1],
                                fillcolor='orange',
                                opacity=0.2,
                                layer="below",
                                line_width=0,
                            ), active_shape_kwargs))

        return fig
Example #4
0
    def plots(self,
              subplots: tp.Optional[tp.MaybeIterable[tp.Union[str, tp.Tuple[
                  str, tp.Kwargs]]]] = None,
              tags: tp.Optional[tp.MaybeIterable[str]] = None,
              column: tp.Optional[tp.Label] = None,
              group_by: tp.GroupByLike = None,
              silence_warnings: tp.Optional[bool] = None,
              template_mapping: tp.Optional[tp.Mapping] = None,
              settings: tp.KwargsLike = None,
              filters: tp.KwargsLike = None,
              subplot_settings: tp.KwargsLike = None,
              show_titles: bool = None,
              hide_id_labels: bool = None,
              group_id_labels: bool = None,
              make_subplots_kwargs: tp.KwargsLike = None,
              **layout_kwargs) -> tp.Optional[tp.BaseFigure]:
        """Plot various parts of this object.

        Args:
            subplots (str, tuple, iterable, or dict): Subplots to plot.

                Each element can be either:

                * a subplot name (see keys in `PlotsBuilderMixin.subplots`)
                * a tuple of a subplot name and a settings dict as in `PlotsBuilderMixin.subplots`.

                The settings dict can contain the following keys:

                * `title`: Title of the subplot. Defaults to the name.
                * `plot_func` (required): Plotting function for custom subplots.
                    Should write the supplied figure `fig` in-place and can return anything (it won't be used).
                * `xaxis_kwargs`: Layout keyword arguments for the x-axis. Defaults to `dict(title='Index')`.
                * `yaxis_kwargs`: Layout keyword arguments for the y-axis. Defaults to empty dict.
                * `tags`, `check_{filter}`, `inv_check_{filter}`, `resolve_plot_func`, `pass_{arg}`,
                    `resolve_path_{arg}`, `resolve_{arg}` and `template_mapping`:
                    The same as in `vectorbt.generic.stats_builder.StatsBuilderMixin` for `calc_func`.
                * Any other keyword argument that overrides the settings or is passed directly to `plot_func`.

                If `resolve_plot_func` is True, the plotting function may "request" any of the
                following arguments by accepting them or if `pass_{arg}` was found in the settings dict:

                * Each of `vectorbt.utils.attr_.AttrResolver.self_aliases`: original object
                    (ungrouped, with no column selected)
                * `group_by`: won't be passed if it was used in resolving the first attribute of `plot_func`
                    specified as a path, use `pass_group_by=True` to pass anyway
                * `column`
                * `subplot_name`
                * `trace_names`: list with the subplot name, can't be used in templates
                * `add_trace_kwargs`: dict with subplot row and column index
                * `xref`
                * `yref`
                * `xaxis`
                * `yaxis`
                * `x_domain`
                * `y_domain`
                * `fig`
                * `silence_warnings`
                * Any argument from `settings`
                * Any attribute of this object if it meant to be resolved
                    (see `vectorbt.utils.attr_.AttrResolver.resolve_attr`)

                !!! note
                    Layout-related resolution arguments such as `add_trace_kwargs` are unavailable
                    before filtering and thus cannot be used in any templates but can still be overridden.

                Pass `subplots='all'` to plot all supported subplots.
            tags (str or iterable): See `tags` in `vectorbt.generic.stats_builder.StatsBuilderMixin`.
            column (str): See `column` in `vectorbt.generic.stats_builder.StatsBuilderMixin`.
            group_by (any): See `group_by` in `vectorbt.generic.stats_builder.StatsBuilderMixin`.
            silence_warnings (bool): See `silence_warnings` in `vectorbt.generic.stats_builder.StatsBuilderMixin`.
            template_mapping (mapping): See `template_mapping` in `vectorbt.generic.stats_builder.StatsBuilderMixin`.

                Applied on `settings`, `make_subplots_kwargs`, and `layout_kwargs`, and then on each subplot settings.
            filters (dict): See `filters` in `vectorbt.generic.stats_builder.StatsBuilderMixin`.
            settings (dict): See `settings` in `vectorbt.generic.stats_builder.StatsBuilderMixin`.
            subplot_settings (dict): See `metric_settings` in `vectorbt.generic.stats_builder.StatsBuilderMixin`.
            show_titles (bool): Whether to show the title of each subplot.
            hide_id_labels (bool): Whether to hide identical legend labels.

                Two labels are identical if their name, marker style and line style match.
            group_id_labels (bool): Whether to group identical legend labels.
            make_subplots_kwargs (dict): Keyword arguments passed to `plotly.subplots.make_subplots`.
            **layout_kwargs: Keyword arguments used to update the layout of the figure.

        !!! note
            `PlotsBuilderMixin` and `vectorbt.generic.stats_builder.StatsBuilderMixin` are very similar.
            Some artifacts follow the same concept, just named differently:

            * `plots_defaults` vs `stats_defaults`
            * `subplots` vs `metrics`
            * `subplot_settings` vs `metric_settings`

            See further notes under `vectorbt.generic.stats_builder.StatsBuilderMixin`.

        Usage:
            See `vectorbt.portfolio.base` for examples.
        """
        from vectorbt._settings import settings as _settings
        plotting_cfg = _settings['plotting']

        # Resolve defaults
        if silence_warnings is None:
            silence_warnings = self.plots_defaults['silence_warnings']
        if show_titles is None:
            show_titles = self.plots_defaults['show_titles']
        if hide_id_labels is None:
            hide_id_labels = self.plots_defaults['hide_id_labels']
        if group_id_labels is None:
            group_id_labels = self.plots_defaults['group_id_labels']
        template_mapping = merge_dicts(self.plots_defaults['template_mapping'],
                                       template_mapping)
        filters = merge_dicts(self.plots_defaults['filters'], filters)
        settings = merge_dicts(self.plots_defaults['settings'], settings)
        subplot_settings = merge_dicts(self.plots_defaults['subplot_settings'],
                                       subplot_settings)
        make_subplots_kwargs = merge_dicts(
            self.plots_defaults['make_subplots_kwargs'], make_subplots_kwargs)
        layout_kwargs = merge_dicts(self.plots_defaults['layout_kwargs'],
                                    layout_kwargs)

        # Replace templates globally (not used at subplot level)
        if len(template_mapping) > 0:
            sub_settings = deep_substitute(settings, mapping=template_mapping)
            sub_make_subplots_kwargs = deep_substitute(
                make_subplots_kwargs, mapping=template_mapping)
            sub_layout_kwargs = deep_substitute(layout_kwargs,
                                                mapping=template_mapping)
        else:
            sub_settings = settings
            sub_make_subplots_kwargs = make_subplots_kwargs
            sub_layout_kwargs = layout_kwargs

        # Resolve self
        reself = self.resolve_self(cond_kwargs=sub_settings,
                                   impacts_caching=False,
                                   silence_warnings=silence_warnings)

        # Prepare subplots
        if subplots is None:
            subplots = reself.plots_defaults['subplots']
        if subplots == 'all':
            subplots = reself.subplots
        if isinstance(subplots, dict):
            subplots = list(subplots.items())
        if isinstance(subplots, (str, tuple)):
            subplots = [subplots]

        # Prepare tags
        if tags is None:
            tags = reself.plots_defaults['tags']
        if isinstance(tags, str) and tags == 'all':
            tags = None
        if isinstance(tags, (str, tuple)):
            tags = [tags]

        # Bring to the same shape
        new_subplots = []
        for i, subplot in enumerate(subplots):
            if isinstance(subplot, str):
                subplot = (subplot, reself.subplots[subplot])
            if not isinstance(subplot, tuple):
                raise TypeError(
                    f"Subplot at index {i} must be either a string or a tuple")
            new_subplots.append(subplot)
        subplots = new_subplots

        # Handle duplicate names
        subplot_counts = Counter(list(map(lambda x: x[0], subplots)))
        subplot_i = {k: -1 for k in subplot_counts.keys()}
        subplots_dct = {}
        for i, (subplot_name, _subplot_settings) in enumerate(subplots):
            if subplot_counts[subplot_name] > 1:
                subplot_i[subplot_name] += 1
                subplot_name = subplot_name + '_' + str(
                    subplot_i[subplot_name])
            subplots_dct[subplot_name] = _subplot_settings

        # Check subplot_settings
        missed_keys = set(subplot_settings.keys()).difference(
            set(subplots_dct.keys()))
        if len(missed_keys) > 0:
            raise ValueError(
                f"Keys {missed_keys} in subplot_settings could not be matched with any subplot"
            )

        # Merge settings
        opt_arg_names_dct = {}
        custom_arg_names_dct = {}
        resolved_self_dct = {}
        mapping_dct = {}
        for subplot_name, _subplot_settings in list(subplots_dct.items()):
            opt_settings = merge_dicts(
                {name: reself
                 for name in reself.self_aliases},
                dict(column=column,
                     group_by=group_by,
                     subplot_name=subplot_name,
                     trace_names=[subplot_name],
                     silence_warnings=silence_warnings), settings)
            _subplot_settings = _subplot_settings.copy()
            passed_subplot_settings = subplot_settings.get(subplot_name, {})
            merged_settings = merge_dicts(opt_settings, _subplot_settings,
                                          passed_subplot_settings)
            subplot_template_mapping = merged_settings.pop(
                'template_mapping', {})
            template_mapping_merged = merge_dicts(template_mapping,
                                                  subplot_template_mapping)
            template_mapping_merged = deep_substitute(template_mapping_merged,
                                                      mapping=merged_settings)
            mapping = merge_dicts(template_mapping_merged, merged_settings)
            # safe because we will use deep_substitute again once layout params are known
            merged_settings = deep_substitute(merged_settings,
                                              mapping=mapping,
                                              safe=True)

            # Filter by tag
            if tags is not None:
                in_tags = merged_settings.get('tags', None)
                if in_tags is None or not match_tags(tags, in_tags):
                    subplots_dct.pop(subplot_name, None)
                    continue

            custom_arg_names = set(_subplot_settings.keys()).union(
                set(passed_subplot_settings.keys()))
            opt_arg_names = set(opt_settings.keys())
            custom_reself = reself.resolve_self(
                cond_kwargs=merged_settings,
                custom_arg_names=custom_arg_names,
                impacts_caching=True,
                silence_warnings=merged_settings['silence_warnings'])

            subplots_dct[subplot_name] = merged_settings
            custom_arg_names_dct[subplot_name] = custom_arg_names
            opt_arg_names_dct[subplot_name] = opt_arg_names
            resolved_self_dct[subplot_name] = custom_reself
            mapping_dct[subplot_name] = mapping

        # Filter subplots
        for subplot_name, _subplot_settings in list(subplots_dct.items()):
            custom_reself = resolved_self_dct[subplot_name]
            mapping = mapping_dct[subplot_name]
            _silence_warnings = _subplot_settings.get('silence_warnings')

            subplot_filters = set()
            for k in _subplot_settings.keys():
                filter_name = None
                if k.startswith('check_'):
                    filter_name = k[len('check_'):]
                elif k.startswith('inv_check_'):
                    filter_name = k[len('inv_check_'):]
                if filter_name is not None:
                    if filter_name not in filters:
                        raise ValueError(
                            f"Metric '{subplot_name}' requires filter '{filter_name}'"
                        )
                    subplot_filters.add(filter_name)

            for filter_name in subplot_filters:
                filter_settings = filters[filter_name]
                _filter_settings = deep_substitute(filter_settings,
                                                   mapping=mapping)
                filter_func = _filter_settings['filter_func']
                warning_message = _filter_settings.get('warning_message', None)
                inv_warning_message = _filter_settings.get(
                    'inv_warning_message', None)
                to_check = _subplot_settings.get('check_' + filter_name, False)
                inv_to_check = _subplot_settings.get(
                    'inv_check_' + filter_name, False)

                if to_check or inv_to_check:
                    whether_true = filter_func(custom_reself,
                                               _subplot_settings)
                    to_remove = (to_check
                                 and not whether_true) or (inv_to_check
                                                           and whether_true)
                    if to_remove:
                        if to_check and warning_message is not None and not _silence_warnings:
                            warnings.warn(warning_message)
                        if inv_to_check and inv_warning_message is not None and not _silence_warnings:
                            warnings.warn(inv_warning_message)

                        subplots_dct.pop(subplot_name, None)
                        custom_arg_names_dct.pop(subplot_name, None)
                        opt_arg_names_dct.pop(subplot_name, None)
                        resolved_self_dct.pop(subplot_name, None)
                        mapping_dct.pop(subplot_name, None)
                        break

        # Any subplots left?
        if len(subplots_dct) == 0:
            if not silence_warnings:
                warnings.warn("No subplots to plot", stacklevel=2)
            return None

        # Set up figure
        rows = sub_make_subplots_kwargs.pop('rows', len(subplots_dct))
        cols = sub_make_subplots_kwargs.pop('cols', 1)
        specs = sub_make_subplots_kwargs.pop('specs', [[{}
                                                        for _ in range(cols)]
                                                       for _ in range(rows)])
        row_col_tuples = []
        for row, row_spec in enumerate(specs):
            for col, col_spec in enumerate(row_spec):
                if col_spec is not None:
                    row_col_tuples.append((row + 1, col + 1))
        shared_xaxes = sub_make_subplots_kwargs.pop('shared_xaxes', True)
        shared_yaxes = sub_make_subplots_kwargs.pop('shared_yaxes', False)
        default_height = plotting_cfg['layout']['height']
        default_width = plotting_cfg['layout']['width'] + 50
        min_space = 10  # space between subplots with no axis sharing
        max_title_spacing = 30
        max_xaxis_spacing = 50
        max_yaxis_spacing = 100
        legend_height = 50
        if show_titles:
            title_spacing = max_title_spacing
        else:
            title_spacing = 0
        if not shared_xaxes and rows > 1:
            xaxis_spacing = max_xaxis_spacing
        else:
            xaxis_spacing = 0
        if not shared_yaxes and cols > 1:
            yaxis_spacing = max_yaxis_spacing
        else:
            yaxis_spacing = 0
        if 'height' in sub_layout_kwargs:
            height = sub_layout_kwargs.pop('height')
        else:
            height = default_height + title_spacing
            if rows > 1:
                height *= rows
                height += min_space * rows - min_space
                height += legend_height - legend_height * rows
                if shared_xaxes:
                    height += max_xaxis_spacing - max_xaxis_spacing * rows
        if 'width' in sub_layout_kwargs:
            width = sub_layout_kwargs.pop('width')
        else:
            width = default_width
            if cols > 1:
                width *= cols
                width += min_space * cols - min_space
                if shared_yaxes:
                    width += max_yaxis_spacing - max_yaxis_spacing * cols
        if height is not None:
            if 'vertical_spacing' in sub_make_subplots_kwargs:
                vertical_spacing = sub_make_subplots_kwargs.pop(
                    'vertical_spacing')
            else:
                vertical_spacing = min_space + title_spacing + xaxis_spacing
            if vertical_spacing is not None and vertical_spacing > 1:
                vertical_spacing /= height
            legend_y = 1 + (min_space + title_spacing) / height
        else:
            vertical_spacing = sub_make_subplots_kwargs.pop(
                'vertical_spacing', None)
            legend_y = 1.02
        if width is not None:
            if 'horizontal_spacing' in sub_make_subplots_kwargs:
                horizontal_spacing = sub_make_subplots_kwargs.pop(
                    'horizontal_spacing')
            else:
                horizontal_spacing = min_space + yaxis_spacing
            if horizontal_spacing is not None and horizontal_spacing > 1:
                horizontal_spacing /= width
        else:
            horizontal_spacing = sub_make_subplots_kwargs.pop(
                'horizontal_spacing', None)
        if show_titles:
            _subplot_titles = []
            for i in range(len(subplots_dct)):
                _subplot_titles.append('$title_' + str(i))
        else:
            _subplot_titles = None
        fig = make_subplots(rows=rows,
                            cols=cols,
                            specs=specs,
                            shared_xaxes=shared_xaxes,
                            shared_yaxes=shared_yaxes,
                            subplot_titles=_subplot_titles,
                            vertical_spacing=vertical_spacing,
                            horizontal_spacing=horizontal_spacing,
                            **sub_make_subplots_kwargs)
        sub_layout_kwargs = merge_dicts(
            dict(showlegend=True,
                 width=width,
                 height=height,
                 legend=dict(orientation="h",
                             yanchor="bottom",
                             y=legend_y,
                             xanchor="right",
                             x=1,
                             traceorder='normal')), sub_layout_kwargs)
        fig.update_layout(
            **sub_layout_kwargs)  # final destination for sub_layout_kwargs

        # Plot subplots
        arg_cache_dct = {}
        for i, (subplot_name,
                _subplot_settings) in enumerate(subplots_dct.items()):
            try:
                final_kwargs = _subplot_settings.copy()
                opt_arg_names = opt_arg_names_dct[subplot_name]
                custom_arg_names = custom_arg_names_dct[subplot_name]
                custom_reself = resolved_self_dct[subplot_name]
                mapping = mapping_dct[subplot_name]

                # Compute figure artifacts
                row, col = row_col_tuples[i]
                xref = 'x' if i == 0 else 'x' + str(i + 1)
                yref = 'y' if i == 0 else 'y' + str(i + 1)
                xaxis = 'xaxis' + xref[1:]
                yaxis = 'yaxis' + yref[1:]
                x_domain = get_domain(xref, fig)
                y_domain = get_domain(yref, fig)
                subplot_layout_kwargs = dict(
                    add_trace_kwargs=dict(row=row, col=col),
                    xref=xref,
                    yref=yref,
                    xaxis=xaxis,
                    yaxis=yaxis,
                    x_domain=x_domain,
                    y_domain=y_domain,
                    fig=fig,
                    pass_fig=True  # force passing fig
                )
                for k in subplot_layout_kwargs:
                    opt_arg_names.add(k)
                    if k in final_kwargs:
                        custom_arg_names.add(k)
                final_kwargs = merge_dicts(subplot_layout_kwargs, final_kwargs)
                mapping = merge_dicts(subplot_layout_kwargs, mapping)
                final_kwargs = deep_substitute(final_kwargs, mapping=mapping)

                # Clean up keys
                for k, v in list(final_kwargs.items()):
                    if k.startswith('check_') or k.startswith(
                            'inv_check_') or k in ('tags', ):
                        final_kwargs.pop(k, None)

                # Get subplot-specific values
                _column = final_kwargs.get('column')
                _group_by = final_kwargs.get('group_by')
                _silence_warnings = final_kwargs.get('silence_warnings')
                title = final_kwargs.pop('title', subplot_name)
                plot_func = final_kwargs.pop('plot_func', None)
                xaxis_kwargs = final_kwargs.pop('xaxis_kwargs', None)
                yaxis_kwargs = final_kwargs.pop('yaxis_kwargs', None)
                resolve_plot_func = final_kwargs.pop('resolve_plot_func', True)
                use_caching = final_kwargs.pop('use_caching', True)

                if plot_func is not None:
                    # Resolve plot_func
                    if resolve_plot_func:
                        if not callable(plot_func):
                            passed_kwargs_out = {}

                            def _getattr_func(
                                obj: tp.Any,
                                attr: str,
                                args: tp.ArgsLike = None,
                                kwargs: tp.KwargsLike = None,
                                call_attr: bool = True,
                                _final_kwargs: tp.Kwargs = final_kwargs,
                                _opt_arg_names: tp.Set[str] = opt_arg_names,
                                _custom_arg_names: tp.Set[str] = custom_arg_names,
                                _arg_cache_dct: tp.Kwargs = arg_cache_dct,
                                _custom_reself: tp.Any = custom_reself,
                                _use_caching: bool = use_caching,
                                _passed_kwargs_out: tp.Kwargs = passed_kwargs_out
                            ) -> tp.Any:
                                """Resolve a single attribute of the `plot_func` path.

                                Used as `getattr_func` for `deep_getattr` when `plot_func`
                                is given as a path rather than a callable.

                                Resolution order:

                                * If `attr` is a key of the subplot's final kwargs, that value
                                    is returned as-is (explicit settings override attributes).
                                * If `obj` is the resolved object of this subplot and
                                    `resolve_path_{attr}` is not set to a falsy value, the
                                    attribute is resolved via `resolve_attr` (when `call_attr`
                                    is True), or fetched with `getattr` otherwise.
                                * Otherwise the attribute is fetched with `getattr` and, if it
                                    is callable and `call_attr` is True, called with `args`
                                    and `kwargs`.

                                The underscore-prefixed parameters are not meant to be passed
                                by callers: they pin the values of the enclosing loop's
                                variables at definition time so this closure never observes
                                values from a later loop iteration (late binding).

                                Note that the `resolve_path_{attr}` key is popped from the
                                subplot's final kwargs, so the dict is mutated in place.
                                """
                                # Explicitly passed settings take precedence over attributes
                                if attr in _final_kwargs:
                                    return _final_kwargs[attr]
                                if args is None:
                                    args = ()
                                if kwargs is None:
                                    kwargs = {}
                                # pop() both reads the flag and removes it so it is not later
                                # forwarded to plot_func
                                if obj is _custom_reself and _final_kwargs.pop(
                                        'resolve_path_' + attr, True):
                                    if call_attr:
                                        # Only optional (resolvable) arguments are offered as
                                        # conditional kwargs; resolve_attr presumably records
                                        # what it actually used in _passed_kwargs_out (the
                                        # caller checks it for 'group_by' afterwards)
                                        return _custom_reself.resolve_attr(
                                            attr,
                                            args=args,
                                            cond_kwargs={
                                                k: v
                                                for k, v in
                                                _final_kwargs.items()
                                                if k in _opt_arg_names
                                            },
                                            kwargs=kwargs,
                                            custom_arg_names=_custom_arg_names,
                                            cache_dct=_arg_cache_dct,
                                            use_caching=_use_caching,
                                            passed_kwargs_out=_passed_kwargs_out
                                        )
                                    return getattr(obj, attr)
                                out = getattr(obj, attr)
                                if callable(out) and call_attr:
                                    return out(*args, **kwargs)
                                return out

                            plot_func = custom_reself.deep_getattr(
                                plot_func,
                                getattr_func=_getattr_func,
                                call_last_attr=False)

                            if 'group_by' in passed_kwargs_out:
                                if 'pass_group_by' not in final_kwargs:
                                    final_kwargs.pop('group_by', None)
                        if not callable(plot_func):
                            raise TypeError("plot_func must be callable")

                        # Resolve arguments
                        func_arg_names = get_func_arg_names(plot_func)
                        for k in func_arg_names:
                            if k not in final_kwargs:
                                if final_kwargs.pop('resolve_' + k, False):
                                    try:
                                        arg_out = custom_reself.resolve_attr(
                                            k,
                                            cond_kwargs=final_kwargs,
                                            custom_arg_names=custom_arg_names,
                                            cache_dct=arg_cache_dct,
                                            use_caching=use_caching)
                                    except AttributeError:
                                        continue
                                    final_kwargs[k] = arg_out
                        for k in list(final_kwargs.keys()):
                            if k in opt_arg_names:
                                if 'pass_' + k in final_kwargs:
                                    if not final_kwargs.get(
                                            'pass_' + k):  # first priority
                                        final_kwargs.pop(k, None)
                                elif k not in func_arg_names:  # second priority
                                    final_kwargs.pop(k, None)
                        for k in list(final_kwargs.keys()):
                            if k.startswith('pass_') or k.startswith(
                                    'resolve_'):
                                final_kwargs.pop(k, None)  # cleanup

                        # Call plot_func
                        plot_func(**final_kwargs)
                    else:
                        # Do not resolve plot_func
                        plot_func(custom_reself, _subplot_settings)

                # Update global layout
                for annotation in fig.layout.annotations:
                    if 'text' in annotation and annotation[
                            'text'] == '$title_' + str(i):
                        annotation['text'] = title
                subplot_layout = dict()
                subplot_layout[xaxis] = merge_dicts(dict(title='Index'),
                                                    xaxis_kwargs)
                subplot_layout[yaxis] = merge_dicts(dict(), yaxis_kwargs)
                fig.update_layout(**subplot_layout)
            except Exception as e:
                warnings.warn(f"Subplot '{subplot_name}' raised an exception",
                              stacklevel=2)
                raise e

        # Remove duplicate legend labels
        found_ids = dict()
        unique_idx = 0
        for trace in fig.data:
            if 'name' in trace:
                name = trace['name']
            else:
                name = None
            if 'marker' in trace:
                marker = trace['marker']
            else:
                marker = {}
            if 'symbol' in marker:
                marker_symbol = marker['symbol']
            else:
                marker_symbol = None
            if 'color' in marker:
                marker_color = marker['color']
            else:
                marker_color = None
            if 'line' in trace:
                line = trace['line']
            else:
                line = {}
            if 'dash' in line:
                line_dash = line['dash']
            else:
                line_dash = None
            if 'color' in line:
                line_color = line['color']
            else:
                line_color = None

            id = (name, marker_symbol, marker_color, line_dash, line_color)
            if id in found_ids:
                if hide_id_labels:
                    trace['showlegend'] = False
                if group_id_labels:
                    trace['legendgroup'] = found_ids[id]
            else:
                if group_id_labels:
                    trace['legendgroup'] = unique_idx
                found_ids[id] = unique_idx
                unique_idx += 1

        # Remove all except the last title if sharing the same axis
        if shared_xaxes:
            i = 0
            for row in range(rows):
                for col in range(cols):
                    if specs[row][col] is not None:
                        xaxis = 'xaxis' if i == 0 else 'xaxis' + str(i + 1)
                        if row < rows - 1:
                            fig.layout[xaxis]['title'] = None
                        i += 1
        if shared_yaxes:
            i = 0
            for row in range(rows):
                for col in range(cols):
                    if specs[row][col] is not None:
                        yaxis = 'yaxis' if i == 0 else 'yaxis' + str(i + 1)
                        if col > 0:
                            fig.layout[yaxis]['title'] = None
                        i += 1

        # Return the figure
        return fig
Example #5
0
    def plot_pnl_returns(self,
                         column: tp.Optional[tp.Label] = None,
                         as_pct: bool = True,
                         marker_size_range: tp.Tuple[float, float] = (7, 14),
                         opacity_range: tp.Tuple[float, float] = (0.75, 0.9),
                         closed_profit_trace_kwargs: tp.KwargsLike = None,
                         closed_loss_trace_kwargs: tp.KwargsLike = None,
                         open_trace_kwargs: tp.KwargsLike = None,
                         hline_shape_kwargs: tp.KwargsLike = None,
                         add_trace_kwargs: tp.KwargsLike = None,
                         xref: str = 'x',
                         yref: str = 'y',
                         fig: tp.Optional[tp.BaseFigure] = None,
                         **layout_kwargs) -> tp.BaseFigure:  # pragma: no cover
        """Plot trade PnL.

        Adds one marker trace per category, placed at each record's exit index:
        closed trades with positive PnL, closed trades with negative PnL, and open
        trades with non-zero PnL. Records with zero PnL (open or closed) are not drawn.
        Marker size and opacity are rescaled from the absolute return. A dashed
        horizontal line at y=0 spanning the x-axis domain is always added.

        Args:
            column (str): Name of the column to plot.
            as_pct (bool): Whether to set y-axis to `Trades.returns`, otherwise to `Trades.pnl`.
            marker_size_range (tuple): Range of marker size.
            opacity_range (tuple): Range of marker opacity.
            closed_profit_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Closed - Profit" markers.
            closed_loss_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Closed - Loss" markers.
            open_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Open" markers.
            hline_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for zeroline.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            xref (str): X coordinate axis.
            yref (str): Y coordinate axis.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments for layout.

        Returns:
            The figure that the traces and the zeroline shape were added to.
        """
        from vectorbt._settings import settings
        plotting_cfg = settings['plotting']

        self_col = self.select_one(column=column, group_by=False)

        if closed_profit_trace_kwargs is None:
            closed_profit_trace_kwargs = {}
        if closed_loss_trace_kwargs is None:
            closed_loss_trace_kwargs = {}
        if open_trace_kwargs is None:
            open_trace_kwargs = {}
        if hline_shape_kwargs is None:
            hline_shape_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}
        marker_size_range = tuple(marker_size_range)
        opacity_range = tuple(opacity_range)
        # Layout key of the target y-axis, e.g. yref 'y2' -> 'yaxis2'
        yaxis = 'yaxis' + yref[1:]

        if fig is None:
            fig = make_figure()
        if as_pct:
            # Percent tick format first, so user-supplied layout_kwargs below can override it
            _layout_kwargs = dict()
            _layout_kwargs[yaxis] = dict(tickformat='.2%')
            fig.update_layout(**_layout_kwargs)
        fig.update_layout(**layout_kwargs)
        x_domain = get_domain(xref, fig)

        if len(self_col.values) > 0:
            # Extract information
            exit_idx = self_col.values['exit_idx']
            pnl = self_col.values['pnl']
            returns = self_col.values['return']
            status = self_col.values['status']

            neutral_mask = pnl == 0
            profit_mask = pnl > 0
            loss_mask = pnl < 0

            marker_size = min_rel_rescale(np.abs(returns), marker_size_range)
            opacity = max_rel_rescale(np.abs(returns), opacity_range)

            open_mask = status == TradeStatus.Open
            closed_profit_mask = (~open_mask) & profit_mask
            closed_loss_mask = (~open_mask) & loss_mask
            # Open records with exactly zero PnL are not drawn
            open_mask &= ~neutral_mask

            def _hover_value_strs(value_idx: int) -> tp.Tuple[str, str]:
                """Return the (PnL, Return) hover format strings.

                The metric plotted on the y-axis is referenced as `%{y}`; the other
                metric is read from `customdata[value_idx]`.
                """
                customdata_ref = '%{customdata[' + str(value_idx) + ']'
                if as_pct:
                    return customdata_ref + ':.6f}', '%{y}'
                return '%{y}', customdata_ref + ':.2%}'

            def _plot_scatter(mask: tp.Array1d, name: tp.TraceName,
                              color: tp.Any, kwargs: tp.Kwargs) -> None:
                """Add one marker trace for the records selected by `mask` (no-op if empty)."""
                if not np.any(mask):
                    return
                if self_col.trade_type == TradeType.Trade:
                    # customdata columns: (trade id, position id, non-y metric)
                    customdata = np.stack(
                        (self_col.values['id'][mask],
                         self_col.values['position_id'][mask],
                         pnl[mask] if as_pct else returns[mask]),
                        axis=1)
                    pnl_str, return_str = _hover_value_strs(2)
                    hovertemplate = "Trade Id: %{customdata[0]}" \
                                    "<br>Position Id: %{customdata[1]}" \
                                    "<br>Date: %{x}" \
                                    f"<br>PnL: {pnl_str}" \
                                    f"<br>Return: {return_str}"
                else:
                    # customdata columns: (position id, non-y metric)
                    customdata = np.stack(
                        (self_col.values['id'][mask],
                         pnl[mask] if as_pct else returns[mask]),
                        axis=1)
                    pnl_str, return_str = _hover_value_strs(1)
                    hovertemplate = "Position Id: %{customdata[0]}" \
                                    "<br>Date: %{x}" \
                                    f"<br>PnL: {pnl_str}" \
                                    f"<br>Return: {return_str}"
                scatter = go.Scatter(
                    x=self_col.wrapper.index[exit_idx[mask]],
                    y=returns[mask] if as_pct else pnl[mask],
                    mode='markers',
                    marker=dict(
                        symbol='circle',
                        color=color,
                        size=marker_size[mask],
                        opacity=opacity[mask],
                        line=dict(width=1, color=adjust_lightness(color)),
                    ),
                    name=name,
                    customdata=customdata,
                    hovertemplate=hovertemplate)
                # User trace kwargs override the defaults above
                scatter.update(**kwargs)
                fig.add_trace(scatter, **add_trace_kwargs)

            # Plot Closed - Profit scatter
            _plot_scatter(closed_profit_mask, 'Closed - Profit',
                          plotting_cfg['contrast_color_schema']['green'],
                          closed_profit_trace_kwargs)

            # Plot Closed - Loss scatter
            _plot_scatter(closed_loss_mask, 'Closed - Loss',
                          plotting_cfg['contrast_color_schema']['red'],
                          closed_loss_trace_kwargs)

            # Plot Open scatter
            _plot_scatter(open_mask, 'Open',
                          plotting_cfg['contrast_color_schema']['orange'],
                          open_trace_kwargs)

        # Plot zeroline across the full x-axis domain (in paper coordinates)
        fig.add_shape(**merge_dicts(
            dict(type='line',
                 xref="paper",
                 yref=yref,
                 x0=x_domain[0],
                 y0=0,
                 x1=x_domain[1],
                 y1=0,
                 line=dict(
                     color="gray",
                     dash="dash",
                 )), hline_shape_kwargs))
        return fig