Exemplo n.º 1
0
    def update(self, data):
        """Refresh the heatmap trace with new values.

        `data` is converted to a NumPy array and reshaped to 2 dimensions; the
        first trace's `z` receives it as-is, or transposed when `self.horizontal`
        is set. The assignment runs inside `self.fig.batch_update()`.
        """
        arr = reshape_fns.to_2d(np.array(data))
        with self.fig.batch_update():
            self.traces[0].z = arr.transpose() if self.horizontal else arr
Exemplo n.º 2
0
 def _indexing_func_meta(self, pd_indexing_func):
     """Index `Trades` with `pd_indexing_func` and also return metadata.

     Returns a tuple `(indexed_copy, group_idxs, col_idxs)`, where the copy carries
     the new wrapper, the new records array and `close` restricted to `col_idxs`."""
     meta = Records._indexing_func_meta(self, pd_indexing_func)
     wrapper, records_arr, group_idxs, col_idxs = meta
     # Keep only the close columns selected by the indexing operation.
     close_arr = to_2d(self.close, raw=True)[:, col_idxs]
     trades = self.copy(wrapper=wrapper,
                        records_arr=records_arr,
                        close=wrapper.wrap(close_arr, group_by=False))
     return trades, group_idxs, col_idxs
Exemplo n.º 3
0
def apply_and_concat_one(n, apply_func, *args, **kwargs):
    """Call `apply_func` for every `i` in `range(n)` and concatenate the results along axis 1.

    Each call receives `i` followed by `*args` and `**kwargs` and must return a
    single 1-dim or 2-dim array. Every result is converted to 2 dimensions with
    `reshape_fns.to_2d` before being stacked horizontally with `np.hstack`."""
    pieces = []
    for i in range(n):
        output = apply_func(i, *args, **kwargs)
        pieces.append(reshape_fns.to_2d(output))
    return np.hstack(pieces)
Exemplo n.º 4
0
    def align_to(self,
                 other: tp.SeriesFrame,
                 wrap_kwargs: tp.KwargsLike = None) -> tp.SeriesFrame:
        """Reindex this object so that its index and columns follow those of `other`.

        Both objects are viewed as 2-dim; `index_fns.align_index_to` maps each axis
        of this object onto the corresponding axis of `other`. The result takes the
        index and columns of `other`, which `wrap_kwargs` can override.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt
        >>> import pandas as pd

        >>> df1 = pd.DataFrame([[1, 2], [3, 4]], index=['x', 'y'], columns=['a', 'b'])
        >>> df1
           a  b
        x  1  2
        y  3  4

        >>> df2 = pd.DataFrame([[5, 6, 7, 8], [9, 10, 11, 12]], index=['x', 'y'],
        ...     columns=pd.MultiIndex.from_arrays([[1, 1, 2, 2], ['a', 'b', 'a', 'b']]))
        >>> df2
               1       2
           a   b   a   b
        x  5   6   7   8
        y  9  10  11  12

        >>> df1.vbt.align_to(df2)
              1     2
           a  b  a  b
        x  1  2  1  2
        y  3  4  3  4
        ```
        """
        checks.assert_instance_of(other, (pd.Series, pd.DataFrame))
        self_2d = reshape_fns.to_2d(self.obj)
        other_2d = reshape_fns.to_2d(other)

        row_positions = index_fns.align_index_to(self_2d.index, other_2d.index)
        col_positions = index_fns.align_index_to(self_2d.columns, other_2d.columns)
        aligned = self_2d.iloc[row_positions, col_positions]

        target_meta = dict(index=other_2d.index, columns=other_2d.columns)
        return self.wrapper.wrap(aligned.values,
                                 group_by=False,
                                 **merge_dicts(target_meta, wrap_kwargs))
Exemplo n.º 5
0
    def apply_and_concat(self,
                         ntimes,
                         *args,
                         apply_func=None,
                         to_2d=False,
                         keys=None,
                         wrap_kwargs=None,
                         **kwargs):
        """Run `apply_func` `ntimes` times and stack the outputs side by side as columns.

        Delegates to `vectorbt.base.combine_fns.apply_and_concat_one`, or to
        `apply_and_concat_one_nb` when `apply_func` is Numba-compiled.

        `*args` and `**kwargs` are forwarded to `apply_func` unchanged. When `to_2d`
        is True, the object is passed as a raw 2-dim NumPy array; otherwise it is
        passed through `np.asarray`. `keys`, if given, becomes the outermost column
        level. Otherwise a level named `apply_idx` holding `0..ntimes-1` is used.

        !!! note
            Every array being concatenated must have the same shape as the broadcast input arrays.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt
        >>> import pandas as pd

        >>> df = pd.DataFrame([[3, 4], [5, 6]], index=['x', 'y'], columns=['a', 'b'])
        >>> df.vbt.apply_and_concat(3, [1, 2, 3],
        ...     apply_func=lambda i, a, b: a * b[i], keys=['c', 'd', 'e'])
              c       d       e
           a  b   a   b   a   b
        x  3  4   6   8   9  12
        y  5  6  10  12  15  18
        ```
        """
        checks.assert_not_none(apply_func)
        obj_arr = (reshape_fns.to_2d(self._obj, raw=True)
                   if to_2d else np.asarray(self._obj))
        concat_fn = (combine_fns.apply_and_concat_one_nb
                     if checks.is_numba_func(apply_func)
                     else combine_fns.apply_and_concat_one)
        result = concat_fn(ntimes, apply_func, obj_arr, *args, **kwargs)

        top_level = keys
        if top_level is None:
            top_level = pd.Index(np.arange(ntimes), name='apply_idx')
        new_columns = index_fns.combine_indexes(top_level, self.wrapper.columns)
        wrap_opts = merge_dicts(dict(columns=new_columns), wrap_kwargs)
        return self.wrapper.wrap(result, group_by=False, **wrap_opts)
Exemplo n.º 6
0
    def combine_with(self, other, *args, combine_func=None, to_2d=False, broadcast_kwargs=None, **kwargs):
        """Combine both using `combine_func` into a Series/DataFrame of the same shape.

        All arguments will be broadcast using `vectorbt.base.reshape_fns.broadcast`
        with `broadcast_kwargs` (None means no extra keyword arguments).

        Arguments `*args` and `**kwargs` will be directly passed to `combine_func`.
        If `to_2d` is True, 2-dimensional NumPy arrays will be passed, otherwise as is.
        If `other` is an accessor, its underlying object is used.

        !!! note
            The resulted array must have the same shape as broadcast input arrays.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt
        >>> import pandas as pd

        >>> sr = pd.Series([1, 2], index=['x', 'y'])
        >>> df = pd.DataFrame([[3, 4], [5, 6]], index=['x', 'y'], columns=['a', 'b'])
        >>> sr.vbt.combine_with(df, combine_func=lambda x, y: x + y)
           a  b
        x  4  5
        y  7  8
        ```
        """
        # Avoid a shared mutable default: build a fresh dict per call.
        if broadcast_kwargs is None:
            broadcast_kwargs = {}
        if isinstance(other, Base_Accessor):
            other = other._obj
        checks.assert_not_none(combine_func)
        if checks.is_numba_func(combine_func):
            # Numba requires writable arrays
            broadcast_kwargs = merge_dicts(dict(require_kwargs=dict(requirements='W')), broadcast_kwargs)
        new_obj, new_other = reshape_fns.broadcast(self._obj, other, **broadcast_kwargs)
        # Optionally cast to 2d array
        if to_2d:
            new_obj_arr = reshape_fns.to_2d(new_obj, raw=True)
            new_other_arr = reshape_fns.to_2d(new_other, raw=True)
        else:
            new_obj_arr = np.asarray(new_obj)
            new_other_arr = np.asarray(new_other)
        result = combine_func(new_obj_arr, new_other_arr, *args, **kwargs)
        return new_obj.vbt.wrapper.wrap(result)
