Beispiel #1
0
class OHLCVDFAccessor(GenericDFAccessor):  # pragma: no cover
    """Accessor on top of OHLCV data. For DataFrames only.

    Accessible through `pd.DataFrame.vbt.ohlcv`."""

    def __init__(self, obj: tp.Frame, column_names: tp.KwargsLike = None, **kwargs) -> None:
        # User overrides; merged with `ohlcv.column_names` from the settings in `column_names`.
        self._column_names = column_names

        GenericDFAccessor.__init__(self, obj, column_names=column_names, **kwargs)

    @property
    def column_names(self) -> tp.Kwargs:
        """Column names.

        Merges `ohlcv.column_names` from `vectorbt._settings.settings` with the
        `column_names` passed to the constructor using `merge_dicts`."""
        from vectorbt._settings import settings
        ohlcv_cfg = settings['ohlcv']

        return merge_dicts(ohlcv_cfg['column_names'], self._column_names)

    def get_column(self, col_name: str) -> tp.Optional[tp.Series]:
        """Get column from `OHLCVDFAccessor.column_names`.

        The lookup is case-insensitive. Returns None if the column cannot be found."""
        # Stringify labels first: `Index.str` raises for non-string (e.g. integer or
        # MultiIndex) columns, whereas a missing column should simply yield None.
        df_column_names = [str(c).lower() for c in self.obj.columns]
        col_name = self.column_names[col_name].lower()
        if col_name not in df_column_names:
            return None
        return self.obj.iloc[:, df_column_names.index(col_name)]

    @property
    def open(self) -> tp.Optional[tp.Series]:
        """Open series."""
        return self.get_column('open')

    @property
    def high(self) -> tp.Optional[tp.Series]:
        """High series."""
        return self.get_column('high')

    @property
    def low(self) -> tp.Optional[tp.Series]:
        """Low series."""
        return self.get_column('low')

    @property
    def close(self) -> tp.Optional[tp.Series]:
        """Close series."""
        return self.get_column('close')

    @property
    def ohlc(self) -> tp.Optional[tp.Frame]:
        """Open, high, low, and close series.

        Columns that cannot be found are skipped. Returns None if none of them are present."""
        # Resolve each column once (the properties perform a lookup on every access).
        to_concat = [sr for sr in (self.open, self.high, self.low, self.close) if sr is not None]
        if len(to_concat) == 0:
            return None
        return pd.concat(to_concat, axis=1)

    @property
    def volume(self) -> tp.Optional[tp.Series]:
        """Volume series."""
        return self.get_column('volume')

    # ############# Stats ############# #

    @property
    def stats_defaults(self) -> tp.Kwargs:
        """Defaults for `OHLCVDFAccessor.stats`.

        Merges `vectorbt.generic.accessors.GenericAccessor.stats_defaults` and
        `ohlcv.stats` from `vectorbt._settings.settings`."""
        from vectorbt._settings import settings
        ohlcv_stats_cfg = settings['ohlcv']['stats']

        return merge_dicts(
            GenericAccessor.stats_defaults.__get__(self),
            ohlcv_stats_cfg
        )

    # Price/volume metrics ignore NaNs: first/last use bfill/ffill, and
    # lowest/highest use nanmin/nanmax so that a single missing bar does not
    # turn the whole metric into NaN.
    _metrics: tp.ClassVar[Config] = Config(
        dict(
            start=dict(
                title='Start',
                calc_func=lambda self: self.wrapper.index[0],
                agg_func=None,
                tags='wrapper'
            ),
            end=dict(
                title='End',
                calc_func=lambda self: self.wrapper.index[-1],
                agg_func=None,
                tags='wrapper'
            ),
            period=dict(
                title='Period',
                calc_func=lambda self: len(self.wrapper.index),
                apply_to_timedelta=True,
                agg_func=None,
                tags='wrapper'
            ),
            first_price=dict(
                title='First Price',
                calc_func=lambda ohlc: nb.bfill_1d_nb(ohlc.values.flatten())[0],
                resolve_ohlc=True,
                tags=['ohlcv', 'ohlc']
            ),
            lowest_price=dict(
                title='Lowest Price',
                calc_func=lambda ohlc: np.nanmin(ohlc.values),
                resolve_ohlc=True,
                tags=['ohlcv', 'ohlc']
            ),
            highest_price=dict(
                title='Highest Price',
                calc_func=lambda ohlc: np.nanmax(ohlc.values),
                resolve_ohlc=True,
                tags=['ohlcv', 'ohlc']
            ),
            last_price=dict(
                title='Last Price',
                calc_func=lambda ohlc: nb.ffill_1d_nb(ohlc.values.flatten())[-1],
                resolve_ohlc=True,
                tags=['ohlcv', 'ohlc']
            ),
            first_volume=dict(
                title='First Volume',
                calc_func=lambda volume: nb.bfill_1d_nb(volume.values)[0],
                resolve_volume=True,
                tags=['ohlcv', 'volume']
            ),
            lowest_volume=dict(
                title='Lowest Volume',
                calc_func=lambda volume: np.nanmin(volume.values),
                resolve_volume=True,
                tags=['ohlcv', 'volume']
            ),
            highest_volume=dict(
                title='Highest Volume',
                calc_func=lambda volume: np.nanmax(volume.values),
                resolve_volume=True,
                tags=['ohlcv', 'volume']
            ),
            last_volume=dict(
                title='Last Volume',
                calc_func=lambda volume: nb.ffill_1d_nb(volume.values)[-1],
                resolve_volume=True,
                tags=['ohlcv', 'volume']
            ),
        ),
        copy_kwargs=dict(copy_mode='deep')
    )

    @property
    def metrics(self) -> Config:
        return self._metrics

    # ############# Plotting ############# #

    def plot(self,
             plot_type: tp.Union[None, str, tp.BaseTraceType] = None,
             show_volume: tp.Optional[bool] = None,
             ohlc_kwargs: tp.KwargsLike = None,
             volume_kwargs: tp.KwargsLike = None,
             ohlc_add_trace_kwargs: tp.KwargsLike = None,
             volume_add_trace_kwargs: tp.KwargsLike = None,
             fig: tp.Optional[tp.BaseFigure] = None,
             **layout_kwargs) -> tp.BaseFigure:  # pragma: no cover
        """Plot OHLCV data.

        Args:
            plot_type: Either 'OHLC', 'Candlestick' or Plotly trace type
                (such as `plotly.graph_objects.Ohlc`). For a trace type, its class name
                is used as the trace name.

                Pass None to use the default.
            show_volume (bool): If True, shows volume as bar chart.

                Pass None to show volume only if a volume column is present.
            ohlc_kwargs (dict): Keyword arguments passed to `plot_type`.
            volume_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Bar`.
            ohlc_add_trace_kwargs (dict): Keyword arguments passed to `add_trace` for OHLC.
            volume_add_trace_kwargs (dict): Keyword arguments passed to `add_trace` for volume.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments for layout.

        Raises:
            ValueError: If `plot_type` is a string other than 'OHLC' or 'Candlestick'.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt

        >>> vbt.YFData.download("BTC-USD").get().vbt.ohlcv.plot()
        ```

        ![](/docs/img/ohlcv_plot.svg)
        """
        from vectorbt._settings import settings
        plotting_cfg = settings['plotting']
        ohlcv_cfg = settings['ohlcv']

        # Resolve every column once; each property access performs a column lookup.
        open_sr = self.open
        high_sr = self.high
        low_sr = self.low
        close_sr = self.close
        volume_sr = self.volume

        if ohlc_kwargs is None:
            ohlc_kwargs = {}
        if volume_kwargs is None:
            volume_kwargs = {}
        if ohlc_add_trace_kwargs is None:
            ohlc_add_trace_kwargs = {}
        if volume_add_trace_kwargs is None:
            volume_add_trace_kwargs = {}
        if show_volume is None:
            show_volume = volume_sr is not None
        if show_volume:
            # OHLC goes to the first subplot row, volume to the second
            ohlc_add_trace_kwargs = merge_dicts(dict(row=1, col=1), ohlc_add_trace_kwargs)
            volume_add_trace_kwargs = merge_dicts(dict(row=2, col=1), volume_add_trace_kwargs)

        # Set up figure
        if fig is None:
            if show_volume:
                fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0, row_heights=[0.7, 0.3])
            else:
                fig = make_figure()
            fig.update_layout(
                showlegend=True,
                xaxis=dict(
                    rangeslider_visible=False,
                    showgrid=True
                ),
                yaxis=dict(
                    showgrid=True
                )
            )
            if show_volume:
                fig.update_layout(
                    xaxis2=dict(
                        showgrid=True
                    ),
                    yaxis2=dict(
                        showgrid=True
                    ),
                    bargap=0
                )
        fig.update_layout(**layout_kwargs)
        if plot_type is None:
            plot_type = ohlcv_cfg['plot_type']
        if isinstance(plot_type, str):
            if plot_type.lower() == 'ohlc':
                plot_name = 'OHLC'
                plot_obj = go.Ohlc
            elif plot_type.lower() == 'candlestick':
                plot_name = 'Candlestick'
                plot_obj = go.Candlestick
            else:
                raise ValueError("Plot type can be either 'OHLC' or 'Candlestick'")
        else:
            plot_obj = plot_type
            # Plotly validates `name` as a string, so never pass the trace class itself
            plot_name = getattr(plot_obj, '__name__', None)
        ohlc = plot_obj(
            x=self.wrapper.index,
            open=open_sr,
            high=high_sr,
            low=low_sr,
            close=close_sr,
            name=plot_name,
            increasing=dict(
                line=dict(
                    color=plotting_cfg['color_schema']['increasing']
                )
            ),
            decreasing=dict(
                line=dict(
                    color=plotting_cfg['color_schema']['decreasing']
                )
            )
        )
        ohlc.update(**ohlc_kwargs)
        fig.add_trace(ohlc, **ohlc_add_trace_kwargs)

        if show_volume:
            # Color each volume bar by the direction of its candle (close vs. open)
            price_change = close_sr.values - open_sr.values
            marker_colors = np.empty(volume_sr.shape, dtype=object)
            marker_colors[price_change > 0] = plotting_cfg['color_schema']['increasing']
            marker_colors[price_change == 0] = plotting_cfg['color_schema']['gray']
            marker_colors[price_change < 0] = plotting_cfg['color_schema']['decreasing']
            volume_bar = go.Bar(
                x=self.wrapper.index,
                y=volume_sr,
                marker=dict(
                    color=marker_colors,
                    line_width=0
                ),
                opacity=0.5,
                name='Volume'
            )
            volume_bar.update(**volume_kwargs)
            fig.add_trace(volume_bar, **volume_add_trace_kwargs)

        return fig

    @property
    def plots_defaults(self) -> tp.Kwargs:
        """Defaults for `OHLCVDFAccessor.plots`.

        Merges `vectorbt.generic.accessors.GenericAccessor.plots_defaults` and
        `ohlcv.plots` from `vectorbt._settings.settings`."""
        from vectorbt._settings import settings
        ohlcv_plots_cfg = settings['ohlcv']['plots']

        return merge_dicts(
            GenericAccessor.plots_defaults.__get__(self),
            ohlcv_plots_cfg
        )

    _subplots: tp.ClassVar[Config] = Config(
        dict(
            plot=dict(
                title='OHLC',
                xaxis_kwargs=dict(
                    showgrid=True,
                    rangeslider_visible=False
                ),
                yaxis_kwargs=dict(
                    showgrid=True
                ),
                check_is_not_grouped=True,
                plot_func='plot',
                show_volume=False,
                tags='ohlcv'
            )
        ),
        copy_kwargs=dict(copy_mode='deep')
    )

    @property
    def subplots(self) -> Config:
        return self._subplots
Beispiel #2
0
    0                  False   False
    1                   True   False
    2                  False    True
    3                   True   False
    4                  False    True
    ```
    """
    pass


# Expose the documented stub's docstring on the generated indicator class
RPROB.__doc__ = _RPROB.__doc__

# Configuration handed to `SignalFactory` to build the exit-only `RPROBEX` class
rprobex_config = Config(
    dict(
        class_name='RPROBEX',
        module_name=__name__,
        short_name='rprobex',
        param_names=['prob'],
        exit_only=True,
        iteratively=False
    )
)
"""Factory config for `RPROBEX`."""

# Configuration handed to `from_choice_func` for the exit side of `RPROBEX`
rprobex_func_config = Config(
    dict(
        exit_choice_func=rand_by_prob_choice_nb,
        exit_settings=dict(
            pass_params=['prob'],
            pass_kwargs=['first', 'temp_idx_arr', 'flex_2d']
        ),
        pass_flex_2d=True,
        param_settings=dict(
            prob=flex_elem_param_config
        ),
        seed=None
    )
)
"""Exit function config for `RPROBEX`."""

RPROBEX = SignalFactory(**rprobex_config).from_choice_func(
Beispiel #3
0
from vectorbt.records.mapped_array import MappedArray
from vectorbt.utils.colors import adjust_lightness
from vectorbt.utils.config import merge_dicts, Config
from vectorbt.utils.decorators import cached_property, cached_method
from vectorbt.utils.figure import make_figure, get_domain
from vectorbt.utils.template import RepEval

__pdoc__ = dict()

# Per-field display settings for drawdown records: human-readable titles, plus
# mappings that turn index positions into timestamps and status codes into labels.
dd_field_config = Config(
    {
        'dtype': drawdown_dt,
        'settings': {
            'id': {'title': 'Drawdown Id'},
            'peak_idx': {'title': 'Peak Timestamp', 'mapping': 'index'},
            'valley_idx': {'title': 'Valley Timestamp', 'mapping': 'index'},
            'peak_val': {'title': 'Peak Value'},
            'valley_val': {'title': 'Valley Value'},
            'end_val': {'title': 'End Value'},
            'status': {'mapping': DrawdownStatus},
        },
    },
    readonly=True,
    as_attrs=False
)
"""_"""

__pdoc__['dd_field_config'] = f"""Field config for `Drawdowns`.

```json
{dd_field_config.to_doc()}
```
"""
Beispiel #4
0
def load(fname, names=__all__, **kwargs):
    """Read a settings file and install each requested entry as a module attribute.

    `**kwargs` are forwarded to `Config.load`."""
    loaded = Config.load(fname, **kwargs)
    for key in names:
        value = loaded[key]
        setattr(this_module, key, value)
Beispiel #5
0
```

Changes take effect immediately."""

import json

from vectorbt.utils.config import Config

__pdoc__ = dict()

# Color schema
color_schema = Config(
    blue="#1f77b4",
    orange="#ff7f0e",
    green="#2ca02c",
    red="#dc3912",
    purple="#9467bd",
    brown="#8c564b",
    pink="#e377c2",
    gray="#7f7f7f",
    yellow="#bcbd22",
    cyan="#17becf"
)
"""_"""

__pdoc__['color_schema'] = f"""Color schema.

```plaintext
{json.dumps(color_schema, indent=2)}
```
"""

# Contrast color schema
contrast_color_schema = Config(blue='#4285F4',
Beispiel #6
0
from vectorbt.records.base import Records
from vectorbt.records.decorators import override_field_config, attach_fields
from vectorbt.records.mapped_array import MappedArray
from vectorbt.utils.colors import adjust_lightness
from vectorbt.utils.config import merge_dicts, Config
from vectorbt.utils.decorators import cached_property, cached_method
from vectorbt.utils.figure import make_figure, get_domain

__pdoc__ = dict()

# Display settings for range records. `idx` is redirected to the `end_idx`
# field so that the generic `Records` index field points at the range end.
ranges_field_config = Config(
    dict(
        dtype=range_dt,
        settings=dict(
            id=dict(title='Range Id'),
            idx=dict(name='end_idx'),  # remap field of Records
            start_idx=dict(title='Start Timestamp', mapping='index'),
            end_idx=dict(title='End Timestamp', mapping='index'),
            status=dict(title='Status', mapping=RangeStatus)
        )
    ),
    readonly=True,
    as_attrs=False
)
"""_"""

__pdoc__['ranges_field_config'] = f"""Field config for `Ranges`.

```json
{ranges_field_config.to_doc()}
```
"""

ranges_attach_field_config = Config(dict(status=dict(attach_filters=True)),
Beispiel #7
0
    for k in names:
        settings[k] = getattr(this_module, k)
    Config(settings).save(fname, **kwargs)


def load(fname, names=__all__, **kwargs):
    """Read a settings file and install each requested entry as a module attribute.

    `**kwargs` are forwarded to `Config.load`."""
    loaded = Config.load(fname, **kwargs)
    for key in names:
        value = loaded[key]
        setattr(this_module, key, value)


__pdoc__ = dict()

# Color schema
color_schema = Config(
    increasing="#1b9e76",
    decreasing="#d95f02"
)
"""_"""

__pdoc__['color_schema'] = f"""Color schema.

```plaintext
{json.dumps(color_schema, indent=2, default=str)}
```
"""

# Contrast color schema
contrast_color_schema = Config(
    blue="#4285F4",
    orange="#FFAA00",
    green="#37B13F",
    red="#EA4335",
    gray="#E2E2E2"
)
Beispiel #8
0
def load(fname: str, names: tp.Iterable[str] = __all__, **kwargs) -> None:
    """Read a settings file and install each requested entry as a module attribute.

    `**kwargs` are forwarded to `Config.load`."""
    loaded = Config.load(fname, **kwargs)
    for key in names:
        value = loaded[key]
        setattr(this_module, key, value)
Beispiel #9
0
# Display settings for log records: titles for every field, plus enum mappings
# for the request/result fields that store enum codes.
logs_field_config = Config(
    dict(
        dtype=log_dt,
        settings=dict(
            id=dict(title='Log Id'),
            group=dict(title='Group'),
            cash=dict(title='Cash'),
            position=dict(title='Position'),
            debt=dict(title='Debt'),
            free_cash=dict(title='Free Cash'),
            val_price=dict(title='Val Price'),
            value=dict(title='Value'),
            req_size=dict(title='Request Size'),
            req_price=dict(title='Request Price'),
            req_size_type=dict(title='Request Size Type', mapping=SizeType),
            req_direction=dict(title='Request Direction', mapping=Direction),
            req_fees=dict(title='Request Fees'),
            req_fixed_fees=dict(title='Request Fixed Fees'),
            req_slippage=dict(title='Request Slippage'),
            req_min_size=dict(title='Request Min Size'),
            req_max_size=dict(title='Request Max Size'),
            req_size_granularity=dict(title='Request Size Granularity'),
            req_reject_prob=dict(title='Request Rejection Prob'),
            req_lock_cash=dict(title='Request Lock Cash'),
            req_allow_partial=dict(title='Request Allow Partial'),
            req_raise_reject=dict(title='Request Raise Rejection'),
            req_log=dict(title='Request Log'),
            new_cash=dict(title='New Cash'),
            new_position=dict(title='New Position'),
            new_debt=dict(title='New Debt'),
            new_free_cash=dict(title='New Free Cash'),
            new_val_price=dict(title='New Val Price'),
            new_value=dict(title='New Value'),
            res_size=dict(title='Result Size'),
            res_price=dict(title='Result Price'),
            res_fees=dict(title='Result Fees'),
            res_side=dict(title='Result Side', mapping=OrderSide),
            res_status=dict(title='Result Status', mapping=OrderStatus),
            res_status_info=dict(title='Result Status Info', mapping=OrderStatusInfo),
            order_id=dict(title='Order Id')
        )
    ),
    readonly=True,
    as_attrs=False
)
Beispiel #10
0
    for k in names:
        settings[k] = getattr(this_module, k)
    Config(settings).save(fname, **kwargs)


def load(fname: str, names: tp.Iterable[str] = __all__, **kwargs) -> None:
    """Read a settings file and install each requested entry as a module attribute.

    `**kwargs` are forwarded to `Config.load`."""
    loaded = Config.load(fname, **kwargs)
    for key in names:
        value = loaded[key]
        setattr(this_module, key, value)


__pdoc__ = dict()

# Color schema
color_schema = Config(
    increasing="#1b9e76",
    decreasing="#d95f02"
)
"""_"""

__pdoc__['color_schema'] = f"""Color schema.

```plaintext
{json.dumps(color_schema, indent=2, default=str)}
```
"""

# Contrast color schema
contrast_color_schema = Config(
    blue="#4285F4",
    orange="#FFAA00",
    green="#37B13F",
    red="#EA4335",
    gray="#E2E2E2"
)
Beispiel #11
0
def save(fname: str, names: tp.Iterable[str] = __all__, **kwargs) -> None:
    """Collect the requested module attributes and write them to a settings file.

    `**kwargs` are forwarded to `Config.save`."""
    collected = {key: getattr(this_module, key) for key in names}
    Config(collected).save(fname, **kwargs)
Beispiel #12
0
import numpy as np
import plotly.graph_objects as go

from vectorbt.utils.config import Config
from vectorbt.utils.docs import fix_class_for_docs
from vectorbt.utils.widgets import CustomFigureWidget
from vectorbt.signals.enums import StopType
from vectorbt.signals.factory import SignalFactory
from vectorbt.signals.nb import (rand_enex_apply_nb, rand_by_prob_choice_nb,
                                 stop_choice_nb, adv_stop_choice_nb)

flex_elem_param_config = Config(
    # A NumPy array counts as one value; pass a list to supply multiple values
    array_like=True,
    # Broadcast the parameter to the input
    bc_to_input=True,
    # Keep the original shape for flexible indexing, which saves memory
    broadcast_kwargs=dict(keep_raw=True),
)
"""Config for flexible element-wise parameters."""

flex_col_param_config = Config(
    array_like=True,
    # Broadcast along axis 1 only (i.e. one value per column)
    bc_to_input=1,
    broadcast_kwargs=dict(keep_raw=True),
)
"""Config for flexible column-wise parameters."""

# ############# Random signals ############# #

RAND = SignalFactory(
    class_name='RAND',
Beispiel #13
0
Changes take effect immediately."""

import numpy as np
import json

from vectorbt.utils.config import Config

__pdoc__ = dict()

# Color schema
color_schema = Config(blue="#1f77b4", orange="#ff7f0e", green="#2ca02c",
                      red="#dc3912", purple="#9467bd", brown="#8c564b",
                      pink="#e377c2", gray="#7f7f7f", yellow="#bcbd22",
                      cyan="#17becf")
"""_"""

__pdoc__['color_schema'] = f"""Color schema.

```plaintext
{json.dumps(color_schema, indent=2)}
```
"""

