示例#1
0
def signal_set_xlim(axes, tmin=None, tmax=None, **kwargs):
    '''
        This set_xlim function replaces the usual matplotlib axes set_xlim
    function.  It will redraw the signals after having downsampled them.
    Inputs:
        axes: The signal axes whose x-limits are being set.
        tmin: The new lower x-limit, or a (tmin, tmax) pair when <tmax> is
              not given (mirroring matplotlib's set_xlim).
        tmax: The new upper x-limit.
        kwargs: Passed on to the original set_xlim (e.g. emit=, auto=).
                Matplotlib's keyword spellings of the limits (left=/right=,
                xmin=/xmax=) are also accepted; they are used for the
                downsampling as well as for setting the limits.
    Returns:
        Whatever the original set_xlim returns (the new limits, for
        matplotlib), or None if the axes are locked and nothing was done.
    '''
    # Calls made while the signals are being redrawn are ignored
    # (presumably re-entrant calls triggered by axes.plot autoscaling).
    if axes._are_axes_locked:
        return
    axes.lock_axes()

    try:
        # Map matplotlib's keyword spellings onto tmin/tmax.  Left in kwargs
        # they would be ignored by the downsampling and would clash with the
        # positional limits forwarded below.  If a limit is given both
        # positionally and by keyword, the keyword stays in kwargs so
        # matplotlib's own argument checking still applies.
        for low_key, high_key in (('left', 'right'), ('xmin', 'xmax')):
            if tmin is None and low_key in kwargs:
                tmin = kwargs.pop(low_key)
            if tmax is None and high_key in kwargs:
                tmax = kwargs.pop(high_key)

        # A single (tmin, tmax) pair may be passed as the first argument.
        if tmax is None and is_iterable(tmin):
            tmin, tmax = tmin

        if hasattr(axes, '_signal_times'):
            for s_id in axes._signal_draw_order:
                # Don't replot if the bounds didn't actually change and this
                # signal already has a line drawn.
                cur_min, cur_max = axes.get_xlim()
                if (cur_min == tmin and cur_max == tmax
                        and s_id in axes._signal_lines):
                    continue

                # Delete the existing line for this signal.
                if s_id in axes._signal_lines:
                    axes._signal_lines[s_id].remove()
                    del axes._signal_lines[s_id]

                # Downsample to the new bounds and redraw.
                new_signal, new_times = downsample_for_plot(
                        axes._signals[s_id],
                        axes._signal_times[s_id],
                        tmin, tmax, axes._signal_num_samples)

                line = axes.plot(new_times, new_signal,
                        *axes._signal_args[s_id],
                        **axes._signal_kwargs[s_id])[0]

                # Save this line so we can remove it later.
                axes._signal_lines[s_id] = line
    finally:
        # Always release the lock.  Otherwise an exception while redrawing
        # would leave the axes locked, and every later call would silently
        # do nothing.
        axes.unlock_axes()

    # Actually change the x-limits.
    return axes._pre_signal_set_xlim(tmin, tmax, **kwargs)
