Exemplo n.º 1
0
def finalize_plot(axes: Axes, title: str, save: bool, name: str):
    """Title, lay out, optionally save, show and close the figure owning *axes*.

    :param axes: Axes whose figure is finalized.
    :param title: Title set on ``axes``.
    :param save: If true, the figure is written to ``name`` via ``savefig``.
    :param name: Output path passed to ``Figure.savefig`` when ``save`` is true.
    """
    # Work on the figure that owns ``axes`` rather than pyplot's *current*
    # figure: the two differ whenever another figure was created/activated
    # after ``axes``, in which case plt.tight_layout()/plt.close() would lay
    # out and close the wrong figure (and leak this one).
    fig = axes.get_figure()
    axes.set_title(title)
    fig.tight_layout()
    if save:
        fig.savefig(name)
    plt.show()
    plt.close(fig)
Exemplo n.º 2
0
def format_date_labels(ax: Axes, rot):
    """Right-align and rotate the x tick labels of *ax* (mini ``autofmt_xdate``).

    :param ax: Axes whose x tick labels are adjusted.
    :param rot: Rotation passed to each label's ``set_rotation``.
    """
    tick_labels = ax.get_xticklabels()
    for tick_label in tick_labels:
        tick_label.set_ha("right")
        tick_label.set_rotation(rot)
    # NOTE(review): presumably reserves room at the bottom for the rotated
    # labels; maybe_adjust_figure is defined elsewhere — confirm semantics.
    owning_figure = ax.get_figure()
    maybe_adjust_figure(owning_figure, bottom=0.2)
Exemplo n.º 3
0
def _add_colorbar(ax: Axes, cmap: colors.Colormap, cmap_data: np.ndarray,
                  norm: colors.Normalize):
    """Attach a colorbar for *cmap*/*norm* next to *ax* (right of the plot).

    A standalone ``ScalarMappable`` carries the colormap, normalization and
    data; it is never drawn itself, only handed to ``Figure.colorbar``.
    """
    figure = ax.get_figure()
    scalar_mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
    # NOTE(review): the original TODO questioned this call. Older matplotlib
    # versions required an array on a standalone ScalarMappable before
    # colorbar() would accept it — confirm whether it is still needed.
    scalar_mappable.set_array(cmap_data)
    figure.colorbar(scalar_mappable, ax=ax)
Exemplo n.º 4
0
def _add_labels(ax: Axes, h: Union[Histogram1D, Histogram2D], kwargs: dict):
    """Add axis and plot labels.
    
    TODO: Document kwargs
    """
    title = kwargs.pop("title", h.title)
    xlabel = kwargs.pop("xlabel", h.axis_names[0])
    ylabel = kwargs.pop("ylabel",
                        h.axis_names[1] if len(h.axis_names) == 2 else None)

    if title:
        ax.set_title(title)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    ax.get_figure().tight_layout()
Exemplo n.º 5
0
def get_aspect(ax: Axes) -> float:
    """Return the physical height/width ratio of an Axes box as a float.

    Matplotlib's ``ax.get_aspect()`` can return a string (e.g. ``'equal'``).
    This helper always yields a number. It multiplies the figure size in
    inches by the Axes' position (figure fractions) to get the Axes' physical
    height and width, then returns their ratio.

    Parameters
    ----------
    ax : Axes
        Matplotlib Axes object

    Returns
    -------
    float
        Axes height divided by Axes width, in display (inch) units

    Notes
    -----

    :Author:
        Lucas Sawade ([email protected])

    :Last Modified:
        2021.01.20 11.30

    """
    fig_width_in, fig_height_in = ax.get_figure().get_size_inches()
    _x0, _y0, width_frac, height_frac = ax.get_position().bounds

    axes_height_in = fig_height_in * height_frac
    axes_width_in = fig_width_in * width_frac
    return axes_height_in / axes_width_in
