コード例 #1
0
ファイル: plot_set.py プロジェクト: biorack/metatlas
def _autoscale(ax: matplotlib.axes.Axes,
               axis: str = "y",
               sides: str = "both",
               margin: float = 0.1) -> None:
    """Rescale one axis of ``ax`` while leaving the other axis untouched.

    Every artist in ``ax.collections`` and ``ax.lines`` is handed to
    ``_update_limts`` together with the current limits of the *other* axis;
    the running (low, high) pair it returns becomes the new data range.
    If no artist contributes (the range is still (+inf, -inf)) the axes are
    left unchanged.

    Parameters
    ----------
    ax : axes to rescale.
    axis : "x" or "y" -- the axis whose limits are recomputed.
    sides : "both", "min" or "max" -- which limit(s) to replace; the other
        side keeps its current value.
    margin : padding added beyond the data range on each adjusted side,
        expressed as a fraction of that range (high - low).
    """
    assert axis in ("x", "y")
    assert sides in ("both", "min", "max")
    rescale_y = axis == "y"

    lower, upper = np.inf, -np.inf
    for artist in ax.collections + ax.lines:
        if rescale_y:
            fixed_limits = ax.get_xlim()
            fixed_data, dependent_data = _get_xy(artist)
        else:
            fixed_limits = ax.get_ylim()
            dependent_data, fixed_data = _get_xy(artist)
        lower, upper = _update_limts(lower, upper, fixed_data, dependent_data, fixed_limits)

    # Nothing was found: keep whatever limits the axes already have.
    if lower == np.inf and upper == -np.inf:
        return
    assert lower != np.inf and upper != -np.inf

    if rescale_y:
        read_limits, write_limits = ax.get_ylim, ax.set_ylim
    else:
        read_limits, write_limits = ax.get_xlim, ax.set_xlim

    padding = margin * (upper - lower)
    if sides in ("both", "min"):
        new_min = lower - padding
    else:
        new_min = read_limits()[0]
    if sides in ("both", "max"):
        new_max = upper + padding
    else:
        new_max = read_limits()[1]
    write_limits(new_min, new_max)
コード例 #2
0
ファイル: plot.py プロジェクト: andreasfelix/apace
def draw_elements(
    ax: mpl.axes.Axes,
    lattice: Lattice,
    *,
    labels: bool = True,
    location: str = "top",
):
    """Draw the elements of a lattice onto a matplotlib axes.

    A strip of coloured rectangles, one per run of identical consecutive
    elements in ``lattice.sequence``, is drawn along the x axis. The strip
    height is 5 % of the current y range and the y limits are enlarged to
    make room for it.

    Parameters
    ----------
    ax : axes to draw into; its current x limits select the visible elements.
    lattice : object whose ``sequence`` yields elements with a ``length``.
    labels : if true, annotate elements whose exact type is ``Dipole`` or
        ``Quadrupole`` with their name, alternating above and below the strip.
    location : ``"top"`` places the strip at the (raised) upper y limit; any
        other value places it below the lower y limit and also draws a
        horizontal baseline through the strip.
    """
    x_min, x_max = ax.get_xlim()
    y_min, y_max = ax.get_ylim()
    rect_height = 0.05 * (y_max - y_min)
    if location == "top":
        y0 = y_max = y_max + rect_height
    else:
        y0 = y_min - rect_height
        y_min -= 3 * rect_height
        # Draw on the axes we were given; plt.hlines would target pyplot's
        # current axes, which is not necessarily `ax`.
        ax.hlines(y0, x_min, x_max, color="black", linewidth=1)
    ax.set_ylim(y_min, y_max)

    sign = -1  # flipped before each label so labels alternate sides
    start = end = 0
    for element, group in groupby(lattice.sequence):
        start = end
        # Consecutive repeats of the same element are drawn as one rectangle.
        end += element.length * sum(1 for _ in group)
        if end <= x_min:
            continue
        elif start >= x_max:
            break

        # Element types without a configured colour are not drawn.
        try:
            color = ELEMENT_COLOR[type(element)]
        except KeyError:
            continue

        y0_local = y0
        if isinstance(element, Dipole) and element.angle < 0:
            # Shift negative-angle dipoles so they are visually distinguishable.
            y0_local += rect_height / 4

        # Clip the rectangle to the visible x range.
        visible_start = max(start, x_min)
        ax.add_patch(
            plt.Rectangle(
                (visible_start, y0_local - 0.5 * rect_height),
                min(end, x_max) - visible_start,
                rect_height,
                facecolor=color,
                clip_on=False,
                zorder=10,
            ))
        if labels and type(element) in {Dipole, Quadrupole}:
            sign = -sign
            ax.annotate(
                element.name,
                xy=(0.5 * (start + end), y0 + sign * rect_height),
                fontsize=FONT_SIZE,
                ha="center",
                va="bottom" if sign > 0 else "top",
                annotation_clip=False,
                zorder=11,
            )
