Example #1
0
    def cycle_colors(cls, axis, nr_cycles=1):
        """
        Cycle the axis color cycle ``nr_cycles`` forward

        :param axis: The axis to manipulate
        :type axis: matplotlib.axes.Axes

        :param nr_cycles: The number of colors to cycle through. Values
            larger than the number of colors wrap around, so cycling by
            ``len(colors) + 1`` is the same as cycling by ``1``. Values lower
            than ``1`` leave the axis untouched.
        :type nr_cycles: int

        .. note::

          This is an absolute cycle, as in, it will always start from the first
          color defined in the color cycle.

        """
        if nr_cycles < 1:
            return

        colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

        # Nothing to rotate, and the modulo below would divide by zero.
        if not colors:
            return

        # Bring the shift back into [0, len(colors)). Subtracting len(colors)
        # a single time is not enough once nr_cycles >= 2 * len(colors) + 1:
        # the slicing below would then produce an unrotated cycle.
        nr_cycles %= len(colors)

        axis.set_prop_cycle(
            make_cycler(color=colors[nr_cycles:] + colors[:nr_cycles]))
Example #2
0
    def set_axis_cycler(cls, axis, *cyclers):
        """
        Context manager to set cyclers on an axis (and the default cycler as
        well), and then restore the default cycler.

        :param axis: The axis to manipulate.
        :type axis: matplotlib.axes.Axes

        :param cyclers: Cyclers merged on top of the current
            ``axes.prop_cycle`` rcParam. Later cyclers override keys of
            earlier ones.

        .. note:: The given cyclers are merged with the original cycler. The
            given cyclers will override any key of the original cycler, and the
            number of values will be adjusted to the maximum size between all
            of them. This way of merging allows decoupling the length of all
            keys.
        """
        orig_cycler = plt.rcParams['axes.prop_cycle']

        # Get the maximum value length among all cyclers involved
        values_len = max(
            len(values) for values in itertools.chain(
                orig_cycler.by_key().values(),
                itertools.chain.from_iterable(
                    user_cycler.by_key().values() for user_cycler in cyclers),
            ))

        # We can only add together cyclers with the same number of values for
        # each key, so cycle through the provided values, up to the right
        # length
        def pad_values(values):
            return list(itertools.islice(itertools.cycle(values), values_len))

        def pad_cycler(to_pad):
            return {
                key: pad_values(values)
                for key, values in to_pad.by_key().items()
            }

        # Merge the cyclers and original cycler together, so we still get the
        # original values of the keys not overridden by the given cyclers.
        parameters = pad_cycler(orig_cycler)
        for user_cycler in cyclers:
            parameters.update(pad_cycler(user_cycler))
        merged_cycler = make_cycler(**parameters)

        def set_cycler(new_cycler):
            # rcParams is updated first so that it gets restored even if the
            # axis rejects the cycler (e.g. when it is not an Axes)
            plt.rcParams['axes.prop_cycle'] = new_cycler
            axis.set_prop_cycle(new_cycler)

        try:
            # Done inside the try: if axis.set_prop_cycle() raises after the
            # global rcParams have been modified, the finally clause still
            # puts the original value back instead of leaking it.
            set_cycler(merged_cycler)
            yield
        finally:
            # Since there is no way to get the cycler from an Axis,
            # we cannot restore the original one, so use the
            # default one instead
            set_cycler(orig_cycler)