Exemplo n.º 6
0
def axes(host_ax: Axes,
         x_label: Optional[str] = None,
         y_label_left: Optional[Union[str, Collection[str]]] = None,
         y_label_right: Optional[Union[str, Collection[str]]] = None,
         offset: float = 0.055,
         tight: Optional[bool] = None) -> List[Axes]:
    """
    Configure `Axes <https://matplotlib.org/api/axes_api.html#axis-labels-title-and-legend>`_ and its
    `Axis-es <https://matplotlib.org/api/axis_api.html>`_.

    :param host_ax: Axes (retrieved with :func:`subplots`).
    :param x_label: Label of the x-axis. Left untouched when ``None``.
    :param y_label_left: If string, label of the only/single left y-axis. If multiple strings, multiple left y-axes
                         are created with 0 being the left-most (the farthest away from the plotting canvas).
    :param y_label_right: If provided and provided as string, second y-axis is generated. If provided as multiple
                          strings, multiple right y-axes are created with 0 being the left-most (the nearest to the
                          plotting canvas).
    :param offset: If multiple axes are generated on the same side, this offset is used to place the spine at the
                   specified Axes coordinate (from 0.0-1.0). As examples if this is 0.1, axes on left will have
                   positions -0.2, -0.1 and 0.0 and axes on the right 1.0, 1.1 and 1.2.
    :param tight: Currently ignored (TODO): ``tight_layout`` is always applied to the host figure.

    :return: Generated and configured (multiple) Axes, from the left-most Axes to the right-most Axes. When
             ``y_label_left`` is given, index 0 is the provided host axis. NOTE: when ``y_label_left`` is
             ``None``/empty the host axis is *not* part of the returned list.
    """
    def make_patch_spines_invisible(_ax: Axes):
        """
        Having been created by :meth:`twinx`, second axis has its frame off, so the line of its detached spine is
        invisible. First, activate the frame but make the patch and spines invisible.
        """
        _ax.set_frame_on(True)
        _ax.patch.set_visible(False)
        for sp in _ax.spines.values():
            sp.set_visible(False)

    # Configure x-axis
    if x_label is not None:
        host_ax.set_xlabel(x_label)
    host_ax.grid(True, 'both')

    # Generate y-axes (that is only partly true as Axes instance contains both axis).
    axs_left: List[Axes] = []
    axs_right: List[Axes] = []

    # By default there is only one axis on the left, which is in fact already the host axis.
    if isinstance(y_label_left, str):
        host_ax.set_ylabel(y_label_left)
        axs_left.append(host_ax)
    elif y_label_left:
        # The tight_layout() call at the end keeps the offset spines inside the
        # figure (a `tight`-driven subplots_adjust was once considered here).
        n_left = len(y_label_left)
        for i, label in enumerate(y_label_left):
            # The left-most axis is the host axis.
            ax = host_ax if i == 0 else host_ax.twinx()
            axs_left.append(ax)
            # Set proper position (offset the spine).
            ax.spines["left"].set_position(("axes", -(n_left - i - 1) * offset))
            # Patch spines are modified only for the axes not touching the plotting canvas (right-most).
            if i < n_left - 1:
                make_patch_spines_invisible(ax)
            ax.spines["left"].set_visible(True)  # Show the left spine.
            ax.yaxis.set_label_position("left")
            ax.yaxis.set_ticks_position("left")
            ax.set_ylabel(label)

    if y_label_right:
        if isinstance(y_label_right, str):
            y_label_right = [y_label_right]
        # As on the left, tight_layout() at the end keeps offset spines visible.
        for i, label in enumerate(y_label_right):
            ax = host_ax.twinx()
            axs_right.append(ax)
            # Set proper position (offset the spine).
            ax.spines["right"].set_position(("axes", 1 + i * offset))
            # Patch spines are modified only for the axes not touching the plotting canvas (left-most).
            if i > 0:
                make_patch_spines_invisible(ax)
            ax.spines["right"].set_visible(True)  # Show the right spine.
            ax.yaxis.set_label_position("right")
            ax.yaxis.set_ticks_position("right")
            ax.set_ylabel(label)

    axs = axs_left + axs_right

    # Configure y-axes: disable the offset notation of the tick formatter.
    for configured_ax in axs:
        configured_ax.get_yaxis().get_major_formatter().set_useOffset(False)

    fig = host_ax.get_figure()
    # noinspection PyProtectedMember
    suptitle = getattr(fig, "_suptitle", None)
    if suptitle and suptitle.get_text():
        # Leave room at the top for the suptitle.
        # https://stackoverflow.com/a/45161551/5616255
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    else:
        fig.tight_layout()

    return axs
