示例#1
0
def align_hilo_ylim(ax1: plt.Axes, ax2: plt.Axes):
    """
    Rescale the y-limits of two axes so their extreme visible ticks line up.

    For each axis the lowest and highest tick lying inside the current
    y-limits are located.  Both axes then receive new limits such that the
    lowest ticks share one relative height and the highest ticks share
    another.  The margins used are the larger of the two existing relative
    margins below the lowest tick and above the highest tick.

    The limits are left untouched when no alignment can be computed:
    an axis has no visible tick, an axis has no tick span (a single
    visible tick), or the two tick bands do not overlap in relative
    position (which would otherwise produce NaN or inverted limits).

    Args:
        ax1: First axes; its y-limits are updated in place.
        ax2: Second axes; its y-limits are updated in place.
    """
    (a1, b1), (a2, b2) = ax1.get_ylim(), ax2.get_ylim()

    # Ticks lying inside the current limits.  Inverted limits (a > b) match
    # nothing here and therefore end up in the "nothing to align" case.
    t1 = np.asarray([y for y in ax1.get_yticks() if a1 <= y <= b1])
    t2 = np.asarray([y for y in ax2.get_yticks() if a2 <= y <= b2])
    if t1.size == 0 or t2.size == 0:
        return

    (lo1, hi1), (lo2, hi2) = (t1.min(), t1.max()), (t2.min(), t2.max())
    (span1, span2) = (hi1 - lo1, hi2 - lo2)
    if span1 <= 0 or span2 <= 0:
        # A single visible tick gives no span to scale against.
        return

    # Relative margin below the lowest tick / above the highest tick;
    # take the larger of the two axes for each side.
    lo = max((lo1 - a1) / (b1 - a1), (lo2 - a2) / (b2 - a2))
    hi = max(1 - (hi1 - a1) / (b1 - a1), 1 - (hi2 - a2) / (b2 - a2))

    # Fraction of the axis height left for the tick band itself.
    room = 1 - hi - lo
    if room <= 0:
        # The tick bands do not overlap in relative position.
        return

    # Full axis height (in data units) needed so the tick band occupies
    # exactly `room` of it.
    f1 = span1 / room
    f2 = span2 / room

    ax1.set_ylim(-lo * f1 + lo1, +hi * f1 + lo1 + span1)
    ax2.set_ylim(-lo * f2 + lo2, +hi * f2 + lo2 + span2)
示例#2
0
def align_twinx_ticks(ax_left: plt.Axes, ax_right: plt.Axes) -> np.ndarray:
    """
    Compute right-axis tick positions that coincide with the left-axis ticks.

    The current y-limits of both axes and the left axis' tick values are
    handed to ``linear_mapping``; its result is returned unchanged.
    Presumably that helper maps values from the left limit interval onto the
    right one -- confirm against its definition.

    Ticks are not aligned automatically by matplotlib and there is no good
    general-purpose solution, hence this explicit mapping.
    """
    return linear_mapping(
        ax_left.get_ylim(),
        ax_right.get_ylim(),
        ax_left.get_yticks(),
    )
示例#3
0
def add_percent_axis(ax: plt.Axes, data_size, flip_axis: bool = False) -> plt.Axes:
    """
    Attach a twin percentage axis to a count plot.

    The twin axis receives the count axis' ticks and limits, each rescaled
    to ``100 * value / data_size``, and is labelled with a percent formatter.

    Args:
        ax: Axes holding the count plot.
        data_size: Total count that corresponds to 100 %.
        flip_axis: If True the counts run along x (a ``twiny`` axis is
            created); otherwise they run along y (``twinx``).

    Returns:
        The newly created twin axes carrying the percentage scale.
    """
    # Pick the axis letter once and dispatch on it instead of duplicating
    # the x / y code paths.
    which = "x" if flip_axis else "y"
    twin = ax.twiny() if flip_axis else ax.twinx()

    count_ticks = getattr(ax, f"get_{which}ticks")()
    getattr(twin, f"set_{which}ticks")(100 * count_ticks / data_size)

    low, high = getattr(ax, f"get_{which}lim")()
    getattr(twin, f"set_{which}lim")(
        (
            100.0 * (float(low) / data_size),
            100.0 * (float(high) / data_size),
        )
    )

    percent_axis = getattr(twin, f"{which}axis")
    percent_axis.set_major_formatter(mtick.PercentFormatter())
    percent_axis.set_tick_params(labelsize=10)

    twin.grid(False)
    return twin