Exemplo n.º 7
0
 def indexing_func_meta(
         self: OrdersT, pd_indexing_func: tp.PandasIndexingFunc,
         **kwargs) -> tp.Tuple[OrdersT, tp.MaybeArray, tp.Array1d]:
     """Index `Orders` with `pd_indexing_func` and return the result with metadata.

     Extra `**kwargs` are forwarded to `Records.indexing_func_meta`. Returns a tuple
     `(indexed_copy, group_idxs, col_idxs)`; the copy's `close` keeps only `col_idxs`."""
     meta = Records.indexing_func_meta(self, pd_indexing_func, **kwargs)
     wrapper, records_arr, group_idxs, col_idxs = meta
     close_arr = to_2d(self.close, raw=True)[:, col_idxs]
     orders = self.copy(wrapper=wrapper,
                        records_arr=records_arr,
                        close=wrapper.wrap(close_arr, group_by=False))
     return orders, group_idxs, col_idxs
Exemplo n.º 8
0
def apply_and_concat_one(n, apply_func, *args, show_progress=False, **kwargs):
    """Call `apply_func` for every `i` in `range(n)` and concatenate the results along axis 1.

    Each call receives `i` followed by `*args` and `**kwargs` and must return a
    single 1-dim or 2-dim array. Results are converted to 2 dimensions and combined
    with `np.column_stack`. A `tqdm` progress bar is shown when `show_progress` is True."""
    progress = tqdm(range(n), disable=not show_progress)
    columns = [reshape_fns.to_2d(apply_func(i, *args, **kwargs)) for i in progress]
    return np.column_stack(columns)
Exemplo n.º 9
0
def perform_init_checks(ts_list, output_list, param_list, mapper_list, name):
    """Validate the objects produced when an indicator is run or sliced.

    Asserts that:
    * the first time series is a Series or DataFrame,
    * every time series and every output shares metadata with the first time series,
    * every parameter array has the same shape as the first one,
    * every mapper is a Series whose index equals the index of the first row of the
      first time series viewed as 2-dim,
    * `name` is a string.
    """
    reference = ts_list[0]
    checks.assert_type(reference, (pd.Series, pd.DataFrame))
    for obj in ts_list + output_list:
        checks.assert_same_meta(reference, obj)
    for param_arr in param_list:
        checks.assert_same_shape(param_list[0], param_arr)
    for mapper in mapper_list:
        checks.assert_type(mapper, pd.Series)
        first_row = reshape_fns.to_2d(ts_list[0]).iloc[0, :]
        checks.assert_same_index(first_row, mapper)
    checks.assert_type(name, str)
Exemplo n.º 10
0
    def apply(self,
              *args,
              apply_func: tp.Optional[tp.Callable] = None,
              keep_pd: bool = False,
              to_2d: bool = False,
              wrap_kwargs: tp.KwargsLike = None,
              **kwargs) -> tp.SeriesFrame:
        """Apply a function `apply_func`.

        Args:
            *args: Variable arguments passed to `apply_func`.
            apply_func (callable): Apply function.

                Can be Numba-compiled.
            keep_pd (bool): Whether to keep inputs as pandas objects, otherwise convert to NumPy arrays.
            to_2d (bool): Whether to reshape inputs to 2-dim arrays, otherwise keep as-is.
            wrap_kwargs (dict): Keyword arguments passed to `vectorbt.base.array_wrapper.ArrayWrapper.wrap`.
            **kwargs: Keyword arguments passed to `apply_func`.

        !!! note
            The resulted array must have the same shape as the original array.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt
        >>> import pandas as pd

        >>> sr = pd.Series([1, 2], index=['x', 'y'])
        >>> sr.vbt.apply(apply_func=lambda x: x ** 2)
        x    1
        y    4
        dtype: int64
        ```
        """
        checks.assert_not_none(apply_func)
        # Optionally cast to 2d array; `raw` controls whether pandas metadata is dropped.
        if to_2d:
            obj = reshape_fns.to_2d(self.obj, raw=not keep_pd)
        else:
            if not keep_pd:
                obj = np.asarray(self.obj)
            else:
                obj = self.obj
        result = apply_func(obj, *args, **kwargs)
        # Re-attach this object's index/columns (overridable via wrap_kwargs).
        return self.wrapper.wrap(result,
                                 group_by=False,
                                 **merge_dicts({}, wrap_kwargs))
Exemplo n.º 11
0
def create_scatter(data=None, trace_names=None, x_labels=None,
                   trace_kwargs=None, fig=None, **layout_kwargs):
    """Create a scatter plot.

    Args:
        data (array_like): Data in any format that can be converted to NumPy.

            Must be of shape (`x_labels`, `trace_names`).
        trace_names (str or list of str): Trace names, corresponding to columns in pandas.
        x_labels (array_like): X-axis labels, corresponding to index in pandas.
        trace_kwargs (dict or list of dict): Keyword arguments passed to `plotly.graph_objects.Scatter`.

            A single dict is applied to every trace; a list/tuple supplies one dict per trace.
        fig (plotly.graph_objects.Figure): Figure to add traces to.
        **layout_kwargs: Keyword arguments for layout.

    Raises:
        ValueError: If both `data` and `trace_names` are None.

    Example:
        ```py
        import vectorbt as vbt

        vbt.plotting.create_scatter(
            data=[[1, 2], [3, 4]],
            trace_names=['a', 'b'],
            x_labels=['x', 'y']
        )
        ```
        ![](/vectorbt/docs/img/create_scatter.png)
    """
    # Avoid a shared mutable default argument.
    if trace_kwargs is None:
        trace_kwargs = {}
    if data is None:
        if trace_names is None:
            raise ValueError("At least trace_names must be passed")
    if trace_names is None:
        # One unnamed trace per data column.
        data = reshape_fns.to_2d(data)
        trace_names = [None] * data.shape[1]
    if isinstance(trace_names, str):
        trace_names = [trace_names]
    if fig is None:
        fig = CustomFigureWidget()
    fig.update_layout(**layout_kwargs)
    for i, trace_name in enumerate(trace_names):
        scatter = go.Scatter(
            x=x_labels,
            name=trace_name,
            showlegend=trace_name is not None
        )
        scatter.update(**(trace_kwargs[i] if isinstance(trace_kwargs, (list, tuple)) else trace_kwargs))
        fig.add_trace(scatter)

    if data is not None:
        # The traces added above are the last len(trace_names) traces of the figure.
        trace_idx = np.arange(len(fig.data) - len(trace_names), len(fig.data))
        update_scatter_data(fig, data, trace_idx=trace_idx)
    return fig
Exemplo n.º 12
0
def create_hist(data=None, trace_names=None, horizontal=False,
                trace_kwargs=None, fig=None, **layout_kwargs):
    """Create a histogram plot.

    Args:
        data (array_like): Data in any format that can be converted to NumPy.

            Must be of shape (any, `trace_names`).
        trace_names (str or list of str): Trace names, corresponding to columns in pandas.
        horizontal (bool): Plot horizontally.
        trace_kwargs (dict or list of dict): Keyword arguments passed to `plotly.graph_objects.Histogram`.

            A single dict is applied to every trace; a list/tuple supplies one dict per trace.
        fig (plotly.graph_objects.Figure): Figure to add traces to.
        **layout_kwargs: Keyword arguments for layout.

    Raises:
        ValueError: If both `data` and `trace_names` are None.

    Example:
        ```py
        import vectorbt as vbt

        vbt.plotting.create_hist(
            data=[[1, 2], [3, 4], [2, 1]],
            trace_names=['a', 'b']
        )
        ```
        ![](/vectorbt/docs/img/create_hist.png)
    """
    # Avoid a shared mutable default argument.
    if trace_kwargs is None:
        trace_kwargs = {}
    if data is None:
        if trace_names is None:
            raise ValueError("At least trace_names must be passed")
    if trace_names is None:
        # One unnamed trace per data column.
        data = reshape_fns.to_2d(data)
        trace_names = [None] * data.shape[1]
    if isinstance(trace_names, str):
        trace_names = [trace_names]
    if fig is None:
        fig = CustomFigureWidget()
        # Overlay histograms so that multiple traces are drawn on top of each other.
        fig.update_layout(barmode='overlay')
    fig.update_layout(**layout_kwargs)
    for i, trace_name in enumerate(trace_names):
        hist = go.Histogram(
            opacity=0.75 if len(trace_names) > 1 else 1,
            name=trace_name,
            showlegend=trace_name is not None
        )
        hist.update(**(trace_kwargs[i] if isinstance(trace_kwargs, (list, tuple)) else trace_kwargs))
        fig.add_trace(hist)

    if data is not None:
        # The traces added above are the last len(trace_names) traces of the figure.
        trace_idx = np.arange(len(fig.data) - len(trace_names), len(fig.data))
        update_hist_data(fig, data, horizontal=horizontal, trace_idx=trace_idx)
    return fig