Exemplo n.º 7
0
def plot3D(
        ax: Axes,
        correlations: Correlations,
        atoms: Tuple[str],
        *,
        color: str = 'tab20',
        nolabels: bool = False,
        showlegend: bool = False,
        offset: int = 0,
        project: int = 2,
        slices: int = 16) -> Slices3D:
    '''
    Factory for the Slices3D object. Splits data into intervals and sets up scrolling.
    Has to return the Slices3D object to the main function, otherwise it is GC'd and scrolling
    stops working.

    :param ax: Axes to draw into; its limits are fixed so all slices share the same xy range.
    :param correlations: Entries formatted ``[sequence_number, residue_type, chemical_shifts]``.
    :param atoms: Atom names, one per chemical-shift dimension; the projected one is dropped.
    :param project: Index of the chemical-shift dimension that is sliced (projected away).
    :param slices: Number of intervals the projected dimension is split into (must be >= 1).
    :param color, nolabels, showlegend, offset: Forwarded to :class:`Slices3D`.
    :return: The Slices3D object connected to the figure's ``scroll_event``.
    :raises ValueError: if ``slices`` is less than 1.
    '''
    if slices < 1:
        # Previously 0 raised ZeroDivisionError and negatives silently
        # produced no slices at all.
        raise ValueError(f"slices must be a positive integer, got {slices!r}")

    # Get mins and maxes so all plots have the same xy coords
    # correlations formatted [sequence_number, residue_type, (chemical_shifts)]
    shifts = list(zip(*(c[2] for c in correlations)))
    maxes = [max(dim) for dim in shifts]
    mins = [min(dim) for dim in shifts]

    # Find intervals for bins
    projection_max = maxes.pop(project)
    projection_min = mins.pop(project)
    bin_width = (projection_max - projection_min) / slices
    cutoffs = [projection_min + i * bin_width for i in range(slices)]
    # Add 1 to not have to deal with < vs <= for last bin
    intervals = list(zip(cutoffs, cutoffs[1:] + [projection_max + 1]))

    # Sets axes equal for all plots
    x_padding = 0.1 * (maxes[0] - mins[0])
    y_padding = 0.1 * (maxes[1] - mins[1])
    ax.set_xlim((mins[0] - x_padding, maxes[0] + x_padding))
    ax.set_ylim((mins[1] - y_padding, maxes[1] + y_padding))

    # The ordering is identical for every interval, so sort once instead of
    # re-sorting the whole list inside the loop.
    ordered = sorted(correlations, key=lambda c: c[2][project])

    sliced_data = []
    for low, high in intervals:
        sliced = []
        # chemical shifts are index 2
        for sequence_number, residue_type, chemical_shifts in ordered:
            if low <= chemical_shifts[project] < high:
                # Remove projected index from list of indices used for plotting
                plot_point = list(chemical_shifts)
                plot_point.pop(project)
                sliced.append([sequence_number, residue_type, plot_point])
        sliced_data.append(sliced)

    atoms = atoms[:project] + atoms[project + 1:]

    three_d = Slices3D(ax,
                       sliced_data,
                       intervals,
                       atoms=atoms,
                       color=color,
                       nolabels=nolabels,
                       showlegend=showlegend,
                       offset=offset)
    ax.get_figure().canvas.mpl_connect('scroll_event', three_d.on_scroll)
    return three_d