示例#4
0
def plot_spectrum(spectrum,
                  annotate_ions: bool = False,
                  mirror_intensity: bool = False,
                  grid: Union[bool, str] = True,
                  ax: plt.Axes = None,
                  peak_color="teal",
                  **plt_kwargs) -> plt.Axes:
    """
    Plot a single MS/MS spectrum.

    Code is largely taken from package "spectrum_utils".

    Parameters
    ----------
    spectrum: matchms.Spectrum
        The spectrum to be plotted.
    annotate_ions:
        Flag indicating whether or not to annotate fragment using peak comments
        (if present in the spectrum). The default is False.
    mirror_intensity:
        Flag indicating whether to flip the intensity axis or not.
    grid:
        Draw grid lines or not. Either a boolean (or 'both') to enable/disable
        both major and minor grid lines or 'major'/'minor' to enable major or
        minor grid lines respectively.
    ax:
        Axes instance on which to plot the spectrum. If None the current Axes
        instance is used.
    peak_color:
        Color used for the peak stems. The default is "teal".
    **plt_kwargs:
        Additional keyword arguments forwarded to ``ax.plot``.

    Returns
    -------
    plt.Axes
        The matplotlib Axes instance on which the spectrum is plotted.
    """
    # pylint: disable=too-many-locals, too-many-arguments
    if ax is None:
        ax = plt.gca()

    mz_values = spectrum.peaks.mz
    # Pad the m/z range to the surrounding multiples of 100.
    min_mz = max(0, np.floor(mz_values[0] / 100 - 1) * 100)
    max_mz = np.ceil(mz_values[-1] / 100 + 1) * 100

    # Intensities relative to the base peak (base peak == 1.0).
    intensities = spectrum.peaks.intensities / spectrum.peaks.intensities.max()

    # Each peak is a vertical stem: two points sharing the same m/z, one at
    # zero and one at the (relative) intensity.
    x = np.zeros([2, mz_values.size], dtype="float")
    x[:, :] = np.tile(mz_values, (2, 1))
    y = np.zeros(x.shape)
    y[1, :] = intensities
    if mirror_intensity:
        y = -y
    ax.plot(x, y, color=peak_color, linewidth=1.0, marker="", zorder=5, **plt_kwargs)

    peak_comments = spectrum.get("peak_comments")
    if annotate_ions and isinstance(peak_comments, dict):
        for mz, comment in peak_comments.items():
            # Index of the peak whose m/z is closest to the annotated m/z.
            idx = (-abs(mz_values - mz)).argmax()
            # Put the label at the tip of the stem, which points downwards
            # when the spectrum is mirrored.
            label_y = -intensities[idx] if mirror_intensity else intensities[idx]
            ax.text(mz, label_y, f"m/z: {mz} \n {comment}",
                    _annotation_kws)

    ax.set_xlim(min_mz, max_mz)
    ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
    # Leave extra head room for annotation text.
    y_max = 1.25 if annotate_ions else 1.10
    y_limits = (-y_max, 0) if mirror_intensity else (0, y_max)
    ax.set_ylim(*y_limits)

    ax.xaxis.set_minor_locator(mticker.AutoMinorLocator())
    ax.yaxis.set_minor_locator(mticker.AutoMinorLocator())
    if grid in (True, "both", "major"):
        ax.grid(visible=True, which="major", color="#9E9E9E", linewidth=0.2)
    if grid in (True, "both", "minor"):
        ax.grid(visible=True, which="minor", color="#9E9E9E", linewidth=0.2)
    ax.set_axisbelow(True)

    ax.tick_params(axis="both", which="both", labelsize="small")
    # Do not label anything beyond 100 % (on either side when mirrored).
    y_ticks = ax.get_yticks()
    ax.set_yticks(y_ticks[np.abs(y_ticks) <= 1.0])

    ax.set_xlabel("m/z", style="italic")
    ax.set_ylabel("Intensity")
    compound_name = spectrum.get("compound_name")
    ax.set_title("Spectrum" if compound_name is None else compound_name)
    return ax