Exemplo n.º 13
0
    def update(self, data):
        """Refresh every bar trace from the columns of `data`.

        `data` is converted to a 2-dim NumPy array; column `i` becomes the `y` of
        trace `i`. Bars whose marker has a colorscale are also colored by the same
        values.

        ## Example

        ```python-repl
        >>> bar.update([[2, 1], [4, 3]])
        ```
        ![](/vectorbt/docs/img/update_bar_data.png)
        """
        arr = reshape_fns.to_2d(np.array(data))
        with self.fig.batch_update():
            for col, trace in enumerate(self.traces):
                trace.y = arr[:, col]
                has_colorscale = trace.marker.colorscale is not None
                if has_colorscale:
                    trace.marker.color = arr[:, col]
Exemplo n.º 14
0
def apply_and_concat_one(n: int,
                         apply_func: tp.Callable,
                         *args,
                         show_progress: bool = False,
                         tqdm_kwargs: tp.KwargsLike = None,
                         **kwargs) -> tp.Array2d:
    """Call `apply_func` for every `i` in `range(n)` and concatenate the results along axis 1.

    Each call receives `i` followed by `*args` and `**kwargs` and must return a
    single 1-dim or 2-dim array. Results are converted to 2 dimensions and combined
    with `np.column_stack`. When `show_progress` is True, a `tqdm` progress bar is shown,
    configured by `tqdm_kwargs`."""
    progress_kwargs = {} if tqdm_kwargs is None else tqdm_kwargs
    progress = tqdm(range(n), disable=not show_progress, **progress_kwargs)
    return np.column_stack(
        [reshape_fns.to_2d(apply_func(i, *args, **kwargs)) for i in progress])
Exemplo n.º 15
0
def update_heatmap_data(fig, data, horizontal=False, trace_idx=None):
    """Update the heatmap data.

    For keyword arguments, see `create_heatmap`.
    `trace_idx` selects the trace to update. It may be omitted only when the figure
    holds at most one trace, in which case trace 0 is used. The trace must be a
    heatmap; its `z` receives `data` (transposed when `horizontal` is True)."""
    arr = reshape_fns.to_2d(np.array(data))
    if trace_idx is None and len(fig.data) > 1:
        raise ValueError("Figure contains more traces than data. Must pass trace_idx.")
    target = 0 if trace_idx is None else trace_idx

    with fig.batch_update():
        trace = fig.data[target]
        if trace.type != 'heatmap':
            raise ValueError(f'Trace at index {target} is not a heatmap')
        trace.z = arr.transpose() if horizontal else arr
Exemplo n.º 16
0
    def update(self, data):
        """Update the trace data.

        `data` is converted to a 2-dim NumPy array, and column `i` feeds trace `i`.
        For each column:
        * NaN values are dropped if `self.remove_nan` is set;
        * values below the `self.from_quantile` quantile and above the
          `self.to_quantile` quantile are filtered out (each bound only if not None);
        * the remaining values go to `x` if `self.horizontal`, else to `y`.

        A column that is empty after NaN removal is passed through as an empty
        array instead of failing in `np.quantile`.
        """
        data = reshape_fns.to_2d(np.array(data))

        with self.fig.batch_update():
            for i, trace in enumerate(self.traces):
                d = data[:, i]
                if self.remove_nan:
                    d = d[~np.isnan(d)]
                # np.quantile raises on an empty array (e.g. an all-NaN column
                # after NaN removal), so only filter when values remain.
                if d.size > 0:
                    mask = np.full(d.shape, True)
                    if self.from_quantile is not None:
                        mask &= d >= np.quantile(d, self.from_quantile)
                    if self.to_quantile is not None:
                        mask &= d <= np.quantile(d, self.to_quantile)
                    d = d[mask]
                if self.horizontal:
                    trace.x = d
                    trace.y = None
                else:
                    trace.x = None
                    trace.y = d
Exemplo n.º 17
0
def update_scatter_data(fig, data, trace_idx=None):
    """Update the scatter data.

    For keyword arguments, see `create_scatter`.
    Optionally, specify the index of the trace `trace_idx` to update (can be multiple).
    Column `i` of `data` (converted to 2 dimensions) becomes the `y` of the trace at
    the `i`-th entry of `trace_idx`.

    Raises:
        ValueError: If `trace_idx` is omitted and the figure has more traces than
            `data` has columns; if `data` has more columns than there are entries in
            `trace_idx`; or if a targeted trace is not a scatter.
    """
    data = reshape_fns.to_2d(np.array(data))
    if trace_idx is None:
        if data.shape[1] < len(fig.data):
            raise ValueError("Figure contains more traces than data. Must pass trace_idx.")
        trace_idx = np.arange(len(fig.data))
    if not isinstance(trace_idx, Iterable):
        trace_idx = [trace_idx]
    if data.shape[1] > len(trace_idx):
        raise ValueError("Data contains more traces than trace_idx")

    with fig.batch_update():
        for i, _trace_idx in enumerate(trace_idx):
            scatter = fig.data[_trace_idx]
            if scatter.type != 'scatter':
                # Report the offending index, not the whole trace_idx collection.
                raise ValueError(f'Trace at index {_trace_idx} is not a scatter')
            scatter.y = data[:, i]
Exemplo n.º 18
0
    def apply(self,
              *args,
              apply_func=None,
              to_2d=False,
              wrap_kwargs=None,
              **kwargs):
        """Apply a function `apply_func`.

        Arguments `*args` and `**kwargs` will be directly passed to `apply_func`.
        If `to_2d` is True, 2-dimensional NumPy arrays will be passed, otherwise as is.
        `wrap_kwargs` are merged into the keyword arguments of `self.wrapper.wrap`.

        !!! note
            The resulted array must have the same shape as the original array.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt
        >>> import pandas as pd

        >>> sr = pd.Series([1, 2], index=['x', 'y'])
        >>> sr.vbt.apply(apply_func=lambda x: x ** 2)
        x    1
        y    4
        dtype: int64
        ```
        """
        checks.assert_not_none(apply_func)
        # Optionally cast to 2d array
        if to_2d:
            obj = reshape_fns.to_2d(self._obj, raw=True)
        else:
            obj = np.asarray(self._obj)
        result = apply_func(obj, *args, **kwargs)
        # Re-attach this object's index/columns (overridable via wrap_kwargs).
        return self.wrapper.wrap(result,
                                 group_by=False,
                                 **merge_dicts({}, wrap_kwargs))
