예제 #1
0
    def __init__(self,
                 data: tp.Optional[tp.ArrayLike] = None,
                 x_labels: tp.Optional[tp.Labels] = None,
                 y_labels: tp.Optional[tp.Labels] = None,
                 is_x_category: bool = False,
                 is_y_category: bool = False,
                 trace_kwargs: tp.KwargsLike = None,
                 add_trace_kwargs: tp.KwargsLike = None,
                 fig: tp.Optional[tp.BaseFigure] = None,
                 **layout_kwargs) -> None:
        """Create a heatmap plot.

        Args:
            data (array_like): Data in any format that can be converted to NumPy.

                Must be of shape (`y_labels`, `x_labels`).
            x_labels (array_like): X-axis labels, corresponding to columns in pandas.

                Required together with `y_labels` when `data` is None.
            y_labels (array_like): Y-axis labels, corresponding to index in pandas.

                Required together with `x_labels` when `data` is None.
            is_x_category (bool): Whether X-axis is a categorical axis.
            is_y_category (bool): Whether Y-axis is a categorical axis.
            trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Heatmap`.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            fig (Figure or FigureWidget): Figure to add traces to.

                If None, a new figure is created and, when the layout settings
                define a width, sized according to the number of rows and columns.
            **layout_kwargs: Keyword arguments for layout.

        Raises:
            ValueError: If `data` is None and `x_labels` or `y_labels` is missing.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt

        >>> heatmap = vbt.plotting.Heatmap(
        ...     data=[[1, 2], [3, 4]],
        ...     x_labels=['a', 'b'],
        ...     y_labels=['x', 'y']
        ... )
        >>> heatmap.fig
        ```
        ![](/vectorbt/docs/img/Heatmap.svg)
        """
        # Forward the complete set of constructor arguments. The two category
        # flags used to be left out even though every other argument is passed.
        # NOTE(review): presumably `Configured` keeps these arguments so the
        # object can be re-created later -- confirm against its implementation.
        Configured.__init__(self,
                            data=data,
                            x_labels=x_labels,
                            y_labels=y_labels,
                            is_x_category=is_x_category,
                            is_y_category=is_y_category,
                            trace_kwargs=trace_kwargs,
                            add_trace_kwargs=add_trace_kwargs,
                            fig=fig,
                            **layout_kwargs)

        from vectorbt._settings import settings
        layout_cfg = settings['plotting']['layout']

        if trace_kwargs is None:
            trace_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}
        if data is None:
            # Without data, the labels are the only source of the axis lengths
            if x_labels is None or y_labels is None:
                raise ValueError(
                    "At least x_labels and y_labels must be passed")
        else:
            data = reshape_fns.to_2d(np.asarray(data))
        if x_labels is not None:
            x_labels = clean_labels(x_labels)
        if y_labels is not None:
            y_labels = clean_labels(y_labels)

        if fig is None:
            fig = make_figure()
            if 'width' in layout_cfg:
                # Calculate nice width and height
                max_width = layout_cfg['width']
                if data is not None:
                    x_len = data.shape[1]
                    y_len = data.shape[0]
                else:
                    x_len = len(x_labels)
                    y_len = len(y_labels)
                total_len = x_len + y_len
                # With both axes empty there is no aspect ratio to derive;
                # skipping also avoids a ZeroDivisionError in the ratios below.
                if total_len > 0:
                    width = math.ceil(
                        renormalize(x_len / total_len, (0, 1),
                                    (0.3 * max_width, max_width)))
                    width = min(width + 150, max_width)  # account for colorbar
                    height = math.ceil(
                        renormalize(y_len / total_len, (0, 1),
                                    (0.3 * max_width, max_width)))
                    height = min(height, max_width * 0.7)  # limit height
                    fig.update_layout(width=width, height=height)

        heatmap = go.Heatmap(hoverongaps=False,
                             colorscale='Plasma',
                             x=x_labels,
                             y=y_labels)
        # User-supplied trace options override the defaults set above
        heatmap.update(**trace_kwargs)
        fig.add_trace(heatmap, **add_trace_kwargs)

        # Build the layout key of the axis the new trace is attached to: the
        # first character of the trace's axis reference is dropped and the rest
        # is appended, so a reference of 'x2' yields the layout key 'xaxis2'.
        axis_kwargs = dict()
        if is_x_category:
            if fig.data[-1]['xaxis'] is not None:
                axis_kwargs['xaxis' +
                            fig.data[-1]['xaxis'][1:]] = dict(type='category')
            else:
                axis_kwargs['xaxis'] = dict(type='category')
        if is_y_category:
            if fig.data[-1]['yaxis'] is not None:
                axis_kwargs['yaxis' +
                            fig.data[-1]['yaxis'][1:]] = dict(type='category')
            else:
                axis_kwargs['yaxis'] = dict(type='category')
        fig.update_layout(**axis_kwargs)
        # Applied last so that explicit layout arguments win over the axis types
        fig.update_layout(**layout_kwargs)

        # The trace just added is the last one in the figure
        TraceUpdater.__init__(self, fig, (fig.data[-1], ))

        if data is not None:
            self.update(data)