示例#5
0
文件: plt2.py 项目: yulkang/pylabyk
def _add_axes_part(source_ax, rect, fun_draw):
    """Add an axes at ``rect`` to ``source_ax``'s figure, copy its artist
    properties from ``source_ax`` and, if given, redraw into it."""
    fig = source_ax.figure  # type: plt.Figure
    part = fig.add_axes(plt.Axes(fig=fig, rect=rect))
    part.update_from(source_ax)
    if fun_draw is not None:
        fun_draw(part)
    return part


def break_axis(
    amin,
    amax=None,
    xy='x',
    ax: plt.Axes = None,
    fun_draw: Callable = None,
    margin=0.05,
) -> (plt.Axes, plt.Axes):
    """
    Replace an axes by two axes showing the data on either side of a break.

    The original axes is hidden and two new axes are added to its figure
    inside the area it occupied: the first shows ``[lim[0], amin]`` and the
    second ``[amax, lim[1]]`` along the broken axis.

    :param amin: data coordinate to start breaking from
    :param amax: data coordinate to end breaking at; defaults to amin
    :param xy: 'x' or 'y'
    :param ax: axes to break; defaults to the current axes
    :param fun_draw: if not None, fun_draw(ax1) and fun_draw(ax2) will
    be run to recreate ax. Use the same function as that was called for
    with ax. Use, e.g., fun_draw=lambda ax: ax.plot(x, y)
    :param margin: fraction of the original axes height left empty between
    the two parts; only used when xy == 'y'
    :return: axs: a list of axes created
    :raises ValueError: if xy is neither 'x' nor 'y'
    """
    if xy not in ('x', 'y'):
        raise ValueError("xy must be 'x' or 'y', got {!r}".format(xy))

    if amax is None:
        amax = amin

    if ax is None:
        ax = plt.gca()

    left, bottom, width, height = ax.get_position().bounds

    if xy == 'x':
        lim = ax.get_xlim()
        # Positions of the break ends as fractions of the full x-range.
        prop_min = (amin - lim[0]) / (lim[1] - lim[0])
        prop_max = (amax - lim[0]) / (lim[1] - lim[0])
        rect1 = [left, bottom, width * prop_min, height]
        rect2 = [
            left + width * prop_max, bottom, width * (1 - prop_max), height
        ]

        ax1 = _add_axes_part(ax, rect1, fun_draw)
        ax1.set_xticks(ax.get_xticks())
        ax1.set_xlim(lim[0], amin)
        ax1.spines['right'].set_visible(False)

        ax2 = _add_axes_part(ax, rect2, fun_draw)
        ax2.set_xticks(ax.get_xticks())
        ax2.set_xlim(amax, lim[1])
        ax2.spines['left'].set_visible(False)
        ax2.set_yticks([])

    else:
        lim = ax.get_ylim()
        # Here prop_min / prop_max are the height fractions given to the
        # lower / upper part; together they add up to (1 - margin).
        prop_all = ((amin - lim[0]) + (lim[1] - amax)) / (1 - margin)
        prop_min = (amin - lim[0]) / prop_all
        prop_max = (lim[1] - amax) / prop_all
        rect1 = [left, bottom, width, height * prop_min]
        # The upper part starts prop_max below the top of the original axes
        # and is prop_max high, so it ends exactly at the original top.
        rect2 = [
            left, bottom + height * (1 - prop_max), width, height * prop_max
        ]

        ax1 = _add_axes_part(ax, rect1, fun_draw)
        ax1.set_yticks(ax.get_yticks())
        ax1.set_ylim(lim[0], amin)
        ax1.spines['top'].set_visible(False)

        ax2 = _add_axes_part(ax, rect2, fun_draw)
        ax2.set_yticks(ax.get_yticks())
        ax2.set_ylim(amax, lim[1])
        ax2.spines['bottom'].set_visible(False)
        ax2.set_xticks([])

    ax.set_visible(False)
    return [ax1, ax2]