Exemplo n.º 19
0
    def __init__(self,
                 data: tp.Optional[tp.ArrayLike] = None,
                 x_labels: tp.Optional[tp.Labels] = None,
                 y_labels: tp.Optional[tp.Labels] = None,
                 is_x_category: bool = False,
                 is_y_category: bool = False,
                 trace_kwargs: tp.KwargsLike = None,
                 add_trace_kwargs: tp.KwargsLike = None,
                 fig: tp.Optional[tp.BaseFigure] = None,
                 **layout_kwargs) -> None:
        """Create a heatmap plot.

        Args:
            data (array_like): Data in any format that can be converted to NumPy.

                Must be of shape (`y_labels`, `x_labels`).
            x_labels (array_like): X-axis labels, corresponding to columns in pandas.
            y_labels (array_like): Y-axis labels, corresponding to index in pandas.
            is_x_category (bool): Whether X-axis is a categorical axis.
            is_y_category (bool): Whether Y-axis is a categorical axis.
            trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Heatmap`.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments for layout.

        Raises:
            ValueError: If `data` is None and either `x_labels` or `y_labels` is None.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt

        >>> heatmap = vbt.plotting.Heatmap(
        ...     data=[[1, 2], [3, 4]],
        ...     x_labels=['a', 'b'],
        ...     y_labels=['x', 'y']
        ... )
        >>> heatmap.fig
        ```
        ![](/vectorbt/docs/img/Heatmap.svg)
        """
        # Pass every constructor argument (including the axis-type flags) so the
        # config given to Configured mirrors the full signature.
        Configured.__init__(self,
                            data=data,
                            x_labels=x_labels,
                            y_labels=y_labels,
                            is_x_category=is_x_category,
                            is_y_category=is_y_category,
                            trace_kwargs=trace_kwargs,
                            add_trace_kwargs=add_trace_kwargs,
                            fig=fig,
                            **layout_kwargs)

        from vectorbt._settings import settings
        layout_cfg = settings['plotting']['layout']

        if trace_kwargs is None:
            trace_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}
        if data is None:
            if x_labels is None or y_labels is None:
                raise ValueError(
                    "At least x_labels and y_labels must be passed")
        else:
            data = reshape_fns.to_2d(np.asarray(data))
        if x_labels is not None:
            x_labels = clean_labels(x_labels)
        if y_labels is not None:
            y_labels = clean_labels(y_labels)

        if fig is None:
            fig = make_figure()
            if 'width' in layout_cfg:
                # Calculate nice width and height
                max_width = layout_cfg['width']
                if data is not None:
                    x_len = data.shape[1]
                    y_len = data.shape[0]
                else:
                    x_len = len(x_labels)
                    y_len = len(y_labels)
                # Scale each dimension by its share of x_len + y_len into
                # the range [0.3 * max_width, max_width].
                width = math.ceil(
                    renormalize(x_len / (x_len + y_len), (0, 1),
                                (0.3 * max_width, max_width)))
                width = min(width + 150, max_width)  # account for colorbar
                height = math.ceil(
                    renormalize(y_len / (x_len + y_len), (0, 1),
                                (0.3 * max_width, max_width)))
                height = min(height, max_width * 0.7)  # limit height
                fig.update_layout(width=width, height=height)

        heatmap = go.Heatmap(hoverongaps=False,
                             colorscale='Plasma',
                             x=x_labels,
                             y=y_labels)
        heatmap.update(**trace_kwargs)
        fig.add_trace(heatmap, **add_trace_kwargs)

        # Mark the axes used by the new trace as categorical when requested.
        # A trace's axis reference looks like 'x2'; its suffix selects the
        # matching layout key ('xaxis2'); no reference means the default axis.
        axis_kwargs = dict()
        if is_x_category:
            if fig.data[-1]['xaxis'] is not None:
                axis_kwargs['xaxis' +
                            fig.data[-1]['xaxis'][1:]] = dict(type='category')
            else:
                axis_kwargs['xaxis'] = dict(type='category')
        if is_y_category:
            if fig.data[-1]['yaxis'] is not None:
                axis_kwargs['yaxis' +
                            fig.data[-1]['yaxis'][1:]] = dict(type='category')
            else:
                axis_kwargs['yaxis'] = dict(type='category')
        fig.update_layout(**axis_kwargs)
        fig.update_layout(**layout_kwargs)

        TraceUpdater.__init__(self, fig, (fig.data[-1], ))

        if data is not None:
            self.update(data)
Exemplo n.º 20
0
    def __init__(self,
                 data=None,
                 trace_names=None,
                 horizontal=False,
                 remove_nan=True,
                 from_quantile=None,
                 to_quantile=None,
                 trace_kwargs=None,
                 add_trace_kwargs=None,
                 fig=None,
                 **layout_kwargs):
        """Create a box plot.

        For keyword arguments, see `Histogram`.

        `trace_kwargs` may be a single dict applied to every trace or a list/tuple
        with one dict per trace. A `name` key in it overrides the corresponding entry
        of `trace_names`. The dicts are copied before use and never modified.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt

        >>> box = vbt.plotting.Box(
        ...     data=[[1, 2], [3, 4], [2, 1]],
        ...     trace_names=['a', 'b']
        ... )
        >>> box.fig
        ```
        ![](/vectorbt/docs/img/Box.png)
        """
        Configured.__init__(self,
                            data=data,
                            trace_names=trace_names,
                            horizontal=horizontal,
                            remove_nan=remove_nan,
                            from_quantile=from_quantile,
                            to_quantile=to_quantile,
                            trace_kwargs=trace_kwargs,
                            add_trace_kwargs=add_trace_kwargs,
                            fig=fig,
                            **layout_kwargs)

        if trace_kwargs is None:
            trace_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}
        if data is None:
            if trace_names is None:
                raise ValueError("At least trace_names must be passed")
        if trace_names is None:
            # One unnamed trace per data column.
            data = reshape_fns.to_2d(data)
            trace_names = [None] * data.shape[1]
        if isinstance(trace_names, str):
            trace_names = [trace_names]

        if fig is None:
            fig = FigureWidget()
        fig.update_layout(**layout_kwargs)

        for i, trace_name in enumerate(trace_names):
            # Copy before popping 'name' so neither the caller's dict (also passed to
            # Configured.__init__ above) is mutated, and a single shared dict applies
            # its 'name' to every trace rather than only the first.
            _trace_kwargs = dict(trace_kwargs[i] if isinstance(
                trace_kwargs, (list, tuple)) else trace_kwargs)
            trace_name = _trace_kwargs.pop('name', trace_name)
            if trace_name is not None:
                trace_name = str(trace_name)
            box = go.Box(name=trace_name, showlegend=trace_name is not None)
            box.update(**_trace_kwargs)
            fig.add_trace(box, **add_trace_kwargs)

        # Slice from an explicit start index: `fig.data[-0:]` would select every
        # trace in the figure when `trace_names` is empty.
        n_new = len(trace_names)
        TraceUpdater.__init__(self, fig, fig.data[len(fig.data) - n_new:])
        self.horizontal = horizontal
        self.remove_nan = remove_nan
        self.from_quantile = from_quantile
        self.to_quantile = to_quantile

        if data is not None:
            self.update(data)