def make_into_publication_axes(
    axes,
    axis="both",
    base_unit=("s", "V"),
    base_unit_prefix=("m", "m"),
    plot_to_base_factor=(1.0, 1.0),
    target_size_frac=(0.2, 0.2),
    scale_bar_origin_frac=(0.7, 0.7),
    live_update=True,
    # NOTE(review): mutable default list.  The scale-bar creators below only
    # pass it on to get_scale_bar_info.  If that helper never mutates it,
    # this is harmless, but a tuple default would be safer.
    chunks=[1, 2, 5, 10, 20, 25, 30, 50, 100, 200, 500],
    y_label_rotation="horizontal",
    color="black",
):
    """
        Makes the axes look like trace plots in publications.  That is, the
    frame and ticks are not drawn, and a scale bar for x and/or y is
    displayed instead.
    Inputs:
        axes: The axes you wish to format this way.
        axis: 'x', 'y', or 'both' -- which scale bar(s) to create.  Any
              other value gives no scale bar on either axis.
        base_unit: The base SI unit, for example: 'V', 'Hz', 'F', 'C'...
                   An (x, y) pair, or a single value used for both axes.
        base_unit_prefix: The prefix needed to get to the plotted units.
                          An (x, y) pair, or a single value used for both.
        plot_to_base_factor: The factor which converts plotted units to base
                             units (e.g. pixels_to_um for an image).
                             An (x, y) pair, or a single value used for both.
        target_size_frac: The desired size of the scale bar.  The result may
                          be off by a bit, depending on the choice of
                          <chunks>.  An (x, y) pair, or a single value used
                          for both.
        scale_bar_origin_frac: The position of the origin of the scale bars,
                               as fractions of the current x and y ranges.
                               Must be an (x, y) pair; unlike the inputs
                               above, it is not broadcast from a single value.
        live_update: If True, the axes are made into signal axes (if they are
                     not already) and flagged for live-updating scale bars
                     (presumably redrawn when the limits change).
        chunks: Passed on to get_scale_bar_info.
        y_label_rotation: The angle of the y label.  Anything that a
                matplotlib Text object accepts as a rotation= kwarg.
        color: The color of the lines and labels.
    Returns:
        None: It just alters the axes.
    """
    if live_update:
        # Live updates presumably rely on the signal-axes machinery; install
        # it only if these axes do not already have it.
        if not hasattr(axes, "_is_signal_axes"):
            make_into_signal_axes(axes)
        axes._live_updating_scalebars = True

    # Publication style: no frame and no ticks; the scale bars replace them.
    axes.set_frame_on(False)
    axes.set_xticks([])
    axes.set_yticks([])

    # parse inputs: each of these may be an (x, y) pair or a single value
    # that is used for both axes.
    if is_iterable(base_unit):
        base_unit_x, base_unit_y = base_unit
    else:
        base_unit_x, base_unit_y = (base_unit, base_unit)

    if is_iterable(base_unit_prefix):
        base_unit_prefix_x, base_unit_prefix_y = base_unit_prefix
    else:
        base_unit_prefix_x, base_unit_prefix_y = (base_unit_prefix, base_unit_prefix)

    if is_iterable(plot_to_base_factor):
        plot_to_base_factor_x, plot_to_base_factor_y = plot_to_base_factor
    else:
        plot_to_base_factor_x, plot_to_base_factor_y = (plot_to_base_factor, plot_to_base_factor)

    if is_iterable(target_size_frac):
        target_size_frac_x, target_size_frac_y = target_size_frac
    else:
        target_size_frac_x, target_size_frac_y = (target_size_frac, target_size_frac)

    def create_x_scale_bar(axes):
        """Draw (or redraw) the horizontal scale bar and its label on <axes>."""
        # clear old bar and label left by a previous call, if any.
        if hasattr(axes, "_x_scale_bar"):
            for line in axes._x_scale_bar:
                line.remove()
            del axes._x_scale_bar
        if hasattr(axes, "_x_scale_text"):
            axes._x_scale_text.remove()
            del axes._x_scale_text

        # min/max (rather than unpacking get_xlim/get_ylim) keeps the ranges
        # positive even when an axis is inverted.
        x_min = numpy.min(axes.get_xlim())
        x_max = numpy.max(axes.get_xlim())
        x_range = x_max - x_min

        y_min = numpy.min(axes.get_ylim())
        y_max = numpy.max(axes.get_ylim())
        y_range = y_max - y_min

        # The range is converted to base units before choosing the bar size.
        # Assumes "scale_base" in the result is the bar length in base
        # units -- confirm against get_scale_bar_info.
        scale_bar_info = get_scale_bar_info(
            target_size_frac_x, x_range * plot_to_base_factor_x, base_unit_prefix=base_unit_prefix_x, chunks=chunks
        )

        # The bar starts at the origin fraction of the x range; its length is
        # converted from base units back to plotted units.
        bar_min = x_min + scale_bar_origin_frac[0] * x_range
        bar_max = bar_min + scale_bar_info["scale_base"] / plot_to_base_factor_x
        bar_y = scale_bar_origin_frac[1] * y_range + y_min

        # clip_on=False keeps the bar visible even if it extends past the axes.
        axes._x_scale_bar = axes.plot((bar_min, bar_max), (bar_y, bar_y), linewidth=2, color=color, clip_on=False)

        bar_range = bar_max - bar_min
        bar_mid = bar_min + bar_range / 2.0

        # Label formatted as "<best_chunk> <scale_unit_prefix><base_unit_x>".
        text = "%s %s%s" % (scale_bar_info["best_chunk"], scale_bar_info["scale_unit_prefix"], base_unit_x)
        f = axes.figure
        canvas_size_in_pixels = (f.get_figwidth() * f.get_dpi(), f.get_figheight() * f.get_dpi())
        # Offset the label 8 pixels below the bar.  as_fraction_axes
        # presumably converts pixels to a fraction of the axes height, which
        # is then scaled to data units by y_range.
        y_pix = as_fraction_axes(y=8, axes=axes, canvas_size_in_pixels=canvas_size_in_pixels) * y_range
        axes._x_scale_text = axes.text(
            bar_mid, bar_y - y_pix, text, horizontalalignment="center", verticalalignment="top", color=color
        )

    def create_y_scale_bar(axes):
        """Draw (or redraw) the vertical scale bar and its label on <axes>."""
        # clear old bar and label left by a previous call, if any.
        if hasattr(axes, "_y_scale_bar"):
            for line in axes._y_scale_bar:
                line.remove()
            del axes._y_scale_bar
        if hasattr(axes, "_y_scale_text"):
            axes._y_scale_text.remove()
            del axes._y_scale_text

        # min/max (rather than unpacking get_xlim/get_ylim) keeps the ranges
        # positive even when an axis is inverted.
        x_min = numpy.min(axes.get_xlim())
        x_max = numpy.max(axes.get_xlim())
        x_range = x_max - x_min

        y_min = numpy.min(axes.get_ylim())
        y_max = numpy.max(axes.get_ylim())
        y_range = y_max - y_min

        # The range is converted to base units before choosing the bar size.
        # Assumes "scale_base" in the result is the bar length in base
        # units -- confirm against get_scale_bar_info.
        scale_bar_info = get_scale_bar_info(
            target_size_frac_y, y_range * plot_to_base_factor_y, base_unit_prefix=base_unit_prefix_y, chunks=chunks
        )

        # The bar starts at the origin fraction of the y range; its length is
        # converted from base units back to plotted units.
        bar_min = scale_bar_origin_frac[1] * y_range + y_min
        bar_max = bar_min + scale_bar_info["scale_base"] / plot_to_base_factor_y
        bar_x = scale_bar_origin_frac[0] * x_range + x_min

        # clip_on=False keeps the bar visible even if it extends past the axes.
        axes._y_scale_bar = axes.plot((bar_x, bar_x), (bar_min, bar_max), linewidth=2, color=color, clip_on=False)

        bar_range = bar_max - bar_min
        bar_mid = bar_min + bar_range / 2.0

        # Label formatted as "<best_chunk> <scale_unit_prefix><base_unit_y>".
        text = "%s %s%s" % (scale_bar_info["best_chunk"], scale_bar_info["scale_unit_prefix"], base_unit_y)
        f = axes.figure
        canvas_size_in_pixels = (f.get_figwidth() * f.get_dpi(), f.get_figheight() * f.get_dpi())
        # Offset the label 8 pixels left of the bar.  as_fraction_axes
        # presumably converts pixels to a fraction of the axes width, which
        # is then scaled to data units by x_range.
        x_pix = as_fraction_axes(x=8, axes=axes, canvas_size_in_pixels=canvas_size_in_pixels) * x_range
        axes._y_scale_text = axes.text(
            bar_x - x_pix,
            bar_mid,
            text,
            horizontalalignment="right",
            verticalalignment="center",
            rotation=y_label_rotation,
            color=color,
        )

    # Store the scale-bar creators on the axes so they can be (re)invoked
    # later; an axis that was not selected gets a no-op.  These statements
    # only install the creators -- nothing is drawn until they are called.
    if axis in ["both", "x"]:
        axes._create_x_scale_bar = create_x_scale_bar
    else:
        axes._create_x_scale_bar = lambda axes: None

    if axis in ["both", "y"]:
        axes._create_y_scale_bar = create_y_scale_bar
    else:
        axes._create_y_scale_bar = lambda axes: None