示例#6
0
文件: pyGM.py 项目: aleximorin/pyGM
    def plot_map(self,
                 im: np.array,
                 title: str = '',
                 cbar_unit: str = None,
                 tag: str = None,
                 meta: dict = None,
                 cmap: str = 'viridis',
                 view_extent: np.array = None,
                 outline: bool = False,
                 points: bool = False,
                 point_color: bool = False,
                 rectangle: bool = False,
                 labels: bool = False,
                 ticks: bool = True,
                 scamap: bool = None,
                 ax: plt.Axes = None,
                 hillshade: bool = False,
                 scale_dict: dict = None,
                 grid: bool = True,
                 alpha: float = 1,
                 showplot: bool = False,
                 sci: bool = False,
                 figsize: tuple = None,
                 ashape: bool = None):
        """Main plotting function; ensures all other plots share the same parameters.

        Inputs:
        im: The 2D or 3D np array to be plotted. Can be the DEM, the thickness, the error, etc.
            For a 3D array only the first band is plotted.
        title: The title of the plot
        cbar_unit: The units of the color bar. The color bar is only drawn when this is given.
        tag: The tag of the plot with which it will be saved. Only used when no ax is provided.
        meta: Raster metadata providing the 'transform' entry; defaults to self.meta.
        cmap: The colormap wanted for the plot, defaults to viridis
        view_extent: The extent wanted for the plot. Defaults to None which will set the extent to the
                     bounds of the outline. Takes as input a numpy array, like the self.point_extent array
        outline: Boolean value, defaults to False. Set to true if you want the outline plotted on the map.
        points: Boolean value, defaults to False. Set to true if you want the points plotted on the map.
        point_color: Boolean value, defaults to False. Set to true if you want the points colored by the
                     third column of self.gpr (presumably the thickness -- confirm).
                     Could be changed in the future to a scalar map instead of a boolean.
        rectangle: Boolean value, defaults to False. Set to true if you want a rectangle outlining the
                   points' extent in the map. CURRENTLY DOESN'T WORK PROPERLY
        labels: Boolean value, defaults to False. Set to true if you want the x and y labels plotted.
        ticks: Boolean value, defaults to True. Set to true if you want the x and y ticks plotted.
        scamap: Scalarmap object, defaults to None. Set if you want a specific colormap scale. Useful for
                subplots. When None, one is built from the nan-min/nan-max of im.
        ax: matplotlib Ax object, defaults to None. Set if you want to specify on which ax to plot.
            Useful for subplots.
        hillshade: Boolean value, defaults to False. Set for a hillshade effect, especially on DEMs
        scale_dict: Dictionary, defaults to None. Parameters to add a scale bar to the map
                    (passed through parse_scale_dict).
        grid: Boolean value, defaults to True. Set to true to toggle the axes grid.
        alpha: Float defaults to 1. Sets the transparency of the main map image
        showplot: Boolean, defaults to False. Set to true if you want to see the figure. Only pops up if
                  no ax object is provided.
        sci: Boolean, defaults to False. Use scientific notation for the tick labels instead of
             thousands-separated integers. Only relevant when ticks is True.
        figsize: Size of the figure created when no ax object is provided.
        ashape: Defaults to None. If truthy, the exterior of self.ashape is drawn in red.

        Returns:
        (fig, ax) when no ax was provided (a new figure is created), otherwise the image
        object returned by ax.imshow.
        """
        # Rasters from rasterio are of (1, m, n) size; keep only the 2D band.
        if im.ndim == 3:
            im = im[0]

        if meta is None:
            meta = self.meta

        # Manages the extent array. The order needed here is different than given from the shapely format
        b = [0, 2, 1, 3]
        extent = [
            rasterio.transform.array_bounds(*im.shape, meta['transform'])[i]
            for i in b
        ]
        xmin, xmax, ymin, ymax = [self.outline.total_bounds[i] for i in b]

        if view_extent is not None:
            view_extent = [view_extent[i] for i in b]
            xmin, xmax, ymin, ymax = view_extent

        fig = None
        if ax is None:
            fig = plt.figure(figsize=figsize, tight_layout=True)
            ax = plt.gca()

        # NOTE(review): scamap is only used to color the points below; the
        # image itself is drawn with `cmap` and imshow's own scaling.
        if scamap is None:
            norm = Normalize(np.nanmin(im), np.nanmax(im))
            scamap = plt.cm.ScalarMappable(cmap=cmap, norm=norm)

        if hillshade:
            hills = im_to_hillshade(self.dem_im[0], 225, 40)
            ax.imshow(hills, extent=[self.extent[i] for i in b], cmap='Greys')

        # Plot the image and add a colorbar
        img = ax.imshow(im, cmap=cmap, extent=extent, alpha=alpha)

        if cbar_unit is not None:
            divider = make_axes_locatable(ax)
            cax = divider.append_axes('right', size=0.2, pad=0.1)
            cbar = plt.colorbar(img, cax=cax, orientation='vertical')
            cbar.ax.get_yaxis().labelpad = 15
            cbar.ax.set_title(f'{cbar_unit}')

        c = 'black'
        e = None
        if point_color:
            c = scamap.to_rgba(self.gpr.iloc[:, 2])
            e = 'black'

        # Plots various map accessories if asked
        if points:
            # Marker outlines are dropped entirely for large point sets.
            lw = 0.1
            if len(self.gpr) > 1000:
                lw = 0
            ax.scatter(self.gpr.geometry.x,
                       self.gpr.geometry.y,
                       c=c,
                       cmap=cmap,
                       edgecolors=e,
                       linewidths=lw)

        if outline:
            self.outline.plot(ax=ax, facecolor='None', edgecolor='black')

        if rectangle:
            rec = self.create_rectangle(color='red', lw=1)
            ax.add_patch(rec)
        if ashape:
            x, y = self.ashape.unary_union.exterior.xy
            ax.plot(x, y, color='red', lw=1)

        if scale_dict is not None:
            scale_dict = parse_scale_dict(scale_dict)

            # Offsets are given as fractions of the viewed extent.
            y_offset = (ymax - ymin) * scale_dict['y_offset']
            x_offset = (xmax - xmin) * scale_dict['x_offset']

            length = scale_dict['length']
            label_length = length

            # The bar is drawn in map units; the label keeps the given value.
            if scale_dict['units'] == 'km':
                length *= 1000

            xs = [xmin + x_offset, xmin + x_offset + length]
            ys = [ymin + y_offset, ymin + y_offset]

            ax.plot(xs, ys, linewidth=5, color=scale_dict['color'])
            ax.text(np.mean(xs),
                    np.mean(ys) + y_offset *
                    (1 + 0.05 / scale_dict['y_offset']),
                    f'{label_length} {scale_dict["units"]}',
                    ha='center',
                    va='center',
                    color=scale_dict['color'])

        # Customises the map
        ax.set_title(f'{title}')
        ax.set_xlim(xmin, xmax)
        ax.set_ylim(ymin, ymax)
        ax.set_aspect('equal')

        if labels:
            ax.set_xlabel('Eastings [m]')
            ax.set_ylabel('Northing [m]')
        if ticks:
            if sci:
                ax.ticklabel_format(axis='both', style='sci', scilimits=(0, 0))
            else:
                # Format through a formatter rather than fixed label strings:
                # fixed labels stay attached to whatever ticks exist right
                # now and become wrong as soon as the locator re-ticks.
                from matplotlib.ticker import FuncFormatter

                def thousands(value, _pos):
                    return '{:,.0f}'.format(value)

                ax.yaxis.set_major_formatter(FuncFormatter(thousands))
                ax.xaxis.set_major_formatter(FuncFormatter(thousands))
                ax.tick_params(axis='x', labelrotation=-45)

        else:
            ax.set_yticklabels([])
            ax.set_xticklabels([])
            ax.set_xticks([])
            ax.set_yticks([])

        if grid:
            ax.grid()

        # Currently no legend: a custom legend for the outline is needed due
        # to a geopandas bug, and something has to be figured out for every
        # model and dem source.

        if fig is not None:
            if showplot:
                plt.show()
            if tag is not None:
                fig.savefig(f'{self.img_folder}/{tag}.png',
                            bbox_inches='tight')
            return fig, ax

        else:
            return img