Exemplo n.º 21
0
    def __init__(self,
                 data=None,
                 trace_names=None,
                 horizontal=False,
                 remove_nan=True,
                 from_quantile=None,
                 to_quantile=None,
                 trace_kwargs=None,
                 add_trace_kwargs=None,
                 fig=None,
                 **layout_kwargs):
        """Create a histogram plot.

        Args:
            data (array_like): Data in any format that can be converted to NumPy.

                Must be of shape (any, `trace_names`).
            trace_names (str or list of str): Trace names, corresponding to columns in pandas.
            horizontal (bool): Plot horizontally.
            remove_nan (bool): Whether to remove NaN values.
            from_quantile (float): Filter out data points before this quantile.

                Should be in range `[0, 1]`.
            to_quantile (float): Filter out data points after this quantile.

                Should be in range `[0, 1]`.
            trace_kwargs (dict or list of dict): Keyword arguments passed to `plotly.graph_objects.Histogram`.

                A single dict is applied to every trace; a list/tuple supplies one dict per
                trace. A `name` key overrides the corresponding entry of `trace_names`.
                The dicts are copied before use and never modified.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            fig (plotly.graph_objects.Figure): Figure to add traces to.
            **layout_kwargs: Keyword arguments for layout.

        Raises:
            ValueError: If both `data` and `trace_names` are None.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt

        >>> hist = vbt.plotting.Histogram(
        ...     data=[[1, 2], [3, 4], [2, 1]],
        ...     trace_names=['a', 'b']
        ... )
        >>> hist.fig
        ```
        ![](/vectorbt/docs/img/Histogram.png)
        """
        Configured.__init__(self,
                            data=data,
                            trace_names=trace_names,
                            horizontal=horizontal,
                            remove_nan=remove_nan,
                            from_quantile=from_quantile,
                            to_quantile=to_quantile,
                            trace_kwargs=trace_kwargs,
                            add_trace_kwargs=add_trace_kwargs,
                            fig=fig,
                            **layout_kwargs)

        if trace_kwargs is None:
            trace_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}
        if data is None:
            if trace_names is None:
                raise ValueError("At least trace_names must be passed")
        if trace_names is None:
            # One unnamed trace per data column.
            data = reshape_fns.to_2d(data)
            trace_names = [None] * data.shape[1]
        if isinstance(trace_names, str):
            trace_names = [trace_names]

        if fig is None:
            fig = FigureWidget()
            # Overlay histograms so that multiple traces are drawn on top of each other.
            fig.update_layout(barmode='overlay')
        fig.update_layout(**layout_kwargs)

        for i, trace_name in enumerate(trace_names):
            # Copy before popping 'name' so neither the caller's dict (also passed to
            # Configured.__init__ above) is mutated, and a single shared dict applies
            # its 'name' to every trace rather than only the first.
            _trace_kwargs = dict(trace_kwargs[i] if isinstance(
                trace_kwargs, (list, tuple)) else trace_kwargs)
            trace_name = _trace_kwargs.pop('name', trace_name)
            if trace_name is not None:
                trace_name = str(trace_name)
            hist = go.Histogram(opacity=0.75 if len(trace_names) > 1 else 1,
                                name=trace_name,
                                showlegend=trace_name is not None)
            hist.update(**_trace_kwargs)
            fig.add_trace(hist, **add_trace_kwargs)

        # Slice from an explicit start index: `fig.data[-0:]` would select every
        # trace in the figure when `trace_names` is empty.
        n_new = len(trace_names)
        TraceUpdater.__init__(self, fig, fig.data[len(fig.data) - n_new:])
        self.horizontal = horizontal
        self.remove_nan = remove_nan
        self.from_quantile = from_quantile
        self.to_quantile = to_quantile

        if data is not None:
            self.update(data)
Exemplo n.º 22
0
    def __init__(self,
                 data=None,
                 trace_names=None,
                 x_labels=None,
                 trace_kwargs=None,
                 add_trace_kwargs=None,
                 fig=None,
                 **layout_kwargs):
        """Create a scatter plot.

        Args:
            data (array_like): Data in any format that can be converted to NumPy.

                Must be of shape (`x_labels`, `trace_names`).
            trace_names (str or list of str): Trace names, corresponding to columns in pandas.
            x_labels (array_like): X-axis labels, corresponding to index in pandas.
            trace_kwargs (dict or list of dict): Keyword arguments passed to `plotly.graph_objects.Scatter`.

                A single dict is applied to every trace; a list/tuple supplies one dict per
                trace. A `name` key overrides the corresponding entry of `trace_names`.
                The dicts are copied before use and never modified.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            fig (plotly.graph_objects.Figure): Figure to add traces to.
            **layout_kwargs: Keyword arguments for layout.

        Raises:
            ValueError: If both `data` and `trace_names` are None.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt

        >>> scatter = vbt.plotting.Scatter(
        ...     data=[[1, 2], [3, 4]],
        ...     trace_names=['a', 'b'],
        ...     x_labels=['x', 'y']
        ... )
        >>> scatter.fig
        ```
        ![](/vectorbt/docs/img/Scatter.png)
        """
        Configured.__init__(self,
                            data=data,
                            trace_names=trace_names,
                            x_labels=x_labels,
                            trace_kwargs=trace_kwargs,
                            add_trace_kwargs=add_trace_kwargs,
                            fig=fig,
                            **layout_kwargs)

        if trace_kwargs is None:
            trace_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}
        if data is None:
            if trace_names is None:
                raise ValueError("At least trace_names must be passed")
        if trace_names is None:
            # One unnamed trace per data column.
            data = reshape_fns.to_2d(data)
            trace_names = [None] * data.shape[1]
        if isinstance(trace_names, str):
            trace_names = [trace_names]
        if x_labels is not None:
            x_labels = clean_labels(x_labels)

        if fig is None:
            fig = FigureWidget()
        fig.update_layout(**layout_kwargs)

        for i, trace_name in enumerate(trace_names):
            # Copy before popping 'name' so neither the caller's dict (also passed to
            # Configured.__init__ above) is mutated, and a single shared dict applies
            # its 'name' to every trace rather than only the first.
            _trace_kwargs = dict(trace_kwargs[i] if isinstance(
                trace_kwargs, (list, tuple)) else trace_kwargs)
            trace_name = _trace_kwargs.pop('name', trace_name)
            if trace_name is not None:
                trace_name = str(trace_name)
            scatter = go.Scatter(x=x_labels,
                                 name=trace_name,
                                 showlegend=trace_name is not None)
            scatter.update(**_trace_kwargs)
            fig.add_trace(scatter, **add_trace_kwargs)

        # Slice from an explicit start index: `fig.data[-0:]` would select every
        # trace in the figure when `trace_names` is empty.
        n_new = len(trace_names)
        TraceUpdater.__init__(self, fig, fig.data[len(fig.data) - n_new:])

        if data is not None:
            self.update(data)
Exemplo n.º 23
0
def create_box(data=None,
               trace_names=None,
               horizontal=False,
               remove_nan=True,
               trace_kwargs=None,
               return_trace_idxs=False,
               row=None,
               col=None,
               fig=None,
               **layout_kwargs):
    """Create a box plot with one `plotly.graph_objects.Box` trace per trace name.

    Args:
        data (array_like): Data in any format that can be converted to NumPy.

            Must be of shape (any, `trace_names`).
        trace_names (str or list of str): Trace names, corresponding to columns in pandas.

            If omitted, `data` is required and one unnamed trace is created per column.
        horizontal (bool): Plot horizontally.
        remove_nan (bool): Whether to remove NaN values.
        trace_kwargs (dict or list of dict): Keyword arguments passed to `plotly.graph_objects.Box`.

            Can be specified per trace as a list/tuple of dicts.
        return_trace_idxs (bool): Whether to return trace indices for `update_box_data`.
        row (int): Row position.
        col (int): Column position.
        fig (plotly.graph_objects.Figure): Figure to add traces to.
        **layout_kwargs: Keyword arguments for layout.

    Returns:
        The figure, or a tuple `(fig, trace_idxs)` if `return_trace_idxs` is True.

    Raises:
        ValueError: If neither `data` nor `trace_names` is passed.

    ## Example

    ```python-repl
    >>> import vectorbt as vbt

    >>> vbt.plotting.create_box(
    ...     data=[[1, 2], [3, 4], [2, 1]],
    ...     trace_names=['a', 'b']
    ... )
    ```
    ![](/vectorbt/docs/img/create_box.png)
    """
    if trace_kwargs is None:
        trace_kwargs = {}
    if trace_names is None:
        if data is None:
            raise ValueError("At least trace_names must be passed")
        # Without names, create one unnamed trace per data column
        data = reshape_fns.to_2d(data)
        trace_names = [None] * data.shape[1]
    elif isinstance(trace_names, str):
        trace_names = [trace_names]

    if fig is None:
        fig = CustomFigureWidget()
        # NOTE(review): `barmode` governs bar traces; box traces are laid out via
        # `boxmode` in plotly — confirm this setting is intended here.
        fig.update_layout(barmode='overlay')
    fig.update_layout(**layout_kwargs)

    per_trace_kwargs = isinstance(trace_kwargs, (list, tuple))
    for idx, name in enumerate(trace_names):
        label = None if name is None else str(name)
        box = go.Box(name=label, showlegend=label is not None)
        box.update(**(trace_kwargs[idx] if per_trace_kwargs else trace_kwargs))
        fig.add_trace(box, row=row, col=col)

    # The traces just added are the last `len(trace_names)` traces of the figure
    n_total = len(fig.data)
    trace_idxs = list(range(n_total - len(trace_names), n_total))
    if data is not None:
        update_box_data(fig,
                        data,
                        horizontal=horizontal,
                        trace_idx=trace_idxs,
                        remove_nan=remove_nan)
    return (fig, trace_idxs) if return_trace_idxs else fig