Example #3
0
            def wrapper(self,
                        *args,
                        filepath=None,
                        axis=None,
                        output=None,
                        img_format=None,
                        always_save=False,
                        colors: TypedList[str] = None,
                        linestyles: TypedList[str] = None,
                        markers: TypedList[str] = None,
                        rc_params=None,
                        **kwargs):
                # Calls the wrapped plot function "f" and handles the shared
                # plumbing: figure creation, style cyclers, rcParams, output
                # formatting and saving to "filepath".

                # Computed once per call rather than once per kwarg inside
                # is_f_param().
                f_params = inspect.signature(f).parameters

                def is_f_param(param):
                    """
                    Return True if the parameter is for `f`, False if it is
                    for setup_plot()
                    """
                    try:
                        desc = f_params[param]
                    except KeyError:
                        return False
                    else:
                        # Passing kwargs=42 to a function taking **kwargs
                        # should not return True here, as we only consider
                        # explicitly listed arguments
                        return desc.kind not in (
                            inspect.Parameter.VAR_KEYWORD,
                            inspect.Parameter.VAR_POSITIONAL,
                        )

                # Factor the *args inside the **kwargs by binding them to the
                # user-facing signature, which is the one of the wrapper.
                kwargs.update(
                    inspect.signature(wrapper).bind_partial(self,
                                                            *args).arguments)

                f_kwargs = {
                    param: val
                    for param, val in kwargs.items() if is_f_param(param)
                }

                img_format = img_format or guess_format(filepath) or 'png'
                local_fig = axis is None

                # When we create the figure ourselves, always save the plot to
                # the default location
                if local_fig and filepath is None and always_save:
                    filepath = self.get_default_plot_path(
                        img_format=img_format,
                        plot_name=f.__name__,
                    )

                # One single-key cycler per style property that was provided
                style_values = dict(
                    color=colors,
                    linestyle=linestyles,
                    marker=markers,
                )
                cyclers = [
                    make_cycler(**{name: value})
                    for name, value in style_values.items()
                    if value
                ]
                if cyclers:
                    set_cycler = lambda axis: cls.set_axis_cycler(
                        axis, *cyclers)
                else:
                    set_cycler = lambda axis: nullcontext()

                if rc_params:
                    set_rc_params = lambda axis: cls.set_axis_rc_params(
                        axis, rc_params)
                else:
                    set_rc_params = lambda axis: nullcontext()

                # Allow returning an axis directly, or just update a given axis
                if return_axis:
                    # In that case, the function takes all the kwargs
                    with set_cycler(axis), set_rc_params(axis):
                        axis = f(**kwargs, axis=axis)
                else:
                    if local_fig:
                        setup_plot_kwargs = {
                            param: val
                            for param, val in kwargs.items()
                            if param not in f_kwargs
                        }
                        fig, axis = self.setup_plot(**setup_plot_kwargs)

                    # A local_fig explicitly given by the caller wins
                    f_kwargs.update(
                        axis=axis,
                        local_fig=f_kwargs.get('local_fig', local_fig),
                    )
                    with set_cycler(axis), set_rc_params(axis):
                        f(**f_kwargs)

                # For an array of axes, use its first element to find the
                # figure. ".flat" also handles 2D grids, where axis[0] would
                # be a row array with no get_figure().
                if isinstance(axis, numpy.ndarray):
                    fig = axis.flat[0].get_figure()
                else:
                    fig = axis.get_figure()

                def resolve_formatter(fmt):
                    format_map = {
                        'rst': cls._get_rst_content,
                        'html': cls._get_html,
                    }
                    try:
                        return format_map[fmt]
                    except KeyError:
                        # The lookup KeyError is an implementation detail
                        raise ValueError(
                            f'Unsupported format: {fmt}') from None

                if output is None:
                    out = axis

                    # Show the LISA figure toolbar
                    if is_running_ipython():
                        # Make sure we only add one button per figure
                        try:
                            toolbar = self._get_fig_data(fig, 'toolbar')
                        except KeyError:
                            toolbar = self._make_fig_toolbar(fig)
                            self._set_fig_data(fig, 'toolbar', toolbar)
                            display(toolbar)

                        mplcursors.cursor(fig)
                else:
                    out = resolve_formatter(output)(f, [], f_kwargs, axis)

                if filepath:
                    if img_format in ('html', 'rst'):
                        content = resolve_formatter(img_format)(f, [],
                                                                f_kwargs, axis)

                        with open(filepath, 'wt', encoding='utf-8') as fd:
                            fd.write(content)
                    else:
                        fig.savefig(filepath,
                                    format=img_format,
                                    bbox_inches='tight')

                return out
Example #4
0
import mplcursors

from ipywidgets import widgets, Layout, interact
from IPython.display import display

from lisa.utils import is_running_ipython

# Palette installed below as matplotlib's default "axes.prop_cycle" when this
# module is imported.
COLOR_CYCLE = [
    '#377eb8',
    '#ff7f00',
    '#4daf4a',
    '#f781bf',
    '#a65628',
    '#984ea3',
    '#999999',
    '#e41a1c',
    '#dede00',
]
"""
Colorblind-friendly cycle, see https://gist.github.com/thriveth/8560036
"""

plt.rcParams['axes.prop_cycle'] = make_cycler(color=COLOR_CYCLE)


class WrappingHBox(widgets.HBox):
    """
    Horizontal box whose children flow onto additional lines when they cannot
    all fit on a single line.

    All positional and keyword arguments are forwarded to
    :class:`ipywidgets.widgets.HBox`, except ``layout`` which is always set by
    this class.
    """

    def __init__(self, *args, **kwargs):
        # "row wrap" moves overflowing items to the next line instead of
        # hiding them, and "space-around" spreads the items of each line
        # evenly.
        wrap_layout = Layout(flex_flow='row wrap',
                             justify_content='space-around')
        super().__init__(*args, layout=wrap_layout, **kwargs)