예제 #2
0
    def __init__(self,
                 data=None,
                 x_labels=None,
                 y_labels=None,
                 trace_kwargs=None,
                 add_trace_kwargs=None,
                 fig=None,
                 **layout_kwargs):
        """Build a heatmap trace and attach it to a figure.

        Args:
            data (array_like): Values in any format convertible to NumPy.

                Expected shape is (`y_labels`, `x_labels`).
            x_labels (array_like): Labels along the X-axis (columns in pandas).
            y_labels (array_like): Labels along the Y-axis (index in pandas).
            trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Heatmap`.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            fig (plotly.graph_objects.Figure): Figure to add the trace to.

                A new `FigureWidget` is created when omitted.
            **layout_kwargs: Keyword arguments for layout.

        Raises:
            ValueError: If `data` is None and `x_labels` or `y_labels` is missing.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt

        >>> heatmap = vbt.plotting.Heatmap(
        ...     data=[[1, 2], [3, 4]],
        ...     x_labels=['a', 'b'],
        ...     y_labels=['x', 'y']
        ... )
        >>> heatmap.fig
        ```
        ![](/vectorbt/docs/img/Heatmap.png)
        """
        # The arguments are handed over exactly as received, before any defaults
        # or conversions below are applied.
        Configured.__init__(
            self,
            data=data,
            x_labels=x_labels,
            y_labels=y_labels,
            trace_kwargs=trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
            **layout_kwargs
        )

        from vectorbt.settings import layout

        trace_kwargs = {} if trace_kwargs is None else trace_kwargs
        add_trace_kwargs = {} if add_trace_kwargs is None else add_trace_kwargs

        # Without data, both label sets are mandatory
        if data is None and (x_labels is None or y_labels is None):
            raise ValueError(
                "At least x_labels and y_labels must be passed")
        if data is not None:
            data = reshape_fns.to_2d(np.array(data))

        if x_labels is not None:
            x_labels = clean_labels(x_labels)
        if y_labels is not None:
            y_labels = clean_labels(y_labels)

        if fig is None:
            fig = FigureWidget()
            if 'width' in layout:
                # Derive figure dimensions from the share of columns and rows
                width_cap = layout['width']
                if data is None:
                    n_cols, n_rows = len(x_labels), len(y_labels)
                else:
                    n_rows, n_cols = data.shape[0], data.shape[1]

                def _scaled(share):
                    # Map a share in [0, 1] onto [30% of the cap, the cap]
                    return math.ceil(
                        renormalize(share, [0, 1],
                                    [0.3 * width_cap, width_cap]))

                # The extra 150 accounts for the colorbar
                fig_width = min(
                    _scaled(n_cols / (n_cols + n_rows)) + 150, width_cap)
                # Height is limited to 70% of the cap
                fig_height = min(
                    _scaled(n_rows / (n_cols + n_rows)), width_cap * 0.7)
                fig.update_layout(width=fig_width, height=fig_height)

        fig.update_layout(**layout_kwargs)

        trace = go.Heatmap(
            hoverongaps=False,
            colorscale='Plasma',
            x=x_labels,
            y=y_labels
        )
        trace.update(**trace_kwargs)
        fig.add_trace(trace, **add_trace_kwargs)

        # The trace just added is the last one in the figure
        TraceUpdater.__init__(self, fig, [fig.data[-1]])

        if data is not None:
            self.update(data)