Exemplo n.º 24
0
def create_heatmap(data=None,
                   x_labels=None,
                   y_labels=None,
                   horizontal=False,
                   trace_kwargs=None,
                   return_trace_idx=False,
                   row=None,
                   col=None,
                   fig=None,
                   **layout_kwargs):
    """Create a heatmap plot.

    Args:
        data (array_like): Data in any format that can be converted to NumPy.

            Must be of shape (`y_labels`, `x_labels`).
        x_labels (array_like): X-axis labels, corresponding to columns in pandas.
        y_labels (array_like): Y-axis labels, corresponding to index in pandas.
        horizontal (bool): Plot horizontally.

            Transposes `data` and swaps `x_labels` with `y_labels` so that the
            labels keep matching the transposed data.
        trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Heatmap`.
        return_trace_idx (bool): Whether to return trace index for `update_heatmap_data`.
        row (int): Row position.
        col (int): Column position.
        fig (plotly.graph_objects.Figure): Figure to add traces to.
        **layout_kwargs: Keyword arguments for layout.

    Returns:
        The figure, or a tuple `(fig, trace_idx)` if `return_trace_idx` is True.

    Raises:
        ValueError: If `data` is not passed and either `x_labels` or `y_labels` is missing.

    ## Example

    ```python-repl
    >>> import vectorbt as vbt

    >>> vbt.plotting.create_heatmap(
    ...     data=[[1, 2], [3, 4]],
    ...     x_labels=['a', 'b'],
    ...     y_labels=['x', 'y']
    ... )
    ```
    ![](/vectorbt/docs/img/create_heatmap.png)
    """
    from vectorbt.settings import layout

    if trace_kwargs is None:
        trace_kwargs = {}
    if data is None:
        if x_labels is None or y_labels is None:
            raise ValueError("At least x_labels and y_labels must be passed")
    else:
        data = reshape_fns.to_2d(np.array(data))
    if horizontal:
        # Transposing swaps rows and columns, so the label sets must be swapped too:
        # the old x-labels now describe rows (y-axis) and vice versa.
        x_labels, y_labels = y_labels, x_labels
        if data is not None:
            data = data.transpose()
        # Data is already transposed here; don't transpose it again on update.
        horizontal = False
    if fig is None:
        fig = CustomFigureWidget()
        if 'width' in layout:
            # Calculate nice width and height from the number of columns and rows
            max_width = layout['width']
            if data is not None:
                x_len = data.shape[1]
                y_len = data.shape[0]
            else:
                x_len = len(x_labels)
                y_len = len(y_labels)
            total_len = x_len + y_len
            if total_len > 0:  # avoid division by zero on empty labels
                width = math.ceil(
                    renormalize(x_len / total_len, [0, 1],
                                [0.3 * max_width, max_width]))
                width = min(width + 150, max_width)  # account for colorbar
                height = math.ceil(
                    renormalize(y_len / total_len, [0, 1],
                                [0.3 * max_width, max_width]))
                height = min(height, max_width * 0.7)  # limit height
                fig.update_layout(width=width, height=height)

    fig.update_layout(**layout_kwargs)
    heatmap = go.Heatmap(hoverongaps=False,
                         colorscale='Plasma',
                         x=x_labels,
                         y=y_labels)
    heatmap.update(**trace_kwargs)
    fig.add_trace(heatmap, row=row, col=col)
    trace_idx = len(fig.data) - 1
    if data is not None:
        update_heatmap_data(fig,
                            data,
                            horizontal=horizontal,
                            trace_idx=trace_idx)
    if return_trace_idx:
        return fig, trace_idx
    return fig