Exemplo n.º 8
0
def add_scalebar(
    ax: Axes,
    direction: str,
    length: float,
    label: str,
    align: str = "left",
    pos_along: float = 0.1,
    pos_across: float = 0.12,
    label_size: float = 8.5,
    label_align: str = "center",
    label_pad: float = 0.2,
    label_offset: float = -0.75,
    lw: float = 1.4,
    brackets: bool = True,
    bracket_length: float = 3,
    label_background: dict = dict(facecolor="white", edgecolor="none",
                                  alpha=1),
    in_layout: bool = True,
):
    """
    Indicate horizontal or vertical scale of an axes.

    Draws a black bar (optionally with perpendicular end brackets) and a text
    label on `ax`, outside the clip region.

    :param direction:  Orientation of the scalebar. "h" (for horizontal) to
                label the x-axis, "v" (for vertical) to label the y-axis.
    :param length:  Size of the scalebar, in data coordinates.
    :param label:  Text to label the scalebar with.
    :param align:  Alignment of the scalebar, in the direction of the scalebar.
                NOTE: currently not used by this function.
    :param pos_along:  Position of the start of the scalebar, in the direction
                of the scalebar, as a fraction of the current axis limits.
    :param pos_across:  Position of the scalebar, in the direction orthogonal
                to the scalebar. In axes coordinates.
    :param label_size:  Fontsize (text height) of the label. In points (1/72-th
                of an inch).
    :param label_align:  {"left", "center", "right"}. Alignment of the label
                text along the writing direction, relative to the scalebar.
    :param label_pad:   When `label_align` is "left" or "right": padding of the
                label against the edge of the scalebar. In `label_size` units.
                Ignored when `label_align` is "center".
    :param label_offset:  Position of the center-line of the label text,
                relative to the scalebar. In `label_size` units. Negative
                values position the label to the left or below the scalebar.
    :param lw:  Line width of the scalebar.
    :param brackets:  Whether to end the scalebar in short perpendicular lines
                at both ends.
    :param bracket_length:  In points (1/72-th of an inch).
    :param label_background:  Passed on to `set_bbox()` of the label text.
    :param in_layout:  Whether to take the scalebar into account when calling
                `fig.tight_layout()`.
    :raises ValueError:  If `direction` is not "h"/"v", or `label_align` is not
                one of "left", "center", "right".
    """
    # Validate up front: an unknown value used to fall through the if/elif
    # chains below and surface much later as an UnboundLocalError.
    if direction not in ("h", "v"):
        raise ValueError(f'direction must be "h" or "v", got {direction!r}')
    if label_align not in ("left", "center", "right"):
        raise ValueError(
            'label_align must be "left", "center" or "right", '
            f'got {label_align!r}')
    if direction == "h":
        trans = ax.get_xaxis_transform()
        lims = ax.get_xlim()
        orthogonal_direction = "v"
    else:
        trans = ax.get_yaxis_transform()
        lims = ax.get_ylim()
        orthogonal_direction = "h"
    # `bar_start` and `bar_end` are in data coordinates.
    bar_start = lims[0] + (pos_along * diff(lims))
    bar_end = bar_start + length
    plot_options = dict(
        c="black",
        lw=lw,
        clip_on=False,
        transform=trans,
        zorder=4,
        solid_capstyle="projecting",
    )
    text_options = dict(va="center",
                        ha=label_align,
                        transform=trans,
                        fontsize=label_size)
    fig = ax.get_figure()
    # Convert point-based sizes into the coordinate systems they are used in.
    bracket_length_axcoords = points_to_figcoords(
        bracket_length,
        fig,
        trans=ax.transAxes,
        direction=orthogonal_direction)
    label_size_datacoords = points_to_figcoords(label_size,
                                                fig,
                                                trans=ax.transData,
                                                direction=direction)
    label_size_axcoords = points_to_figcoords(label_size,
                                              fig,
                                              trans=ax.transAxes,
                                              direction=orthogonal_direction)
    label_pos_across = pos_across + (label_offset * label_size_axcoords)
    if label_align == "left":
        label_pos_along = bar_start + (label_pad * label_size_datacoords)
    elif label_align == "right":
        label_pos_along = bar_end - (label_pad * label_size_datacoords)
    else:  # "center" (validated above)
        label_pos_along = (bar_start + bar_end) / 2
    text_coords = (label_pos_along, label_pos_across)
    bar_coords = ([bar_start, bar_end], [pos_across, pos_across])
    # Format of below coords: imagine a horizontal bar, then: ([x, x], [y, y]).
    # Brackets point away from the side the label is on.
    label_direction = sign(label_offset)
    bracket_direction = -label_direction
    bracket_start_coords = (
        [bar_start, bar_start],
        [pos_across, pos_across + bracket_direction * bracket_length_axcoords],
    )
    bracket_end_coords = (
        [bar_end, bar_end],
        [pos_across, pos_across + bracket_direction * bracket_length_axcoords],
    )
    if direction == "v":
        # Everything above was built for a horizontal bar: swap (x, y).
        bar_coords = bar_coords[::-1]
        text_coords = text_coords[::-1]
        bracket_start_coords = bracket_start_coords[::-1]
        bracket_end_coords = bracket_end_coords[::-1]
        text_options.update(dict(rotation=90, rotation_mode="anchor"))
    artists: List[Artist] = []
    artists.extend(ax.plot(*bar_coords, **plot_options))
    if brackets:
        artists.extend(ax.plot(*bracket_start_coords, **plot_options))
        artists.extend(ax.plot(*bracket_end_coords, **plot_options))
    text: Text = ax.text(*text_coords, label, **text_options)
    text.set_bbox(label_background)
    artists.append(text)
    for artist in artists:
        artist.set_in_layout(in_layout)