コード例 #3
0
def zoom_xy_and_save(fig: matplotlib.figure.Figure,
                     ax: matplotlib.axes.Axes,
                     figbase: str,
                     plot_ext: str,
                     xyzoom: List[Tuple[float, float, float, float]],
                     scale: float = 1000) -> None:
    """
    Zoom in on subregions in x,y-space and save one figure per subregion.

    All subregions are saved with the same window size: the smallest window
    with the aspect ratio of the current axes limits that contains every
    requested region. After saving, the original axes limits are restored.

    Arguments
    ---------
    fig : matplotlib.figure.Figure
        Figure to be processed.
    ax : matplotlib.axes.Axes
        Axes to be processed.
    figbase : str
        Base name of the figure to be saved; files are named
        ``<figbase>.sub<N><plot_ext>`` with N counting from 1.
    plot_ext : str
        File extension of the figure to be saved.
    xyzoom : List[Tuple[float, float, float, float]]
        List of (xmin, xmax, ymin, ymax) values to zoom into.
    scale: float
        Indicates whether the axes are in m (1) or km (1000); zoom
        coordinates are divided by this value before being applied.
    """
    orig_xmin, orig_xmax = ax.get_xlim()
    orig_ymin, orig_ymax = ax.get_ylim()
    aspect = (orig_ymax - orig_ymin) / (orig_xmax - orig_xmin)

    # Find the common window width that fits every region at this aspect.
    zoom_width = 0
    for region in xyzoom:
        width = region[1] - region[0]
        height = region[3] - region[2]
        if height < aspect * width:
            # x range limiting
            needed = width
        else:
            # y range limiting
            needed = height / aspect
        zoom_width = max(zoom_width, needed)
    zoom_height = zoom_width * aspect

    # Centre the common window on each region in turn and save.
    for number, region in enumerate(xyzoom, start=1):
        x_mid = (region[0] + region[1]) / 2
        y_mid = (region[2] + region[3]) / 2
        ax.set_xlim(xmin=(x_mid - zoom_width / 2) / scale,
                    xmax=(x_mid + zoom_width / 2) / scale)
        ax.set_ylim(ymin=(y_mid - zoom_height / 2) / scale,
                    ymax=(y_mid + zoom_height / 2) / scale)
        savefig(fig, f"{figbase}.sub{number}{plot_ext}")

    # Put the axes back the way we found them.
    ax.set_xlim(xmin=orig_xmin, xmax=orig_xmax)
    ax.set_ylim(ymin=orig_ymin, ymax=orig_ymax)
コード例 #4
0
    def default_vertices(self, ax: matplotlib.axes.Axes) -> tuple:
        """
        Default to rectangle that has a quarter-width/height border.

        The rectangle is inset from the current axes limits by one quarter of
        the x span on the left and right and one quarter of the y span on the
        top and bottom.

        Returns the four corners as
        ``((x1, y1), (x1, y2), (x2, y2), (x2, y1))``.
        """
        xlims = ax.get_xlim()
        ylims = ax.get_ylim()
        w = np.diff(xlims)
        h = np.diff(ylims)
        # True division: floor division (`//`) rounded the inset down to a
        # whole number, giving no border at all for spans smaller than 4.
        x1, x2 = xlims + w / 4 * np.array([1, -1])
        y1, y2 = ylims + h / 4 * np.array([1, -1])

        return ((x1, y1), (x1, y2), (x2, y2), (x2, y1))