Exemplo n.º 25
0
    def combine(self,
                other: tp.MaybeTupleList[tp.Union[tp.ArrayLike,
                                                  "BaseAccessor"]],
                *args,
                allow_multiple: bool = True,
                combine_func: tp.Optional[tp.Callable] = None,
                keep_pd: bool = False,
                to_2d: bool = False,
                concat: bool = False,
                numba_loop: bool = False,
                use_ray: bool = False,
                broadcast: bool = True,
                broadcast_kwargs: tp.KwargsLike = None,
                keys: tp.Optional[tp.IndexLike] = None,
                wrap_kwargs: tp.KwargsLike = None,
                **kwargs) -> tp.SeriesFrame:
        """Combine this object with one or more objects in `other` using `combine_func`.

        Args:
            other (array_like): Object(s) to combine this array with.

                Accessors are unwrapped to their underlying objects.
            *args: Variable arguments passed to `combine_func`.
            allow_multiple (bool): Whether a tuple/list in `other` is treated as multiple objects.
            combine_func (callable): Function that combines two arrays. Required.

                Can be Numba-compiled.
            keep_pd (bool): Whether to keep inputs as pandas objects, otherwise convert to NumPy arrays.
            to_2d (bool): Whether to reshape inputs to 2-dim arrays, otherwise keep as-is.
            concat (bool): Whether to concatenate the results along the column axis.
                Otherwise, pairwise combine into a Series/DataFrame of the same shape.

                If True, see `vectorbt.base.combine_fns.combine_and_concat`.
                If False, see `vectorbt.base.combine_fns.combine_multiple`.
            numba_loop (bool): Whether to loop using Numba.

                Set to True when iterating large number of times over small input,
                but note that Numba doesn't support variable keyword arguments.
            use_ray (bool): Whether to use Ray to execute `combine_func` in parallel.

                Only works with `concat` set to True, and not together with a Numba loop.
                See `vectorbt.base.combine_fns.ray_apply` for related keyword arguments.
            broadcast (bool): Whether to broadcast all inputs.
            broadcast_kwargs (dict): Keyword arguments passed to `vectorbt.base.reshape_fns.broadcast`.
            keys (index_like): Outermost column level (used only when concatenating).
            wrap_kwargs (dict): Keyword arguments passed to `vectorbt.base.array_wrapper.ArrayWrapper.wrap`.
            **kwargs: Keyword arguments passed to `combine_func`.

        Returns:
            Series/DataFrame wrapped like the (broadcast) original object.

        Raises:
            ValueError: If `use_ray` is combined with a Numba loop or with `concat=False`.

        !!! note
            If `other` resolves to a single object, `combine_func` is applied once and
            the result is returned directly; `concat` and `keys` have no effect then.

        !!! note
            If `combine_func` is Numba-compiled, will broadcast using `WRITEABLE` and `C_CONTIGUOUS`
            flags, which can lead to an expensive computation overhead if passed objects are large and
            have different shape/memory order. You also must ensure that all objects have the same data type.

            Also remember to bring each in `*args` to a Numba-compatible format.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt
        >>> import pandas as pd

        >>> sr = pd.Series([1, 2], index=['x', 'y'])
        >>> df = pd.DataFrame([[3, 4], [5, 6]], index=['x', 'y'], columns=['a', 'b'])

        >>> sr.vbt.combine(df, combine_func=lambda x, y: x + y)
           a  b
        x  4  5
        y  7  8

        >>> sr.vbt.combine([df, df*2], combine_func=lambda x, y: x + y)
            a   b
        x  10  13
        y  17  20

        >>> sr.vbt.combine([df, df*2], combine_func=lambda x, y: x + y, concat=True, keys=['c', 'd'])
              c       d
           a  b   a   b
        x  4  5   7   9
        y  7  8  12  14
        ```

        Use Ray for small inputs and large processing times:

        ```python-repl
        >>> import time

        >>> def combine_func(a, b):
        ...     time.sleep(1)
        ...     return a + b

        >>> sr = pd.Series([1, 2, 3])

        >>> %timeit sr.vbt.combine([1, 1, 1], combine_func=combine_func)
        3.01 s ± 2.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

        >>> %timeit sr.vbt.combine([1, 1, 1], combine_func=combine_func, concat=True, use_ray=True)
        1.02 s ± 2.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
        ```
        """
        # Normalize `other` into a tuple of plain objects (accessors are unwrapped)
        if allow_multiple and isinstance(other, (tuple, list)):
            raw_others = other
        else:
            raw_others = (other, )
        others = tuple(
            item.obj if isinstance(item, BaseAccessor) else item
            for item in raw_others)
        checks.assert_not_none(combine_func)

        # Bring all objects to a common shape
        if not broadcast:
            new_obj, new_others = self.obj, others
        else:
            if broadcast_kwargs is None:
                broadcast_kwargs = {}
            if checks.is_numba_func(combine_func):
                # Numba needs writeable arrays that all share the same (C) memory order
                broadcast_kwargs = merge_dicts(
                    dict(require_kwargs=dict(requirements=['W', 'C'])),
                    broadcast_kwargs)
            new_obj, *new_others = reshape_fns.broadcast(
                self.obj, *others, **broadcast_kwargs)
        if not checks.is_pandas(new_obj):
            new_obj = ArrayWrapper.from_shape(new_obj.shape).wrap(new_obj)

        # Convert inputs into the format `combine_func` expects
        all_objs = (new_obj, *new_others)
        if to_2d:
            inputs = tuple(
                reshape_fns.to_2d(item, raw=not keep_pd) for item in all_objs)
        elif keep_pd:
            inputs = all_objs
        else:
            inputs = tuple(np.asarray(item) for item in all_objs)

        # Exactly one `other`: a single call, without concatenation or keys
        if len(inputs) == 2:
            result = combine_func(inputs[0], inputs[1], *args, **kwargs)
            return ArrayWrapper.from_obj(new_obj).wrap(
                result, **merge_dicts({}, wrap_kwargs))

        first, rest = inputs[0], inputs[1:]
        if concat:
            # Combine `first` with each of `rest` and stack the results horizontally
            if checks.is_numba_func(combine_func) and numba_loop:
                if use_ray:
                    raise ValueError("Ray cannot be used within Numba")
                for prev, curr in zip(inputs, rest):
                    checks.assert_meta_equal(prev, curr)
                result = combine_fns.combine_and_concat_nb(
                    first, rest, combine_func, *args, **kwargs)
            elif use_ray:
                result = combine_fns.combine_and_concat_ray(
                    first, rest, combine_func, *args, **kwargs)
            else:
                result = combine_fns.combine_and_concat(
                    first, rest, combine_func, *args, **kwargs)
            base_columns = ArrayWrapper.from_obj(new_obj).columns
            if keys is None:
                keys = pd.Index(np.arange(len(new_others)),
                                name='combine_idx')
            new_columns = index_fns.combine_indexes([keys, base_columns])
            return ArrayWrapper.from_obj(new_obj).wrap(
                result, **merge_dicts(dict(columns=new_columns), wrap_kwargs))

        # Fold all inputs pairwise into one object of the same shape
        if use_ray:
            raise ValueError("Ray cannot be used with concat=False")
        if checks.is_numba_func(combine_func) and numba_loop:
            for prev, curr in zip(inputs, rest):
                checks.assert_dtype_equal(prev, curr)
            result = combine_fns.combine_multiple_nb(
                inputs, combine_func, *args, **kwargs)
        else:
            result = combine_fns.combine_multiple(
                inputs, combine_func, *args, **kwargs)
        return ArrayWrapper.from_obj(new_obj).wrap(
            result, **merge_dicts({}, wrap_kwargs))
Exemplo n.º 26
0
    def apply_and_concat(self,
                         ntimes: int,
                         *args,
                         apply_func: tp.Optional[tp.Callable] = None,
                         keep_pd: bool = False,
                         to_2d: bool = False,
                         numba_loop: bool = False,
                         use_ray: bool = False,
                         keys: tp.Optional[tp.IndexLike] = None,
                         wrap_kwargs: tp.KwargsLike = None,
                         **kwargs) -> tp.Frame:
        """Apply `apply_func` `ntimes` times and concatenate the results along columns.
        See `vectorbt.base.combine_fns.apply_and_concat_one`.

        `apply_func` is called as `apply_func(i, obj, *args, **kwargs)` for each `i`
        from 0 to `ntimes - 1`, where `obj` is this object converted according to
        `keep_pd` and `to_2d`.

        Args:
            ntimes (int): Number of times to call `apply_func`.
            *args: Variable arguments passed to `apply_func`.
            apply_func (callable): Apply function. Required.

                Can be Numba-compiled.
            keep_pd (bool): Whether to keep inputs as pandas objects, otherwise convert to NumPy arrays.
            to_2d (bool): Whether to reshape inputs to 2-dim arrays, otherwise keep as-is.
            numba_loop (bool): Whether to loop using Numba.

                Set to True when iterating large number of times over small input,
                but note that Numba doesn't support variable keyword arguments.
            use_ray (bool): Whether to use Ray to execute `apply_func` in parallel.

                Cannot be combined with a Numba loop (Numba-compiled `apply_func` and `numba_loop=True`).
                See `vectorbt.base.combine_fns.ray_apply` for related keyword arguments.
            keys (index_like): Outermost column level.

                Defaults to a range index named `apply_idx`.
            wrap_kwargs (dict): Keyword arguments passed to `vectorbt.base.array_wrapper.ArrayWrapper.wrap`.
            **kwargs: Keyword arguments passed to `apply_func`.

        Returns:
            DataFrame with the results stacked horizontally; the columns combine `keys`
            with the original columns.

        Raises:
            ValueError: If `use_ray` is combined with a Numba loop.

        !!! note
            The resulted arrays to be concatenated must have the same shape as broadcast input arrays.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt
        >>> import pandas as pd

        >>> df = pd.DataFrame([[3, 4], [5, 6]], index=['x', 'y'], columns=['a', 'b'])
        >>> df.vbt.apply_and_concat(3, [1, 2, 3],
        ...     apply_func=lambda i, a, b: a * b[i], keys=['c', 'd', 'e'])
              c       d       e
           a  b   a   b   a   b
        x  3  4   6   8   9  12
        y  5  6  10  12  15  18
        ```

        Use Ray for small inputs and large processing times:

        ```python-repl
        >>> import time

        >>> def apply_func(i, a):
        ...     time.sleep(1)
        ...     return a

        >>> sr = pd.Series([1, 2, 3])

        >>> %timeit sr.vbt.apply_and_concat(3, apply_func=apply_func)
        3.01 s ± 2.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

        >>> %timeit sr.vbt.apply_and_concat(3, apply_func=apply_func, use_ray=True)
        1.01 s ± 2.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
        ```
        """
        checks.assert_not_none(apply_func)
        # Optionally cast to 2d array, otherwise to NumPy unless pandas is kept
        if to_2d:
            obj = reshape_fns.to_2d(self.obj, raw=not keep_pd)
        elif keep_pd:
            obj = self.obj
        else:
            obj = np.asarray(self.obj)

        if checks.is_numba_func(apply_func) and numba_loop:
            if use_ray:
                raise ValueError("Ray cannot be used within Numba")
            result = combine_fns.apply_and_concat_one_nb(
                ntimes, apply_func, obj, *args, **kwargs)
        elif use_ray:
            result = combine_fns.apply_and_concat_one_ray(
                ntimes, apply_func, obj, *args, **kwargs)
        else:
            result = combine_fns.apply_and_concat_one(
                ntimes, apply_func, obj, *args, **kwargs)

        # Build column hierarchy: `keys` (or the call index) on top of the original columns
        if keys is None:
            keys = pd.Index(np.arange(ntimes), name='apply_idx')
        new_columns = index_fns.combine_indexes([keys, self.wrapper.columns])
        return self.wrapper.wrap(result,
                                 group_by=False,
                                 **merge_dicts(dict(columns=new_columns),
                                               wrap_kwargs))