예제 #3
0
def create_heatmap(data=None,
                   x_labels=None,
                   y_labels=None,
                   horizontal=False,
                   trace_kwargs=None,
                   return_trace_idx=False,
                   row=None,
                   col=None,
                   fig=None,
                   **layout_kwargs):
    """Create a heatmap plot.

    Args:
        data (array_like): Data in any format that can be converted to NumPy.

            Must be of shape (`y_labels`, `x_labels`).
        x_labels (array_like): X-axis labels, corresponding to columns in pandas.

            Required together with `y_labels` when `data` is None.
        y_labels (array_like): Y-axis labels, corresponding to index in pandas.

            Required together with `x_labels` when `data` is None.
        horizontal (bool): Plot horizontally.

            The data is transposed and the X and Y labels are swapped.
        trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Heatmap`.
        return_trace_idx (bool): Whether to return trace index for `update_heatmap_data`.
        row (int): Row position.
        col (int): Column position.
        fig (plotly.graph_objects.Figure): Figure to add traces to.

            If None, a new figure is created and, when the layout settings
            define a width, sized according to the number of rows and columns.
        **layout_kwargs: Keyword arguments for layout.

    Returns:
        The figure, or a tuple of the figure and the index of the added trace
        if `return_trace_idx` is True.

    Raises:
        ValueError: If `data` is None and `x_labels` or `y_labels` is missing.

    ## Example

    ```python-repl
    >>> import vectorbt as vbt

    >>> vbt.plotting.create_heatmap(
    ...     data=[[1, 2], [3, 4]],
    ...     x_labels=['a', 'b'],
    ...     y_labels=['x', 'y']
    ... )
    ```
    ![](/vectorbt/docs/img/create_heatmap.png)
    """
    from vectorbt.settings import layout

    if trace_kwargs is None:
        trace_kwargs = {}
    if data is None:
        # Without data, the labels are the only source of the axis lengths
        if x_labels is None or y_labels is None:
            raise ValueError("At least x_labels and y_labels must be passed")
    else:
        data = reshape_fns.to_2d(np.array(data))
    if horizontal:
        # Swap the axes. The previous code assigned `y_labels` twice, which left
        # `x_labels` untouched and put the X labels on both axes.
        x_labels, y_labels = y_labels, x_labels
        if data is not None:
            data = data.transpose()
        # Labels and data are already flipped, so the update below must not
        # flip them a second time.
        horizontal = False
    if fig is None:
        fig = CustomFigureWidget()
        if 'width' in layout:
            # Calculate nice width and height
            max_width = layout['width']
            if data is not None:
                x_len = data.shape[1]
                y_len = data.shape[0]
            else:
                x_len = len(x_labels)
                y_len = len(y_labels)
            total_len = x_len + y_len
            # With both axes empty there is no aspect ratio to derive;
            # skipping also avoids a ZeroDivisionError in the ratios below.
            if total_len > 0:
                width = math.ceil(
                    renormalize(x_len / total_len, [0, 1],
                                [0.3 * max_width, max_width]))
                width = min(width + 150, max_width)  # account for colorbar
                height = math.ceil(
                    renormalize(y_len / total_len, [0, 1],
                                [0.3 * max_width, max_width]))
                height = min(height, max_width * 0.7)  # limit height
                fig.update_layout(width=width, height=height)

    fig.update_layout(**layout_kwargs)
    heatmap = go.Heatmap(hoverongaps=False,
                         colorscale='Plasma',
                         x=x_labels,
                         y=y_labels)
    # User-supplied trace options override the defaults set above
    heatmap.update(**trace_kwargs)
    fig.add_trace(heatmap, row=row, col=col)
    # The trace just added is the last one in the figure
    trace_idx = len(fig.data) - 1
    if data is not None:
        update_heatmap_data(fig,
                            data,
                            horizontal=horizontal,
                            trace_idx=trace_idx)
    if return_trace_idx:
        return fig, trace_idx
    return fig