# Contrast color schema
Beispiel #14
0
class Records(Wrapping, StatsBuilderMixin, PlotsBuilderMixin, RecordsWithFields, metaclass=MetaRecords):
    """Wraps the actual records array (such as trades) and exposes methods for mapping
    it to some array of values (such as PnL of each trade).

    Args:
        wrapper (ArrayWrapper): Array wrapper.

            See `vectorbt.base.array_wrapper.ArrayWrapper`.
        records_arr (array_like): A structured NumPy array of records.

            Must have the fields `id` (record index) and `col` (column index).
        col_mapper (ColumnMapper): Column mapper if already known.

            !!! note
                It depends on `records_arr`, so make sure to invalidate `col_mapper` upon creating
                a `Records` instance with a modified `records_arr`.

                `Records.replace` does it automatically.
        **kwargs: Custom keyword arguments passed to the config.

            Useful if any subclass wants to extend the config.
    """

    _field_config: tp.ClassVar[Config] = Config(
        dict(
            dtype=None,
            settings=dict(
                id=dict(
                    name='id',
                    title='Id'
                ),
                col=dict(
                    name='col',
                    title='Column',
                    mapping='columns'
                ),
                idx=dict(
                    name='idx',
                    title='Timestamp',
                    mapping='index'
                )
            )
        ),
        readonly=True,
        as_attrs=False
    )

    @property
    def field_config(self) -> Config:
        """Field config of `${cls_name}`.

        ```json
        ${field_config}
        ```
        """
        # The docstring above is a template (`${...}` placeholders) and is kept verbatim.
        config = self._field_config
        return config

    def __init__(self,
                 wrapper: ArrayWrapper,
                 records_arr: tp.RecordArray,
                 col_mapper: tp.Optional[ColumnMapper] = None,
                 **kwargs) -> None:
        """Initialize the records; see the class docstring for the arguments.

        Raises:
            TypeError: If a field of `dtype` in `Records.field_config` is found neither in
                `records_arr` nor among the field names declared in the field settings."""
        Wrapping.__init__(
            self,
            wrapper,
            records_arr=records_arr,
            col_mapper=col_mapper,
            **kwargs
        )
        # The class inherits from both builder mixins, so initialize both of them
        # (not only the stats one).
        StatsBuilderMixin.__init__(self)
        PlotsBuilderMixin.__init__(self)

        # Check fields: records must be a structured array (dtype with fields)
        records_arr = np.asarray(records_arr)
        checks.assert_not_none(records_arr.dtype.fields)
        field_names = {
            dct.get('name', field_name)
            for field_name, dct in self.field_config.get('settings', {}).items()
        }
        dtype = self.field_config.get('dtype', None)
        if dtype is not None:
            for field in dtype.names:
                if field not in records_arr.dtype.names and field not in field_names:
                    raise TypeError(f"Field '{field}' from {dtype} cannot be found in records or config")

        self._records_arr = records_arr
        if col_mapper is None:
            col_mapper = ColumnMapper(wrapper, self.col_arr)
        self._col_mapper = col_mapper

    def replace(self: RecordsT, **kwargs) -> RecordsT:
        """See `vectorbt.utils.config.Configured.replace`.

        If this instance carries a column mapper and either `wrapper` or `records_arr`
        is replaced by a different object, `col_mapper` is reset to None so that a
        fresh mapper gets built for the new instance."""
        if self.config.get('col_mapper', None) is not None:
            wrapper_changed = 'wrapper' in kwargs and self.wrapper is not kwargs.get('wrapper')
            records_changed = 'records_arr' in kwargs and self.records_arr is not kwargs.get('records_arr')
            if wrapper_changed or records_changed:
                kwargs['col_mapper'] = None
        return Configured.replace(self, **kwargs)

    def get_by_col_idxs(self, col_idxs: tp.Array1d) -> tp.RecordArray:
        """Select the records that belong to the given column indices.

        Returns a new records array."""
        mapper = self.col_mapper
        if mapper.is_sorted():
            # Records sorted by column allow the faster range-based selection
            return nb.record_col_range_select_nb(self.values, mapper.col_range, to_1d_array(col_idxs))
        return nb.record_col_map_select_nb(self.values, mapper.col_map, to_1d_array(col_idxs))

    def indexing_func_meta(self, pd_indexing_func: tp.PandasIndexingFunc, **kwargs) -> IndexingMetaT:
        """Index `Records` by columns only and return the resulting metadata.

        Returns a tuple of the new wrapper, the selected records array, group indices,
        and column indices."""
        wrapper_meta = self.wrapper.indexing_func_meta(pd_indexing_func, column_only_select=True, **kwargs)
        new_wrapper, _, group_idxs, col_idxs = wrapper_meta
        return new_wrapper, self.get_by_col_idxs(col_idxs), group_idxs, col_idxs

    def indexing_func(self: RecordsT, pd_indexing_func: tp.PandasIndexingFunc, **kwargs) -> RecordsT:
        """Index `Records` and return a new instance built from the selection."""
        meta = self.indexing_func_meta(pd_indexing_func, **kwargs)
        return self.replace(wrapper=meta[0], records_arr=meta[1])

    @property
    def records_arr(self) -> tp.RecordArray:
        """Underlying structured array of records."""
        arr = self._records_arr
        return arr

    @property
    def values(self) -> tp.RecordArray:
        """Alias of `Records.records_arr`."""
        arr = self.records_arr
        return arr

    def __len__(self) -> int:
        """Number of records."""
        arr = self.values
        return len(arr)

    @property
    def records(self) -> tp.Frame:
        """Records as a DataFrame with one column per record field."""
        arr = self.values
        return pd.DataFrame.from_records(arr)

    @property
    def recarray(self) -> tp.RecArray:
        """Records viewed as `np.recarray`, so fields are accessible as attributes."""
        arr = self.values
        return arr.view(np.recarray)

    @property
    def col_mapper(self) -> ColumnMapper:
        """Mapper between records and columns.

        See `vectorbt.records.col_mapper.ColumnMapper`."""
        mapper = self._col_mapper
        return mapper

    @property
    def records_readable(self) -> tp.Frame:
        """Records in readable format.

        Starts from a copy of `Records.records` and, for every column that has an
        entry in the `settings` of `Records.field_config`:

        * drops the column if its `ignore` setting is truthy,
        * renames it to its `title`, if one is set,
        * if a `mapping` is set, writes mapped values into the (possibly renamed)
          column: `'index'` goes through `Records.get_map_field_to_index`, any other
          mapping through `Records.get_apply_mapping_arr`."""
        df = self.records.copy()
        field_settings = self.field_config.get('settings', {})
        # `df.columns` is evaluated once, so the loop walks the original labels even
        # though `df` is rebound by `drop` and renamed in place below.
        for col_name in df.columns:
            if col_name in field_settings:
                dct = field_settings[col_name]
                if dct.get('ignore', False):
                    df = df.drop(columns=col_name)
                    continue
                field_name = dct.get('name', col_name)
                if 'title' in dct:
                    title = dct['title']
                    new_columns = dict()
                    # NOTE(review): the rename is keyed by `field_name`, not `col_name`;
                    # if a setting's `name` differs from its key, no column is renamed
                    # here — confirm this is intended.
                    new_columns[field_name] = title
                    df.rename(columns=new_columns, inplace=True)
                else:
                    title = field_name
                if 'mapping' in dct:
                    if isinstance(dct['mapping'], str) and dct['mapping'] == 'index':
                        # Map index positions to index labels
                        df[title] = self.get_map_field_to_index(col_name)
                    else:
                        df[title] = self.get_apply_mapping_arr(col_name)
        return df

    def get_field_setting(self, field: str, setting: str, default: tp.Any = None) -> tp.Any:
        """Look up `setting` for `field` in `Records.field_config`, falling back to `default`."""
        all_settings = self.field_config.get('settings', {})
        per_field = all_settings.get(field, {})
        return per_field.get(setting, default)

    def get_field_name(self, field: str) -> str:
        """Return the actual record field name for `field` (defaults to `field` itself).

        Uses `Records.field_config`."""
        return self.get_field_setting(field, 'name', default=field)

    def get_field_title(self, field: str) -> str:
        """Return the display title for `field` (defaults to `field` itself).

        Uses `Records.field_config`."""
        return self.get_field_setting(field, 'title', default=field)

    def get_field_mapping(self, field: str) -> tp.Optional[tp.MappingLike]:
        """Return the value mapping configured for `field`, or None if there is none.

        Uses `Records.field_config`."""
        return self.get_field_setting(field, 'mapping', default=None)

    def get_field_arr(self, field: str) -> tp.Array1d:
        """Return the array of values stored under `field`, resolving its name first.

        Uses `Records.field_config`."""
        name = self.get_field_name(field)
        return self.values[name]

    def get_map_field(self, field: str, **kwargs) -> MappedArray:
        """Build a mapped array for `field` with its configured mapping attached.

        Uses `Records.field_config`; `**kwargs` are passed to `Records.map_field`."""
        name = self.get_field_name(field)
        mapping = self.get_field_mapping(field)
        return self.map_field(name, mapping=mapping, **kwargs)

    def get_apply_mapping_arr(self, field: str, **kwargs) -> tp.Array1d:
        """Return the values of `field` after applying its configured mapping.

        Uses `Records.field_config`."""
        mapped = self.get_map_field(field, **kwargs)
        return mapped.apply_mapping().values

    def get_map_field_to_index(self, field: str, **kwargs) -> tp.Index:
        """Return the values of `field` translated into index labels.

        Uses `Records.field_config`."""
        mapped = self.get_map_field(field, **kwargs)
        return mapped.to_index()

    @property
    def id_arr(self) -> tp.Array1d:
        """Array of record ids."""
        name = self.get_field_name('id')
        return self.values[name]

    @property
    def col_arr(self) -> tp.Array1d:
        """Array of column indices, one per record."""
        name = self.get_field_name('col')
        return self.values[name]

    @property
    def idx_arr(self) -> tp.Optional[tp.Array1d]:
        """Array of index positions, or None if the `idx` field is configured with name None."""
        name = self.get_field_name('idx')
        return None if name is None else self.values[name]

    @cached_method
    def is_sorted(self, incl_id: bool = False) -> bool:
        """Whether records are sorted by column (and by id within each column if `incl_id`)."""
        cols = self.col_arr
        if not incl_id:
            return nb.is_col_sorted_nb(cols)
        return nb.is_col_idx_sorted_nb(cols, self.id_arr)

    def sort(self: RecordsT, incl_id: bool = False, group_by: tp.GroupByLike = None, **kwargs) -> RecordsT:
        """Return records sorted by column (primary) and, optionally, by id (secondary).

        If the records are already sorted, no reordering happens.

        !!! note
            Sorting is expensive. A better approach is to append records already in the correct order."""
        if self.is_sorted(incl_id=incl_id):
            result = self.replace(**kwargs)
        else:
            if incl_id:
                order = np.lexsort((self.id_arr, self.col_arr))  # expensive!
            else:
                order = np.argsort(self.col_arr)
            result = self.replace(records_arr=self.values[order], **kwargs)
        return result.regroup(group_by)

    def apply_mask(self: RecordsT, mask: tp.Array1d, group_by: tp.GroupByLike = None, **kwargs) -> RecordsT:
        """Return a new instance that keeps only the records where `mask` is non-zero."""
        kept = np.take(self.values, np.flatnonzero(mask))
        filtered = self.replace(records_arr=kept, **kwargs)
        return filtered.regroup(group_by)

    def map_array(self,
                  a: tp.ArrayLike,
                  idx_arr: tp.Optional[tp.ArrayLike] = None,
                  mapping: tp.Optional[tp.MappingLike] = None,
                  group_by: tp.GroupByLike = None,
                  **kwargs) -> MappedArray:
        """Convert array to mapped array.

        The length of the array should match that of the records. When `idx_arr` is not
        given, `Records.idx_arr` is used. `**kwargs` are passed to `MappedArray`."""
        a = a if isinstance(a, np.ndarray) else np.asarray(a)
        checks.assert_shape_equal(a, self.values)
        idx_arr = self.idx_arr if idx_arr is None else idx_arr
        mapped = MappedArray(
            self.wrapper,
            a,
            self.col_arr,
            id_arr=self.id_arr,
            idx_arr=idx_arr,
            mapping=mapping,
            col_mapper=self.col_mapper,
            **kwargs
        )
        return mapped.regroup(group_by)

    def map_field(self, field: str, **kwargs) -> MappedArray:
        """Convert the raw values of record field `field` to a mapped array.

        `**kwargs` are passed to `Records.map_array`."""
        return self.map_array(self.values[field], **kwargs)

    def map(self,
            map_func_nb: tp.RecordMapFunc, *args,
            dtype: tp.Optional[tp.DTypeLike] = None,
            **kwargs) -> MappedArray:
        """Map each record to a scalar value. Returns mapped array.

        See `vectorbt.records.nb.map_records_nb`; `*args` are forwarded to it after
        `map_func_nb`. The result is cast with `np.asarray(..., dtype=dtype)`.

        `**kwargs` are passed to `Records.map_array`."""
        checks.assert_numba_func(map_func_nb)
        raw_out = nb.map_records_nb(self.values, map_func_nb, *args)
        return self.map_array(np.asarray(raw_out, dtype=dtype), **kwargs)

    def apply(self,
              apply_func_nb: tp.RecordApplyFunc, *args,
              group_by: tp.GroupByLike = None,
              apply_per_group: bool = False,
              dtype: tp.Optional[tp.DTypeLike] = None,
              **kwargs) -> MappedArray:
        """Apply function on records per column/group. Returns mapped array.

        Applies per group if `apply_per_group` is True, otherwise per column (the column
        map is built with `group_by=False`). The mapped array is regrouped by `group_by`
        in either case.

        See `vectorbt.records.nb.apply_on_records_nb`.

        `**kwargs` are passed to `Records.map_array`."""
        checks.assert_numba_func(apply_func_nb)
        map_group_by = group_by if apply_per_group else False
        col_map = self.col_mapper.get_col_map(group_by=map_group_by)
        raw_out = nb.apply_on_records_nb(self.values, col_map, apply_func_nb, *args)
        return self.map_array(np.asarray(raw_out, dtype=dtype), group_by=group_by, **kwargs)

    @cached_method
    def count(self, group_by: tp.GroupByLike = None, wrap_kwargs: tp.KwargsLike = None) -> tp.MaybeSeries:
        """Return count by column (or by group when `group_by` is given).

        Counts come from the second element of the column map; the reduced result is
        named 'count' unless overridden via `wrap_kwargs`."""
        wrap_kwargs = merge_dicts(dict(name_or_index='count'), wrap_kwargs)
        counts = self.col_mapper.get_col_map(group_by=group_by)[1]
        return self.wrapper.wrap_reduced(counts, group_by=group_by, **wrap_kwargs)

    # ############# Stats ############# #

    @property
    def stats_defaults(self) -> tp.Kwargs:
        """Defaults for `Records.stats`.

        Merges `vectorbt.generic.stats_builder.StatsBuilderMixin.stats_defaults` and
        `records.stats` from `vectorbt._settings.settings` (the latter takes precedence)."""
        from vectorbt._settings import settings
        own_cfg = settings['records']['stats']
        base_cfg = StatsBuilderMixin.stats_defaults.__get__(self)
        return merge_dicts(base_cfg, own_cfg)

    # Metrics shown by `Records.stats`: wrapper-level start/end/period plus the record count.
    _metrics: tp.ClassVar[Config] = Config(
        {
            'start': {
                'title': 'Start',
                'calc_func': lambda self: self.wrapper.index[0],
                'agg_func': None,
                'tags': 'wrapper',
            },
            'end': {
                'title': 'End',
                'calc_func': lambda self: self.wrapper.index[-1],
                'agg_func': None,
                'tags': 'wrapper',
            },
            'period': {
                'title': 'Period',
                'calc_func': lambda self: len(self.wrapper.index),
                'apply_to_timedelta': True,
                'agg_func': None,
                'tags': 'wrapper',
            },
            'count': {
                'title': 'Count',
                'calc_func': 'count',
                'tags': 'records',
            },
        },
        copy_kwargs={'copy_mode': 'deep'},
    )

    @property
    def metrics(self) -> Config:
        """Metrics config; returns the class-level `_metrics`."""
        metrics_cfg = self._metrics
        return metrics_cfg

    # ############# Plotting ############# #

    @property
    def plots_defaults(self) -> tp.Kwargs:
        """Defaults for `Records.plots`.

        Merges `vectorbt.generic.plots_builder.PlotsBuilderMixin.plots_defaults` and
        `records.plots` from `vectorbt._settings.settings` (the latter takes precedence)."""
        from vectorbt._settings import settings
        own_cfg = settings['records']['plots']
        base_cfg = PlotsBuilderMixin.plots_defaults.__get__(self)
        return merge_dicts(base_cfg, own_cfg)

    @property
    def subplots(self) -> Config:
        """Subplots config; returns the class-level `_subplots`."""
        subplots_cfg = self._subplots
        return subplots_cfg

    # ############# Docs ############# #

    @classmethod
    def build_field_config_doc(cls, source_cls: tp.Optional[type] = None) -> str:
        """Build field config documentation.

        Takes the `field_config` docstring of `source_cls` (defaults to `Records`) as a
        `string.Template` and fills in `$field_config` and `$cls_name` for `cls`."""
        template_owner = Records if source_cls is None else source_cls
        raw_doc = get_dict_attr(template_owner, 'field_config').__doc__
        template = string.Template(inspect.cleandoc(raw_doc))
        return template.substitute(
            field_config=cls.field_config.to_doc(),
            cls_name=cls.__name__
        )

    @classmethod
    def override_field_config_doc(cls, __pdoc__: dict, source_cls: tp.Optional[type] = None) -> None:
        """Call this method on each subclass that overrides `field_config`.

        Stores the output of `build_field_config_doc` in `__pdoc__` under
        `'<ClassName>.field_config'`."""
        rendered_doc = cls.build_field_config_doc(source_cls=source_cls)
        __pdoc__[f'{cls.__name__}.field_config'] = rendered_doc
Beispiel #15
0
class Data(Wrapping, StatsBuilderMixin, PlotsBuilderMixin, metaclass=MetaData):
    """Class that downloads, updates, and manages data coming from a data source.

    Data is held as a dictionary of Series/DataFrames keyed by symbol. `__init__` checks
    that every object has the same metadata as the first one (`checks.assert_meta_equal`)."""

    def __init__(self,
                 wrapper: ArrayWrapper,
                 data: tp.Data,
                 tz_localize: tp.Optional[tp.TimezoneLike],
                 tz_convert: tp.Optional[tp.TimezoneLike],
                 missing_index: str,
                 missing_columns: str,
                 download_kwargs: dict,
                 **kwargs) -> None:
        Wrapping.__init__(
            self,
            wrapper,
            data=data,
            tz_localize=tz_localize,
            tz_convert=tz_convert,
            missing_index=missing_index,
            missing_columns=missing_columns,
            download_kwargs=download_kwargs,
            **kwargs
        )
        StatsBuilderMixin.__init__(self)
        PlotsBuilderMixin.__init__(self)

        checks.assert_instance_of(data, dict)
        # Compare every symbol against the first one (looked up once, not per iteration)
        first_obj = next(iter(data.values()), None)
        for obj in data.values():
            checks.assert_meta_equal(obj, first_obj)
        self._data = data
        self._tz_localize = tz_localize
        self._tz_convert = tz_convert
        self._missing_index = missing_index
        self._missing_columns = missing_columns
        self._download_kwargs = download_kwargs

    def indexing_func(self: DataT, pd_indexing_func: tp.PandasIndexingFunc, **kwargs) -> DataT:
        """Perform indexing on `Data`."""
        new_wrapper = pd_indexing_func(self.wrapper)
        new_data = {k: pd_indexing_func(v) for k, v in self.data.items()}
        return self.replace(
            wrapper=new_wrapper,
            data=new_data
        )

    @property
    def data(self) -> tp.Data:
        """Data dictionary keyed by symbol."""
        return self._data

    @property
    def symbols(self) -> tp.List[tp.Label]:
        """List of symbols."""
        return list(self.data.keys())

    @property
    def tz_localize(self) -> tp.Optional[tp.TimezoneLike]:
        """`tz_localize` initially passed to `Data.download_symbol`."""
        return self._tz_localize

    @property
    def tz_convert(self) -> tp.Optional[tp.TimezoneLike]:
        """`tz_convert` initially passed to `Data.download_symbol`."""
        return self._tz_convert

    @property
    def missing_index(self) -> str:
        """`missing_index` initially passed to `Data.download_symbol`."""
        return self._missing_index

    @property
    def missing_columns(self) -> str:
        """`missing_columns` initially passed to `Data.download_symbol`."""
        return self._missing_columns

    @property
    def download_kwargs(self) -> dict:
        """Keyword arguments initially passed to `Data.download_symbol`."""
        return self._download_kwargs

    @classmethod
    def align_index(cls, data: tp.Data, missing: str = 'nan') -> tp.Data:
        """Align data to have the same index.

        The argument `missing` accepts the following values:

        * 'nan': set missing data points to NaN
        * 'drop': remove missing data points
        * 'raise': raise an error"""
        if len(data) == 1:
            return data

        index = None
        for k, v in data.items():
            if index is None:
                index = v.index
            else:
                if len(index.intersection(v.index)) != len(index.union(v.index)):
                    if missing == 'nan':
                        warnings.warn("Symbols have mismatching index. "
                                      "Setting missing data points to NaN.", stacklevel=2)
                        index = index.union(v.index)
                    elif missing == 'drop':
                        warnings.warn("Symbols have mismatching index. "
                                      "Dropping missing data points.", stacklevel=2)
                        index = index.intersection(v.index)
                    elif missing == 'raise':
                        raise ValueError("Symbols have mismatching index")
                    else:
                        raise ValueError(f"missing='{missing}' is not recognized")

        # reindex
        new_data = {k: v.reindex(index=index) for k, v in data.items()}
        return new_data

    @classmethod
    def align_columns(cls, data: tp.Data, missing: str = 'raise') -> tp.Data:
        """Align data to have the same columns.

        See `Data.align_index` for `missing`."""
        if len(data) == 1:
            return data

        columns = None
        multiple_columns = False
        name_is_none = False
        for k, v in data.items():
            if isinstance(v, pd.Series):
                if v.name is None:
                    name_is_none = True
                v = v.to_frame()
            else:
                multiple_columns = True
            if columns is None:
                columns = v.columns
            else:
                if len(columns.intersection(v.columns)) != len(columns.union(v.columns)):
                    if missing == 'nan':
                        warnings.warn("Symbols have mismatching columns. "
                                      "Setting missing data points to NaN.", stacklevel=2)
                        columns = columns.union(v.columns)
                    elif missing == 'drop':
                        warnings.warn("Symbols have mismatching columns. "
                                      "Dropping missing data points.", stacklevel=2)
                        columns = columns.intersection(v.columns)
                    elif missing == 'raise':
                        raise ValueError("Symbols have mismatching columns")
                    else:
                        raise ValueError(f"missing='{missing}' is not recognized")

        # reindex
        new_data = {}
        for k, v in data.items():
            if isinstance(v, pd.Series):
                # Convert exactly as in the loop above so unnamed Series get the same default
                # column label. Explicitly passing name=None creates a column labelled None
                # on pandas >= 2.0, which would not match `columns` and reindex to all-NaN.
                v = v.to_frame()
            v = v.reindex(columns=columns)
            if not multiple_columns:
                v = v[columns[0]]
                if name_is_none:
                    v = v.rename(None)
            new_data[k] = v
        return new_data

    @classmethod
    def select_symbol_kwargs(cls, symbol: tp.Label, kwargs: dict) -> dict:
        """Select keyword arguments belonging to `symbol`."""
        _kwargs = dict()
        for k, v in kwargs.items():
            if isinstance(v, symbol_dict):
                if symbol in v:
                    _kwargs[k] = v[symbol]
            else:
                _kwargs[k] = v
        return _kwargs

    @classmethod
    def from_data(cls: tp.Type[DataT],
                  data: tp.Data,
                  tz_localize: tp.Optional[tp.TimezoneLike] = None,
                  tz_convert: tp.Optional[tp.TimezoneLike] = None,
                  missing_index: tp.Optional[str] = None,
                  missing_columns: tp.Optional[str] = None,
                  wrapper_kwargs: tp.KwargsLike = None,
                  **kwargs) -> DataT:
        """Create a new `Data` instance from (aligned) data.

        Args:
            data (dict): Dictionary of array-like objects keyed by symbol.
            tz_localize (timezone_like): If the index is tz-naive, convert to a timezone.

                See `vectorbt.utils.datetime_.to_timezone`.
            tz_convert (timezone_like): Convert the index from one timezone to another.

                See `vectorbt.utils.datetime_.to_timezone`.
            missing_index (str): See `Data.align_index`.
            missing_columns (str): See `Data.align_columns`.
            wrapper_kwargs (dict): Keyword arguments passed to `vectorbt.base.array_wrapper.ArrayWrapper`.
            **kwargs: Keyword arguments passed to the `__init__` method.

        For defaults, see `data` in `vectorbt._settings.settings`."""
        from vectorbt._settings import settings
        data_cfg = settings['data']

        # Get global defaults
        if tz_localize is None:
            tz_localize = data_cfg['tz_localize']
        if tz_convert is None:
            tz_convert = data_cfg['tz_convert']
        if missing_index is None:
            missing_index = data_cfg['missing_index']
        if missing_columns is None:
            missing_columns = data_cfg['missing_columns']
        if wrapper_kwargs is None:
            wrapper_kwargs = {}

        data = data.copy()
        for k, v in data.items():
            # Convert array to pandas
            if not isinstance(v, (pd.Series, pd.DataFrame)):
                v = np.asarray(v)
                if v.ndim == 1:
                    v = pd.Series(v)
                else:
                    v = pd.DataFrame(v)

            # Perform operations with datetime-like index
            if isinstance(v.index, pd.DatetimeIndex):
                if tz_localize is not None:
                    if not is_tz_aware(v.index):
                        v = v.tz_localize(to_timezone(tz_localize))
                if tz_convert is not None:
                    v = v.tz_convert(to_timezone(tz_convert))
                v.index.freq = v.index.inferred_freq
            data[k] = v

        # Align index and columns
        data = cls.align_index(data, missing=missing_index)
        data = cls.align_columns(data, missing=missing_columns)

        # Create new instance
        symbols = list(data.keys())
        wrapper = ArrayWrapper.from_obj(data[symbols[0]], **wrapper_kwargs)
        return cls(
            wrapper,
            data,
            tz_localize=tz_localize,
            tz_convert=tz_convert,
            missing_index=missing_index,
            missing_columns=missing_columns,
            **kwargs
        )

    @classmethod
    def download_symbol(cls, symbol: tp.Label, **kwargs) -> tp.SeriesFrame:
        """Abstract method to download a symbol."""
        raise NotImplementedError

    @classmethod
    def download(cls: tp.Type[DataT],
                 symbols: tp.Union[tp.Label, tp.Labels],
                 tz_localize: tp.Optional[tp.TimezoneLike] = None,
                 tz_convert: tp.Optional[tp.TimezoneLike] = None,
                 missing_index: tp.Optional[str] = None,
                 missing_columns: tp.Optional[str] = None,
                 wrapper_kwargs: tp.KwargsLike = None,
                 **kwargs) -> DataT:
        """Download data using `Data.download_symbol`.

        Args:
            symbols (hashable or sequence of hashable): One or multiple symbols.

                !!! note
                    Tuple is considered as a single symbol (since hashable).
            tz_localize (any): See `Data.from_data`.
            tz_convert (any): See `Data.from_data`.
            missing_index (str): See `Data.from_data`.
            missing_columns (str): See `Data.from_data`.
            wrapper_kwargs (dict): See `Data.from_data`.
            **kwargs: Passed to `Data.download_symbol`.

                If two symbols require different keyword arguments, pass `symbol_dict` for each argument.
        """
        if checks.is_hashable(symbols):
            symbols = [symbols]
        elif not checks.is_sequence(symbols):
            raise TypeError("Symbols must be either hashable or sequence of hashable")

        data = dict()
        for s in symbols:
            # Select keyword arguments for this symbol
            _kwargs = cls.select_symbol_kwargs(s, kwargs)

            # Download data for this symbol
            data[s] = cls.download_symbol(s, **_kwargs)

        # Create new instance from data
        return cls.from_data(
            data,
            tz_localize=tz_localize,
            tz_convert=tz_convert,
            missing_index=missing_index,
            missing_columns=missing_columns,
            wrapper_kwargs=wrapper_kwargs,
            download_kwargs=kwargs
        )

    def update_symbol(self, symbol: tp.Label, **kwargs) -> tp.SeriesFrame:
        """Abstract method to update a symbol."""
        raise NotImplementedError

    def update(self: DataT, **kwargs) -> DataT:
        """Update the data using `Data.update_symbol`.

        Args:
            **kwargs: Passed to `Data.update_symbol`.

                If two symbols require different keyword arguments, pass `symbol_dict` for each argument.

        !!! note
            Returns a new `Data` instance."""
        new_data = dict()
        for k, v in self.data.items():
            # Select keyword arguments for this symbol
            _kwargs = self.select_symbol_kwargs(k, kwargs)

            # Download new data for this symbol
            new_obj = self.update_symbol(k, **_kwargs)

            # Convert array to pandas
            if not isinstance(new_obj, (pd.Series, pd.DataFrame)):
                new_obj = np.asarray(new_obj)
                # New rows start at the last existing index label; the overlapping row is
                # resolved below by keeping the newest value (duplicated(keep='last'))
                index = pd.RangeIndex(
                    start=v.index[-1],
                    stop=v.index[-1] + new_obj.shape[0],
                    step=1
                )
                if new_obj.ndim == 1:
                    new_obj = pd.Series(new_obj, index=index)
                else:
                    new_obj = pd.DataFrame(new_obj, index=index)

            # Perform operations with datetime-like index
            if isinstance(new_obj.index, pd.DatetimeIndex):
                if self.tz_localize is not None:
                    if not is_tz_aware(new_obj.index):
                        new_obj = new_obj.tz_localize(to_timezone(self.tz_localize))
                if self.tz_convert is not None:
                    new_obj = new_obj.tz_convert(to_timezone(self.tz_convert))

            new_data[k] = new_obj

        # Align index and columns
        new_data = self.align_index(new_data, missing=self.missing_index)
        new_data = self.align_columns(new_data, missing=self.missing_columns)

        # Concatenate old and new data
        for k, v in new_data.items():
            if isinstance(self.data[k], pd.Series):
                if isinstance(v, pd.DataFrame):
                    v = v[self.data[k].name]
            else:
                v = v[self.data[k].columns]
            v = pd.concat((self.data[k], v), axis=0)
            v = v[~v.index.duplicated(keep='last')]
            if isinstance(v.index, pd.DatetimeIndex):
                v.index.freq = v.index.inferred_freq
            new_data[k] = v

        # Create new instance
        new_index = new_data[self.symbols[0]].index
        return self.replace(
            wrapper=self.wrapper.replace(index=new_index),
            data=new_data
        )

    @cached_method
    def concat(self, level_name: str = 'symbol') -> tp.Data:
        """Return a dict of Series/DataFrames with symbols as columns, keyed by column name.

        With multiple symbols, each value is a DataFrame whose columns are the symbols
        (column index named `level_name`). With a single symbol, each value is a Series
        named after that symbol. Column dtypes of the underlying data are preserved."""
        symbols = self.symbols
        first_data = self.data[symbols[0]]
        index = first_data.index
        if isinstance(first_data, pd.Series):
            columns = pd.Index([first_data.name])
        else:
            columns = first_data.columns

        new_data = {}
        for c in columns:
            # A symbol stored as a Series contributes itself for every column
            col_data = [
                self.data[s] if isinstance(self.data[s], pd.Series) else self.data[s][c]
                for s in symbols
            ]
            if len(symbols) > 1:
                # Build the frame from the column objects instead of filling an empty
                # (object-dtype) frame via .loc, which keeps object dtype on newer pandas.
                # Positional keys prevent tuple symbols from being expanded here; the
                # symbol labels are then assigned exactly as before.
                new_df = pd.DataFrame(dict(enumerate(col_data)), index=index)
                new_df.columns = pd.Index(symbols, name=level_name)
                new_data[c] = new_df
            else:
                new_data[c] = pd.Series(col_data[0], index=index, name=symbols[0], copy=True)

        return new_data

    def get(self, column: tp.Optional[tp.Label] = None, **kwargs) -> tp.MaybeTuple[tp.SeriesFrame]:
        """Get column data.

        If one symbol, returns data for that symbol.
        If multiple symbols, performs concatenation first and returns a DataFrame if one column
        and a tuple of DataFrames if a list of columns passed."""
        if len(self.symbols) == 1:
            if column is None:
                return self.data[self.symbols[0]]
            return self.data[self.symbols[0]][column]

        concat_data = self.concat(**kwargs)
        if len(concat_data) == 1:
            return tuple(concat_data.values())[0]
        if column is not None:
            if isinstance(column, list):
                return tuple([concat_data[c] for c in column])
            return concat_data[column]
        return tuple(concat_data.values())

    # ############# Stats ############# #

    @property
    def stats_defaults(self) -> tp.Kwargs:
        """Defaults for `Data.stats`.

        Merges `vectorbt.generic.stats_builder.StatsBuilderMixin.stats_defaults` and
        `data.stats` from `vectorbt._settings.settings`."""
        from vectorbt._settings import settings
        data_stats_cfg = settings['data']['stats']

        return merge_dicts(
            StatsBuilderMixin.stats_defaults.__get__(self),
            data_stats_cfg
        )

    _metrics: tp.ClassVar[Config] = Config(
        dict(
            start=dict(
                title='Start',
                calc_func=lambda self: self.wrapper.index[0],
                agg_func=None,
                tags='wrapper'
            ),
            end=dict(
                title='End',
                calc_func=lambda self: self.wrapper.index[-1],
                agg_func=None,
                tags='wrapper'
            ),
            period=dict(
                title='Period',
                calc_func=lambda self: len(self.wrapper.index),
                apply_to_timedelta=True,
                agg_func=None,
                tags='wrapper'
            ),
            total_symbols=dict(
                title='Total Symbols',
                calc_func=lambda self: len(self.symbols),
                agg_func=None,
                tags='data'
            ),
            null_counts=dict(
                title='Null Counts',
                calc_func=lambda self, group_by:
                {
                    k: v.isnull().vbt(wrapper=self.wrapper).sum(group_by=group_by)
                    for k, v in self.data.items()
                },
                tags='data'
            )
        ),
        copy_kwargs=dict(copy_mode='deep')
    )

    @property
    def metrics(self) -> Config:
        return self._metrics

    # ############# Plotting ############# #

    def plot(self,
             column: tp.Optional[tp.Label] = None,
             base: tp.Optional[float] = None,
             **kwargs) -> tp.Union[tp.BaseFigure, plotting.Scatter]:  # pragma: no cover
        """Plot one column of the data, with one trace per symbol.

        Args:
            column (str): Name of the column to plot.
            base (float): Rebase all series of a column to a given initial base.

                !!! note
                    The column should contain prices.
            kwargs (dict): Keyword arguments passed to `vectorbt.generic.accessors.GenericAccessor.plot`.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt

        >>> start = '2021-01-01 UTC'  # crypto is in UTC
        >>> end = '2021-06-01 UTC'
        >>> data = vbt.YFData.download(['BTC-USD', 'ETH-USD', 'ADA-USD'], start=start, end=end)

        >>> data.plot(column='Close', base=1)
        ```

        ![](/docs/img/data_plot.svg)"""
        self_col = self.select_one(column=column, group_by=False)
        data = self_col.get()
        if base is not None:
            data = data.vbt.rebase(base)
        return data.vbt.plot(**kwargs)

    @property
    def plots_defaults(self) -> tp.Kwargs:
        """Defaults for `Data.plots`.

        Merges `vectorbt.generic.plots_builder.PlotsBuilderMixin.plots_defaults` and
        `data.plots` from `vectorbt._settings.settings`."""
        from vectorbt._settings import settings
        data_plots_cfg = settings['data']['plots']

        return merge_dicts(
            PlotsBuilderMixin.plots_defaults.__get__(self),
            data_plots_cfg
        )

    _subplots: tp.ClassVar[Config] = Config(
        dict(
            plot=dict(
                check_is_not_grouped=True,
                plot_func='plot',
                pass_add_trace_kwargs=True,
                tags='data'
            )
        ),
        copy_kwargs=dict(copy_mode='deep')
    )

    @property
    def subplots(self) -> Config:
        return self._subplots