コード例 #5
0
ファイル: plot.py プロジェクト: kapilsh/pyqstrat
def _plot_data(ax: mpl.axes.Axes, data: PlotData) -> Optional[List[mpl.lines.Line2D]]:
    """Draw ``data`` on ``ax`` in the way selected by ``data.display_attributes``.

    Returns the object produced by the plot, scatter or bar call so the
    caller can add legends, or None for plot kinds that do not assign one
    (filled lines, candlesticks, box plots, 3d plots).

    Raises ``Exception`` when the data type / display attribute combination
    is not supported.
    """
    artist = None  # Return line objects so we can add legends
    attrs = data.display_attributes

    if isinstance(data, (XYData, TimeSeries)):
        if isinstance(data, XYData):
            x, y = data.x, data.y
        else:
            # Time series are plotted against their integer position.
            x, y = np.arange(len(data.timestamps)), data.values

        if isinstance(attrs, LinePlotAttributes):
            artist, = ax.plot(x, y, linestyle=attrs.line_type, linewidth=attrs.line_width, color=attrs.color)
            if attrs.marker is not None:  # type: ignore
                ax.scatter(x, y, marker=attrs.marker, c=attrs.marker_color, s=attrs.marker_size, zorder=100)
        elif isinstance(attrs, ScatterPlotAttributes):
            artist = ax.scatter(x, y, marker=attrs.marker, c=attrs.marker_color, s=attrs.marker_size, zorder=100)
        elif isinstance(attrs, BarPlotAttributes):
            artist = ax.bar(x, y, color=attrs.color)  # type: ignore
        elif isinstance(attrs, FilledLinePlotAttributes):
            x, y = np.nan_to_num(x), np.nan_to_num(y)
            above_zero = np.where(y > 0, y, 0)
            below_zero = np.where(y < 0, y, 0)
            ax.fill_between(x, above_zero, color=attrs.positive_color, step='post', linewidth=0.0)
            ax.fill_between(x, below_zero, color=attrs.negative_color, step='post', linewidth=0.0)
        else:
            raise Exception(f'unknown plot combination: {type(data)} {type(disp)}' if False else f'unknown plot combination: {type(data)} {type(attrs)}')

        # For scatter and filled line, xlim and ylim does not seem to get set automatically
        if isinstance(attrs, (ScatterPlotAttributes, FilledLinePlotAttributes)):
            x_low, x_high = _adjust_axis_limit(ax.get_xlim(), x)
            if not (np.isnan(x_low) or np.isnan(x_high)):
                ax.set_xlim((x_low, x_high))

            y_low, y_high = _adjust_axis_limit(ax.get_ylim(), y)
            if not (np.isnan(y_low) or np.isnan(y_high)):
                ax.set_ylim((y_low, y_high))

    elif isinstance(data, TradeSet) and isinstance(attrs, ScatterPlotAttributes):
        positions = np.arange(len(data.timestamps))
        artist = ax.scatter(positions, data.values, marker=attrs.marker, c=attrs.marker_color, s=attrs.marker_size, zorder=100)
    elif isinstance(data, TradeBarSeries) and isinstance(attrs, CandleStickPlotAttributes):
        positions = np.arange(len(data.timestamps))
        draw_candlestick(ax, positions, data.o, data.h, data.l, data.c, data.v, data.vwap,
                         colorup=attrs.colorup, colordown=attrs.colordown)
    elif isinstance(data, BucketedValues) and isinstance(attrs, BoxPlotAttributes):
        draw_boxplot(
            ax, data.bucket_names, data.bucket_values, attrs.proportional_widths, attrs.notched,  # type: ignore
            attrs.show_outliers, attrs.show_means, attrs.show_all)  # type: ignore
    elif isinstance(data, XYZData) and isinstance(attrs, (SurfacePlotAttributes, ContourPlotAttributes)):
        if isinstance(attrs, ContourPlotAttributes):
            display_type: str = 'contour'
        else:
            display_type = 'surface'
        draw_3d_plot(ax, data.x, data.y, data.z, display_type, attrs.marker, attrs.marker_size,
                     attrs.marker_color, attrs.interpolation, attrs.cmap)
    else:
        raise Exception(f'unknown plot combination: {type(data)} {type(attrs)}')

    return artist