Exemplo n.º 27
0
    def __init__(self,
                 data: tp.Optional[tp.ArrayLike] = None,
                 trace_names: tp.TraceNames = None,
                 x_labels: tp.Optional[tp.Labels] = None,
                 trace_kwargs: tp.KwargsLikeSequence = None,
                 add_trace_kwargs: tp.KwargsLike = None,
                 fig: tp.Optional[tp.BaseFigure] = None,
                 **layout_kwargs) -> None:
        """Create a bar plot with one `plotly.graph_objects.Bar` trace per trace name.

        Args:
            data (array_like): Data in any format that can be converted to NumPy.

                Must be of shape (`x_labels`, `trace_names`).
            trace_names (str or list of str): Trace names, corresponding to columns in pandas.

                If omitted, `data` is required and one unnamed trace is created per column.
            x_labels (array_like): X-axis labels, corresponding to index in pandas.
            trace_kwargs (dict or list of dict): Keyword arguments passed to `plotly.graph_objects.Bar`.

                Can be specified per trace as a sequence of dicts. A `name` key
                overrides the corresponding trace name.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments for layout.

        Raises:
            ValueError: If neither `data` nor `trace_names` is passed.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt

        >>> bar = vbt.plotting.Bar(
        ...     data=[[1, 2], [3, 4]],
        ...     trace_names=['a', 'b'],
        ...     x_labels=['x', 'y']
        ... )
        >>> bar.fig
        ```
        ![](/vectorbt/docs/img/Bar.svg)
        """
        Configured.__init__(self,
                            data=data,
                            trace_names=trace_names,
                            x_labels=x_labels,
                            trace_kwargs=trace_kwargs,
                            add_trace_kwargs=add_trace_kwargs,
                            fig=fig,
                            **layout_kwargs)

        trace_kwargs = {} if trace_kwargs is None else trace_kwargs
        add_trace_kwargs = {} if add_trace_kwargs is None else add_trace_kwargs
        if trace_names is None:
            if data is None:
                raise ValueError("At least trace_names must be passed")
            # Without names, create one unnamed trace per data column
            data = reshape_fns.to_2d(np.asarray(data))
            trace_names = [None] * data.shape[1]
        elif isinstance(trace_names, str):
            trace_names = [trace_names]
        if x_labels is not None:
            x_labels = clean_labels(x_labels)

        if fig is None:
            fig = make_figure()
        fig.update_layout(**layout_kwargs)

        for idx, default_name in enumerate(trace_names):
            bar_kwargs = resolve_dict(trace_kwargs, i=idx)
            # An explicit `name` in the trace kwargs takes precedence
            label = bar_kwargs.pop('name', default_name)
            if label is not None:
                label = str(label)
            bar = go.Bar(x=x_labels, name=label, showlegend=label is not None)
            bar.update(**bar_kwargs)
            fig.add_trace(bar, **add_trace_kwargs)

        # Track the traces just added (the last `len(trace_names)` in the figure)
        n_new = len(trace_names)
        TraceUpdater.__init__(self, fig, fig.data[-n_new:])

        if data is not None:
            self.update(data)
Exemplo n.º 28
0
    def __init__(self,
                 data=None,
                 x_labels=None,
                 y_labels=None,
                 trace_kwargs=None,
                 add_trace_kwargs=None,
                 fig=None,
                 **layout_kwargs):
        """Create a heatmap plot backed by a single `plotly.graph_objects.Heatmap` trace.

        Args:
            data (array_like): Data in any format that can be converted to NumPy.

                Must be of shape (`y_labels`, `x_labels`).
            x_labels (array_like): X-axis labels, corresponding to columns in pandas.
            y_labels (array_like): Y-axis labels, corresponding to index in pandas.
            trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Heatmap`.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            fig (plotly.graph_objects.Figure): Figure to add traces to.

                If omitted, a new figure is created and, when the layout settings
                define a `width`, sized from the number of rows and columns.
            **layout_kwargs: Keyword arguments for layout.

        Raises:
            ValueError: If `data` is not passed and either `x_labels` or `y_labels` is missing.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt

        >>> heatmap = vbt.plotting.Heatmap(
        ...     data=[[1, 2], [3, 4]],
        ...     x_labels=['a', 'b'],
        ...     y_labels=['x', 'y']
        ... )
        >>> heatmap.fig
        ```
        ![](/vectorbt/docs/img/Heatmap.png)
        """
        Configured.__init__(self,
                            data=data,
                            x_labels=x_labels,
                            y_labels=y_labels,
                            trace_kwargs=trace_kwargs,
                            add_trace_kwargs=add_trace_kwargs,
                            fig=fig,
                            **layout_kwargs)

        from vectorbt.settings import layout

        trace_kwargs = {} if trace_kwargs is None else trace_kwargs
        add_trace_kwargs = {} if add_trace_kwargs is None else add_trace_kwargs
        if data is not None:
            data = reshape_fns.to_2d(np.array(data))
        elif x_labels is None or y_labels is None:
            raise ValueError(
                "At least x_labels and y_labels must be passed")
        if x_labels is not None:
            x_labels = clean_labels(x_labels)
        if y_labels is not None:
            y_labels = clean_labels(y_labels)

        if fig is None:
            fig = FigureWidget()
            if 'width' in layout:
                # Derive width/height from the share of columns vs. rows,
                # bounded by the configured layout width
                max_width = layout['width']
                if data is None:
                    n_cols = len(x_labels)
                    n_rows = len(y_labels)
                else:
                    n_cols = data.shape[1]
                    n_rows = data.shape[0]
                n_all = n_cols + n_rows
                width = math.ceil(
                    renormalize(n_cols / n_all, [0, 1],
                                [0.3 * max_width, max_width]))
                width = min(width + 150, max_width)  # room for the colorbar
                height = math.ceil(
                    renormalize(n_rows / n_all, [0, 1],
                                [0.3 * max_width, max_width]))
                height = min(height, max_width * 0.7)  # cap the height
                fig.update_layout(width=width, height=height)

        fig.update_layout(**layout_kwargs)

        heatmap = go.Heatmap(hoverongaps=False,
                             colorscale='Plasma',
                             x=x_labels,
                             y=y_labels)
        heatmap.update(**trace_kwargs)
        fig.add_trace(heatmap, **add_trace_kwargs)

        # The heatmap is the most recently added trace
        TraceUpdater.__init__(self, fig, [fig.data[-1]])

        if data is not None:
            self.update(data)
Exemplo n.º 29
0
    def update(self, data):
        """Set `data` as the z-values of the first tracked trace.

        `data` is converted with `np.array` and reshaped to 2 dimensions first.
        The assignment is done inside `fig.batch_update()`.
        """
        z_values = np.array(data)
        z_values = reshape_fns.to_2d(z_values)
        with self.fig.batch_update():
            self.traces[0].z = z_values
Exemplo n.º 30
0
    def update(self, data: tp.ArrayLike) -> None:
        """Set `data` as the z-values of the first tracked trace.

        `data` is converted with `np.asarray` and reshaped to 2 dimensions first.
        The assignment is done inside `fig.batch_update()`.
        """
        z_values = np.asarray(data)
        z_values = reshape_fns.to_2d(z_values)
        with self.fig.batch_update():
            self.traces[0].z = z_values