Beispiel #16
0
class Logs(Records):
    """Extends `Records` for working with log records."""

    @property
    def field_config(self) -> Config:
        """Field config of log records (the `_field_config` attribute)."""
        return self._field_config

    # ############# Stats ############# #

    @property
    def stats_defaults(self) -> tp.Kwargs:
        """Defaults for `Logs.stats`.

        Merges `vectorbt.records.base.Records.stats_defaults` and
        `logs.stats` from `vectorbt._settings.settings` (the latter takes precedence)."""
        from vectorbt._settings import settings
        own_cfg = settings['logs']['stats']
        base_cfg = Records.stats_defaults.__get__(self)
        return merge_dicts(base_cfg, own_cfg)

    # Wrapper-level start/end/period, the record count, and value counts of the
    # `res_status` and `res_status_info` fields (each converted into a dict of Series).
    _metrics: tp.ClassVar[Config] = Config(
        dict(
            start=dict(
                title='Start',
                calc_func=lambda self: self.wrapper.index[0],
                agg_func=None,
                tags='wrapper'
            ),
            end=dict(
                title='End',
                calc_func=lambda self: self.wrapper.index[-1],
                agg_func=None,
                tags='wrapper'
            ),
            period=dict(
                title='Period',
                calc_func=lambda self: len(self.wrapper.index),
                apply_to_timedelta=True,
                agg_func=None,
                tags='wrapper'
            ),
            total_records=dict(
                title='Total Records',
                calc_func='count',
                tags='records'
            ),
            res_status_counts=dict(
                title='Status Counts',
                calc_func='res_status.value_counts',
                incl_all_keys=True,
                post_calc_func=lambda self, out, settings: to_dict(out, orient='index_series'),
                tags=['logs', 'res_status', 'value_counts']
            ),
            res_status_info_counts=dict(
                title='Status Info Counts',
                calc_func='res_status_info.value_counts',
                post_calc_func=lambda self, out, settings: to_dict(out, orient='index_series'),
                tags=['logs', 'res_status_info', 'value_counts']
            )
        ),
        copy_kwargs=dict(copy_mode='deep')
    )

    @property
    def metrics(self) -> Config:
        """Metrics config; returns the class-level `_metrics`."""
        return self._metrics

    # ############# Plotting ############# #

    @property
    def plots_defaults(self) -> tp.Kwargs:
        """Defaults for `Logs.plots`.

        Merges `vectorbt.records.base.Records.plots_defaults` and
        `logs.plots` from `vectorbt._settings.settings` (the latter takes precedence)."""
        from vectorbt._settings import settings
        own_cfg = settings['logs']['plots']
        base_cfg = Records.plots_defaults.__get__(self)
        return merge_dicts(base_cfg, own_cfg)

    @property
    def subplots(self) -> Config:
        """Subplots config; returns the `_subplots` attribute."""
        return self._subplots
Beispiel #17
0
        self.register_template(theme)
        self['plotting']['color_schema'].update(self['plotting']['themes'][theme]['color_schema'])
        self['plotting']['layout']['template'] = 'vbt_' + theme

    def reset_theme(self) -> None:
        """Reset to the default theme by calling `set_theme` with 'light'."""
        default_theme = 'light'
        self.set_theme(default_theme)