コード例 #6
0
ファイル: plot.py プロジェクト: andreasfelix/apace
def draw_sub_lattices(
    ax: mpl.axes.Axes,
    lattice: Lattice,
    *,
    labels: bool = True,
    location: str = "top",
):
    """Mark the children of a lattice on a matplotlib axes.

    X ticks (with a dashed grid) are placed at the cumulative boundaries of
    ``lattice.children`` that fall inside the current x limits. If ``labels``
    is true, every child that is itself a ``Lattice`` and is at least partly
    visible gets its name written at the centre of its visible part.

    Parameters
    ----------
    ax : axes to draw into.
    lattice : object whose ``children`` have a ``length`` and a ``name``.
    labels : whether to annotate the sub-lattice names.
    location : ``"top"`` writes the names just below the upper y limit; any
        other value extends the y range downwards and writes them there.
    """
    x_min, x_max = ax.get_xlim()
    length_gen = [0.0, *(obj.length for obj in lattice.children)]
    position_list = np.add.accumulate(length_gen)
    i_min = np.searchsorted(position_list, x_min)
    i_max = np.searchsorted(position_list, x_max, side="right")
    ticks = position_list[i_min:i_max]
    ax.set_xticks(ticks)
    ax.grid(color=Color.LIGHT_GRAY, linestyle="--", linewidth=1)

    if not labels:
        return

    y_min, y_max = ax.get_ylim()
    height = 0.08 * (y_max - y_min)
    if location == "top":
        y0 = y_max - height
    else:
        y0, y_min = y_min - height / 3, y_min - height

    ax.set_ylim(y_min, y_max)
    start = end = 0
    for obj in lattice.children:
        # Advance both bounds up front so that children skipped below still
        # move `start`; otherwise the next label would be centred using the
        # start position of an earlier child.
        start, end = end, end + obj.length
        if not isinstance(obj, Lattice) or start >= x_max or end <= x_min:
            continue

        # Centre the label on the visible part of the sub-lattice.
        x0 = 0.5 * (max(start, x_min) + min(end, x_max))
        ax.annotate(
            obj.name,
            xy=(x0, y0),
            fontsize=FONT_SIZE + 2,
            fontstyle="oblique",
            va="center",
            ha="center",
            clip_on=True,
            zorder=102,
        )
コード例 #7
0
ファイル: __init__.py プロジェクト: f-koehler/mlxtk
    def apply(self, axes: matplotlib.axes.Axes,
              figure: matplotlib.figure.Figure):
        """Apply the stored plot settings to ``axes`` and ``figure``.

        Sets the grid, optionally switches either axis to a log scale,
        overrides each axis limit for which a value is stored (a stored value
        of None keeps the current limit) and, when a dpi is stored and a
        figure is given, sets the figure dpi.
        """
        axes.grid(self.grid)
        if self.logx:
            axes.set_xscale("log")
        if self.logy:
            axes.set_yscale("log")

        # Read the limits only after the scales are set.
        current_xmin, current_xmax = axes.get_xlim()
        current_ymin, current_ymax = axes.get_ylim()

        def choose(stored, current):
            # A stored None means "leave this limit as it is".
            return current if stored is None else stored

        new_xmin = choose(self.xmin, current_xmin)
        new_xmax = choose(self.xmax, current_xmax)
        new_ymin = choose(self.ymin, current_ymin)
        new_ymax = choose(self.ymax, current_ymax)
        axes.set_xlim(xmin=new_xmin, xmax=new_xmax)
        axes.set_ylim(ymin=new_ymin, ymax=new_ymax)

        if self.dpi and figure is not None:
            figure.set_dpi(self.dpi)