Example #5
0
            def wrapper(self,
                        *args,
                        filepath=None,
                        axis=None,
                        output=None,
                        img_format=None,
                        always_save=False,
                        colors=None,
                        **kwargs):
                # Calls the wrapped plot function "func" and handles the shared
                # plumbing: figure creation, color cycler, output formatting
                # and saving to "filepath".

                # Bind the function to the instance, so we avoid having "self"
                # showing up in the signature, which breaks parameter
                # formatting code.
                f = func.__get__(self, type(self))

                # Computed once per call rather than once per kwarg inside
                # is_f_param().
                f_params = inspect.signature(f).parameters

                def is_f_param(param):
                    """
                    Return True if the parameter is for `f`, False if it is
                    for setup_plot()
                    """
                    try:
                        desc = f_params[param]
                    except KeyError:
                        return False
                    else:
                        # Passing kwargs=42 to a function taking **kwargs
                        # should not return True here, as we only consider
                        # explicitly listed arguments
                        return desc.kind not in (
                            inspect.Parameter.VAR_KEYWORD,
                            inspect.Parameter.VAR_POSITIONAL,
                        )

                f_kwargs = {
                    param: val
                    for param, val in kwargs.items() if is_f_param(param)
                }

                img_format = img_format or guess_format(filepath) or 'png'
                local_fig = axis is None

                # When we create the figure ourselves, always save the plot to
                # the default location
                if local_fig and filepath is None and always_save:
                    filepath = self.get_default_plot_path(
                        img_format=img_format,
                        plot_name=f.__name__,
                    )

                if colors:
                    cycler = make_cycler(color=colors)
                    set_cycler = lambda axis: cls.set_axis_cycler(axis, cycler)
                else:
                    set_cycler = lambda axis: nullcontext()
                # Allow returning an axis directly, or just update a given axis
                if return_axis:
                    # In that case, the function takes all the kwargs
                    with set_cycler(axis):
                        axis = f(*args, **kwargs, axis=axis)
                else:
                    if local_fig:
                        setup_plot_kwargs = {
                            param: val
                            for param, val in kwargs.items()
                            if param not in f_kwargs
                        }
                        fig, axis = self.setup_plot(**setup_plot_kwargs)

                    # A local_fig explicitly given by the caller is already in
                    # f_kwargs; passing it a second time as a keyword would
                    # raise "got multiple values for keyword argument". A
                    # separate dict keeps f_kwargs unchanged for the
                    # formatters below.
                    call_kwargs = dict(f_kwargs)
                    call_kwargs.setdefault('local_fig', local_fig)
                    call_kwargs['axis'] = axis
                    with set_cycler(axis):
                        f(*args, **call_kwargs)

                # For an array of axes, use its first element to find the
                # figure. ".flat" also handles 2D grids, where axis[0] would
                # be a row array with no get_figure().
                if isinstance(axis, numpy.ndarray):
                    fig = axis.flat[0].get_figure()
                else:
                    fig = axis.get_figure()

                def resolve_formatter(fmt):
                    format_map = {
                        'rst': cls._get_rst_content,
                        'html': cls._get_html,
                    }
                    try:
                        return format_map[fmt]
                    except KeyError:
                        # The lookup KeyError is an implementation detail
                        raise ValueError(
                            'Unsupported format: {}'.format(fmt)) from None

                if output is None:
                    out = axis

                    # Show the LISA figure toolbar
                    if is_running_ipython():
                        # Make sure we only add one button per figure
                        try:
                            toolbar = self._get_fig_data(fig, 'toolbar')
                        except KeyError:
                            toolbar = self._make_fig_toolbar(fig)
                            self._set_fig_data(fig, 'toolbar', toolbar)
                            display(toolbar)

                        mplcursors.cursor(fig)
                else:
                    out = resolve_formatter(output)(f, args, f_kwargs, axis)

                if filepath:
                    if img_format in ('html', 'rst'):
                        content = resolve_formatter(img_format)(f, args,
                                                                f_kwargs, axis)

                        with open(filepath, 'wt', encoding='utf-8') as fd:
                            fd.write(content)
                    else:
                        fig.savefig(filepath,
                                    format=img_format,
                                    bbox_inches='tight')

                return out