settings = SettingsConfig(
    dict(
        numba=dict(
            check_func_type=True,
            check_func_suffix=False
        ),
        config=Config(),  # flex
        configured=dict(
            config=Config(  # flex
                dict(
                    readonly=True
                )
            ),
        ),
        caching=dict(
            enabled=True,
            whitelist=[
                CacheCondition(base_cls=ArrayWrapper),
                CacheCondition(base_cls=ColumnGrouper),
                CacheCondition(base_cls=ColumnMapper)
            ],
            blacklist=[]
Beispiel #18
0
class MappedArray(Wrapping,
                  StatsBuilderMixin,
                  PlotsBuilderMixin,
                  metaclass=MetaMappedArray):
    """Exposes methods for reducing, converting, and plotting arrays mapped by
    `vectorbt.records.base.Records` class.

    Args:
        wrapper (ArrayWrapper): Array wrapper.

            See `vectorbt.base.array_wrapper.ArrayWrapper`.
        mapped_arr (array_like): A one-dimensional array of mapped record values.
        col_arr (array_like): A one-dimensional column array.

            Must be of the same size as `mapped_arr`.
        id_arr (array_like): A one-dimensional id array. Defaults to simple range.

            Must be of the same size as `mapped_arr`.
        idx_arr (array_like): A one-dimensional index array. Optional.

            Must be of the same size as `mapped_arr`.
        mapping (namedtuple, dict or callable): Mapping.

            The strings 'index' and 'columns' (case-insensitive) are resolved to the
            wrapper's index and columns respectively before being passed to `to_mapping`.
        col_mapper (ColumnMapper): Column mapper if already known.

            !!! note
                It depends upon `wrapper` and `col_arr`, so make sure to invalidate `col_mapper` upon creating
                a `MappedArray` instance with a modified `wrapper` or `col_arr`.

                `MappedArray.replace` does it automatically.
        **kwargs: Custom keyword arguments passed to the config.

            Useful if any subclass wants to extend the config.

    Sizes of `col_arr`, `id_arr` (if passed) and `idx_arr` (if passed) are validated against
    `mapped_arr` with `checks.assert_shape_equal`.
    """

    def __init__(self,
                 wrapper: ArrayWrapper,
                 mapped_arr: tp.ArrayLike,
                 col_arr: tp.ArrayLike,
                 id_arr: tp.Optional[tp.ArrayLike] = None,
                 idx_arr: tp.Optional[tp.ArrayLike] = None,
                 mapping: tp.Optional[tp.MappingLike] = None,
                 col_mapper: tp.Optional[ColumnMapper] = None,
                 **kwargs) -> None:
        # Hand the raw constructor arguments to `Wrapping`; `replace` later reads
        # them back through `self.config`.
        Wrapping.__init__(self,
                          wrapper,
                          mapped_arr=mapped_arr,
                          col_arr=col_arr,
                          id_arr=id_arr,
                          idx_arr=idx_arr,
                          mapping=mapping,
                          col_mapper=col_mapper,
                          **kwargs)
        StatsBuilderMixin.__init__(self)
        # Initialize the plotting mixin as well, symmetric with the stats mixin above.
        # If the mixin defines no `__init__`, this resolves to `object.__init__` (a no-op).
        PlotsBuilderMixin.__init__(self)

        mapped_arr = np.asarray(mapped_arr)
        col_arr = np.asarray(col_arr)
        checks.assert_shape_equal(mapped_arr, col_arr, axis=0)
        if id_arr is None:
            id_arr = np.arange(len(mapped_arr))
        else:
            id_arr = np.asarray(id_arr)
            # The docstring requires `id_arr` to match `mapped_arr`; validate it like `idx_arr`.
            checks.assert_shape_equal(mapped_arr, id_arr, axis=0)
        if idx_arr is not None:
            idx_arr = np.asarray(idx_arr)
            checks.assert_shape_equal(mapped_arr, idx_arr, axis=0)
        if mapping is not None:
            mapping = to_mapping(self._resolve_mapping_alias(mapping))

        self._mapped_arr = mapped_arr
        self._id_arr = id_arr
        self._col_arr = col_arr
        self._idx_arr = idx_arr
        self._mapping = mapping
        if col_mapper is None:
            col_mapper = ColumnMapper(wrapper, col_arr)
        self._col_mapper = col_mapper

    def _resolve_mapping_alias(self,
                               mapping: tp.Optional[tp.MappingLike]) -> tp.Optional[tp.MappingLike]:
        """Replace the string aliases 'index'/'columns' (case-insensitive) with
        `wrapper.index`/`wrapper.columns`; return any other value unchanged."""
        if isinstance(mapping, str):
            alias = mapping.lower()
            if alias == 'index':
                return self.wrapper.index
            if alias == 'columns':
                return self.wrapper.columns
        return mapping

    def replace(self: MappedArrayT, **kwargs) -> MappedArrayT:
        """See `vectorbt.utils.config.Configured.replace`.

        If the config holds a `col_mapper` and `wrapper` or `col_arr` is replaced with a
        different object, `col_mapper` is reset to None so the new instance builds its own
        (it depends on both)."""
        if self.config.get('col_mapper', None) is not None:
            if 'wrapper' in kwargs:
                if self.wrapper is not kwargs.get('wrapper'):
                    kwargs['col_mapper'] = None
            if 'col_arr' in kwargs:
                if self.col_arr is not kwargs.get('col_arr'):
                    kwargs['col_mapper'] = None
        return Configured.replace(self, **kwargs)

    def indexing_func_meta(self, pd_indexing_func: tp.PandasIndexingFunc,
                           **kwargs) -> IndexingMetaT:
        """Perform indexing on `MappedArray` and return metadata.

        Only column selection is performed (`column_only_select=True`). The mapped, id and
        (if present) index arrays are narrowed to the records of the selected columns.

        Returns:
            Tuple of new wrapper, mapped array, column array, id array, index array (or None),
            group indices, and column indices."""
        new_wrapper, _, group_idxs, col_idxs = \
            self.wrapper.indexing_func_meta(pd_indexing_func, column_only_select=True, **kwargs)
        new_indices, new_col_arr = self.col_mapper._col_idxs_meta(col_idxs)
        new_mapped_arr = self.values[new_indices]
        new_id_arr = self.id_arr[new_indices]
        if self.idx_arr is not None:
            new_idx_arr = self.idx_arr[new_indices]
        else:
            new_idx_arr = None
        return new_wrapper, new_mapped_arr, new_col_arr, new_id_arr, new_idx_arr, group_idxs, col_idxs

    def indexing_func(self: MappedArrayT,
                      pd_indexing_func: tp.PandasIndexingFunc,
                      **kwargs) -> MappedArrayT:
        """Perform indexing on `MappedArray` and return a new instance built from
        `MappedArray.indexing_func_meta`."""
        new_wrapper, new_mapped_arr, new_col_arr, new_id_arr, new_idx_arr, _, _ = \
            self.indexing_func_meta(pd_indexing_func, **kwargs)
        return self.replace(wrapper=new_wrapper,
                            mapped_arr=new_mapped_arr,
                            col_arr=new_col_arr,
                            id_arr=new_id_arr,
                            idx_arr=new_idx_arr)

    @property
    def mapped_arr(self) -> tp.Array1d:
        """Mapped array."""
        return self._mapped_arr

    @property
    def values(self) -> tp.Array1d:
        """Mapped array (alias of `MappedArray.mapped_arr`)."""
        return self.mapped_arr

    def __len__(self) -> int:
        """Number of mapped values."""
        return len(self.values)

    @property
    def col_arr(self) -> tp.Array1d:
        """Column array."""
        return self._col_arr

    @property
    def col_mapper(self) -> ColumnMapper:
        """Column mapper.

        See `vectorbt.records.col_mapper.ColumnMapper`."""
        return self._col_mapper

    @property
    def id_arr(self) -> tp.Array1d:
        """Id array."""
        return self._id_arr

    @property
    def idx_arr(self) -> tp.Optional[tp.Array1d]:
        """Index array."""
        return self._idx_arr

    @property
    def mapping(self) -> tp.Optional[tp.Mapping]:
        """Mapping."""
        return self._mapping

    @cached_method
    def is_sorted(self, incl_id: bool = False) -> bool:
        """Check whether mapped array is sorted by column array (and by id array if `incl_id`)."""
        if incl_id:
            return nb.is_col_idx_sorted_nb(self.col_arr, self.id_arr)
        return nb.is_col_sorted_nb(self.col_arr)

    def sort(self: MappedArrayT,
             incl_id: bool = False,
             idx_arr: tp.Optional[tp.Array1d] = None,
             group_by: tp.GroupByLike = None,
             **kwargs) -> MappedArrayT:
        """Sort mapped array by column array (primary) and id array (secondary, optional).

        If already sorted, the arrays are kept as-is and only `idx_arr`/`**kwargs` are applied.

        `**kwargs` are passed to `MappedArray.replace`."""
        if idx_arr is None:
            idx_arr = self.idx_arr
        if self.is_sorted(incl_id=incl_id):
            return self.replace(idx_arr=idx_arr, **kwargs).regroup(group_by)
        if incl_id:
            ind = np.lexsort((self.id_arr, self.col_arr))  # expensive!
        else:
            ind = np.argsort(self.col_arr)
        return self.replace(
            mapped_arr=self.values[ind],
            col_arr=self.col_arr[ind],
            id_arr=self.id_arr[ind],
            idx_arr=idx_arr[ind] if idx_arr is not None else None,
            **kwargs).regroup(group_by)

    def apply_mask(self: MappedArrayT,
                   mask: tp.Array1d,
                   idx_arr: tp.Optional[tp.Array1d] = None,
                   group_by: tp.GroupByLike = None,
                   **kwargs) -> MappedArrayT:
        """Return a new class instance, filtered by mask.

        Elements at non-zero positions of `mask` (see `np.flatnonzero`) are kept.

        `**kwargs` are passed to `MappedArray.replace`."""
        if idx_arr is None:
            idx_arr = self.idx_arr
        mask_indices = np.flatnonzero(mask)
        return self.replace(mapped_arr=np.take(self.values, mask_indices),
                            col_arr=np.take(self.col_arr, mask_indices),
                            id_arr=np.take(self.id_arr, mask_indices),
                            idx_arr=np.take(idx_arr, mask_indices)
                            if idx_arr is not None else None,
                            **kwargs).regroup(group_by)

    def map_to_mask(self,
                    inout_map_func_nb: tp.MaskInOutMapFunc,
                    *args,
                    group_by: tp.GroupByLike = None) -> tp.Array1d:
        """Map mapped array to a mask.

        See `vectorbt.records.nb.mapped_to_mask_nb`."""
        col_map = self.col_mapper.get_col_map(group_by=group_by)
        return nb.mapped_to_mask_nb(self.values, col_map, inout_map_func_nb,
                                    *args)

    @cached_method
    def top_n_mask(self, n: int, **kwargs) -> tp.Array1d:
        """Return mask of top N elements in each column/group."""
        return self.map_to_mask(nb.top_n_inout_map_nb, n, **kwargs)

    @cached_method
    def bottom_n_mask(self, n: int, **kwargs) -> tp.Array1d:
        """Return mask of bottom N elements in each column/group."""
        return self.map_to_mask(nb.bottom_n_inout_map_nb, n, **kwargs)

    @cached_method
    def top_n(self: MappedArrayT, n: int, **kwargs) -> MappedArrayT:
        """Filter top N elements from each column/group.

        `**kwargs` are passed to `MappedArray.apply_mask`."""
        return self.apply_mask(self.top_n_mask(n), **kwargs)

    @cached_method
    def bottom_n(self: MappedArrayT, n: int, **kwargs) -> MappedArrayT:
        """Filter bottom N elements from each column/group.

        `**kwargs` are passed to `MappedArray.apply_mask`."""
        return self.apply_mask(self.bottom_n_mask(n), **kwargs)

    @cached_method
    def is_expandable(self,
                      idx_arr: tp.Optional[tp.Array1d] = None,
                      group_by: tp.GroupByLike = None) -> bool:
        """See `vectorbt.records.nb.is_mapped_expandable_nb`.

        Raises:
            ValueError: If neither `idx_arr` nor `MappedArray.idx_arr` is available."""
        if idx_arr is None:
            if self.idx_arr is None:
                raise ValueError("Must pass idx_arr")
            idx_arr = self.idx_arr
        col_arr = self.col_mapper.get_col_arr(group_by=group_by)
        target_shape = self.wrapper.get_shape_2d(group_by=group_by)
        return nb.is_mapped_expandable_nb(col_arr, idx_arr, target_shape)

    def to_pd(self,
              idx_arr: tp.Optional[tp.Array1d] = None,
              ignore_index: bool = False,
              fill_value: float = np.nan,
              group_by: tp.GroupByLike = None,
              wrap_kwargs: tp.KwargsLike = None) -> tp.SeriesFrame:
        """Expand mapped array to a Series/DataFrame.

        If `ignore_index`, will ignore the index and stack data points on top of each other in every column/group
        (see `vectorbt.records.nb.stack_expand_mapped_nb`). Otherwise, see `vectorbt.records.nb.expand_mapped_nb`.

        !!! note
            Will raise an error if there are multiple values pointing to the same position.
            Set `ignore_index` to True in this case.

        !!! warning
            Mapped arrays represent information in the most memory-friendly format.
            Mapping back to pandas may occupy lots of memory if records are sparse.

        Raises:
            ValueError: If no index array is available, or if several values point to the
                same position (without `ignore_index`)."""
        if ignore_index:
            if self.wrapper.ndim == 1:
                return self.wrapper.wrap(self.values,
                                         index=np.arange(len(self.values)),
                                         group_by=group_by,
                                         **merge_dicts({}, wrap_kwargs))
            col_map = self.col_mapper.get_col_map(group_by=group_by)
            out = nb.stack_expand_mapped_nb(self.values, col_map, fill_value)
            return self.wrapper.wrap(out,
                                     index=np.arange(out.shape[0]),
                                     group_by=group_by,
                                     **merge_dicts({}, wrap_kwargs))
        if idx_arr is None:
            if self.idx_arr is None:
                raise ValueError("Must pass idx_arr")
            idx_arr = self.idx_arr
        if not self.is_expandable(idx_arr=idx_arr, group_by=group_by):
            raise ValueError(
                "Multiple values are pointing to the same position. Use ignore_index."
            )
        col_arr = self.col_mapper.get_col_arr(group_by=group_by)
        target_shape = self.wrapper.get_shape_2d(group_by=group_by)
        out = nb.expand_mapped_nb(self.values, col_arr, idx_arr, target_shape,
                                  fill_value)
        return self.wrapper.wrap(out,
                                 group_by=group_by,
                                 **merge_dicts({}, wrap_kwargs))

    def apply(self: MappedArrayT,
              apply_func_nb: tp.MappedApplyFunc,
              *args,
              group_by: tp.GroupByLike = None,
              apply_per_group: bool = False,
              dtype: tp.Optional[tp.DTypeLike] = None,
              **kwargs) -> MappedArrayT:
        """Apply function on mapped array per column/group. Returns mapped array.

        Applies per group if `apply_per_group` is True.

        See `vectorbt.records.nb.apply_on_mapped_nb`.

        `**kwargs` are passed to `MappedArray.replace`."""
        checks.assert_numba_func(apply_func_nb)
        if apply_per_group:
            col_map = self.col_mapper.get_col_map(group_by=group_by)
        else:
            col_map = self.col_mapper.get_col_map(group_by=False)
        mapped_arr = nb.apply_on_mapped_nb(self.values, col_map, apply_func_nb,
                                           *args)
        mapped_arr = np.asarray(mapped_arr, dtype=dtype)
        return self.replace(mapped_arr=mapped_arr, **kwargs).regroup(group_by)

    def reduce(self,
               reduce_func_nb: tp.ReduceFunc,
               *args,
               idx_arr: tp.Optional[tp.Array1d] = None,
               returns_array: bool = False,
               returns_idx: bool = False,
               to_index: bool = True,
               fill_value: tp.Scalar = np.nan,
               group_by: tp.GroupByLike = None,
               wrap_kwargs: tp.KwargsLike = None) -> tp.MaybeSeriesFrame:
        """Reduce mapped array by column/group.

        If `returns_array` is False and `returns_idx` is False, see `vectorbt.records.nb.reduce_mapped_nb`.
        If `returns_array` is False and `returns_idx` is True, see `vectorbt.records.nb.reduce_mapped_to_idx_nb`.
        If `returns_array` is True and `returns_idx` is False, see `vectorbt.records.nb.reduce_mapped_to_array_nb`.
        If `returns_array` is True and `returns_idx` is True, see `vectorbt.records.nb.reduce_mapped_to_idx_array_nb`.

        If `returns_idx` is True, must pass `idx_arr`. Set `to_index` to False to return raw positions instead
        of labels. Use `fill_value` to set the default value. Set `group_by` to False to disable grouping.

        Raises:
            ValueError: If `returns_idx` is True and no index array is available.
        """
        # Perform checks
        checks.assert_numba_func(reduce_func_nb)
        if idx_arr is None:
            if self.idx_arr is None:
                if returns_idx:
                    raise ValueError("Must pass idx_arr")
            idx_arr = self.idx_arr

        # Perform main computation
        col_map = self.col_mapper.get_col_map(group_by=group_by)
        if not returns_array:
            if not returns_idx:
                out = nb.reduce_mapped_nb(self.values, col_map, fill_value,
                                          reduce_func_nb, *args)
            else:
                out = nb.reduce_mapped_to_idx_nb(self.values, col_map, idx_arr,
                                                 fill_value, reduce_func_nb,
                                                 *args)
        else:
            if not returns_idx:
                out = nb.reduce_mapped_to_array_nb(self.values, col_map,
                                                   fill_value, reduce_func_nb,
                                                   *args)
            else:
                out = nb.reduce_mapped_to_idx_array_nb(self.values, col_map,
                                                       idx_arr, fill_value,
                                                       reduce_func_nb, *args)

        # Perform post-processing: index results are integer positions, with -1 used
        # as the fill for missing entries before optional conversion to labels.
        wrap_kwargs = merge_dicts(
            dict(name_or_index='reduce' if not returns_array else None,
                 to_index=returns_idx and to_index,
                 fillna=-1 if returns_idx else None,
                 dtype=np.int_ if returns_idx else None), wrap_kwargs)
        return self.wrapper.wrap_reduced(out, group_by=group_by, **wrap_kwargs)

    @cached_method
    def nth(self,
            n: int,
            group_by: tp.GroupByLike = None,
            wrap_kwargs: tp.KwargsLike = None,
            **kwargs) -> tp.MaybeSeries:
        """Return n-th element of each column/group."""
        wrap_kwargs = merge_dicts(dict(name_or_index='nth'), wrap_kwargs)
        return self.reduce(generic_nb.nth_reduce_nb,
                           n,
                           returns_array=False,
                           returns_idx=False,
                           group_by=group_by,
                           wrap_kwargs=wrap_kwargs,
                           **kwargs)

    @cached_method
    def nth_index(self,
                  n: int,
                  group_by: tp.GroupByLike = None,
                  wrap_kwargs: tp.KwargsLike = None,
                  **kwargs) -> tp.MaybeSeries:
        """Return index of n-th element of each column/group."""
        wrap_kwargs = merge_dicts(dict(name_or_index='nth_index'), wrap_kwargs)
        return self.reduce(generic_nb.nth_index_reduce_nb,
                           n,
                           returns_array=False,
                           returns_idx=True,
                           group_by=group_by,
                           wrap_kwargs=wrap_kwargs,
                           **kwargs)

    @cached_method
    def min(self,
            group_by: tp.GroupByLike = None,
            wrap_kwargs: tp.KwargsLike = None,
            **kwargs) -> tp.MaybeSeries:
        """Return min by column/group."""
        wrap_kwargs = merge_dicts(dict(name_or_index='min'), wrap_kwargs)
        return self.reduce(generic_nb.min_reduce_nb,
                           returns_array=False,
                           returns_idx=False,
                           group_by=group_by,
                           wrap_kwargs=wrap_kwargs,
                           **kwargs)

    @cached_method
    def max(self,
            group_by: tp.GroupByLike = None,
            wrap_kwargs: tp.KwargsLike = None,
            **kwargs) -> tp.MaybeSeries:
        """Return max by column/group."""
        wrap_kwargs = merge_dicts(dict(name_or_index='max'), wrap_kwargs)
        return self.reduce(generic_nb.max_reduce_nb,
                           returns_array=False,
                           returns_idx=False,
                           group_by=group_by,
                           wrap_kwargs=wrap_kwargs,
                           **kwargs)

    @cached_method
    def mean(self,
             group_by: tp.GroupByLike = None,
             wrap_kwargs: tp.KwargsLike = None,
             **kwargs) -> tp.MaybeSeries:
        """Return mean by column/group."""
        wrap_kwargs = merge_dicts(dict(name_or_index='mean'), wrap_kwargs)
        return self.reduce(generic_nb.mean_reduce_nb,
                           returns_array=False,
                           returns_idx=False,
                           group_by=group_by,
                           wrap_kwargs=wrap_kwargs,
                           **kwargs)

    @cached_method
    def median(self,
               group_by: tp.GroupByLike = None,
               wrap_kwargs: tp.KwargsLike = None,
               **kwargs) -> tp.MaybeSeries:
        """Return median by column/group."""
        wrap_kwargs = merge_dicts(dict(name_or_index='median'), wrap_kwargs)
        return self.reduce(generic_nb.median_reduce_nb,
                           returns_array=False,
                           returns_idx=False,
                           group_by=group_by,
                           wrap_kwargs=wrap_kwargs,
                           **kwargs)

    @cached_method
    def std(self,
            ddof: int = 1,
            group_by: tp.GroupByLike = None,
            wrap_kwargs: tp.KwargsLike = None,
            **kwargs) -> tp.MaybeSeries:
        """Return std by column/group (`ddof` is passed to the reducer)."""
        wrap_kwargs = merge_dicts(dict(name_or_index='std'), wrap_kwargs)
        return self.reduce(generic_nb.std_reduce_nb,
                           ddof,
                           returns_array=False,
                           returns_idx=False,
                           group_by=group_by,
                           wrap_kwargs=wrap_kwargs,
                           **kwargs)

    @cached_method
    def sum(self,
            fill_value: tp.Scalar = 0.,
            group_by: tp.GroupByLike = None,
            wrap_kwargs: tp.KwargsLike = None,
            **kwargs) -> tp.MaybeSeries:
        """Return sum by column/group (empty columns/groups get `fill_value`, 0 by default)."""
        wrap_kwargs = merge_dicts(dict(name_or_index='sum'), wrap_kwargs)
        return self.reduce(generic_nb.sum_reduce_nb,
                           fill_value=fill_value,
                           returns_array=False,
                           returns_idx=False,
                           group_by=group_by,
                           wrap_kwargs=wrap_kwargs,
                           **kwargs)

    @cached_method
    def idxmin(self,
               group_by: tp.GroupByLike = None,
               wrap_kwargs: tp.KwargsLike = None,
               **kwargs) -> tp.MaybeSeries:
        """Return index of min by column/group."""
        wrap_kwargs = merge_dicts(dict(name_or_index='idxmin'), wrap_kwargs)
        return self.reduce(generic_nb.argmin_reduce_nb,
                           returns_array=False,
                           returns_idx=True,
                           group_by=group_by,
                           wrap_kwargs=wrap_kwargs,
                           **kwargs)

    @cached_method
    def idxmax(self,
               group_by: tp.GroupByLike = None,
               wrap_kwargs: tp.KwargsLike = None,
               **kwargs) -> tp.MaybeSeries:
        """Return index of max by column/group."""
        wrap_kwargs = merge_dicts(dict(name_or_index='idxmax'), wrap_kwargs)
        return self.reduce(generic_nb.argmax_reduce_nb,
                           returns_array=False,
                           returns_idx=True,
                           group_by=group_by,
                           wrap_kwargs=wrap_kwargs,
                           **kwargs)

    @cached_method
    def describe(self,
                 percentiles: tp.Optional[tp.ArrayLike] = None,
                 ddof: int = 1,
                 group_by: tp.GroupByLike = None,
                 wrap_kwargs: tp.KwargsLike = None,
                 **kwargs) -> tp.SeriesFrame:
        """Return statistics by column/group.

        The median (0.5) is always included among `percentiles` (default: 0.25, 0.5, 0.75).
        A missing `count` (empty column/group) is reported as 0."""
        if percentiles is not None:
            percentiles = to_1d_array(percentiles)
        else:
            percentiles = np.array([0.25, 0.5, 0.75])
        percentiles = percentiles.tolist()
        if 0.5 not in percentiles:
            percentiles.append(0.5)
        percentiles = np.unique(percentiles)
        perc_formatted = pd.io.formats.format.format_percentiles(percentiles)
        index = pd.Index(
            ['count', 'mean', 'std', 'min', *perc_formatted, 'max'])
        wrap_kwargs = merge_dicts(dict(name_or_index=index), wrap_kwargs)
        out = self.reduce(generic_nb.describe_reduce_nb,
                          percentiles,
                          ddof,
                          returns_array=True,
                          returns_idx=False,
                          group_by=group_by,
                          wrap_kwargs=wrap_kwargs,
                          **kwargs)
        if isinstance(out, pd.DataFrame):
            # Assign the filled row back: `out.loc['count'].fillna(..., inplace=True)` may act
            # on a copy of the row, in which case `out` is never updated (always the case
            # under pandas copy-on-write).
            out.loc['count'] = out.loc['count'].fillna(0.)
        else:
            if np.isnan(out.loc['count']):
                out.loc['count'] = 0.
        return out

    @cached_method
    def count(self,
              group_by: tp.GroupByLike = None,
              wrap_kwargs: tp.KwargsLike = None) -> tp.MaybeSeries:
        """Return number of values by column/group (the lengths stored in the column map)."""
        wrap_kwargs = merge_dicts(dict(name_or_index='count'), wrap_kwargs)
        return self.wrapper.wrap_reduced(
            self.col_mapper.get_col_map(group_by=group_by)[1],
            group_by=group_by,
            **wrap_kwargs)

    @cached_method
    def value_counts(self,
                     normalize: bool = False,
                     sort_uniques: bool = True,
                     sort: bool = False,
                     ascending: bool = False,
                     dropna: bool = False,
                     group_by: tp.GroupByLike = None,
                     mapping: tp.Optional[tp.MappingLike] = None,
                     incl_all_keys: bool = False,
                     wrap_kwargs: tp.KwargsLike = None,
                     **kwargs) -> tp.SeriesFrame:
        """See `vectorbt.generic.accessors.GenericAccessor.value_counts`.

        Args:
            normalize: Divide counts by the total number of values across all columns/groups.
            sort_uniques: Sort rows by unique value.
            sort: Sort rows by total count across columns/groups (applied after `sort_uniques`).
            ascending: Direction used by `sort`.
            dropna: Drop the row of null values.
            group_by: Grouping to count by.
            mapping: Mapping applied to the resulting index. Defaults to `MappedArray.mapping`.
            incl_all_keys: Add zero-count rows for mapping keys that never occur.
            wrap_kwargs: Keyword arguments passed to the wrapper.
            **kwargs: Passed to `apply_mapping`.

        !!! note
            Only values stored in the mapped array are counted; positions without a record
            are not counted as missing values."""
        import inspect

        if mapping is None:
            mapping = self.mapping
        if isinstance(mapping, str):
            mapping = to_mapping(self._resolve_mapping_alias(mapping))

        # Keep NaN as its own code/unique. `na_sentinel=None` was deprecated in pandas 1.5 in
        # favor of `use_na_sentinel=False` and removed in pandas 2.0.
        if 'use_na_sentinel' in inspect.signature(pd.factorize).parameters:
            factorize_kwargs = dict(use_na_sentinel=False)
        else:
            factorize_kwargs = dict(na_sentinel=None)
        mapped_codes, mapped_uniques = pd.factorize(self.values,
                                                    sort=False,
                                                    **factorize_kwargs)
        col_map = self.col_mapper.get_col_map(group_by=group_by)
        value_counts = nb.mapped_value_counts_nb(mapped_codes,
                                                 len(mapped_uniques), col_map)
        if incl_all_keys and mapping is not None:
            missing_keys = []
            for x in mapping:
                if pd.isnull(x) and pd.isnull(mapped_uniques).any():
                    continue
                if x not in mapped_uniques:
                    missing_keys.append(x)
            # Only extend when needed: concatenating an empty (float) array would
            # needlessly upcast integer uniques.
            if missing_keys:
                value_counts = np.vstack(
                    (value_counts,
                     np.full((len(missing_keys), value_counts.shape[1]), 0)))
                mapped_uniques = np.concatenate(
                    (mapped_uniques, np.array(missing_keys)))
        if dropna:
            # `pd.isnull` (unlike `np.isnan`) also works for object/string uniques.
            nan_mask = pd.isnull(mapped_uniques)
            value_counts = value_counts[~nan_mask]
            mapped_uniques = mapped_uniques[~nan_mask]
        if sort_uniques:
            new_indices = mapped_uniques.argsort()
            value_counts = value_counts[new_indices]
            mapped_uniques = mapped_uniques[new_indices]
        value_counts_sum = value_counts.sum(axis=1)
        if normalize:
            value_counts = value_counts / value_counts_sum.sum()
        if sort:
            if ascending:
                new_indices = value_counts_sum.argsort()
            else:
                new_indices = (-value_counts_sum).argsort()
            value_counts = value_counts[new_indices]
            mapped_uniques = mapped_uniques[new_indices]
        value_counts_pd = self.wrapper.wrap(value_counts,
                                            index=mapped_uniques,
                                            group_by=group_by,
                                            **merge_dicts({}, wrap_kwargs))
        if mapping is not None:
            value_counts_pd.index = apply_mapping(value_counts_pd.index,
                                                  mapping, **kwargs)
        return value_counts_pd

    @cached_method
    def apply_mapping(self: MappedArrayT,
                      mapping: tp.Optional[tp.MappingLike] = None,
                      **kwargs) -> MappedArrayT:
        """Apply mapping on each element and return a new instance.

        `mapping` defaults to `MappedArray.mapping`. `**kwargs` are passed to `MappedArray.replace`."""
        if mapping is None:
            mapping = self.mapping
        if isinstance(mapping, str):
            mapping = to_mapping(self._resolve_mapping_alias(mapping))
        return self.replace(mapped_arr=apply_mapping(self.values, mapping),
                            **kwargs)

    def to_index(self) -> pd.Index:
        """Convert to index.

        The mapped values are used to index `wrapper.index` (typically as integer positions)."""
        return self.wrapper.index[self.values]

    # ############# Stats ############# #

    @property
    def stats_defaults(self) -> tp.Kwargs:
        """Defaults for `MappedArray.stats`.

        Merges `vectorbt.generic.stats_builder.StatsBuilderMixin.stats_defaults` and
        `mapped_array.stats` from `vectorbt._settings.settings`."""
        from vectorbt._settings import settings
        mapped_array_stats_cfg = settings['mapped_array']['stats']

        return merge_dicts(StatsBuilderMixin.stats_defaults.__get__(self),
                           mapped_array_stats_cfg)

    _metrics: tp.ClassVar[Config] = Config(dict(
        start=dict(title='Start',
                   calc_func=lambda self: self.wrapper.index[0],
                   agg_func=None,
                   tags='wrapper'),
        end=dict(title='End',
                 calc_func=lambda self: self.wrapper.index[-1],
                 agg_func=None,
                 tags='wrapper'),
        period=dict(title='Period',
                    calc_func=lambda self: len(self.wrapper.index),
                    apply_to_timedelta=True,
                    agg_func=None,
                    tags='wrapper'),
        count=dict(title='Count', calc_func='count', tags='mapped_array'),
        mean=dict(title='Mean',
                  calc_func='mean',
                  inv_check_has_mapping=True,
                  tags=['mapped_array', 'describe']),
        std=dict(title='Std',
                 calc_func='std',
                 inv_check_has_mapping=True,
                 tags=['mapped_array', 'describe']),
        min=dict(title='Min',
                 calc_func='min',
                 inv_check_has_mapping=True,
                 tags=['mapped_array', 'describe']),
        median=dict(title='Median',
                    calc_func='median',
                    inv_check_has_mapping=True,
                    tags=['mapped_array', 'describe']),
        max=dict(title='Max',
                 calc_func='max',
                 inv_check_has_mapping=True,
                 tags=['mapped_array', 'describe']),
        idx_min=dict(title='Min Index',
                     calc_func='idxmin',
                     inv_check_has_mapping=True,
                     agg_func=None,
                     tags=['mapped_array', 'index']),
        idx_max=dict(title='Max Index',
                     calc_func='idxmax',
                     inv_check_has_mapping=True,
                     agg_func=None,
                     tags=['mapped_array', 'index']),
        value_counts=dict(title='Value Counts',
                          calc_func=lambda value_counts: to_dict(
                              value_counts, orient='index_series'),
                          resolve_value_counts=True,
                          check_has_mapping=True,
                          tags=['mapped_array', 'value_counts'])),
                                           copy_kwargs=dict(copy_mode='deep'))

    @property
    def metrics(self) -> Config:
        """Metric definitions, read from the `_metrics` attribute."""
        return self._metrics

    # ############# Plotting ############# #

    def histplot(self,
                 group_by: tp.GroupByLike = None,
                 **kwargs) -> tp.BaseFigure:  # pragma: no cover
        """Plot histogram by column/group (values stacked with `ignore_index=True`)."""
        return self.to_pd(group_by=group_by,
                          ignore_index=True).vbt.histplot(**kwargs)

    def boxplot(self,
                group_by: tp.GroupByLike = None,
                **kwargs) -> tp.BaseFigure:  # pragma: no cover
        """Plot box plot by column/group (values stacked with `ignore_index=True`)."""
        return self.to_pd(group_by=group_by,
                          ignore_index=True).vbt.boxplot(**kwargs)

    @property
    def plots_defaults(self) -> tp.Kwargs:
        """Defaults for `MappedArray.plots`.

        Merges `vectorbt.generic.plots_builder.PlotsBuilderMixin.plots_defaults` and
        `mapped_array.plots` from `vectorbt._settings.settings`."""
        from vectorbt._settings import settings
        mapped_array_plots_cfg = settings['mapped_array']['plots']

        return merge_dicts(PlotsBuilderMixin.plots_defaults.__get__(self),
                           mapped_array_plots_cfg)

    _subplots: tp.ClassVar[Config] = Config(
        dict(to_pd_plot=dict(check_is_not_grouped=True,
                             plot_func='to_pd.vbt.plot',
                             pass_trace_names=False,
                             tags='mapped_array')),
        copy_kwargs=dict(copy_mode='deep'))

    @property
    def subplots(self) -> Config:
        """Subplot definitions, read from the `_subplots` attribute."""
        return self._subplots
Beispiel #19
0
class Ranges(Records):
    """Extends `Records` for working with range records.

    Requires `records_arr` to have all fields defined in `vectorbt.generic.enums.range_dt`."""
    @property
    def field_config(self) -> Config:
        """Field config of `Ranges`."""
        return self._field_config

    def __init__(self,
                 wrapper: ArrayWrapper,
                 records_arr: tp.RecordArray,
                 ts: tp.Optional[tp.ArrayLike] = None,
                 **kwargs) -> None:
        # `ts` is forwarded to `Records.__init__` as well, presumably so that
        # `replace` can carry it over to new instances.
        Records.__init__(self, wrapper, records_arr, ts=ts, **kwargs)
        self._ts = ts

    def indexing_func(self: RangesT, pd_indexing_func: tp.PandasIndexingFunc,
                      **kwargs) -> RangesT:
        """Perform indexing on `Ranges`.

        `Ranges.ts` (if attached) is sliced to the selected columns as well, so that
        the new instance stays aligned with its records."""
        new_wrapper, new_records_arr, _, col_idxs = \
            Records.indexing_func_meta(self, pd_indexing_func, **kwargs)
        if self.ts is not None:
            # `ts` may be a Series (as produced by `Ranges.from_ts` for Series input) or a
            # plain array: `.values` of a Series is 1-dim and cannot be indexed by column,
            # and a NumPy array has no `.values`. Bring it to two dimensions first.
            new_ts = new_wrapper.wrap(to_2d_array(self.ts)[:, col_idxs],
                                      group_by=False)
        else:
            new_ts = None
        return self.replace(wrapper=new_wrapper,
                            records_arr=new_records_arr,
                            ts=new_ts)

    @classmethod
    def from_ts(cls: tp.Type[RangesT],
                ts: tp.ArrayLike,
                gap_value: tp.Optional[tp.Scalar] = None,
                attach_ts: bool = True,
                wrapper_kwargs: tp.KwargsLike = None,
                **kwargs) -> RangesT:
        """Build `Ranges` from time series `ts`.

        Searches for sequences of

        * True values in boolean data (False acts as a gap),
        * positive values in integer data (-1 acts as a gap), and
        * non-NaN values in any other data (NaN acts as a gap).

        Args:
            ts (array_like): Time series to search for ranges.
            gap_value (scalar): Value that separates ranges. If None, it is derived
                from the dtype of `ts` as listed above.
            attach_ts (bool): Whether to attach `ts` to the instance (see `Ranges.ts`).
            wrapper_kwargs (dict): Keyword arguments passed to `ArrayWrapper.from_obj`.
            **kwargs: Keyword arguments passed to `Ranges.__init__`.

        See `vectorbt.generic.nb.find_ranges_nb`."""
        if wrapper_kwargs is None:
            wrapper_kwargs = {}

        ts_pd = to_pd_array(ts)
        ts_arr = to_2d_array(ts_pd)
        if gap_value is None:
            if np.issubdtype(ts_arr.dtype, np.bool_):
                gap_value = False
            elif np.issubdtype(ts_arr.dtype, np.integer):
                gap_value = -1
            else:
                gap_value = np.nan
        records_arr = nb.find_ranges_nb(ts_arr, gap_value)
        wrapper = ArrayWrapper.from_obj(ts_pd, **wrapper_kwargs)
        return cls(wrapper,
                   records_arr,
                   ts=ts_pd if attach_ts else None,
                   **kwargs)

    @property
    def ts(self) -> tp.Optional[tp.SeriesFrame]:
        """Original time series that records are built from (optional)."""
        return self._ts

    def to_mask(self,
                group_by: tp.GroupByLike = None,
                wrap_kwargs: tp.KwargsLike = None) -> tp.SeriesFrame:
        """Convert ranges to a mask.

        See `vectorbt.generic.nb.ranges_to_mask_nb`."""
        col_map = self.col_mapper.get_col_map(group_by=group_by)
        mask = nb.ranges_to_mask_nb(self.get_field_arr('start_idx'),
                                    self.get_field_arr('end_idx'),
                                    self.get_field_arr('status'), col_map,
                                    len(self.wrapper.index))
        return self.wrapper.wrap(mask,
                                 group_by=group_by,
                                 **merge_dicts({}, wrap_kwargs))

    @cached_property
    def duration(self) -> MappedArray:
        """Duration of each range (in raw format).

        See `vectorbt.generic.nb.range_duration_nb`."""
        duration = nb.range_duration_nb(self.get_field_arr('start_idx'),
                                        self.get_field_arr('end_idx'),
                                        self.get_field_arr('status'))
        return self.map_array(duration)

    @cached_method
    def avg_duration(self,
                     group_by: tp.GroupByLike = None,
                     wrap_kwargs: tp.KwargsLike = None,
                     **kwargs) -> tp.MaybeSeries:
        """Average range duration (as timedelta)."""
        wrap_kwargs = merge_dicts(
            dict(to_timedelta=True, name_or_index='avg_duration'), wrap_kwargs)
        return self.duration.mean(group_by=group_by,
                                  wrap_kwargs=wrap_kwargs,
                                  **kwargs)

    @cached_method
    def max_duration(self,
                     group_by: tp.GroupByLike = None,
                     wrap_kwargs: tp.KwargsLike = None,
                     **kwargs) -> tp.MaybeSeries:
        """Maximum range duration (as timedelta)."""
        wrap_kwargs = merge_dicts(
            dict(to_timedelta=True, name_or_index='max_duration'), wrap_kwargs)
        return self.duration.max(group_by=group_by,
                                 wrap_kwargs=wrap_kwargs,
                                 **kwargs)

    @cached_method
    def coverage(self,
                 overlapping: bool = False,
                 normalize: bool = True,
                 group_by: tp.GroupByLike = None,
                 wrap_kwargs: tp.KwargsLike = None) -> tp.MaybeSeries:
        """Coverage, that is, the number of steps that are covered by all ranges.

        `overlapping` and `normalize` are passed to `vectorbt.generic.nb.range_coverage_nb`."""
        col_map = self.col_mapper.get_col_map(group_by=group_by)
        # Total number of steps per column/group: rows times columns in the group
        index_lens = self.wrapper.grouper.get_group_lens(
            group_by=group_by) * self.wrapper.shape[0]
        coverage = nb.range_coverage_nb(self.get_field_arr('start_idx'),
                                        self.get_field_arr('end_idx'),
                                        self.get_field_arr('status'),
                                        col_map,
                                        index_lens,
                                        overlapping=overlapping,
                                        normalize=normalize)
        wrap_kwargs = merge_dicts(dict(name_or_index='coverage'), wrap_kwargs)
        return self.wrapper.wrap_reduced(coverage,
                                         group_by=group_by,
                                         **wrap_kwargs)

    # ############# Stats ############# #

    @property
    def stats_defaults(self) -> tp.Kwargs:
        """Defaults for `Ranges.stats`.

        Merges `vectorbt.records.base.Records.stats_defaults` and
        `ranges.stats` from `vectorbt._settings.settings`."""
        from vectorbt._settings import settings
        ranges_stats_cfg = settings['ranges']['stats']

        return merge_dicts(Records.stats_defaults.__get__(self),
                           ranges_stats_cfg)

    _metrics: tp.ClassVar[Config] = Config(dict(
        start=dict(title='Start',
                   calc_func=lambda self: self.wrapper.index[0],
                   agg_func=None,
                   tags='wrapper'),
        end=dict(title='End',
                 calc_func=lambda self: self.wrapper.index[-1],
                 agg_func=None,
                 tags='wrapper'),
        period=dict(title='Period',
                    calc_func=lambda self: len(self.wrapper.index),
                    apply_to_timedelta=True,
                    agg_func=None,
                    tags='wrapper'),
        coverage=dict(title='Coverage',
                      calc_func='coverage',
                      overlapping=False,
                      normalize=False,
                      apply_to_timedelta=True,
                      tags=['ranges', 'coverage']),
        overlap_coverage=dict(title='Overlap Coverage',
                              calc_func='coverage',
                              overlapping=True,
                              normalize=False,
                              apply_to_timedelta=True,
                              tags=['ranges', 'coverage']),
        total_records=dict(title='Total Records',
                           calc_func='count',
                           tags='records'),
        duration=dict(title='Duration',
                      calc_func='duration.describe',
                      post_calc_func=lambda self, out, settings: {
                          'Min': out.loc['min'],
                          'Median': out.loc['50%'],
                          'Max': out.loc['max'],
                          'Mean': out.loc['mean'],
                          'Std': out.loc['std']
                      },
                      apply_to_timedelta=True,
                      tags=['ranges', 'duration']),
    ), copy_kwargs=dict(copy_mode='deep'))

    @property
    def metrics(self) -> Config:
        """Metric definitions of `Ranges`."""
        return self._metrics

    # ############# Plotting ############# #

    def plot(self,
             column: tp.Optional[tp.Label] = None,
             top_n: int = 5,
             plot_zones: bool = True,
             ts_trace_kwargs: tp.KwargsLike = None,
             start_trace_kwargs: tp.KwargsLike = None,
             end_trace_kwargs: tp.KwargsLike = None,
             open_shape_kwargs: tp.KwargsLike = None,
             closed_shape_kwargs: tp.KwargsLike = None,
             add_trace_kwargs: tp.KwargsLike = None,
             xref: str = 'x',
             yref: str = 'y',
             fig: tp.Optional[tp.BaseFigure] = None,
             **layout_kwargs) -> tp.BaseFigure:  # pragma: no cover
        """Plot ranges.

        Args:
            column (str): Name of the column to plot.
            top_n (int): Filter top N range records by maximum duration.
            plot_zones (bool): Whether to plot zones.
            ts_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Ranges.ts`.
            start_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for start values.
            end_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for end values.
            open_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for open zones.
            closed_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for closed zones.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            xref (str): X coordinate axis.
            yref (str): Y coordinate axis.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments for layout.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt
        >>> from datetime import datetime, timedelta
        >>> import pandas as pd

        >>> price = pd.Series([1, 2, 1, 2, 3, 2, 1, 2], name='Price')
        >>> price.index = [datetime(2020, 1, 1) + timedelta(days=i) for i in range(len(price))]
        >>> vbt.Ranges.from_ts(price >= 2, wrapper_kwargs=dict(freq='1 day')).plot()
        ```

        ![](/docs/img/ranges_plot.svg)
        """
        from vectorbt._settings import settings
        plotting_cfg = settings['plotting']

        self_col = self.select_one(column=column, group_by=False)
        if top_n is not None:
            self_col = self_col.apply_mask(self_col.duration.top_n_mask(top_n))

        if ts_trace_kwargs is None:
            ts_trace_kwargs = {}
        ts_trace_kwargs = merge_dicts(
            dict(line=dict(color=plotting_cfg['color_schema']['blue'])),
            ts_trace_kwargs)
        if start_trace_kwargs is None:
            start_trace_kwargs = {}
        if end_trace_kwargs is None:
            end_trace_kwargs = {}
        if open_shape_kwargs is None:
            open_shape_kwargs = {}
        if closed_shape_kwargs is None:
            closed_shape_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}

        if fig is None:
            fig = make_figure()
        fig.update_layout(**layout_kwargs)
        y_domain = get_domain(yref, fig)

        if self_col.ts is not None:
            fig = self_col.ts.vbt.plot(trace_kwargs=ts_trace_kwargs,
                                       add_trace_kwargs=add_trace_kwargs,
                                       fig=fig)

        if self_col.count() > 0:
            # Extract information
            id_ = self_col.get_field_arr('id')
            id_title = self_col.get_field_title('id')

            start_idx = self_col.get_map_field_to_index('start_idx')
            start_idx_title = self_col.get_field_title('start_idx')
            if self_col.ts is not None:
                start_val = self_col.ts.loc[start_idx]
            else:
                start_val = np.full(len(start_idx), 0)

            end_idx = self_col.get_map_field_to_index('end_idx')
            end_idx_title = self_col.get_field_title('end_idx')
            if self_col.ts is not None:
                end_val = self_col.ts.loc[end_idx]
            else:
                end_val = np.full(len(end_idx), 0)

            duration = np.vectorize(str)(self_col.wrapper.to_timedelta(
                self_col.duration.values, to_pd=True, silence_warnings=True))

            status = self_col.get_field_arr('status')

            # Plot start markers
            start_customdata = id_[:, None]
            start_scatter = go.Scatter(
                x=start_idx,
                y=start_val,
                mode='markers',
                marker=dict(
                    symbol='diamond',
                    color=plotting_cfg['contrast_color_schema']['blue'],
                    size=7,
                    line=dict(
                        width=1,
                        color=adjust_lightness(
                            plotting_cfg['contrast_color_schema']['blue']))),
                name='Start',
                customdata=start_customdata,
                hovertemplate=f"{id_title}: %{{customdata[0]}}"
                f"<br>{start_idx_title}: %{{x}}")
            start_scatter.update(**start_trace_kwargs)
            fig.add_trace(start_scatter, **add_trace_kwargs)

            def _add_end_markers_and_zones(mask, name, color_key, fillcolor, shape_kwargs):
                """Add end markers (and, if `plot_zones`, zone rectangles) for ranges in `mask`."""
                marker_color = plotting_cfg['contrast_color_schema'][color_key]
                end_customdata = np.stack((id_[mask], duration[mask]), axis=1)
                end_scatter = go.Scatter(
                    x=end_idx[mask],
                    y=end_val[mask],
                    mode='markers',
                    marker=dict(
                        symbol='diamond',
                        color=marker_color,
                        size=7,
                        line=dict(width=1,
                                  color=adjust_lightness(marker_color))),
                    name=name,
                    customdata=end_customdata,
                    hovertemplate=f"{id_title}: %{{customdata[0]}}"
                    f"<br>{end_idx_title}: %{{x}}"
                    f"<br>Duration: %{{customdata[1]}}")
                end_scatter.update(**end_trace_kwargs)
                fig.add_trace(end_scatter, **add_trace_kwargs)

                if plot_zones:
                    # Mask once, then walk the selected start/end pairs together
                    # (the previous version re-applied the mask on every iteration)
                    for x0, x1 in zip(start_idx[mask], end_idx[mask]):
                        fig.add_shape(**merge_dicts(
                            dict(
                                type="rect",
                                xref=xref,
                                yref="paper",
                                x0=x0,
                                y0=y_domain[0],
                                x1=x1,
                                y1=y_domain[1],
                                fillcolor=fillcolor,
                                opacity=0.2,
                                layer="below",
                                line_width=0,
                            ), shape_kwargs))

            closed_mask = status == RangeStatus.Closed
            if closed_mask.any():
                # Plot closed end markers and closed range zones
                _add_end_markers_and_zones(closed_mask, 'Closed', 'green', 'teal',
                                           closed_shape_kwargs)

            open_mask = status == RangeStatus.Open
            if open_mask.any():
                # Plot open end markers and open range zones
                _add_end_markers_and_zones(open_mask, 'Open', 'orange', 'orange',
                                           open_shape_kwargs)

        return fig

    @property
    def plots_defaults(self) -> tp.Kwargs:
        """Defaults for `Ranges.plots`.

        Merges `vectorbt.records.base.Records.plots_defaults` and
        `ranges.plots` from `vectorbt._settings.settings`."""
        from vectorbt._settings import settings
        ranges_plots_cfg = settings['ranges']['plots']

        return merge_dicts(Records.plots_defaults.__get__(self),
                           ranges_plots_cfg)

    _subplots: tp.ClassVar[Config] = Config(
        dict(plot=dict(title="Ranges",
                       check_is_not_grouped=True,
                       plot_func='plot',
                       tags='ranges')),
        copy_kwargs=dict(copy_mode='deep'))

    @property
    def subplots(self) -> Config:
        """Subplot definitions of `Ranges`."""
        return self._subplots
Beispiel #20
0
__pdoc__ = {}

# Field config for `Orders`: the record dtype plus per-field titles
# (and a value mapping for the `side` field).
orders_field_config = Config(
    dict(
        dtype=order_dt,
        settings=dict(
            id=dict(title='Order Id'),
            size=dict(title='Size'),
            price=dict(title='Price'),
            fees=dict(title='Fees'),
            side=dict(title='Side', mapping=OrderSide),
        ),
    ),
    readonly=True,
    as_attrs=False,
)
"""_"""

__pdoc__['orders_field_config'] = f"""Field config for `Orders`.
Beispiel #21
0
def save(fname, names=__all__, **kwargs):
    """Save settings to a file.

    Collects the attributes of `this_module` listed in `names` into a `Config`
    and saves it.

    Args:
        fname: File name passed to `Config.save`.
        names: Names of module attributes to save. Defaults to `__all__`
            (bound when this function is defined).
        **kwargs: Keyword arguments passed to `Config.save`.
    """
    settings = {k: getattr(this_module, k) for k in names}
    Config(settings).save(fname, **kwargs)
Beispiel #22
0
class Orders(Records):
    """Extends `Records` for working with order records."""

    @property
    def field_config(self) -> Config:
        """Field config of `Orders`."""
        return self._field_config

    def __init__(self,
                 wrapper: ArrayWrapper,
                 records_arr: tp.RecordArray,
                 close: tp.Optional[tp.ArrayLike] = None,
                 **kwargs) -> None:
        # `close` is forwarded to `Records.__init__` as well, presumably so that
        # `replace` can carry it over to new instances.
        Records.__init__(
            self,
            wrapper,
            records_arr,
            close=close,
            **kwargs
        )
        self._close = close

    def indexing_func(self: OrdersT, pd_indexing_func: tp.PandasIndexingFunc, **kwargs) -> OrdersT:
        """Perform indexing on `Orders`.

        `Orders.close` (if attached) is sliced to the selected columns as well."""
        new_wrapper, new_records_arr, _, col_idxs = \
            Records.indexing_func_meta(self, pd_indexing_func, **kwargs)
        if self.close is not None:
            new_close = new_wrapper.wrap(to_2d_array(self.close)[:, col_idxs], group_by=False)
        else:
            new_close = None
        return self.replace(
            wrapper=new_wrapper,
            records_arr=new_records_arr,
            close=new_close
        )

    @property
    def close(self) -> tp.Optional[tp.SeriesFrame]:
        """Reference price such as close (optional)."""
        return self._close

    # ############# Stats ############# #

    @property
    def stats_defaults(self) -> tp.Kwargs:
        """Defaults for `Orders.stats`.

        Merges `vectorbt.records.base.Records.stats_defaults` and
        `orders.stats` from `vectorbt._settings.settings`."""
        from vectorbt._settings import settings
        orders_stats_cfg = settings['orders']['stats']

        return merge_dicts(
            Records.stats_defaults.__get__(self),
            orders_stats_cfg
        )

    _metrics: tp.ClassVar[Config] = Config(
        dict(
            start=dict(
                title='Start',
                calc_func=lambda self: self.wrapper.index[0],
                agg_func=None,
                tags='wrapper'
            ),
            end=dict(
                title='End',
                calc_func=lambda self: self.wrapper.index[-1],
                agg_func=None,
                tags='wrapper'
            ),
            period=dict(
                title='Period',
                calc_func=lambda self: len(self.wrapper.index),
                apply_to_timedelta=True,
                agg_func=None,
                tags='wrapper'
            ),
            total_records=dict(
                title='Total Records',
                calc_func='count',
                tags='records'
            ),
            total_buy_orders=dict(
                title='Total Buy Orders',
                calc_func='buy.count',
                tags=['orders', 'buy']
            ),
            total_sell_orders=dict(
                title='Total Sell Orders',
                calc_func='sell.count',
                tags=['orders', 'sell']
            ),
            min_size=dict(
                title='Min Size',
                calc_func='size.min',
                tags=['orders', 'size']
            ),
            max_size=dict(
                title='Max Size',
                calc_func='size.max',
                tags=['orders', 'size']
            ),
            avg_size=dict(
                title='Avg Size',
                calc_func='size.mean',
                tags=['orders', 'size']
            ),
            avg_buy_size=dict(
                title='Avg Buy Size',
                calc_func='buy.size.mean',
                tags=['orders', 'buy', 'size']
            ),
            avg_sell_size=dict(
                title='Avg Sell Size',
                calc_func='sell.size.mean',
                tags=['orders', 'sell', 'size']
            ),
            avg_buy_price=dict(
                title='Avg Buy Price',
                calc_func='buy.price.mean',
                tags=['orders', 'buy', 'price']
            ),
            avg_sell_price=dict(
                title='Avg Sell Price',
                calc_func='sell.price.mean',
                tags=['orders', 'sell', 'price']
            ),
            total_fees=dict(
                title='Total Fees',
                calc_func='fees.sum',
                tags=['orders', 'fees']
            ),
            min_fees=dict(
                title='Min Fees',
                calc_func='fees.min',
                tags=['orders', 'fees']
            ),
            max_fees=dict(
                title='Max Fees',
                calc_func='fees.max',
                tags=['orders', 'fees']
            ),
            avg_fees=dict(
                title='Avg Fees',
                calc_func='fees.mean',
                tags=['orders', 'fees']
            ),
            avg_buy_fees=dict(
                title='Avg Buy Fees',
                calc_func='buy.fees.mean',
                tags=['orders', 'buy', 'fees']
            ),
            avg_sell_fees=dict(
                title='Avg Sell Fees',
                calc_func='sell.fees.mean',
                tags=['orders', 'sell', 'fees']
            ),
        ),
        copy_kwargs=dict(copy_mode='deep')
    )

    @property
    def metrics(self) -> Config:
        """Metric definitions of `Orders`."""
        return self._metrics

    # ############# Plotting ############# #

    def plot(self,
             column: tp.Optional[tp.Label] = None,
             close_trace_kwargs: tp.KwargsLike = None,
             buy_trace_kwargs: tp.KwargsLike = None,
             sell_trace_kwargs: tp.KwargsLike = None,
             add_trace_kwargs: tp.KwargsLike = None,
             fig: tp.Optional[tp.BaseFigure] = None,
             **layout_kwargs) -> tp.BaseFigure:  # pragma: no cover
        """Plot orders.

        Args:
            column (str): Name of the column to plot.
            close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Orders.close`.
            buy_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Buy" markers.
            sell_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Sell" markers.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments for layout.

        ## Example

        ```python-repl
        >>> import pandas as pd
        >>> from datetime import datetime, timedelta
        >>> import vectorbt as vbt

        >>> price = pd.Series([1., 2., 3., 2., 1.], name='Price')
        >>> price.index = [datetime(2020, 1, 1) + timedelta(days=i) for i in range(len(price))]
        >>> size = pd.Series([1., 1., 1., 1., -1.])
        >>> orders = vbt.Portfolio.from_orders(price, size).orders

        >>> orders.plot()
        ```

        ![](/docs/img/orders_plot.svg)"""
        from vectorbt._settings import settings
        plotting_cfg = settings['plotting']

        self_col = self.select_one(column=column, group_by=False)

        if close_trace_kwargs is None:
            close_trace_kwargs = {}
        close_trace_kwargs = merge_dicts(dict(
            line=dict(
                color=plotting_cfg['color_schema']['blue']
            ),
            name='Close'
        ), close_trace_kwargs)
        if buy_trace_kwargs is None:
            buy_trace_kwargs = {}
        if sell_trace_kwargs is None:
            sell_trace_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}

        if fig is None:
            fig = make_figure()
        fig.update_layout(**layout_kwargs)

        # Plot price
        if self_col.close is not None:
            fig = self_col.close.vbt.plot(trace_kwargs=close_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig)

        if self_col.count() > 0:
            # Extract information
            id_ = self_col.get_field_arr('id')
            id_title = self_col.get_field_title('id')

            idx = self_col.get_map_field_to_index('idx')
            idx_title = self_col.get_field_title('idx')

            size = self_col.get_field_arr('size')
            size_title = self_col.get_field_title('size')

            fees = self_col.get_field_arr('fees')
            fees_title = self_col.get_field_title('fees')

            price = self_col.get_field_arr('price')
            price_title = self_col.get_field_title('price')

            side = self_col.get_field_arr('side')

            # Buy and sell markers share the same hover layout
            hovertemplate = (
                f"{id_title}: %{{customdata[0]}}"
                f"<br>{idx_title}: %{{x}}"
                f"<br>{price_title}: %{{y}}"
                f"<br>{size_title}: %{{customdata[1]:.6f}}"
                f"<br>{fees_title}: %{{customdata[2]:.6f}}"
            )

            def _add_side_markers(mask, name, symbol, color_key, trace_kwargs):
                """Add one marker trace for the orders selected by `mask`."""
                marker_color = plotting_cfg['contrast_color_schema'][color_key]
                customdata = np.stack((
                    id_[mask],
                    size[mask],
                    fees[mask]
                ), axis=1)
                scatter = go.Scatter(
                    x=idx[mask],
                    y=price[mask],
                    mode='markers',
                    marker=dict(
                        symbol=symbol,
                        color=marker_color,
                        size=8,
                        line=dict(
                            width=1,
                            color=adjust_lightness(marker_color)
                        )
                    ),
                    name=name,
                    customdata=customdata,
                    hovertemplate=hovertemplate
                )
                scatter.update(**trace_kwargs)
                fig.add_trace(scatter, **add_trace_kwargs)

            buy_mask = side == OrderSide.Buy
            if buy_mask.any():
                # Plot buy markers
                _add_side_markers(buy_mask, 'Buy', 'triangle-up', 'green', buy_trace_kwargs)

            sell_mask = side == OrderSide.Sell
            if sell_mask.any():
                # Plot sell markers
                _add_side_markers(sell_mask, 'Sell', 'triangle-down', 'red', sell_trace_kwargs)

        return fig

    @property
    def plots_defaults(self) -> tp.Kwargs:
        """Defaults for `Orders.plots`.

        Merges `vectorbt.records.base.Records.plots_defaults` and
        `orders.plots` from `vectorbt._settings.settings`."""
        from vectorbt._settings import settings
        orders_plots_cfg = settings['orders']['plots']

        return merge_dicts(
            Records.plots_defaults.__get__(self),
            orders_plots_cfg
        )

    _subplots: tp.ClassVar[Config] = Config(
        dict(
            plot=dict(
                title="Orders",
                yaxis_kwargs=dict(title="Price"),
                check_is_not_grouped=True,
                plot_func='plot',
                tags='orders'
            )
        ),
        copy_kwargs=dict(copy_mode='deep')
    )

    @property
    def subplots(self) -> Config:
        """Subplot definitions of `Orders`."""
        return self._subplots
Beispiel #23
0
"""Common configurations for indicators."""

from vectorbt.utils.config import Config

flex_elem_param_config = Config(
    {
        # A NumPy array counts as ONE value; use a list to pass several values.
        'is_array_like': True,
        # Broadcast the parameter against the input.
        'bc_to_input': True,
        'broadcast_kwargs': {
            # Keep the original shape so flexible indexing can be used instead of
            # materializing the broadcast array, which saves memory.
            'keep_raw': True,
        },
    }
)
"""Config for flexible, element-wise parameters."""

flex_col_param_config = Config(
    {
        'is_array_like': True,
        # Broadcast along axis 1, i.e. one value per column.
        'bc_to_input': 1,
        # Display one parameter value per column.
        'per_column': True,
        'broadcast_kwargs': {'keep_raw': True},
    }
)
"""Config for flexible, column-wise parameters."""
Beispiel #24
0
"""Global defaults."""

from vectorbt.utils.config import Config

# --- Plotly layout ---------------------------------------------------------
layout = Config(
    frozen=False,
    autosize=False,
    width=700,
    height=300,
    margin={'b': 30, 't': 30},
    hovermode='closest',
    # Ten-color default palette used for traces.
    colorway=[
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ],
)
"""Default Plotly layout."""

# --- Portfolio -------------------------------------------------------------
portfolio = Config(
    init_capital=1.0,
    fees=0.0,
    slippage=0.0,
    year_freq='1Y',
    risk_free=0.0,
    required_return=0.0,
    cutoff=0.05,
)
"""Default portfolio parameters."""

# Broadcasting
broadcasting = Config(index_from='strict',
                      columns_from='stack',
                      ignore_single=True,
Beispiel #25
0
        self.register_template(theme)
        self['plotting']['color_schema'].update(self['plotting']['themes'][theme]['color_schema'])
        self['plotting']['layout']['template'] = 'vbt_' + theme

    def reset_theme(self) -> None:
        """Switch back to the default theme, which is 'light'."""
        default_theme = 'light'
        self.set_theme(default_theme)


settings = SettingsConfig(
    dict(
        config=Config(  # flex
            dict(
                configured=Config(  # flex
                    dict(
                        readonly=True
                    )
                )
            ),
        ),
        caching=dict(
            enabled=True,
            whitelist=[
                CacheCondition(base_cls=ArrayWrapper),
                CacheCondition(base_cls=ColumnGrouper),
                CacheCondition(base_cls=ColumnMapper)
            ],
            blacklist=[]
        ),
        broadcasting=dict(
            align_index=False,
Beispiel #26
0
        0    False
        1    False
        2     True
        3     True
        4     True
        Name: array_0, dtype: bool
        ```
    """
    pass


RPROB.__doc__ = _RPROB.__doc__  # reuse the documentation written on `_RPROB`

rprobx_config = Config(
    {
        'class_name': 'RPROBX',
        'module_name': __name__,
        'short_name': 'rprobx',
        'mode': 'exits',
        'param_names': ['prob'],
    }
)
"""Factory config for `RPROBX`."""

rprobx_func_config = Config(
    {
        'exit_choice_func': rand_by_prob_choice_nb,
        'exit_settings': {
            'pass_params': ['prob'],
            'pass_kwargs': ['pick_first', 'temp_idx_arr', 'flex_2d'],
        },
        'pass_flex_2d': True,
        # `prob` is an element-wise parameter with flexible broadcasting.
        'param_settings': {'prob': flex_elem_param_config},
        'seed': None,
    }
)
"""Exit function config for `RPROBX`."""

RPROBX = (
    SignalFactory(**rprobx_config)
    .from_choice_func(**rprobx_func_config)
)
Beispiel #27
0
# Binary magic-method names mapped to the NumPy function that implements each one
# (presumably used to attach operators to other classes; the consumer is not shown here).
binary_magic_config = Config(
    dict(
        # comparison ops
        __eq__=dict(func=np.equal),
        __ne__=dict(func=np.not_equal),
        __lt__=dict(func=np.less),
        __gt__=dict(func=np.greater),
        __le__=dict(func=np.less_equal),
        __ge__=dict(func=np.greater_equal),
        # arithmetic ops
        __add__=dict(func=np.add),
        __sub__=dict(func=np.subtract),
        __mul__=dict(func=np.multiply),
        __pow__=dict(func=np.power),
        __mod__=dict(func=np.mod),
        __floordiv__=dict(func=np.floor_divide),
        __truediv__=dict(func=np.true_divide),
        # reflected arithmetic ops: operands are swapped before calling the ufunc
        __radd__=dict(func=lambda x, y: np.add(y, x)),
        __rsub__=dict(func=lambda x, y: np.subtract(y, x)),
        __rmul__=dict(func=lambda x, y: np.multiply(y, x)),
        __rpow__=dict(func=lambda x, y: np.power(y, x)),
        __rmod__=dict(func=lambda x, y: np.mod(y, x)),
        __rfloordiv__=dict(func=lambda x, y: np.floor_divide(y, x)),
        __rtruediv__=dict(func=lambda x, y: np.true_divide(y, x)),
        # mask (bitwise) ops
        __and__=dict(func=np.bitwise_and),
        __or__=dict(func=np.bitwise_or),
        __xor__=dict(func=np.bitwise_xor),
        # reflected mask ops
        __rand__=dict(func=lambda x, y: np.bitwise_and(y, x)),
        __ror__=dict(func=lambda x, y: np.bitwise_or(y, x)),
        __rxor__=dict(func=lambda x, y: np.bitwise_xor(y, x)),
    ),
    readonly=True,
    as_attrs=False,
)
Beispiel #28
0
class PlotsBuilderMixin(metaclass=MetaPlotsBuilderMixin):
    """Mixin that implements `PlotsBuilderMixin.plots`.

    Required to be a subclass of `vectorbt.base.array_wrapper.Wrapping`."""
    def __init__(self):
        # The mixin relies on `Wrapping` (e.g. `self.wrapper` in `plots_defaults`).
        checks.assert_instance_of(self, Wrapping)

        # Give every instance its own copy of the class-level subplots config so that
        # in-place edits on the instance do not leak back into the class.
        class_subplots = self.__class__._subplots
        self._subplots = class_subplots.copy()

    @property
    def writeable_attrs(self) -> tp.Set[str]:
        """Names of writeable attributes that are saved/copied together with the config."""
        attr_names = {'_subplots'}
        return attr_names

    @property
    def plots_defaults(self) -> tp.Kwargs:
        """Default keyword arguments for `PlotsBuilderMixin.plots`.

        Taken from `plots_builder` in `vectorbt._settings.settings`, with the
        wrapper's frequency injected under `settings.freq`."""
        from vectorbt._settings import settings
        builder_defaults = settings['plots_builder']
        freq_override = {'settings': {'freq': self.wrapper.freq}}
        return merge_dicts(builder_defaults, freq_override)

    # No subplots by default; subclasses provide their own config.
    _subplots: tp.ClassVar[Config] = Config({}, copy_kwargs={'copy_mode': 'deep'})

    @property
    def subplots(self) -> Config:
        """Subplots supported by `${cls_name}`.

        ```json
        ${subplots}
        ```

        Returns `${cls_name}._subplots`, which gets (deep) copied upon creation of each instance.
        Thus, changing this config won't affect the class.

        To change subplots, you can either change the config in-place, override this property,
        or overwrite the instance variable `${cls_name}._subplots`."""
        # NOTE(review): the docstring above contains `${cls_name}` / `${subplots}` placeholders,
        # which suggests it is rendered as a template per subclass elsewhere (not visible here).
        # Keep the placeholders intact when editing it.
        # The per-instance copy is made in `__init__` (`self.__class__._subplots.copy()`).
        return self._subplots

    def plots(self,
              subplots: tp.Optional[tp.MaybeIterable[tp.Union[str, tp.Tuple[
                  str, tp.Kwargs]]]] = None,
              tags: tp.Optional[tp.MaybeIterable[str]] = None,
              column: tp.Optional[tp.Label] = None,
              group_by: tp.GroupByLike = None,
              silence_warnings: tp.Optional[bool] = None,
              template_mapping: tp.Optional[tp.Mapping] = None,
              settings: tp.KwargsLike = None,
              filters: tp.KwargsLike = None,
              subplot_settings: tp.KwargsLike = None,
              show_titles: bool = None,
              hide_id_labels: bool = None,
              group_id_labels: bool = None,
              make_subplots_kwargs: tp.KwargsLike = None,
              **layout_kwargs) -> tp.Optional[tp.BaseFigure]:
        """Plot various parts of this object.

        Args:
            subplots (str, tuple, iterable, or dict): Subplots to plot.

                Each element can be either:

                * a subplot name (see keys in `PlotsBuilderMixin.subplots`)
                * a tuple of a subplot name and a settings dict as in `PlotsBuilderMixin.subplots`.

                The settings dict can contain the following keys:

                * `title`: Title of the subplot. Defaults to the name.
                * `plot_func` (required): Plotting function for custom subplots.
                    Should write the supplied figure `fig` in-place and can return anything (it won't be used).
                * `xaxis_kwargs`: Layout keyword arguments for the x-axis. Defaults to `dict(title='Index')`.
                * `yaxis_kwargs`: Layout keyword arguments for the y-axis. Defaults to empty dict.
                * `tags`, `check_{filter}`, `inv_check_{filter}`, `resolve_plot_func`, `pass_{arg}`,
                    `resolve_path_{arg}`, `resolve_{arg}` and `template_mapping`:
                    The same as in `vectorbt.generic.stats_builder.StatsBuilderMixin` for `calc_func`.
                * Any other keyword argument that overrides the settings or is passed directly to `plot_func`.

                If `resolve_plot_func` is True, the plotting function may "request" any of the
                following arguments by accepting them or if `pass_{arg}` was found in the settings dict:

                * Each of `vectorbt.utils.attr_.AttrResolver.self_aliases`: original object
                    (ungrouped, with no column selected)
                * `group_by`: won't be passed if it was used in resolving the first attribute of `plot_func`
                    specified as a path, use `pass_group_by=True` to pass anyway
                * `column`
                * `subplot_name`
                * `trace_names`: list with the subplot name, can't be used in templates
                * `add_trace_kwargs`: dict with subplot row and column index
                * `xref`
                * `yref`
                * `xaxis`
                * `yaxis`
                * `x_domain`
                * `y_domain`
                * `fig`
                * `silence_warnings`
                * Any argument from `settings`
                * Any attribute of this object if it meant to be resolved
                    (see `vectorbt.utils.attr_.AttrResolver.resolve_attr`)

                !!! note
                    Layout-related resolution arguments such as `add_trace_kwargs` are unavailable
                    before filtering and thus cannot be used in any templates but can still be overridden.

                Pass `subplots='all'` to plot all supported subplots.
            tags (str or iterable): See `tags` in `vectorbt.generic.stats_builder.StatsBuilderMixin`.
            column (str): See `column` in `vectorbt.generic.stats_builder.StatsBuilderMixin`.
            group_by (any): See `group_by` in `vectorbt.generic.stats_builder.StatsBuilderMixin`.
            silence_warnings (bool): See `silence_warnings` in `vectorbt.generic.stats_builder.StatsBuilderMixin`.
            template_mapping (mapping): See `template_mapping` in `vectorbt.generic.stats_builder.StatsBuilderMixin`.

                Applied on `settings`, `make_subplots_kwargs`, and `layout_kwargs`, and then on each subplot settings.
            filters (dict): See `filters` in `vectorbt.generic.stats_builder.StatsBuilderMixin`.
            settings (dict): See `settings` in `vectorbt.generic.stats_builder.StatsBuilderMixin`.
            subplot_settings (dict): See `metric_settings` in `vectorbt.generic.stats_builder.StatsBuilderMixin`.
            show_titles (bool): Whether to show the title of each subplot.
            hide_id_labels (bool): Whether to hide identical legend labels.

                Two labels are identical if their name, marker style and line style match.
            group_id_labels (bool): Whether to group identical legend labels.
            make_subplots_kwargs (dict): Keyword arguments passed to `plotly.subplots.make_subplots`.
            **layout_kwargs: Keyword arguments used to update the layout of the figure.

        !!! note
            `PlotsBuilderMixin` and `vectorbt.generic.stats_builder.StatsBuilderMixin` are very similar.
            Some artifacts follow the same concept, just named differently:

            * `plots_defaults` vs `stats_defaults`
            * `subplots` vs `metrics`
            * `subplot_settings` vs `metric_settings`

            See further notes under `vectorbt.generic.stats_builder.StatsBuilderMixin`.

        Usage:
            See `vectorbt.portfolio.base` for examples.
        """
        from vectorbt._settings import settings as _settings
        plotting_cfg = _settings['plotting']

        # Resolve defaults
        if silence_warnings is None:
            silence_warnings = self.plots_defaults['silence_warnings']
        if show_titles is None:
            show_titles = self.plots_defaults['show_titles']
        if hide_id_labels is None:
            hide_id_labels = self.plots_defaults['hide_id_labels']
        if group_id_labels is None:
            group_id_labels = self.plots_defaults['group_id_labels']
        template_mapping = merge_dicts(self.plots_defaults['template_mapping'],
                                       template_mapping)
        filters = merge_dicts(self.plots_defaults['filters'], filters)
        settings = merge_dicts(self.plots_defaults['settings'], settings)
        subplot_settings = merge_dicts(self.plots_defaults['subplot_settings'],
                                       subplot_settings)
        make_subplots_kwargs = merge_dicts(
            self.plots_defaults['make_subplots_kwargs'], make_subplots_kwargs)
        layout_kwargs = merge_dicts(self.plots_defaults['layout_kwargs'],
                                    layout_kwargs)

        # Replace templates globally (not used at subplot level)
        if len(template_mapping) > 0:
            sub_settings = deep_substitute(settings, mapping=template_mapping)
            sub_make_subplots_kwargs = deep_substitute(
                make_subplots_kwargs, mapping=template_mapping)
            sub_layout_kwargs = deep_substitute(layout_kwargs,
                                                mapping=template_mapping)
        else:
            sub_settings = settings
            sub_make_subplots_kwargs = make_subplots_kwargs
            sub_layout_kwargs = layout_kwargs

        # Resolve self
        reself = self.resolve_self(cond_kwargs=sub_settings,
                                   impacts_caching=False,
                                   silence_warnings=silence_warnings)

        # Prepare subplots
        if subplots is None:
            subplots = reself.plots_defaults['subplots']
        if subplots == 'all':
            subplots = reself.subplots
        if isinstance(subplots, dict):
            subplots = list(subplots.items())
        if isinstance(subplots, (str, tuple)):
            subplots = [subplots]

        # Prepare tags
        if tags is None:
            tags = reself.plots_defaults['tags']
        if isinstance(tags, str) and tags == 'all':
            tags = None
        if isinstance(tags, (str, tuple)):
            tags = [tags]

        # Bring to the same shape
        new_subplots = []
        for i, subplot in enumerate(subplots):
            if isinstance(subplot, str):
                subplot = (subplot, reself.subplots[subplot])
            if not isinstance(subplot, tuple):
                raise TypeError(
                    f"Subplot at index {i} must be either a string or a tuple")
            new_subplots.append(subplot)
        subplots = new_subplots

        # Handle duplicate names
        subplot_counts = Counter(list(map(lambda x: x[0], subplots)))
        subplot_i = {k: -1 for k in subplot_counts.keys()}
        subplots_dct = {}
        for i, (subplot_name, _subplot_settings) in enumerate(subplots):
            if subplot_counts[subplot_name] > 1:
                subplot_i[subplot_name] += 1
                subplot_name = subplot_name + '_' + str(
                    subplot_i[subplot_name])
            subplots_dct[subplot_name] = _subplot_settings

        # Check subplot_settings
        missed_keys = set(subplot_settings.keys()).difference(
            set(subplots_dct.keys()))
        if len(missed_keys) > 0:
            raise ValueError(
                f"Keys {missed_keys} in subplot_settings could not be matched with any subplot"
            )

        # Merge settings
        opt_arg_names_dct = {}
        custom_arg_names_dct = {}
        resolved_self_dct = {}
        mapping_dct = {}
        for subplot_name, _subplot_settings in list(subplots_dct.items()):
            opt_settings = merge_dicts(
                {name: reself
                 for name in reself.self_aliases},
                dict(column=column,
                     group_by=group_by,
                     subplot_name=subplot_name,
                     trace_names=[subplot_name],
                     silence_warnings=silence_warnings), settings)
            _subplot_settings = _subplot_settings.copy()
            passed_subplot_settings = subplot_settings.get(subplot_name, {})
            merged_settings = merge_dicts(opt_settings, _subplot_settings,
                                          passed_subplot_settings)
            subplot_template_mapping = merged_settings.pop(
                'template_mapping', {})
            template_mapping_merged = merge_dicts(template_mapping,
                                                  subplot_template_mapping)
            template_mapping_merged = deep_substitute(template_mapping_merged,
                                                      mapping=merged_settings)
            mapping = merge_dicts(template_mapping_merged, merged_settings)
            # safe because we will use deep_substitute again once layout params are known
            merged_settings = deep_substitute(merged_settings,
                                              mapping=mapping,
                                              safe=True)

            # Filter by tag
            if tags is not None:
                in_tags = merged_settings.get('tags', None)
                if in_tags is None or not match_tags(tags, in_tags):
                    subplots_dct.pop(subplot_name, None)
                    continue

            custom_arg_names = set(_subplot_settings.keys()).union(
                set(passed_subplot_settings.keys()))
            opt_arg_names = set(opt_settings.keys())
            custom_reself = reself.resolve_self(
                cond_kwargs=merged_settings,
                custom_arg_names=custom_arg_names,
                impacts_caching=True,
                silence_warnings=merged_settings['silence_warnings'])

            subplots_dct[subplot_name] = merged_settings
            custom_arg_names_dct[subplot_name] = custom_arg_names
            opt_arg_names_dct[subplot_name] = opt_arg_names
            resolved_self_dct[subplot_name] = custom_reself
            mapping_dct[subplot_name] = mapping

        # Filter subplots
        for subplot_name, _subplot_settings in list(subplots_dct.items()):
            custom_reself = resolved_self_dct[subplot_name]
            mapping = mapping_dct[subplot_name]
            _silence_warnings = _subplot_settings.get('silence_warnings')

            subplot_filters = set()
            for k in _subplot_settings.keys():
                filter_name = None
                if k.startswith('check_'):
                    filter_name = k[len('check_'):]
                elif k.startswith('inv_check_'):
                    filter_name = k[len('inv_check_'):]
                if filter_name is not None:
                    if filter_name not in filters:
                        raise ValueError(
                            f"Metric '{subplot_name}' requires filter '{filter_name}'"
                        )
                    subplot_filters.add(filter_name)

            for filter_name in subplot_filters:
                filter_settings = filters[filter_name]
                _filter_settings = deep_substitute(filter_settings,
                                                   mapping=mapping)
                filter_func = _filter_settings['filter_func']
                warning_message = _filter_settings.get('warning_message', None)
                inv_warning_message = _filter_settings.get(
                    'inv_warning_message', None)
                to_check = _subplot_settings.get('check_' + filter_name, False)
                inv_to_check = _subplot_settings.get(
                    'inv_check_' + filter_name, False)

                if to_check or inv_to_check:
                    whether_true = filter_func(custom_reself,
                                               _subplot_settings)
                    to_remove = (to_check
                                 and not whether_true) or (inv_to_check
                                                           and whether_true)
                    if to_remove:
                        if to_check and warning_message is not None and not _silence_warnings:
                            warnings.warn(warning_message)
                        if inv_to_check and inv_warning_message is not None and not _silence_warnings:
                            warnings.warn(inv_warning_message)

                        subplots_dct.pop(subplot_name, None)
                        custom_arg_names_dct.pop(subplot_name, None)
                        opt_arg_names_dct.pop(subplot_name, None)
                        resolved_self_dct.pop(subplot_name, None)
                        mapping_dct.pop(subplot_name, None)
                        break

        # Any subplots left?
        if len(subplots_dct) == 0:
            if not silence_warnings:
                warnings.warn("No subplots to plot", stacklevel=2)
            return None

        # Set up figure
        rows = sub_make_subplots_kwargs.pop('rows', len(subplots_dct))
        cols = sub_make_subplots_kwargs.pop('cols', 1)
        specs = sub_make_subplots_kwargs.pop('specs', [[{}
                                                        for _ in range(cols)]
                                                       for _ in range(rows)])
        row_col_tuples = []
        for row, row_spec in enumerate(specs):
            for col, col_spec in enumerate(row_spec):
                if col_spec is not None:
                    row_col_tuples.append((row + 1, col + 1))
        shared_xaxes = sub_make_subplots_kwargs.pop('shared_xaxes', True)
        shared_yaxes = sub_make_subplots_kwargs.pop('shared_yaxes', False)
        default_height = plotting_cfg['layout']['height']
        default_width = plotting_cfg['layout']['width'] + 50
        min_space = 10  # space between subplots with no axis sharing
        max_title_spacing = 30
        max_xaxis_spacing = 50
        max_yaxis_spacing = 100
        legend_height = 50
        if show_titles:
            title_spacing = max_title_spacing
        else:
            title_spacing = 0
        if not shared_xaxes and rows > 1:
            xaxis_spacing = max_xaxis_spacing
        else:
            xaxis_spacing = 0
        if not shared_yaxes and cols > 1:
            yaxis_spacing = max_yaxis_spacing
        else:
            yaxis_spacing = 0
        if 'height' in sub_layout_kwargs:
            height = sub_layout_kwargs.pop('height')
        else:
            height = default_height + title_spacing
            if rows > 1:
                height *= rows
                height += min_space * rows - min_space
                height += legend_height - legend_height * rows
                if shared_xaxes:
                    height += max_xaxis_spacing - max_xaxis_spacing * rows
        if 'width' in sub_layout_kwargs:
            width = sub_layout_kwargs.pop('width')
        else:
            width = default_width
            if cols > 1:
                width *= cols
                width += min_space * cols - min_space
                if shared_yaxes:
                    width += max_yaxis_spacing - max_yaxis_spacing * cols
        if height is not None:
            if 'vertical_spacing' in sub_make_subplots_kwargs:
                vertical_spacing = sub_make_subplots_kwargs.pop(
                    'vertical_spacing')
            else:
                vertical_spacing = min_space + title_spacing + xaxis_spacing
            if vertical_spacing is not None and vertical_spacing > 1:
                vertical_spacing /= height
            legend_y = 1 + (min_space + title_spacing) / height
        else:
            vertical_spacing = sub_make_subplots_kwargs.pop(
                'vertical_spacing', None)
            legend_y = 1.02
        if width is not None:
            if 'horizontal_spacing' in sub_make_subplots_kwargs:
                horizontal_spacing = sub_make_subplots_kwargs.pop(
                    'horizontal_spacing')
            else:
                horizontal_spacing = min_space + yaxis_spacing
            if horizontal_spacing is not None and horizontal_spacing > 1:
                horizontal_spacing /= width
        else:
            horizontal_spacing = sub_make_subplots_kwargs.pop(
                'horizontal_spacing', None)
        if show_titles:
            _subplot_titles = []
            for i in range(len(subplots_dct)):
                _subplot_titles.append('$title_' + str(i))
        else:
            _subplot_titles = None
        fig = make_subplots(rows=rows,
                            cols=cols,
                            specs=specs,
                            shared_xaxes=shared_xaxes,
                            shared_yaxes=shared_yaxes,
                            subplot_titles=_subplot_titles,
                            vertical_spacing=vertical_spacing,
                            horizontal_spacing=horizontal_spacing,
                            **sub_make_subplots_kwargs)
        sub_layout_kwargs = merge_dicts(
            dict(showlegend=True,
                 width=width,
                 height=height,
                 legend=dict(orientation="h",
                             yanchor="bottom",
                             y=legend_y,
                             xanchor="right",
                             x=1,
                             traceorder='normal')), sub_layout_kwargs)
        fig.update_layout(
            **sub_layout_kwargs)  # final destination for sub_layout_kwargs

        # Plot subplots
        arg_cache_dct = {}
        for i, (subplot_name,
                _subplot_settings) in enumerate(subplots_dct.items()):
            try:
                final_kwargs = _subplot_settings.copy()
                opt_arg_names = opt_arg_names_dct[subplot_name]
                custom_arg_names = custom_arg_names_dct[subplot_name]
                custom_reself = resolved_self_dct[subplot_name]
                mapping = mapping_dct[subplot_name]

                # Compute figure artifacts
                row, col = row_col_tuples[i]
                xref = 'x' if i == 0 else 'x' + str(i + 1)
                yref = 'y' if i == 0 else 'y' + str(i + 1)
                xaxis = 'xaxis' + xref[1:]
                yaxis = 'yaxis' + yref[1:]
                x_domain = get_domain(xref, fig)
                y_domain = get_domain(yref, fig)
                subplot_layout_kwargs = dict(
                    add_trace_kwargs=dict(row=row, col=col),
                    xref=xref,
                    yref=yref,
                    xaxis=xaxis,
                    yaxis=yaxis,
                    x_domain=x_domain,
                    y_domain=y_domain,
                    fig=fig,
                    pass_fig=True  # force passing fig
                )
                for k in subplot_layout_kwargs:
                    opt_arg_names.add(k)
                    if k in final_kwargs:
                        custom_arg_names.add(k)
                final_kwargs = merge_dicts(subplot_layout_kwargs, final_kwargs)
                mapping = merge_dicts(subplot_layout_kwargs, mapping)
                final_kwargs = deep_substitute(final_kwargs, mapping=mapping)

                # Clean up keys
                for k, v in list(final_kwargs.items()):
                    if k.startswith('check_') or k.startswith(
                            'inv_check_') or k in ('tags', ):
                        final_kwargs.pop(k, None)

                # Get subplot-specific values
                _column = final_kwargs.get('column')
                _group_by = final_kwargs.get('group_by')
                _silence_warnings = final_kwargs.get('silence_warnings')
                title = final_kwargs.pop('title', subplot_name)
                plot_func = final_kwargs.pop('plot_func', None)
                xaxis_kwargs = final_kwargs.pop('xaxis_kwargs', None)
                yaxis_kwargs = final_kwargs.pop('yaxis_kwargs', None)
                resolve_plot_func = final_kwargs.pop('resolve_plot_func', True)
                use_caching = final_kwargs.pop('use_caching', True)

                if plot_func is not None:
                    # Resolve plot_func
                    if resolve_plot_func:
                        if not callable(plot_func):
                            passed_kwargs_out = {}

                            def _getattr_func(
                                obj: tp.Any,
                                attr: str,
                                args: tp.ArgsLike = None,
                                kwargs: tp.KwargsLike = None,
                                call_attr: bool = True,
                                _final_kwargs: tp.Kwargs = final_kwargs,
                                _opt_arg_names: tp.Set[str] = opt_arg_names,
                                _custom_arg_names: tp.
                                Set[str] = custom_arg_names,
                                _arg_cache_dct: tp.Kwargs = arg_cache_dct
                            ) -> tp.Any:
                                if attr in final_kwargs:
                                    return final_kwargs[attr]
                                if args is None:
                                    args = ()
                                if kwargs is None:
                                    kwargs = {}
                                if obj is custom_reself and _final_kwargs.pop(
                                        'resolve_path_' + attr, True):
                                    if call_attr:
                                        return custom_reself.resolve_attr(
                                            attr,
                                            args=args,
                                            cond_kwargs={
                                                k: v
                                                for k, v in
                                                _final_kwargs.items()
                                                if k in _opt_arg_names
                                            },
                                            kwargs=kwargs,
                                            custom_arg_names=_custom_arg_names,
                                            cache_dct=_arg_cache_dct,
                                            use_caching=use_caching,
                                            passed_kwargs_out=passed_kwargs_out
                                        )
                                    return getattr(obj, attr)
                                out = getattr(obj, attr)
                                if callable(out) and call_attr:
                                    return out(*args, **kwargs)
                                return out

                            plot_func = custom_reself.deep_getattr(
                                plot_func,
                                getattr_func=_getattr_func,
                                call_last_attr=False)

                            if 'group_by' in passed_kwargs_out:
                                if 'pass_group_by' not in final_kwargs:
                                    final_kwargs.pop('group_by', None)
                        if not callable(plot_func):
                            raise TypeError("plot_func must be callable")

                        # Resolve arguments
                        func_arg_names = get_func_arg_names(plot_func)
                        for k in func_arg_names:
                            if k not in final_kwargs:
                                if final_kwargs.pop('resolve_' + k, False):
                                    try:
                                        arg_out = custom_reself.resolve_attr(
                                            k,
                                            cond_kwargs=final_kwargs,
                                            custom_arg_names=custom_arg_names,
                                            cache_dct=arg_cache_dct,
                                            use_caching=use_caching)
                                    except AttributeError:
                                        continue
                                    final_kwargs[k] = arg_out
                        for k in list(final_kwargs.keys()):
                            if k in opt_arg_names:
                                if 'pass_' + k in final_kwargs:
                                    if not final_kwargs.get(
                                            'pass_' + k):  # first priority
                                        final_kwargs.pop(k, None)
                                elif k not in func_arg_names:  # second priority
                                    final_kwargs.pop(k, None)
                        for k in list(final_kwargs.keys()):
                            if k.startswith('pass_') or k.startswith(
                                    'resolve_'):
                                final_kwargs.pop(k, None)  # cleanup

                        # Call plot_func
                        plot_func(**final_kwargs)
                    else:
                        # Do not resolve plot_func
                        plot_func(custom_reself, _subplot_settings)

                # Update global layout
                for annotation in fig.layout.annotations:
                    if 'text' in annotation and annotation[
                            'text'] == '$title_' + str(i):
                        annotation['text'] = title
                subplot_layout = dict()
                subplot_layout[xaxis] = merge_dicts(dict(title='Index'),
                                                    xaxis_kwargs)
                subplot_layout[yaxis] = merge_dicts(dict(), yaxis_kwargs)
                fig.update_layout(**subplot_layout)
            except Exception as e:
                warnings.warn(f"Subplot '{subplot_name}' raised an exception",
                              stacklevel=2)
                raise e

        # Remove duplicate legend labels
        found_ids = dict()
        unique_idx = 0
        for trace in fig.data:
            if 'name' in trace:
                name = trace['name']
            else:
                name = None
            if 'marker' in trace:
                marker = trace['marker']
            else:
                marker = {}
            if 'symbol' in marker:
                marker_symbol = marker['symbol']
            else:
                marker_symbol = None
            if 'color' in marker:
                marker_color = marker['color']
            else:
                marker_color = None
            if 'line' in trace:
                line = trace['line']
            else:
                line = {}
            if 'dash' in line:
                line_dash = line['dash']
            else:
                line_dash = None
            if 'color' in line:
                line_color = line['color']
            else:
                line_color = None

            id = (name, marker_symbol, marker_color, line_dash, line_color)
            if id in found_ids:
                if hide_id_labels:
                    trace['showlegend'] = False
                if group_id_labels:
                    trace['legendgroup'] = found_ids[id]
            else:
                if group_id_labels:
                    trace['legendgroup'] = unique_idx
                found_ids[id] = unique_idx
                unique_idx += 1

        # Remove all except the last title if sharing the same axis
        if shared_xaxes:
            i = 0
            for row in range(rows):
                for col in range(cols):
                    if specs[row][col] is not None:
                        xaxis = 'xaxis' if i == 0 else 'xaxis' + str(i + 1)
                        if row < rows - 1:
                            fig.layout[xaxis]['title'] = None
                        i += 1
        if shared_yaxes:
            i = 0
            for row in range(rows):
                for col in range(cols):
                    if specs[row][col] is not None:
                        yaxis = 'yaxis' if i == 0 else 'yaxis' + str(i + 1)
                        if col > 0:
                            fig.layout[yaxis]['title'] = None
                        i += 1

        # Return the figure
        return fig

    # ############# Docs ############# #

    @classmethod
    def build_subplots_doc(cls, source_cls: tp.Optional[type] = None) -> str:
        """Render the `subplots` documentation for this class.

        The docstring of the `subplots` attribute found on `source_cls`
        (`PlotsBuilderMixin` when not given) is cleaned with `inspect.cleandoc`
        and treated as a `string.Template`; the `subplots` placeholder receives
        `cls.subplots.to_doc()` and `cls_name` receives the class name."""
        doc_owner = PlotsBuilderMixin if source_cls is None else source_cls
        template_text = inspect.cleandoc(
            get_dict_attr(doc_owner, 'subplots').__doc__)
        substitutions = {
            'subplots': cls.subplots.to_doc(),
            'cls_name': cls.__name__,
        }
        return string.Template(template_text).substitute(substitutions)

    @classmethod
    def override_subplots_doc(cls,
                              __pdoc__: dict,
                              source_cls: tp.Optional[type] = None) -> None:
        """Register the rendered `subplots` docs of `cls` in `__pdoc__`.

        Call this for every subclass that overrides `subplots`. The entry is
        stored under the key `'<ClassName>.subplots'` and its value is produced
        by `build_subplots_doc`."""
        rendered = cls.build_subplots_doc(source_cls=source_cls)
        __pdoc__[f"{cls.__name__}.subplots"] = rendered
Beispiel #29
0
class Drawdowns(Ranges):
    """Extends `vectorbt.generic.ranges.Ranges` for working with drawdown records.

    Requires `records_arr` to have all fields defined in `vectorbt.generic.enums.drawdown_dt`."""
    @property
    def field_config(self) -> Config:
        """Field configuration of these records, as stored in `_field_config`."""
        config = self._field_config
        return config

    def __init__(self,
                 wrapper: ArrayWrapper,
                 records_arr: tp.RecordArray,
                 ts: tp.Optional[tp.ArrayLike] = None,
                 **kwargs) -> None:
        """Initialize drawdown records.

        Everything, including `ts`, is forwarded to `Ranges.__init__`; `ts` is
        additionally kept on the instance so `Drawdowns.ts` can return it."""
        Ranges.__init__(
            self,
            wrapper,
            records_arr,
            ts=ts,
            **kwargs
        )
        # Own handle on the source series, exposed through the `ts` property.
        self._ts = ts

    def indexing_func(self: DrawdownsT,
                      pd_indexing_func: tp.PandasIndexingFunc,
                      **kwargs) -> DrawdownsT:
        """Perform indexing on `Drawdowns`.

        Wrapper and records are indexed via
        `vectorbt.generic.ranges.Ranges.indexing_func_meta`. If a time series is
        attached, it is reduced to the selected columns and re-wrapped with the
        new wrapper; otherwise the new instance has no time series either."""
        new_wrapper, new_records_arr, _, col_idxs = \
            Ranges.indexing_func_meta(self, pd_indexing_func, **kwargs)
        if self.ts is not None:
            # `to_2d_array` also accepts a Series `ts`, whose `.values` is 1-dim
            # and would raise on `[:, col_idxs]`; for DataFrames it is equivalent.
            new_ts = new_wrapper.wrap(to_2d_array(self.ts)[:, col_idxs],
                                      group_by=False)
        else:
            new_ts = None
        return self.replace(wrapper=new_wrapper,
                            records_arr=new_records_arr,
                            ts=new_ts)

    @classmethod
    def from_ts(cls: tp.Type[DrawdownsT],
                ts: tp.ArrayLike,
                attach_ts: bool = True,
                wrapper_kwargs: tp.KwargsLike = None,
                **kwargs) -> DrawdownsT:
        """Construct `Drawdowns` by detecting drawdowns in the time series `ts`.

        `ts` is converted to a pandas object, drawdown records are generated by
        `vectorbt.generic.nb.get_drawdowns_nb`, and the wrapper is built from
        `ts` using `wrapper_kwargs`. When `attach_ts` is False, the resulting
        instance does not keep `ts`.

        `**kwargs` will be passed to `Drawdowns.__init__`."""
        pd_ts = to_pd_array(ts)
        drawdown_records = nb.get_drawdowns_nb(to_2d_array(pd_ts))
        wrapper_cfg = merge_dicts({}, wrapper_kwargs)
        wrapper = ArrayWrapper.from_obj(pd_ts, **wrapper_cfg)
        attached_ts = pd_ts if attach_ts else None
        return cls(wrapper, drawdown_records, ts=attached_ts, **kwargs)

    @property
    def ts(self) -> tp.Optional[tp.SeriesFrame]:
        """Time series the records were built from, or None when not attached."""
        attached = self._ts
        return attached

    # ############# Drawdown ############# #

    @cached_property
    def drawdown(self) -> MappedArray:
        """Drawdown of each record as a mapped array.

        Computed by `vectorbt.generic.nb.dd_drawdown_nb` from the `peak_val`
        and `valley_val` fields; covers recovered and active drawdowns alike."""
        peak_vals = self.get_field_arr('peak_val')
        valley_vals = self.get_field_arr('valley_val')
        return self.map_array(nb.dd_drawdown_nb(peak_vals, valley_vals))

    @cached_method
    def avg_drawdown(self,
                     group_by: tp.GroupByLike = None,
                     wrap_kwargs: tp.KwargsLike = None,
                     **kwargs) -> tp.MaybeSeries:
        """Average drawdown (ADD): mean of `Drawdowns.drawdown`.

        The result is named `avg_drawdown` unless `wrap_kwargs` says otherwise."""
        merged_wrap_kwargs = merge_dicts({'name_or_index': 'avg_drawdown'},
                                         wrap_kwargs)
        mapped = self.drawdown
        return mapped.mean(group_by=group_by,
                           wrap_kwargs=merged_wrap_kwargs,
                           **kwargs)

    @cached_method
    def max_drawdown(self,
                     group_by: tp.GroupByLike = None,
                     wrap_kwargs: tp.KwargsLike = None,
                     **kwargs) -> tp.MaybeSeries:
        """Maximum drawdown (MDD): minimum of `Drawdowns.drawdown`.

        The minimum is taken because drawdowns are negative fractions, so the
        deepest one is the smallest value."""
        merged_wrap_kwargs = merge_dicts({'name_or_index': 'max_drawdown'},
                                         wrap_kwargs)
        mapped = self.drawdown
        return mapped.min(group_by=group_by,
                          wrap_kwargs=merged_wrap_kwargs,
                          **kwargs)

    # ############# Recovery ############# #

    @cached_property
    def recovery_return(self) -> MappedArray:
        """Recovery return of each record as a mapped array.

        Computed by `vectorbt.generic.nb.dd_recovery_return_nb` from the
        `valley_val` and `end_val` fields; covers recovered and active
        drawdowns alike."""
        valley_vals = self.get_field_arr('valley_val')
        end_vals = self.get_field_arr('end_val')
        return self.map_array(nb.dd_recovery_return_nb(valley_vals, end_vals))

    @cached_method
    def avg_recovery_return(self,
                            group_by: tp.GroupByLike = None,
                            wrap_kwargs: tp.KwargsLike = None,
                            **kwargs) -> tp.MaybeSeries:
        """Average recovery return: mean of `Drawdowns.recovery_return`."""
        merged_wrap_kwargs = merge_dicts(
            {'name_or_index': 'avg_recovery_return'}, wrap_kwargs)
        mapped = self.recovery_return
        return mapped.mean(group_by=group_by,
                           wrap_kwargs=merged_wrap_kwargs,
                           **kwargs)

    @cached_method
    def max_recovery_return(self,
                            group_by: tp.GroupByLike = None,
                            wrap_kwargs: tp.KwargsLike = None,
                            **kwargs) -> tp.MaybeSeries:
        """Maximum recovery return: max of `Drawdowns.recovery_return`."""
        merged_wrap_kwargs = merge_dicts(
            {'name_or_index': 'max_recovery_return'}, wrap_kwargs)
        mapped = self.recovery_return
        return mapped.max(group_by=group_by,
                          wrap_kwargs=merged_wrap_kwargs,
                          **kwargs)

    # ############# Duration ############# #

    @cached_property
    def decline_duration(self) -> MappedArray:
        """Decline duration of each record as a mapped array.

        Computed by `vectorbt.generic.nb.dd_decline_duration_nb` from the
        `start_idx` and `valley_idx` fields; covers recovered and active
        drawdowns alike."""
        start_idxs = self.get_field_arr('start_idx')
        valley_idxs = self.get_field_arr('valley_idx')
        return self.map_array(nb.dd_decline_duration_nb(start_idxs, valley_idxs))

    @cached_property
    def recovery_duration(self) -> MappedArray:
        """See `vectorbt.generic.nb.dd_recovery_duration_nb`.

        Computed from the `valley_idx` and `end_idx` fields, i.e. it covers the
        recovery phase of each drawdown. For recovery speed relative to the
        decline, see `Drawdowns.recovery_duration_ratio`.

        Takes into account both recovered and active drawdowns."""
        recovery_duration = nb.dd_recovery_duration_nb(
            self.get_field_arr('valley_idx'), self.get_field_arr('end_idx'))
        return self.map_array(recovery_duration)

    @cached_property
    def recovery_duration_ratio(self) -> MappedArray:
        """See `vectorbt.generic.nb.dd_recovery_duration_ratio_nb`.

        Computed from the `start_idx`, `valley_idx` and `end_idx` fields.
        A value higher than 1 means the recovery was slower than the decline.

        Takes into account both recovered and active drawdowns."""
        recovery_duration_ratio = nb.dd_recovery_duration_ratio_nb(
            self.get_field_arr('start_idx'), self.get_field_arr('valley_idx'),
            self.get_field_arr('end_idx'))
        return self.map_array(recovery_duration_ratio)

    # ############# Status: Active ############# #

    @cached_method
    def active_drawdown(self,
                        group_by: tp.GroupByLike = None,
                        wrap_kwargs: tp.KwargsLike = None) -> tp.MaybeSeries:
        """Drawdown of the most recent active drawdown, per column.

        Measured as `(end_val - peak_val) / peak_val` of the last active record.
        Raises `ValueError` if `group_by` results in grouping."""
        is_grouped = self.wrapper.grouper.is_grouped(group_by=group_by)
        if is_grouped:
            raise ValueError("Grouping is not supported by this method")
        wrap_kwargs = merge_dicts({'name_or_index': 'active_drawdown'},
                                  wrap_kwargs)
        active_records = self.active
        last_end = active_records.end_val.nth(-1, group_by=group_by)
        last_peak = active_records.peak_val.nth(-1, group_by=group_by)
        drawdown_now = (last_end - last_peak) / last_peak
        return self.wrapper.wrap_reduced(drawdown_now,
                                         group_by=group_by,
                                         **wrap_kwargs)

    @cached_method
    def active_duration(self,
                        group_by: tp.GroupByLike = None,
                        wrap_kwargs: tp.KwargsLike = None,
                        **kwargs) -> tp.MaybeSeries:
        """Duration of the most recent active drawdown, per column.

        Wrapped as a timedelta by default. Raises `ValueError` if `group_by`
        results in grouping."""
        is_grouped = self.wrapper.grouper.is_grouped(group_by=group_by)
        if is_grouped:
            raise ValueError("Grouping is not supported by this method")
        wrap_kwargs = merge_dicts(
            {'to_timedelta': True, 'name_or_index': 'active_duration'},
            wrap_kwargs)
        durations = self.active.duration
        return durations.nth(-1,
                             group_by=group_by,
                             wrap_kwargs=wrap_kwargs,
                             **kwargs)

    @cached_method
    def active_recovery(self,
                        group_by: tp.GroupByLike = None,
                        wrap_kwargs: tp.KwargsLike = None) -> tp.MaybeSeries:
        """Recovery of the most recent active drawdown, per column.

        Measured as the share of the peak-to-valley drop that has been regained:
        `(end_val - valley_val) / (peak_val - valley_val)` of the last active
        record. Raises `ValueError` if `group_by` results in grouping."""
        is_grouped = self.wrapper.grouper.is_grouped(group_by=group_by)
        if is_grouped:
            raise ValueError("Grouping is not supported by this method")
        wrap_kwargs = merge_dicts({'name_or_index': 'active_recovery'},
                                  wrap_kwargs)
        active_records = self.active
        last_peak = active_records.peak_val.nth(-1, group_by=group_by)
        last_end = active_records.end_val.nth(-1, group_by=group_by)
        last_valley = active_records.valley_val.nth(-1, group_by=group_by)
        regained = last_end - last_valley
        total_drop = last_peak - last_valley
        return self.wrapper.wrap_reduced(regained / total_drop,
                                         group_by=group_by,
                                         **wrap_kwargs)

    @cached_method
    def active_recovery_return(self,
                               group_by: tp.GroupByLike = None,
                               wrap_kwargs: tp.KwargsLike = None,
                               **kwargs) -> tp.MaybeSeries:
        """Recovery return of the most recent active drawdown, per column.

        Raises `ValueError` if `group_by` results in grouping."""
        is_grouped = self.wrapper.grouper.is_grouped(group_by=group_by)
        if is_grouped:
            raise ValueError("Grouping is not supported by this method")
        wrap_kwargs = merge_dicts({'name_or_index': 'active_recovery_return'},
                                  wrap_kwargs)
        returns = self.active.recovery_return
        return returns.nth(-1,
                           group_by=group_by,
                           wrap_kwargs=wrap_kwargs,
                           **kwargs)

    @cached_method
    def active_recovery_duration(self,
                                 group_by: tp.GroupByLike = None,
                                 wrap_kwargs: tp.KwargsLike = None,
                                 **kwargs) -> tp.MaybeSeries:
        """Recovery duration of the most recent active drawdown, per column.

        Wrapped as a timedelta by default. Raises `ValueError` if `group_by`
        results in grouping."""
        is_grouped = self.wrapper.grouper.is_grouped(group_by=group_by)
        if is_grouped:
            raise ValueError("Grouping is not supported by this method")
        wrap_kwargs = merge_dicts(
            {'to_timedelta': True, 'name_or_index': 'active_recovery_duration'},
            wrap_kwargs)
        durations = self.active.recovery_duration
        return durations.nth(-1,
                             group_by=group_by,
                             wrap_kwargs=wrap_kwargs,
                             **kwargs)

    # ############# Stats ############# #

    @property
    def stats_defaults(self) -> tp.Kwargs:
        """Default arguments for `Drawdowns.stats`.

        Starts from `vectorbt.generic.ranges.Ranges.stats_defaults` and overlays
        `drawdowns.stats` from `vectorbt._settings.settings`."""
        from vectorbt._settings import settings
        own_cfg = settings['drawdowns']['stats']
        parent_defaults = Ranges.stats_defaults.__get__(self)
        return merge_dicts(parent_defaults, own_cfg)

    # Metric definitions consumed by `Drawdowns.stats`: each key configures one
    # metric (title, how to compute it, optional post-processing, and tags).
    _metrics: tp.ClassVar[Config] = Config(dict(
        # Metrics read directly off the wrapper index.
        start=dict(title='Start',
                   calc_func=lambda self: self.wrapper.index[0],
                   agg_func=None,
                   tags='wrapper'),
        end=dict(title='End',
                 calc_func=lambda self: self.wrapper.index[-1],
                 agg_func=None,
                 tags='wrapper'),
        period=dict(title='Period',
                    calc_func=lambda self: len(self.wrapper.index),
                    apply_to_timedelta=True,
                    agg_func=None,
                    tags='wrapper'),
        # Fractions are scaled to percent in post_calc_func.
        coverage=dict(title='Coverage [%]',
                      calc_func='coverage',
                      post_calc_func=lambda self, out, settings: out * 100,
                      tags=['ranges', 'duration']),
        total_records=dict(title='Total Records',
                           calc_func='count',
                           tags='records'),
        total_recovered=dict(title='Total Recovered Drawdowns',
                             calc_func='recovered.count',
                             tags='drawdowns'),
        total_active=dict(title='Total Active Drawdowns',
                          calc_func='active.count',
                          tags='drawdowns'),
        # Active-drawdown metrics: the underlying `active_*` methods raise on
        # grouping, hence check_is_not_grouped=True. Drawdowns are negative
        # fractions, so they are negated to report a positive percentage.
        active_dd=dict(title='Active Drawdown [%]',
                       calc_func='active_drawdown',
                       post_calc_func=lambda self, out, settings: -out * 100,
                       check_is_not_grouped=True,
                       tags=['drawdowns', 'active']),
        active_duration=dict(title='Active Duration',
                             calc_func='active_duration',
                             fill_wrap_kwargs=True,
                             check_is_not_grouped=True,
                             tags=['drawdowns', 'active', 'duration']),
        active_recovery=dict(
            title='Active Recovery [%]',
            calc_func='active_recovery',
            post_calc_func=lambda self, out, settings: out * 100,
            check_is_not_grouped=True,
            tags=['drawdowns', 'active']),
        active_recovery_return=dict(
            title='Active Recovery Return [%]',
            calc_func='active_recovery_return',
            post_calc_func=lambda self, out, settings: out * 100,
            check_is_not_grouped=True,
            tags=['drawdowns', 'active']),
        active_recovery_duration=dict(title='Active Recovery Duration',
                                      calc_func='active_recovery_duration',
                                      fill_wrap_kwargs=True,
                                      check_is_not_grouped=True,
                                      tags=['drawdowns', 'active',
                                            'duration']),
        # The RepEval templates switch between all drawdowns and recovered ones
        # only, depending on `incl_active`. NOTE(review): `incl_active` is
        # presumably provided through the stats settings (see `stats_defaults`)
        # — confirm.
        max_dd=dict(
            title='Max Drawdown [%]',
            calc_func=RepEval(
                "'max_drawdown' if incl_active else 'recovered.max_drawdown'"),
            post_calc_func=lambda self, out, settings: -out * 100,
            tags=RepEval(
                "['drawdowns'] if incl_active else ['drawdowns', 'recovered']")
        ),
        avg_dd=dict(
            title='Avg Drawdown [%]',
            calc_func=RepEval(
                "'avg_drawdown' if incl_active else 'recovered.avg_drawdown'"),
            post_calc_func=lambda self, out, settings: -out * 100,
            tags=RepEval(
                "['drawdowns'] if incl_active else ['drawdowns', 'recovered']")
        ),
        max_dd_duration=dict(
            title='Max Drawdown Duration',
            calc_func=RepEval(
                "'max_duration' if incl_active else 'recovered.max_duration'"),
            fill_wrap_kwargs=True,
            tags=RepEval(
                "['drawdowns', 'duration'] if incl_active else ['drawdowns', 'recovered', 'duration']"
            )),
        avg_dd_duration=dict(
            title='Avg Drawdown Duration',
            calc_func=RepEval(
                "'avg_duration' if incl_active else 'recovered.avg_duration'"),
            fill_wrap_kwargs=True,
            tags=RepEval(
                "['drawdowns', 'duration'] if incl_active else ['drawdowns', 'recovered', 'duration']"
            )),
        # Recovery metrics are always computed on recovered drawdowns only.
        max_return=dict(title='Max Recovery Return [%]',
                        calc_func='recovered.recovery_return.max',
                        post_calc_func=lambda self, out, settings: out * 100,
                        tags=['drawdowns', 'recovered']),
        avg_return=dict(title='Avg Recovery Return [%]',
                        calc_func='recovered.recovery_return.mean',
                        post_calc_func=lambda self, out, settings: out * 100,
                        tags=['drawdowns', 'recovered']),
        max_recovery_duration=dict(title='Max Recovery Duration',
                                   calc_func='recovered.recovery_duration.max',
                                   apply_to_timedelta=True,
                                   tags=['drawdowns', 'recovered',
                                         'duration']),
        avg_recovery_duration=dict(
            title='Avg Recovery Duration',
            calc_func='recovered.recovery_duration.mean',
            apply_to_timedelta=True,
            tags=['drawdowns', 'recovered', 'duration']),
        recovery_duration_ratio=dict(
            title='Avg Recovery Duration Ratio',
            calc_func='recovered.recovery_duration_ratio.mean',
            tags=['drawdowns', 'recovered'])),
                                           copy_kwargs=dict(copy_mode='deep'))

    @property
    def metrics(self) -> Config:
        """Metrics config of these drawdowns, as stored in `_metrics`."""
        config = self._metrics
        return config

    # ############# Plotting ############# #

    def plot(self,
             column: tp.Optional[tp.Label] = None,
             top_n: int = 5,
             plot_zones: bool = True,
             ts_trace_kwargs: tp.KwargsLike = None,
             peak_trace_kwargs: tp.KwargsLike = None,
             valley_trace_kwargs: tp.KwargsLike = None,
             recovery_trace_kwargs: tp.KwargsLike = None,
             active_trace_kwargs: tp.KwargsLike = None,
             decline_shape_kwargs: tp.KwargsLike = None,
             recovery_shape_kwargs: tp.KwargsLike = None,
             active_shape_kwargs: tp.KwargsLike = None,
             add_trace_kwargs: tp.KwargsLike = None,
             xref: str = 'x',
             yref: str = 'y',
             fig: tp.Optional[tp.BaseFigure] = None,
             **layout_kwargs) -> tp.BaseFigure:  # pragma: no cover
        """Plot drawdowns.

        Args:
            column (str): Name of the column to plot.
            top_n (int): Filter top N drawdown records by maximum drawdown.
            plot_zones (bool): Whether to plot zones.
            ts_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Drawdowns.ts`.
            peak_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for peak values.
            valley_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for valley values.
            recovery_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for recovery values.
            active_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for active recovery values.
            decline_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for decline zones.
            recovery_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for recovery zones.
            active_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for active recovery zones.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            xref (str): X coordinate axis.
            yref (str): Y coordinate axis.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments for layout.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt
        >>> from datetime import datetime, timedelta
        >>> import pandas as pd

        >>> price = pd.Series([1, 2, 1, 2, 3, 2, 1, 2], name='Price')
        >>> price.index = [datetime(2020, 1, 1) + timedelta(days=i) for i in range(len(price))]
        >>> vbt.Drawdowns.from_ts(price, wrapper_kwargs=dict(freq='1 day')).plot()
        ```

        ![](/docs/img/drawdowns_plot.svg)
        """
        from vectorbt._settings import settings
        plotting_cfg = settings['plotting']

        self_col = self.select_one(column=column, group_by=False)
        if top_n is not None:
            # Drawdowns is negative, thus top_n becomes bottom_n
            self_col = self_col.apply_mask(
                self_col.drawdown.bottom_n_mask(top_n))

        if ts_trace_kwargs is None:
            ts_trace_kwargs = {}
        ts_trace_kwargs = merge_dicts(
            dict(line=dict(color=plotting_cfg['color_schema']['blue'])),
            ts_trace_kwargs)
        if peak_trace_kwargs is None:
            peak_trace_kwargs = {}
        if valley_trace_kwargs is None:
            valley_trace_kwargs = {}
        if recovery_trace_kwargs is None:
            recovery_trace_kwargs = {}
        if active_trace_kwargs is None:
            active_trace_kwargs = {}
        if decline_shape_kwargs is None:
            decline_shape_kwargs = {}
        if recovery_shape_kwargs is None:
            recovery_shape_kwargs = {}
        if active_shape_kwargs is None:
            active_shape_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}

        if fig is None:
            fig = make_figure()
        fig.update_layout(**layout_kwargs)
        y_domain = get_domain(yref, fig)

        if self_col.ts is not None:
            fig = self_col.ts.vbt.plot(trace_kwargs=ts_trace_kwargs,
                                       add_trace_kwargs=add_trace_kwargs,
                                       fig=fig)

        if self_col.count() > 0:
            # Extract information
            id_ = self_col.get_field_arr('id')
            id_title = self_col.get_field_title('id')

            peak_idx = self_col.get_map_field_to_index('peak_idx')
            peak_idx_title = self_col.get_field_title('peak_idx')

            if self_col.ts is not None:
                peak_val = self_col.ts.loc[peak_idx]
            else:
                peak_val = self_col.get_field_arr('peak_val')
            peak_val_title = self_col.get_field_title('peak_val')

            valley_idx = self_col.get_map_field_to_index('valley_idx')
            valley_idx_title = self_col.get_field_title('valley_idx')

            if self_col.ts is not None:
                valley_val = self_col.ts.loc[valley_idx]
            else:
                valley_val = self_col.get_field_arr('valley_val')
            valley_val_title = self_col.get_field_title('valley_val')

            end_idx = self_col.get_map_field_to_index('end_idx')
            end_idx_title = self_col.get_field_title('end_idx')

            if self_col.ts is not None:
                end_val = self_col.ts.loc[end_idx]
            else:
                end_val = self_col.get_field_arr('end_val')
            end_val_title = self_col.get_field_title('end_val')

            drawdown = self_col.drawdown.values
            recovery_return = self_col.recovery_return.values
            decline_duration = np.vectorize(str)(self_col.wrapper.to_timedelta(
                self_col.decline_duration.values,
                to_pd=True,
                silence_warnings=True))
            recovery_duration = np.vectorize(str)(
                self_col.wrapper.to_timedelta(
                    self_col.recovery_duration.values,
                    to_pd=True,
                    silence_warnings=True))
            duration = np.vectorize(str)(self_col.wrapper.to_timedelta(
                self_col.duration.values, to_pd=True, silence_warnings=True))

            status = self_col.get_field_arr('status')

            peak_mask = peak_idx != np.roll(
                end_idx, 1)  # peak and recovery at same time -> recovery wins
            if peak_mask.any():
                # Plot peak markers
                peak_customdata = id_[peak_mask][:, None]
                peak_scatter = go.Scatter(
                    x=peak_idx[peak_mask],
                    y=peak_val[peak_mask],
                    mode='markers',
                    marker=dict(
                        symbol='diamond',
                        color=plotting_cfg['contrast_color_schema']['blue'],
                        size=7,
                        line=dict(width=1,
                                  color=adjust_lightness(
                                      plotting_cfg['contrast_color_schema']
                                      ['blue']))),
                    name='Peak',
                    customdata=peak_customdata,
                    hovertemplate=f"{id_title}: %{{customdata[0]}}"
                    f"<br>{peak_idx_title}: %{{x}}"
                    f"<br>{peak_val_title}: %{{y}}")
                peak_scatter.update(**peak_trace_kwargs)
                fig.add_trace(peak_scatter, **add_trace_kwargs)

            recovered_mask = status == DrawdownStatus.Recovered
            if recovered_mask.any():
                # Plot valley markers
                valley_customdata = np.stack(
                    (id_[recovered_mask], drawdown[recovered_mask],
                     decline_duration[recovered_mask]),
                    axis=1)
                valley_scatter = go.Scatter(
                    x=valley_idx[recovered_mask],
                    y=valley_val[recovered_mask],
                    mode='markers',
                    marker=dict(
                        symbol='diamond',
                        color=plotting_cfg['contrast_color_schema']['red'],
                        size=7,
                        line=dict(width=1,
                                  color=adjust_lightness(
                                      plotting_cfg['contrast_color_schema']
                                      ['red']))),
                    name='Valley',
                    customdata=valley_customdata,
                    hovertemplate=f"{id_title}: %{{customdata[0]}}"
                    f"<br>{valley_idx_title}: %{{x}}"
                    f"<br>{valley_val_title}: %{{y}}"
                    f"<br>Drawdown: %{{customdata[1]:.2%}}"
                    f"<br>Duration: %{{customdata[2]}}")
                valley_scatter.update(**valley_trace_kwargs)
                fig.add_trace(valley_scatter, **add_trace_kwargs)

                if plot_zones:
                    # Plot drawdown zones
                    for i in range(len(id_[recovered_mask])):
                        fig.add_shape(**merge_dicts(
                            dict(
                                type="rect",
                                xref=xref,
                                yref="paper",
                                x0=peak_idx[recovered_mask][i],
                                y0=y_domain[0],
                                x1=valley_idx[recovered_mask][i],
                                y1=y_domain[1],
                                fillcolor='red',
                                opacity=0.2,
                                layer="below",
                                line_width=0,
                            ), decline_shape_kwargs))

                # Plot recovery markers
                recovery_customdata = np.stack(
                    (id_[recovered_mask], recovery_return[recovered_mask],
                     recovery_duration[recovered_mask]),
                    axis=1)
                recovery_scatter = go.Scatter(
                    x=end_idx[recovered_mask],
                    y=end_val[recovered_mask],
                    mode='markers',
                    marker=dict(
                        symbol='diamond',
                        color=plotting_cfg['contrast_color_schema']['green'],
                        size=7,
                        line=dict(width=1,
                                  color=adjust_lightness(
                                      plotting_cfg['contrast_color_schema']
                                      ['green']))),
                    name='Recovery/Peak',
                    customdata=recovery_customdata,
                    hovertemplate=f"{id_title}: %{{customdata[0]}}"
                    f"<br>{end_idx_title}: %{{x}}"
                    f"<br>{end_val_title}: %{{y}}"
                    f"<br>Return: %{{customdata[1]:.2%}}"
                    f"<br>Duration: %{{customdata[2]}}")
                recovery_scatter.update(**recovery_trace_kwargs)
                fig.add_trace(recovery_scatter, **add_trace_kwargs)

                if plot_zones:
                    # Plot recovery zones
                    for i in range(len(id_[recovered_mask])):
                        fig.add_shape(**merge_dicts(
                            dict(
                                type="rect",
                                xref=xref,
                                yref="paper",
                                x0=valley_idx[recovered_mask][i],
                                y0=y_domain[0],
                                x1=end_idx[recovered_mask][i],
                                y1=y_domain[1],
                                fillcolor='green',
                                opacity=0.2,
                                layer="below",
                                line_width=0,
                            ), recovery_shape_kwargs))

            # Plot active markers
            active_mask = status == DrawdownStatus.Active
            if active_mask.any():
                active_customdata = np.stack(
                    (id_[active_mask], drawdown[active_mask],
                     duration[active_mask]),
                    axis=1)
                active_scatter = go.Scatter(
                    x=end_idx[active_mask],
                    y=end_val[active_mask],
                    mode='markers',
                    marker=dict(
                        symbol='diamond',
                        color=plotting_cfg['contrast_color_schema']['orange'],
                        size=7,
                        line=dict(width=1,
                                  color=adjust_lightness(
                                      plotting_cfg['contrast_color_schema']
                                      ['orange']))),
                    name='Active',
                    customdata=active_customdata,
                    hovertemplate=f"{id_title}: %{{customdata[0]}}"
                    f"<br>{end_idx_title}: %{{x}}"
                    f"<br>{end_val_title}: %{{y}}"
                    f"<br>Return: %{{customdata[1]:.2%}}"
                    f"<br>Duration: %{{customdata[2]}}")
                active_scatter.update(**active_trace_kwargs)
                fig.add_trace(active_scatter, **add_trace_kwargs)

                if plot_zones:
                    # Plot active drawdown zones
                    for i in range(len(id_[active_mask])):
                        fig.add_shape(**merge_dicts(
                            dict(
                                type="rect",
                                xref=xref,
                                yref="paper",
                                x0=peak_idx[active_mask][i],
                                y0=y_domain[0],
                                x1=end_idx[active_mask][i],
                                y1=y_domain[1],
                                fillcolor='orange',
                                opacity=0.2,
                                layer="below",
                                line_width=0,
                            ), active_shape_kwargs))

        return fig

    @property
    def plots_defaults(self) -> tp.Kwargs:
        """Default keyword arguments for `Drawdowns.plots`.

        Combines the defaults of `vectorbt.generic.ranges.Ranges.plots_defaults`
        with `drawdowns.plots` from `vectorbt._settings.settings`, the latter
        being merged on top."""
        from vectorbt._settings import settings

        drawdown_plots = settings['drawdowns']['plots']
        # Call the parent class's property getter explicitly on this instance.
        parent_defaults = Ranges.plots_defaults.__get__(self)
        return merge_dicts(parent_defaults, drawdown_plots)

    # Class-level subplot definitions. `copy_mode='deep'` presumably makes copies
    # of this config deep rather than shallow.
    _subplots: tp.ClassVar[Config] = Config(
        {
            'plot': {
                'title': "Drawdowns",
                'check_is_not_grouped': True,
                'plot_func': 'plot',
                'tags': 'drawdowns',
            },
        },
        copy_kwargs={'copy_mode': 'deep'},
    )

    @property
    def subplots(self) -> Config:
        """Subplot definitions of this object.

        Returns `self._subplots`, which resolves to the instance attribute if one
        was set and to the class-level config otherwise."""
        subplots_cfg = self._subplots
        return subplots_cfg
Beispiel #30
0
class StatsBuilderMixin(metaclass=MetaStatsBuilderMixin):
    """Mixin that implements `StatsBuilderMixin.stats`.

    Required to be a subclass of `vectorbt.base.array_wrapper.Wrapping`."""
    def __init__(self) -> None:
        """Validate the host class and give the instance its own metrics config.

        Raises whatever `checks.assert_instance_of` raises if this object is not
        a `vectorbt.base.array_wrapper.Wrapping`."""
        checks.assert_instance_of(self, Wrapping)

        # Copy the class-level config so that edits made on this instance
        # do not leak back into the class (see the `metrics` docstring).
        class_metrics = type(self)._metrics
        self._metrics = class_metrics.copy()

    @property
    def writeable_attrs(self) -> tp.Set[str]:
        """Names of instance attributes that are saved/copied together with the config.

        `_metrics` is listed because `__init__` assigns each instance its own copy."""
        writeable = {'_metrics'}
        return writeable

    @property
    def stats_defaults(self) -> tp.Kwargs:
        """Default keyword arguments for `StatsBuilderMixin.stats`.

        Starts from `stats_builder` in `vectorbt._settings.settings` and merges
        in the wrapper's frequency under `settings['freq']`."""
        from vectorbt._settings import settings

        builder_cfg = settings['stats_builder']
        freq_override = dict(settings=dict(freq=self.wrapper.freq))
        return merge_dicts(builder_cfg, freq_override)

    # Base metrics; `__init__` copies this config onto every instance.
    # Keep each lambda's parameter named `self`: `stats` matches calc_func
    # arguments by name when resolving them.
    _metrics: tp.ClassVar[Config] = Config(
        {
            # First index label of the wrapper.
            'start': {
                'title': 'Start',
                'calc_func': lambda self: self.wrapper.index[0],
                'agg_func': None,
                'tags': 'wrapper',
            },
            # Last index label of the wrapper.
            'end': {
                'title': 'End',
                'calc_func': lambda self: self.wrapper.index[-1],
                'agg_func': None,
                'tags': 'wrapper',
            },
            # Number of index rows; converted via `to_timedelta` when that is enabled.
            'period': {
                'title': 'Period',
                'calc_func': lambda self: len(self.wrapper.index),
                'apply_to_timedelta': True,
                'agg_func': None,
                'tags': 'wrapper',
            },
        },
        copy_kwargs={'copy_mode': 'deep'},
    )

    @property
    def metrics(self) -> Config:
        """Metrics supported by `${cls_name}`.

        ```json
        ${metrics}
        ```

        Returns `${cls_name}._metrics`, which gets (deep) copied upon creation of each instance.
        Thus, changing this config won't affect the class.

        To change metrics, you can either change the config in-place, override this property,
        or overwrite the instance variable `${cls_name}._metrics`."""
        # NOTE(review): the `${cls_name}` / `${metrics}` placeholders suggest this docstring
        # is a template substituted elsewhere (presumably per subclass) — keep it verbatim.
        # Returns the per-instance copy made in `__init__`, not the class-level config.
        return self._metrics

    def stats(
            self,
            metrics: tp.Optional[tp.MaybeIterable[tp.Union[str, tp.Tuple[
                str, tp.Kwargs]]]] = None,
            tags: tp.Optional[tp.MaybeIterable[str]] = None,
            column: tp.Optional[tp.Label] = None,
            group_by: tp.GroupByLike = None,
            agg_func: tp.Optional[tp.Callable] = np.mean,
            silence_warnings: tp.Optional[bool] = None,
            template_mapping: tp.Optional[tp.Mapping] = None,
            settings: tp.KwargsLike = None,
            filters: tp.KwargsLike = None,
            metric_settings: tp.KwargsLike = None
    ) -> tp.Optional[tp.SeriesFrame]:
        """Compute various metrics on this object.

        Args:
            metrics (str, tuple, iterable, or dict): Metrics to calculate.

                Each element can be either:

                * a metric name (see keys in `StatsBuilderMixin.metrics`)
                * a tuple of a metric name and a settings dict as in `StatsBuilderMixin.metrics`.

                The settings dict can contain the following keys:

                * `title`: Title of the metric. Defaults to the name.
                * `tags`: Single or multiple tags to associate this metric with.
                    If any of these tags is in `tags`, keeps this metric.
                * `check_{filter}` and `inv_check_{filter}`: Whether to check this metric against a
                    filter defined in `filters`. True (or False for inverse) means to keep this metric.
                * `calc_func` (required): Calculation function for custom metrics.
                    Should return either a scalar for one column/group, pd.Series for multiple columns/groups,
                    or a dict of such for multiple sub-metrics.
                * `resolve_calc_func`: whether to resolve `calc_func`. If the function can be accessed
                    by traversing attributes of this object, you can specify the path to this function
                    as a string (see `vectorbt.utils.attr_.deep_getattr` for the path format).
                    If `calc_func` is a function, arguments from merged metric settings are matched with
                    arguments in the signature (see below). If `resolve_calc_func` is False, `calc_func`
                    should accept (resolved) self and dictionary of merged metric settings.
                    Defaults to True.
                * `post_calc_func`: Function to post-process the result of `calc_func`.
                    Should accept (resolved) self, output of `calc_func`, and dictionary of merged metric settings,
                    and return whatever is acceptable to be returned by `calc_func`. Defaults to None.
                * `fill_wrap_kwargs`: Whether to fill `wrap_kwargs` with `to_timedelta` and `silence_warnings`.
                    Defaults to False.
                * `apply_to_timedelta`: Whether to apply `vectorbt.base.array_wrapper.ArrayWrapper.to_timedelta`
                    on the result. To disable this globally, pass `to_timedelta=False` in `settings`.
                    Defaults to False.
                * `pass_{arg}`: Whether to pass any argument from the settings (see below). Defaults to True if
                    this argument was found in the function's signature. Set to False to not pass.
                    If argument to be passed was not found, `pass_{arg}` is removed.
                * `resolve_path_{arg}`: Whether to resolve an argument that is meant to be an attribute of
                    this object and is the first part of the path of `calc_func`. Passes only optional arguments.
                    Defaults to True. See `vectorbt.utils.attr_.AttrResolver.resolve_attr`.
                * `resolve_{arg}`: Whether to resolve an argument that is meant to be an attribute of
                    this object and is present in the function's signature. Defaults to False.
                    See `vectorbt.utils.attr_.AttrResolver.resolve_attr`.
                * `template_mapping`: Mapping to replace templates in metric settings. Used across all settings.
                * Any other keyword argument that overrides the settings or is passed directly to `calc_func`.

                If `resolve_calc_func` is True, the calculation function may "request" any of the
                following arguments by accepting them or if `pass_{arg}` was found in the settings dict:

                * Each of `vectorbt.utils.attr_.AttrResolver.self_aliases`: original object
                    (ungrouped, with no column selected)
                * `group_by`: won't be passed if it was used in resolving the first attribute of `calc_func`
                    specified as a path, use `pass_group_by=True` to pass anyway
                * `column`
                * `metric_name`
                * `agg_func`
                * `silence_warnings`
                * `to_timedelta`: replaced by True if None and frequency is set
                * Any argument from `settings`
                * Any attribute of this object if it meant to be resolved
                    (see `vectorbt.utils.attr_.AttrResolver.resolve_attr`)

                Pass `metrics='all'` to calculate all supported metrics.
            tags (str or iterable): Tags to select.

                See `vectorbt.utils.tags.match_tags`.
            column (str): Name of the column/group.

                !!! hint
                    There are two ways to select a column: `obj['a'].stats()` and `obj.stats(column='a')`.
                    They both accomplish the same thing but in different ways: `obj['a'].stats()` computes
                    statistics of the column 'a' only, while `obj.stats(column='a')` computes statistics of
                    all columns first and only then selects the column 'a'. The first method is preferred
                    when you have a lot of data or caching is disabled. The second method is preferred when
                    most attributes have already been cached.
            group_by (any): Group or ungroup columns. See `vectorbt.base.column_grouper.ColumnGrouper`.
            agg_func (callable): Aggregation function to aggregate statistics across all columns.
                Defaults to mean.

                Should take `pd.Series` and return a const.

                Has only effect if `column` was specified or this object contains only one column of data.

                If `agg_func` has been overridden by a metric:

                * it only takes effect if global `agg_func` is not None
                * will raise a warning if it's None but the result of calculation has multiple values
            silence_warnings (bool): Whether to silence all warnings.
            template_mapping (mapping): Global mapping to replace templates.

                Gets merged over `template_mapping` from `StatsBuilderMixin.stats_defaults`.

                Applied on `settings` and then on each metric settings.
            filters (dict): Filters to apply.

                Each item consists of the filter name and settings dict.

                The settings dict can contain the following keys:

                * `filter_func`: Filter function that should accept resolved self and
                    merged settings for a metric, and return either True or False.
                * `warning_message`: Warning message to be shown when skipping a metric.
                    Can be a template that will be substituted using merged metric settings as mapping.
                    Defaults to None.
                * `inv_warning_message`: Same as `warning_message` but for inverse checks.

                Gets merged over `filters` from `StatsBuilderMixin.stats_defaults`.
            settings (dict): Global settings and resolution arguments.

                Extends/overrides `settings` from `StatsBuilderMixin.stats_defaults`.
                Gets extended/overridden by metric settings.
            metric_settings (dict): Keyword arguments for each metric.

                Extends/overrides all global and metric settings.

        For template logic, see `vectorbt.utils.template`.

        For defaults, see `StatsBuilderMixin.stats_defaults`.

        !!! hint
            There are two types of arguments: optional (or resolution) and mandatory arguments.
            Optional arguments are only passed if they are found in the function's signature.
            Mandatory arguments are passed regardless of this. Optional arguments can only be defined
            using `settings` (that is, globally), while mandatory arguments can be defined both using
            default metric settings and `{metric_name}_kwargs`. Overriding optional arguments using default
            metric settings or `{metric_name}_kwargs` won't turn them into mandatory. For this, pass `pass_{arg}=True`.

        !!! hint
            Make sure to resolve and then to re-use as many object attributes as possible to
            utilize built-in caching (even if global caching is disabled).

        ## Example

        See `vectorbt.portfolio.base` for examples.
        """
        # Resolve defaults
        if silence_warnings is None:
            silence_warnings = self.stats_defaults.get('silence_warnings',
                                                       False)
        template_mapping = merge_dicts(
            self.stats_defaults.get('template_mapping', {}), template_mapping)
        filters = merge_dicts(self.stats_defaults.get('filters', {}), filters)
        settings = merge_dicts(self.stats_defaults.get('settings', {}),
                               settings)
        metric_settings = merge_dicts(
            self.stats_defaults.get('metric_settings', {}), metric_settings)

        # Replace templates globally (not used at metric level)
        if len(template_mapping) > 0:
            sub_settings = deep_substitute(settings, mapping=template_mapping)
        else:
            sub_settings = settings

        # Resolve self
        reself = self.resolve_self(cond_kwargs=sub_settings,
                                   impacts_caching=False,
                                   silence_warnings=silence_warnings)

        # Prepare metrics
        if metrics is None:
            metrics = reself.stats_defaults.get('metrics', 'all')
        if metrics == 'all':
            metrics = reself.metrics
        if isinstance(metrics, dict):
            metrics = list(metrics.items())
        if isinstance(metrics, (str, tuple)):
            metrics = [metrics]

        # Prepare tags
        if tags is None:
            tags = reself.stats_defaults.get('tags', 'all')
        if isinstance(tags, str) and tags == 'all':
            tags = None
        if isinstance(tags, (str, tuple)):
            tags = [tags]

        # Bring to the same shape
        new_metrics = []
        for i, metric in enumerate(metrics):
            if isinstance(metric, str):
                metric = (metric, reself.metrics[metric])
            if not isinstance(metric, tuple):
                raise TypeError(
                    f"Metric at index {i} must be either a string or a tuple")
            new_metrics.append(metric)
        metrics = new_metrics

        # Handle duplicate names
        metric_counts = Counter(list(map(lambda x: x[0], metrics)))
        metric_i = {k: -1 for k in metric_counts.keys()}
        metrics_dct = {}
        for i, (metric_name, _metric_settings) in enumerate(metrics):
            if metric_counts[metric_name] > 1:
                metric_i[metric_name] += 1
                metric_name = metric_name + '_' + str(metric_i[metric_name])
            metrics_dct[metric_name] = _metric_settings

        # Check metric_settings
        missed_keys = set(metric_settings.keys()).difference(
            set(metrics_dct.keys()))
        if len(missed_keys) > 0:
            raise ValueError(
                f"Keys {missed_keys} in metric_settings could not be matched with any metric"
            )

        # Merge settings
        opt_arg_names_dct = {}
        custom_arg_names_dct = {}
        resolved_self_dct = {}
        mapping_dct = {}
        for metric_name, _metric_settings in list(metrics_dct.items()):
            opt_settings = merge_dicts(
                {name: reself
                 for name in reself.self_aliases},
                dict(column=column,
                     group_by=group_by,
                     metric_name=metric_name,
                     agg_func=agg_func,
                     silence_warnings=silence_warnings,
                     to_timedelta=None), settings)
            _metric_settings = _metric_settings.copy()
            passed_metric_settings = metric_settings.get(metric_name, {})
            merged_settings = merge_dicts(opt_settings, _metric_settings,
                                          passed_metric_settings)
            metric_template_mapping = merged_settings.pop(
                'template_mapping', {})
            template_mapping_merged = merge_dicts(template_mapping,
                                                  metric_template_mapping)
            template_mapping_merged = deep_substitute(template_mapping_merged,
                                                      mapping=merged_settings)
            mapping = merge_dicts(template_mapping_merged, merged_settings)
            merged_settings = deep_substitute(merged_settings, mapping=mapping)

            # Filter by tag
            if tags is not None:
                in_tags = merged_settings.get('tags', None)
                if in_tags is None or not match_tags(tags, in_tags):
                    metrics_dct.pop(metric_name, None)
                    continue

            custom_arg_names = set(_metric_settings.keys()).union(
                set(passed_metric_settings.keys()))
            opt_arg_names = set(opt_settings.keys())
            custom_reself = reself.resolve_self(
                cond_kwargs=merged_settings,
                custom_arg_names=custom_arg_names,
                impacts_caching=True,
                silence_warnings=merged_settings['silence_warnings'])

            metrics_dct[metric_name] = merged_settings
            custom_arg_names_dct[metric_name] = custom_arg_names
            opt_arg_names_dct[metric_name] = opt_arg_names
            resolved_self_dct[metric_name] = custom_reself
            mapping_dct[metric_name] = mapping

        # Filter metrics
        for metric_name, _metric_settings in list(metrics_dct.items()):
            custom_reself = resolved_self_dct[metric_name]
            mapping = mapping_dct[metric_name]
            _silence_warnings = _metric_settings.get('silence_warnings')

            metric_filters = set()
            for k in _metric_settings.keys():
                filter_name = None
                if k.startswith('check_'):
                    filter_name = k[len('check_'):]
                elif k.startswith('inv_check_'):
                    filter_name = k[len('inv_check_'):]
                if filter_name is not None:
                    if filter_name not in filters:
                        raise ValueError(
                            f"Metric '{metric_name}' requires filter '{filter_name}'"
                        )
                    metric_filters.add(filter_name)

            for filter_name in metric_filters:
                filter_settings = filters[filter_name]
                _filter_settings = deep_substitute(filter_settings,
                                                   mapping=mapping)
                filter_func = _filter_settings['filter_func']
                warning_message = _filter_settings.get('warning_message', None)
                inv_warning_message = _filter_settings.get(
                    'inv_warning_message', None)
                to_check = _metric_settings.get('check_' + filter_name, False)
                inv_to_check = _metric_settings.get('inv_check_' + filter_name,
                                                    False)

                if to_check or inv_to_check:
                    whether_true = filter_func(custom_reself, _metric_settings)
                    to_remove = (to_check
                                 and not whether_true) or (inv_to_check
                                                           and whether_true)
                    if to_remove:
                        if to_check and warning_message is not None and not _silence_warnings:
                            warnings.warn(warning_message)
                        if inv_to_check and inv_warning_message is not None and not _silence_warnings:
                            warnings.warn(inv_warning_message)

                        metrics_dct.pop(metric_name, None)
                        custom_arg_names_dct.pop(metric_name, None)
                        opt_arg_names_dct.pop(metric_name, None)
                        resolved_self_dct.pop(metric_name, None)
                        mapping_dct.pop(metric_name, None)
                        break

        # Any metrics left?
        if len(metrics_dct) == 0:
            if not silence_warnings:
                warnings.warn("No metrics to calculate", stacklevel=2)
            return None

        # Compute stats
        arg_cache_dct = {}
        stats_dct = {}
        used_agg_func = False
        for i, (metric_name,
                _metric_settings) in enumerate(metrics_dct.items()):
            try:
                final_kwargs = _metric_settings.copy()
                opt_arg_names = opt_arg_names_dct[metric_name]
                custom_arg_names = custom_arg_names_dct[metric_name]
                custom_reself = resolved_self_dct[metric_name]

                # Clean up keys
                for k, v in list(final_kwargs.items()):
                    if k.startswith('check_') or k.startswith(
                            'inv_check_') or k in ('tags', ):
                        final_kwargs.pop(k, None)

                # Get metric-specific values
                _column = final_kwargs.get('column')
                _group_by = final_kwargs.get('group_by')
                _agg_func = final_kwargs.get('agg_func')
                _silence_warnings = final_kwargs.get('silence_warnings')
                if final_kwargs['to_timedelta'] is None:
                    final_kwargs[
                        'to_timedelta'] = custom_reself.wrapper.freq is not None
                to_timedelta = final_kwargs.get('to_timedelta')
                title = final_kwargs.pop('title', metric_name)
                calc_func = final_kwargs.pop('calc_func')
                resolve_calc_func = final_kwargs.pop('resolve_calc_func', True)
                post_calc_func = final_kwargs.pop('post_calc_func', None)
                use_caching = final_kwargs.pop('use_caching', True)
                fill_wrap_kwargs = final_kwargs.pop('fill_wrap_kwargs', False)
                if fill_wrap_kwargs:
                    final_kwargs['wrap_kwargs'] = merge_dicts(
                        dict(to_timedelta=to_timedelta,
                             silence_warnings=_silence_warnings),
                        final_kwargs.get('wrap_kwargs', None))
                apply_to_timedelta = final_kwargs.pop('apply_to_timedelta',
                                                      False)

                # Resolve calc_func
                if resolve_calc_func:
                    if not callable(calc_func):
                        passed_kwargs_out = {}

                        def _getattr_func(
                            obj: tp.Any,
                            attr: str,
                            args: tp.ArgsLike = None,
                            kwargs: tp.KwargsLike = None,
                            call_attr: bool = True,
                            _final_kwargs: tp.Kwargs = final_kwargs,
                            _opt_arg_names: tp.Set[str] = opt_arg_names,
                            _custom_arg_names: tp.Set[str] = custom_arg_names,
                            _arg_cache_dct: tp.Kwargs = arg_cache_dct
                        ) -> tp.Any:
                            if attr in final_kwargs:
                                return final_kwargs[attr]
                            if args is None:
                                args = ()
                            if kwargs is None:
                                kwargs = {}
                            if obj is custom_reself and _final_kwargs.pop(
                                    'resolve_path_' + attr, True):
                                if call_attr:
                                    return custom_reself.resolve_attr(
                                        attr,
                                        args=args,
                                        cond_kwargs={
                                            k: v
                                            for k, v in _final_kwargs.items()
                                            if k in _opt_arg_names
                                        },
                                        kwargs=kwargs,
                                        custom_arg_names=_custom_arg_names,
                                        cache_dct=_arg_cache_dct,
                                        use_caching=use_caching,
                                        passed_kwargs_out=passed_kwargs_out)
                                return getattr(obj, attr)
                            out = getattr(obj, attr)
                            if callable(out) and call_attr:
                                return out(*args, **kwargs)
                            return out

                        calc_func = custom_reself.deep_getattr(
                            calc_func,
                            getattr_func=_getattr_func,
                            call_last_attr=False)

                        if 'group_by' in passed_kwargs_out:
                            if 'pass_group_by' not in final_kwargs:
                                final_kwargs.pop('group_by', None)

                    # Resolve arguments
                    if callable(calc_func):
                        func_arg_names = get_func_arg_names(calc_func)
                        for k in func_arg_names:
                            if k not in final_kwargs:
                                if final_kwargs.pop('resolve_' + k, False):
                                    try:
                                        arg_out = custom_reself.resolve_attr(
                                            k,
                                            cond_kwargs=final_kwargs,
                                            custom_arg_names=custom_arg_names,
                                            cache_dct=arg_cache_dct,
                                            use_caching=use_caching)
                                    except AttributeError:
                                        continue
                                    final_kwargs[k] = arg_out
                        for k in list(final_kwargs.keys()):
                            if k in opt_arg_names:
                                if 'pass_' + k in final_kwargs:
                                    if not final_kwargs.get(
                                            'pass_' + k):  # first priority
                                        final_kwargs.pop(k, None)
                                elif k not in func_arg_names:  # second priority
                                    final_kwargs.pop(k, None)
                        for k in list(final_kwargs.keys()):
                            if k.startswith('pass_') or k.startswith(
                                    'resolve_'):
                                final_kwargs.pop(k, None)  # cleanup

                        # Call calc_func
                        out = calc_func(**final_kwargs)
                    else:
                        # calc_func is already a result
                        out = calc_func
                else:
                    # Do not resolve calc_func
                    out = calc_func(custom_reself, _metric_settings)

                # Call post_calc_func
                if post_calc_func is not None:
                    out = post_calc_func(custom_reself, out, _metric_settings)

                # Post-process and store the metric
                multiple = True
                if not isinstance(out, dict):
                    multiple = False
                    out = {None: out}
                for k, v in out.items():
                    # Resolve title
                    if multiple:
                        if title is None:
                            t = str(k)
                        else:
                            t = title + ': ' + str(k)
                    else:
                        t = title

                    # Check result type
                    if checks.is_any_array(v) and not checks.is_series(v):
                        raise TypeError(
                            "calc_func must return either a scalar for one column/group, "
                            "pd.Series for multiple columns/groups, or a dict of such. "
                            f"Not {type(v)}.")

                    # Handle apply_to_timedelta
                    if apply_to_timedelta and to_timedelta:
                        v = custom_reself.wrapper.to_timedelta(
                            v, silence_warnings=_silence_warnings)

                    # Select column or aggregate
                    if checks.is_series(v):
                        if _column is not None:
                            v = custom_reself.select_one_from_obj(
                                v,
                                custom_reself.wrapper.regroup(_group_by),
                                column=_column)
                        elif _agg_func is not None and agg_func is not None:
                            v = _agg_func(v)
                            used_agg_func = True
                        elif _agg_func is None and agg_func is not None:
                            if not _silence_warnings:
                                warnings.warn(
                                    f"Metric '{metric_name}' returned multiple values "
                                    f"despite having no aggregation function",
                                    stacklevel=2)
                            continue

                    # Store metric
                    if t in stats_dct:
                        if not _silence_warnings:
                            warnings.warn(f"Duplicate metric title '{t}'",
                                          stacklevel=2)
                    stats_dct[t] = v
            except Exception as e:
                warnings.warn(f"Metric '{metric_name}' raised an exception",
                              stacklevel=2)
                raise e

        # Return the stats
        if reself.wrapper.get_ndim(group_by=group_by) == 1:
            return pd.Series(stats_dct,
                             name=reself.wrapper.get_name(group_by=group_by))
        if column is not None:
            return pd.Series(stats_dct, name=column)
        if agg_func is not None:
            if used_agg_func and not silence_warnings:
                warnings.warn(
                    f"Object has multiple columns. Aggregating using {agg_func}. "
                    f"Pass column to select a single column/group.",
                    stacklevel=2)
            return pd.Series(stats_dct, name='agg_func_' + agg_func.__name__)
        new_index = reself.wrapper.grouper.get_columns(group_by=group_by)
        stats_df = pd.DataFrame(stats_dct, index=new_index)
        return stats_df

    # ############# Docs ############# #

    @classmethod
    def build_metrics_doc(cls, source_cls: tp.Optional[type] = None) -> str:
        """Render the documentation of the `metrics` attribute for `cls`.

        The docstring of the `metrics` attribute found on `source_cls` (falls back to
        `StatsBuilderMixin` when `source_cls` is None) is used as a `string.Template`.
        It is first normalized with `inspect.cleandoc`. Then it is filled with two
        substitutions: `metrics` (the result of `cls.metrics.to_doc()`) and `cls_name`
        (the name of `cls`).

        Raises whatever `string.Template.substitute` raises, e.g. `KeyError` if the
        template references a placeholder other than `metrics` / `cls_name`."""
        if source_cls is None:
            source_cls = StatsBuilderMixin
        # Look the attribute up via get_dict_attr rather than plain getattr.
        # NOTE(review): presumably this fetches the raw attribute object (so its own
        # __doc__ is read, not the docstring of the value it yields). Confirm.
        template_text = inspect.cleandoc(get_dict_attr(source_cls, 'metrics').__doc__)
        substitutions = dict(
            metrics=cls.metrics.to_doc(),
            cls_name=cls.__name__,
        )
        return string.Template(template_text).substitute(substitutions)

    @classmethod
    def override_metrics_doc(cls,
                             __pdoc__: dict,
                             source_cls: tp.Optional[type] = None) -> None:
        """Register the rendered `metrics` documentation of `cls` in `__pdoc__`.

        Call this method on each subclass that overrides `metrics`. The entry is
        written (overwriting any existing one) under the key `'<cls name>.metrics'`.
        Its value is the string returned by `cls.build_metrics_doc(source_cls=source_cls)`.
        The mapping is mutated in place and nothing is returned."""
        doc_key = f'{cls.__name__}.metrics'
        __pdoc__[doc_key] = cls.build_metrics_doc(source_cls=source_cls)