コード例 #8
0
ファイル: plotting.py プロジェクト: josebsalazar/RTGLOBAL
def plot_vlines(
    ax: matplotlib.axes.Axes,
    vlines: preprocessing.NamedDates,
    alignment: str,
) -> None:
    """ Mark special events with labeled, dotted vertical lines.

    Only entries whose x position lies within the current x limits are drawn.
    Labels are shortened to at most 20 characters and written rotated by
    90 degrees next to the line.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        the subplot to draw into
    vlines : dict of { datetime : label }
        the dates and labels for the lines
    alignment : str
        one of { "top", "bottom" }; the label is anchored 2 % of the y range
        away from the corresponding y limit

    Raises
    ------
    ValueError
        If a line is visible and `alignment` is neither "top" nor "bottom".
    """
    y_low, y_high = ax.get_ylim()
    x_low, x_high = ax.get_xlim()
    for when, text in vlines.items():
        if not (x_low <= ax.xaxis.convert_xunits(when) <= x_high):
            continue

        short_text = textwrap.shorten(text, width=20, placeholder="...")
        ax.axvline(when, color="gray", linestyle=":")

        if alignment == 'top':
            fraction = 0.98
        elif alignment == 'bottom':
            fraction = 0.02
        else:
            raise ValueError(f"Unsupported alignment: '{alignment}'")

        ax.text(
            when, y_low+fraction*(y_high-y_low),
            s=f'{short_text}\n',
            color="gray",
            rotation=90,
            horizontalalignment="center",
            verticalalignment=alignment,
        )
コード例 #9
0
def equal_axlim(axs: mpl.axes.Axes, mode: str = 'union') -> None:
    """Give the x and y axes of `axs` the same limits.

    Parameters
    ----------
    axs : mpl.axes.Axes
        `Axes` instance whose limits are to be adjusted.
    mode : str
        How the shared limits are derived from the current ones:
            'union'
                Smallest lower and largest upper limit of both axes,
                *default*.
            'intersect'
                Largest lower and smallest upper limit of both axes.
            'x'
                Use the x limits for both axes.
            'y'
                Use the y limits for both axes.
    Raises
    ------
    ValueError
        If `mode` is not one of the options above.
    """
    x_range = axs.get_xlim()
    y_range = axs.get_ylim()

    if mode == 'union':
        shared = (min(x_range[0], y_range[0]), max(x_range[1], y_range[1]))
    elif mode == 'intersect':
        shared = (max(x_range[0], y_range[0]), min(x_range[1], y_range[1]))
    elif mode == 'x':
        shared = x_range
    elif mode == 'y':
        shared = y_range
    else:
        raise ValueError(f"Unknown mode '{mode}'. Shoulde be one of: "
                         "'union', 'intersect', 'x', 'y'.")

    axs.set_xlim(shared)
    axs.set_ylim(shared)
コード例 #10
0
 def plot(self, ax: mpl.axes.Axes):
     """Plot the styled object onto a matplotlib axes.

     The object's array is padded via ``self.pad_to`` to the integer extents
     of the current axes limits and shown with ``imshow`` using ``self.cmap``
     and the ``"alpha"`` entry of ``self.style`` (default 0.5).
     """
     x_first, x_second = ax.get_xlim()
     y_first, y_second = ax.get_ylim()
     x_extent = int(abs(x_second - x_first))
     y_extent = int(abs(y_second - y_first))
     # NOTE(review): pad_to receives the y extent first and the x extent
     # second -- confirm this matches pad_to's parameter order.
     padded = self.pad_to(y_extent, x_extent)
     ax.imshow(padded, cmap=self.cmap, alpha=self.style.get("alpha", 0.5))