Пример #1
0
def heated_barplot(
    data: pd.Series, desat: float = 0.6, ax: Axes = None, figsize: tuple = (8, 10)
) -> Axes:
    """Plot a sharply divided ranking of positive and negative values.

    The values are drawn as horizontal bars in descending order. Positive
    values are colored with the reversed "Reds" palette, non-positive values
    with the "Blues" palette, and a vertical gray line marks zero.

    Args:
        data (pd.Series): Data to plot. The Series passed in is left
            untouched; sorting is done on a copy.
        desat (float, optional): Saturation of bar colors. Defaults to 0.6.
        ax (Axes, optional): Axes to plot on. When None, a new figure and
            axes are created. Defaults to None.
        figsize (tuple, optional): Figure size, only used when a new figure
            is created. Defaults to (8, 10).

    Returns:
        Axes: Axes for the plot.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    # Sort a copy: an in-place sort would silently reorder the caller's Series.
    data = data.sort_values(ascending=False)
    blues = sns.color_palette("Blues", (data <= 0).sum(), desat=desat)
    reds = sns.color_palette("Reds_r", (data > 0).sum(), desat=desat)
    # Data is descending, so the positive (red) bars come first.
    palette = reds + blues
    ax = sns.barplot(
        x=data.values, y=data.index, palette=palette, orient="h", ec="gray", ax=ax
    )
    ax.axvline(0.0, color="gray", lw=1, ls="-")
    return ax
Пример #2
0
def decorate_azimuth_ax(
    ax: Axes,
    label: str,
    length_array: np.ndarray,
    set_array: np.ndarray,
    set_names: Tuple[str, ...],
    set_ranges: SetRangeTuple,
    axial: bool,
    visualize_sets: bool,
    append_azimuth_set_text: bool = False,
):
    """
    Decorate azimuth rose plot ax.

    Adds a boxed title (the label wrapped to a width of 10), a boxed text
    with the number of entries in ``set_array`` (optionally followed by the
    text from ``_create_azimuth_set_text``) and, when ``visualize_sets`` is
    set, dashed lines at every edge of every range in ``set_ranges``.
    Box positions depend on whether the plot is ``axial``.
    """
    font_family = "DejaVu Sans"

    # Title is the name of the target area or group
    title_x, title_y = (0.94, 0.8) if axial else (1.15, 1.0)
    ax.set_title(
        fill(label, 10),
        x=title_x,
        y=title_y,
        fontsize="large",
        fontweight="bold",
        fontfamily=font_family,
        va="top",
        bbox={
            "boxstyle": "square",
            "facecolor": "linen",
            "alpha": 1,
            "linewidth": 2,
        },
        transform=ax.transAxes,
        ha="center",
    )

    # Count box, optionally extended with the per-set description
    text_parts = [f"n ={len(set_array)}"]
    if append_azimuth_set_text:
        text_parts.append(
            _create_azimuth_set_text(length_array, set_array, set_names))
    text_x, text_y = (0.96, 0.3) if axial else (1.1, 0.15)
    ax.text(
        x=text_x,
        y=text_y,
        s="\n".join(text_parts),
        transform=ax.transAxes,
        fontsize="medium",
        weight="roman",
        bbox={
            "boxstyle": "square",
            "facecolor": "linen",
            "alpha": 1,
            "pad": 0.45,
        },
        fontfamily=font_family,
        va="top",
        ha="center",
    )

    if not visualize_sets:
        return

    # Add lines to denote azimuth set edges (edges are given in degrees)
    for edges in set_ranges:
        for edge_deg in edges:
            ax.axvline(np.deg2rad(edge_deg), linestyle="dashed", color="black")
Пример #3
0
def all_traces(record_file: File, ax: Axes):
    """Plot the full traces of all neurons together with the trial onsets.

    The scaled lever trajectory is drawn 5 units below zero, each scaled
    calcium trace is stacked 5 units above the previous one, and a vertical
    line is drawn at every lever timestamp divided by the lever sample rate.
    """
    lever = load_mat(record_file["response"])
    traces = _scale(DataFrame.load(record_file["measurement"]).values)
    sample_count = lever.shape[1]
    rate = lever.sample_rate
    time = np.linspace(0, sample_count / rate, sample_count)

    ax.plot(time, _scale(lever.values[0]) - 5, COLORS[1])
    for level, trace in enumerate(traces):
        ax.plot(time, trace + level * 5)

    # trial onsets
    onsets = lever.timestamps / rate
    for onset in onsets:
        ax.axvline(x=onset, color=COLORS[2])
Пример #4
0
def draw_cves(
    axis: axes.Axes, project: tp.Type[Project],
    revisions: tp.List[FullCommitHash], cve_line_width: int, cve_color: str,
    label_size: int, vertical_alignment: str
) -> None:
    """
    Annotates CVEs for a project in an existing plot.

    Every CVE gets a vertical line and a rotated text label. The x-position
    is the index of the CVE's revision in ``revisions``; a revision that is
    not part of the sample is placed halfway between its closest samples.

    Args:
        axis: the axis to use for the plot
        project: the project to add CVEs for
        revisions: a list of revisions included in the plot in the order they
                   appear on the x-axis
        cve_line_width: the line width of CVE annotations
        cve_color: the color of CVE annotations
        label_size: the label size of CVE annotations
        vertical_alignment: the vertical alignment of CVE annotations
    """
    commit_map = create_lazy_commit_map_loader(project.NAME)()
    sampled_time_ids = [commit_map.time_id(rev) for rev in revisions]

    provider = CVEProvider.get_provider_for_project(project)
    for revision, cves in provider.get_revision_cve_tuples():
        time_id = commit_map.time_id(revision)
        if time_id not in sampled_time_ids:
            # revision not in sample; draw line between closest samples
            earlier = sum(1 for sampled in sampled_time_ids if sampled < time_id)
            x_pos = earlier - 0.5
        else:
            x_pos = float(revisions.index(revision))

        # x in data coordinates, y in axes coordinates
        label_transform = axis.get_xaxis_transform()
        for cve in cves:
            axis.axvline(
                x_pos,
                label=cve.cve_id,
                linewidth=cve_line_width,
                color=cve_color
            )
            axis.text(
                x_pos + 0.1,
                0,
                cve.cve_id,
                transform=label_transform,
                rotation=90,
                size=label_size,
                color=cve_color,
                va=vertical_alignment
            )
Пример #5
0
def plot_peaks(ax: Axes, peak_list: List[Peak.Peak], label: str = "Peaks", style: str = 'o') -> List[Line2D]:
	"""
	Plots the locations of peaks as found by PyMassSpec.

	When ``style`` contains the word "line" (case-insensitive) a thin vertical
	line is drawn at the retention time of each peak. Otherwise each peak is
	drawn as a marker at its retention time, with the height being the sum of
	the intensities in the peak's mass spectrum.

	:param ax: The axes to plot the peaks on
	:param peak_list: List of peaks to plot
	:param label: label for plot legend.
	:param style: The marker style. See `https://matplotlib.org/3.1.1/api/markers_api.html` for a complete list

	:return: A list of Line2D objects representing the plotted data.

	:raises TypeError: If ``peak_list`` is not a list of Peak objects.
	"""

	if not is_peak_list(peak_list):
		raise TypeError("'peak_list' must be a list of Peak objects")

	if "line" in style.lower():
		return [
			ax.axvline(x=peak.rt, color="lightgrey", alpha=0.8, linewidth=0.3)
			for peak in peak_list
		]

	retention_times = [peak.rt for peak in peak_list]
	# Summed spectrum intensity is used as the height, not ``peak.height``.
	intensities = [sum(peak.mass_spectrum.intensity_list) for peak in peak_list]

	return ax.plot(retention_times, intensities, style, label=label)
Пример #6
0
def draw_bugs(axis: axes.Axes, project: tp.Type[Project],
              revisions: tp.List[FullCommitHash], bug_line_width: int,
              bug_color: str, label_size: int,
              vertical_alignment: str) -> None:
    """
    Annotates bugs for a project in an existing plot.

    Every bug gets a vertical line and a rotated ``#<issue id>`` label at the
    position of its fixing commit. A fixing commit that is not part of the
    sample is placed halfway between its closest samples.

    Args:
        axis: the axis to use for the plot
        project: the project to add bugs for
        revisions: a list of revisions included in the plot in the order they
                   appear on the x-axis
        bug_line_width: the line width of bug annotations
        bug_color: the color of bug annotations
        label_size: the label size of bug annotations
        vertical_alignment: the vertical alignment of bug annotations
    """
    commit_map = create_lazy_commit_map_loader(project.NAME)()
    sampled_time_ids = [commit_map.time_id(rev) for rev in revisions]

    provider = BugProvider.get_provider_for_project(project)
    for raw_bug in provider.find_raw_bugs():
        fix_time_id = commit_map.time_id(raw_bug.fixing_commit)
        if fix_time_id not in sampled_time_ids:
            # revision not in sample; draw line between closest samples
            earlier = sum(
                1 for sampled in sampled_time_ids if sampled < fix_time_id
            )
            x_pos = earlier - 0.5
        else:
            x_pos = float(revisions.index(raw_bug.fixing_commit))

        bug_label = f"#{raw_bug.issue_id}"

        # x in data coordinates, y in axes coordinates
        label_transform = axis.get_xaxis_transform()
        axis.axvline(x_pos,
                     label=bug_label,
                     linewidth=bug_line_width,
                     color=bug_color)
        axis.text(x_pos + 0.1,
                  0,
                  bug_label,
                  transform=label_transform,
                  rotation=90,
                  size=label_size,
                  color=bug_color,
                  va=vertical_alignment)
Пример #7
0
    def add_stat_line(
            ax: Axes,
            series_input: SeriesPlotIn,
            stat: Callable[[Iterable[float]], float] = np.mean,
            vert: bool = True,
            **kwargs
    ) -> Axes:
        """Draw a line at a summary statistic of a series.

        ``stat`` is applied to ``series_input.data`` and the result is drawn
        as a vertical line (``vert=True``) or a horizontal line, in the
        series' color. The line is dashed unless ``linestyle`` is given in
        ``kwargs``; all ``kwargs`` are forwarded to the Matplotlib call.

        Returns the axes that was passed in.
        """
        # Dashed by default; an explicit linestyle from the caller wins.
        kwargs.setdefault("linestyle", "--")

        stat_val: float = stat(series_input.data)

        if not vert:
            ax.axhline(y=stat_val, color=series_input.color, **kwargs)
            return ax

        ax.axvline(x=stat_val, color=series_input.color, **kwargs)
        return ax
Пример #8
0
def example_traces(ax: Axes, record_file: File, start: float, end: float, cells: Set[int]):
    """Visualize calcium trace of cells and the lever trajectory.

    ``start`` and ``end`` are multiplied by the lever sample rate and by the
    recording frame rate to get the index window for the lever trajectory and
    for the calcium traces respectively. Only rows whose index is in
    ``cells`` are drawn, each shifted 2 units above the previous one.
    """
    lever = load_mat(record_file["response"])
    calcium = DataFrame.load(record_file["measurement"])
    frame_rate = record_file.attrs['frame_rate']
    window = [start, end]
    l_start, l_end = np.rint(np.multiply(window, lever.sample_rate)).astype(np.int_)
    c_start, c_end = np.rint(np.multiply(window, frame_rate)).astype(np.int_)

    # lever trajectory
    lever_len = l_end - l_start
    ax.plot(np.linspace(0, lever_len, lever_len),
            _scale(lever.values[0][l_start: l_end]), COLORS[1])

    # NOTE(review): the number of points comes from the lever trajectory while
    # the end value comes from the calcium trace — confirm this is intended.
    time = np.linspace(0, calcium.shape[1] / frame_rate, lever.shape[1])
    offsets = iter(range(0, 500, 2))
    for row_idx, trace in enumerate(calcium.values):
        if row_idx not in cells:
            continue
        ax.plot(time[c_start: c_end] - l_start, _scale(trace[c_start: c_end]) + next(offsets))

    # NOTE(review): l_start is an index but is subtracted from values already
    # divided by the sample rate — verify the units of the x-axis.
    stamps = lever.timestamps
    in_window = (stamps > l_start) & (stamps < l_end)
    stim_onsets = stamps[in_window] / lever.sample_rate - l_start
    for onset in stim_onsets:
        ax.axvline(x=onset, color=COLORS[2])
Пример #9
0
    def _set_figure(cls, ax: axes.Axes, energy_range: Sequence, dos_range: Sequence):
        """Configure the axes limits, label, zero line and legend for a DOS plot.

        A limit is only applied when the corresponding range is truthy.

        :params ax: matplotlib.axes.Axes object
        :params energy_range: range of energy, applied as the x limits
        :params dos_range: range of dos, applied as the y limits
        """

        # y-axis
        ax.set_ylabel("DOS")
        if dos_range:
            dos_low, dos_high = dos_range[0], dos_range[1]
            ax.set_ylim(dos_low, dos_high)

        # x-axis
        if energy_range:
            energy_low, energy_high = energy_range[0], energy_range[1]
            ax.set_xlim(energy_low, energy_high)

        # others: dashed blue line at x = 0, then the legend
        ax.axvline(0, linestyle="--", c='b', lw=1.0)
        ax.legend()
Пример #10
0
def growth_curve(ax: Axes,
                 plate: Plate,
                 scatter_color: str,
                 line_color: str = None,
                 growth_params: bool = True):
    """
    Add a growth curve scatter plot, with median, to an axis

    Every colony is drawn as a faint scatter of its measurements over elapsed
    time in hours, and the per-timepoint median of the plate is drawn as a
    line on top.

    :param ax: a Matplotlib Axes object to add a plot to
    :param plate: a Plate instance
    :param scatter_color: a Colormap color
    :param line_color: a Colormap color for the median; defaults to
        scatter_color
    :param growth_params: if True, also draw the lag time, carrying capacity
        and maximum growth rate lines (each only when its value is positive)
        and label the median line "Median" instead of with the plate id
    """
    from statistics import median

    def _hours(delta):
        # Matplotlib does not yet support timedeltas so we have to convert manually to float
        return delta.total_seconds() / 3600

    if line_color is None:
        line_color = scatter_color

    for colony in plate.items:
        # Sort the (time, value) pairs together; sorting only the keys would
        # misalign x and y whenever the mapping is not already time-ordered.
        points = sorted(colony.growth_curve.data.items())
        ax.scatter(
            [_hours(timepoint) for timepoint, _ in points],
            [value for _, value in points],
            color=scatter_color,
            marker="o",
            s=1,
            alpha=0.25)

    # Plot the median
    plate_points = sorted(plate.growth_curve.data.items())
    ax.plot([_hours(timepoint) for timepoint, _ in plate_points],
            [median(values) for _, values in plate_points],
            color=line_color,
            label="Median" if growth_params else f"Plate {plate.id}",
            linewidth=2)

    if growth_params:
        # Plot lag, vmax and carrying capacity lines
        if plate.growth_curve.lag_time.total_seconds() > 0:
            ax.axvline(_hours(plate.growth_curve.lag_time),
                       color="grey",
                       linestyle="dashed",
                       alpha=0.5,
                       label="Lag time")

        if plate.growth_curve.carrying_capacity > 0:
            ax.axhline(plate.growth_curve.carrying_capacity,
                       color="blue",
                       linestyle="dashed",
                       alpha=0.5,
                       label="Carrying\ncapacity")

        if plate.growth_curve.growth_rate > 0:
            # Straight line from the end of the lag phase at y = 0 up to the
            # carrying capacity, with the growth rate (scaled by 3600) as slope
            y0, y1 = 0, plate.growth_curve.carrying_capacity
            x0 = _hours(plate.growth_curve.lag_time)
            x1 = ((y1 - y0) / (plate.growth_curve.growth_rate * 3600)) + x0
            ax.plot([x0, x1], [y0, y1],
                    color="red",
                    linestyle="dashed",
                    alpha=0.5,
                    label="Maximum\ngrowth rate")
    def show_agent_opinions(
        self,
        t=-1,
        direction=True,
        sort=False,
        ax: Axes = None,
        fig: Figure = None,
        colorbar: bool = True,
        title: str = "Agent opinions",
        show_middle=True,
        **kwargs,
    ) -> Tuple[Figure, Axes]:
        """Draw one horizontal bar per agent showing its opinion at time ``t``.

        :param t: time point, resolved to an index with ``get_time_point_idx``.
        :param direction: if False, plot only the magnitude of the opinions.
        :param sort: truthy to sort by opinion, or a numpy array of indices
            used to reorder the opinions.
        :param ax: axes to draw on.
        :param fig: figure, returned unchanged.
        :param colorbar: add an inset colorbar that carries the opinion label.
        :param title: axes title; skipped when falsy.
        :param show_middle: when sorting, mark the bar whose opinion is
            closest to zero.
        :param kwargs: ``cmap`` is popped, the rest goes to ``ax.barh``.
        :return: the ``fig`` and ``ax`` that were passed in.

        NOTE(review): ``ax`` defaults to None but is used unconditionally —
        presumably a decorator outside this view supplies it; confirm.
        """
        cmap = kwargs.pop("cmap", OPINIONS_CMAP)

        time_idx = get_time_point_idx(self.sn.result.t, t)
        opinions = self.sn.result.y[:, time_idx]
        agents = np.arange(self.sn.N)
        if not direction:
            # only magnitude
            opinions = np.abs(opinions)

        reorder = bool(np.iterable(sort) or sort)
        if reorder:
            if isinstance(sort, np.ndarray):
                # sort passed as indices
                order = sort
            else:
                logger.warning(
                    "sorting opinions for `show_agent_opinions` means agent indices are jumbled"
                )
                # sort by opinion
                order = np.argsort(opinions)
            opinions = opinions[order]

        limits = self._get_equal_opinion_limits()
        mappable = ScalarMappable(norm=Normalize(*limits), cmap=cmap)
        bar_colors = mappable.to_rgba(opinions)

        ax.barh(
            agents,
            opinions,
            color=bar_colors,
            edgecolor="None",
            linewidth=0,  # remove bar borders
            height=1,  # per agent
            **kwargs,
        )
        ax.axvline(x=0, ls="-", color="k", alpha=0.5, lw=1)

        if reorder and show_middle:
            # position of the opinion closest to zero
            middle = np.argmin(np.abs(opinions))
            ax.hlines(
                y=middle,
                xmin=limits[0],
                xmax=limits[1],
                ls="--",
                color="k",
                alpha=0.5,
                lw=1,
            )
            ax.annotate(
                f"{middle}",
                xy=(np.min(opinions), middle),
                fontsize="small",
                color="k",
                alpha=0.5,
                va="bottom",
                ha="left",
            )

        if colorbar:
            # create colorbar axes without stealing from main ax
            cbar = colorbar_inset(mappable,
                                  "outer bottom",
                                  size="5%",
                                  pad=0.01,
                                  ax=ax)
            sns.despine(ax=ax, bottom=True)
            ax.tick_params(axis="x", bottom=False, labelbottom=False)
            cbar.set_label(OPINION_SYMBOL)

        ax.set_ylim(0, self.sn.N)
        ax.set_xlim(*limits)
        if not colorbar:
            # xlabel not part of colorbar
            ax.set_xlabel(OPINION_SYMBOL)
        ax.set_ylabel("Agent $i$")
        ax.yaxis.set_major_locator(MaxNLocator(5))
        if title:
            ax.set_title(title)
        return fig, ax
Пример #12
0
class Bars:
    """This class represents a complex horizontal bar chart.

    This class extends (by composition) the functionality provided
    by Matplotlib.
    The chart is automatically rendered in Jupyter notebooks and can
    be saved on disk.
    The chart can be tailored to a great extent by passing keyword
    arguments to the constructor. (SEE the class attribute **Bars.conf**
    for listing the other optional **kwargs**).
    If it is not enough, the **conf.py** module in the Catbars package
    gives users full control over "rcParams".

    Parameters
    -----------    

    numbers : iterable container
        The numbers specifying the width of each bar. First numbers are
        converted into bars appearing on the top of the figure.

    left_labels : iterable container or str, optional
        Labels associated with the bars on the left.
        The "rank" option creates one-based indices.
        The "proportion" option creates labels representing the
        relative proportion of each bar in percents.
        "rank" and "proportion" labels depend on the "slice" unless
        "global_view" is True.

    right_labels : iterable container or str, optional
        Labels associated with the bars on the right. It accepts the same
        values as "left_labels".

    colors : iterable container, optional
        The container items can be of any type. Bar colors are
        automatically inferred in function of the available "tints" (SEE
        Bars.conf) and the most common items in the "slice" (unless
        "global_view" is True). If there are more distinct items than
        available "tints", "default_color" and "default_label" are used with
        residual items. The automatic color selection can be overridden
        by "color_dic".

    line_dic : dict, optional
        This dictionary has to contain three keys: "number", "color" and 
        "label". It describes an optional vertical line to 
        draw.

    sort : bool, optional
        If True, "numbers" are sorted in descending order. Optional labels
        and the "colors" parameter are sorted in the same way. The default
        value is False.

    slice : tuple : (start,stop), optional
        start and stop are one-based indices. Slicing precedes sorting unless
        "global_view" is True.

    global_view : bool, optional
        If True, the whole dataset is considered instead of the optional
        slice when sorting, coloring, setting x bounds and creating
        "rank" and "proportion" labels. The default value is False.

    auto_scale : bool, optional
        If True, the logarithmic scale is used when it seems better for
        readability. The default value is False.

    color_dic : dict, optional
        A dictionary mapping "colors" items (keys) to Matplotlib colors
        (values)."colors" items which are not specified by the dic. are
        treated as residual items (SEE "colors").

    title : str, optional
        Figure title.

    xlabel : str, optional
        The Matplotlib xlabel.

    ylabel : str, optional
        The Matplotlib ylabel.

    legend_title : str, optional

    legend_visible : bool, optional
        The default value is True.

    figsize : (width, height), optional
        The Matplotlib figsize. The default value is (6,5).

    dpi : number, optional
        The Matplotlib dpi. The default value is 100.

    file_name : str or path-like or file-like object, optional
        The path of the png file to write to. (SEE the method print_pdf()
        for writing pdf files).

    Returns
    --------
    catbars.bars.Bars
        A Bars instance. It encapsulates useful Matplotlib objects.
    
    
    Attributes
    -----------

    conf : dict
        This class attribute contains the advanced optional
        constructor parameters along with their current values.
        In particular, it contains the "figsize", "dpi", "tints",
        "default_color" and "default_label" values.
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes
    canvas : matplotlib.backends.backend_agg.FigureCanvasAgg
    data : catbars.models.AbstractModel
        The Bars class delegates to another class data processing tasks.

    Methods
    -------
    print_png(file_name)
        To write png files.
    print_pdf(file_name)
        To write pdf files.

    

    """

    conf = Conf.conf

    def __init__(self,
                 numbers,
                 left_labels = None,
                 right_labels = None, # 'proportion' 'rank'
                 colors = None,
                 line_dic = None,
                 sort = False,
                 slice = None, # one-based indexing
                 global_view = False,
                 auto_scale = False,
                 color_dic = None,
                 title = None,
                 xlabel = None,
                 ylabel = None,
                 legend_title = None,
                 legend_visible = True,
                 file_name = None,
                 **kwargs):
        """
        Build the data model, create the Matplotlib objects and lay out
        the chart.

        The data space can adapt to long labels but only to
        some extent because the long label sizes are fixed.
        This class moves the edges of the axes to make room
        for labels (SEE Matplotlib HOW-TOs).

        The remaining keyword arguments are passed to
        ``Conf.change_conf``. A ``log_level`` entry (the name of a
        ``logging`` level) additionally configures ``logging``.
        When ``file_name`` is given, the chart is written as a png file.
        """
        if 'log_level' in kwargs:
            logging.basicConfig(format='{levelname}:\n{message}',
                                level= getattr(logging, kwargs['log_level']),
                                style = '{')

        # Configuration: matplotlibrc is decorated by conf.py.
        self.conf = Conf.change_conf(kwargs)

        # Data formatted by the model.
        self.data = None

        # Core Matplotlib objects.
        self.fig = None
        self.ax = None
        self.canvas = None
        self.vertical_line = None
        self.bars = None # BarContainer.
        self._virtual_bars  = None # For global_view.

        # Helper attributes.
        self._left_label_data = None
        self._right_label_texts = None

        # Titles.
        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.legend_title = legend_title

        self._global_view = global_view
        self.legend = None
        self._legend_width = 0
        self.legend_visible = legend_visible

        # The vertical line.
        self.line_x = None
        self.line_label = None
        self.line_color = None

        # Original position of the axes edges in the figure.
        self._x0 = 0
        self._y0 = 0
        self._width = 1
        self._height = 1

        # To deal with not square figures,
        # only x sizes are adapted.
        self._x_coeff = 1

        #########################################################

        # Model.
        factory = ModelFactory(
            numbers,
            global_view = global_view,
            left_labels = left_labels,
            right_labels = right_labels,
            colors = colors,
            sort = sort,
            slice = slice,
            default_label = self.conf['default_label'],
            color_dic = color_dic,
            tints = self.conf['tints'],
            default_color = self.conf['default_color'])

        self.data = factory.model

        self.fig = Figure(figsize = self.conf['figsize'],
                          dpi = self.conf['dpi'])

        self.canvas = FigureCanvasAgg(self.fig)

        self.ax = Axes(self.fig,
                       [self._x0,
                        self._y0,
                        self._width,
                        self._height])

        self.fig.add_axes(self.ax)
        self.canvas.draw()

        # Horizontal sizes are scaled by height / width so that they match
        # the vertical ones on non-square figures.
        w, h = self.fig.get_size_inches()
        self._x_coeff = h / w

        # margin.
        margin = self.conf['margin']

        self._x0 = self._x_coeff * margin
        self._y0 = margin
        self._width = self._width - 2 * self._x_coeff * margin
        self._height = self._height - 2 * margin

        self._set_position()

        self.ax.tick_params(axis = 'y',
                            length = 0)

        # The visibility flag is passed positionally: its keyword was
        # renamed from "b" to "visible" in Matplotlib 3.5, so neither
        # keyword works across all versions.
        self.ax.grid(True,
                     axis = 'x',
                     which = 'both',
                     color = 'black',
                     alpha = 0.3)

        for name in ['top', 'right']:
            self.ax.spines[name].set_visible(False)

        # xscale.
        if auto_scale is True:
            # To improve clarity.
            if self.data.spread > 1 or self.data.maximum > 1e6:
                self.ax.set(xscale = 'log')
        else:
            default_formatter = self.ax.get_xaxis().get_major_formatter()
            custom_formatter = self.build_formatter(default_formatter)
            formatter = matplotlib.ticker.FuncFormatter(custom_formatter)
            self.ax.get_xaxis().set_major_formatter(formatter)

        # Title.
        if self.title is not None:
            self._manage_title()

        _kwargs = dict()

        # Left labels.
        if self.data.left_labels is not None:
            _kwargs['tick_label'] = self.data.left_labels
        else:
            _kwargs['tick_label'] = ''

        # colors.
        if self.data.actual_colors is not None:
            _kwargs['color'] = self.data.actual_colors
        else:
            _kwargs['color'] = self.data.default_color

        # bars.
        self.bars = self.ax.barh(list(range(self.data.length)),
                                 self.data.numbers,
                                 height = 1,
                                 edgecolor = 'white',
                                 linewidth = 1, # 0.4
                                 alpha = self.conf['color_alpha'],
                                 **_kwargs)

        # To fix x bounds, virtual bars are used.
        if self._global_view is True:
            self._virtual_bars = self.ax.barh(
                [0, 0],
                [self.data.minimum,
                 self.data.maximum],
                height = 0.5,
                edgecolor = 'white',
                linewidth = 1, # 0.4
                alpha = self.conf['color_alpha'],
                visible = False)

        # The vertical line.
        if line_dic is not None:
            self._set_line(line_dic)

            # The line is only drawn when it falls inside the data range.
            if (self.line_x is not None and
                self.data.minimum <= self.line_x <= self.data.maximum):
                #
                self.vertical_line = self.ax.axvline(
                    self.line_x,
                    ymin = 0,
                    ymax = 1,
                    color = self.line_color,
                    linewidth = 2,
                    alpha = self.conf['color_alpha'])

        # Left label constraint solving.
        self._make_room_for_left_labels()

        # ylabel.
        if self.ylabel is not None:
            self._manage_ylabel()

        # Legend.
        if (self.legend_visible is True and
            self.data.colors is not None):
            #
            self._draw_legend()
            self._make_room_for_legend()

        # Right labels.
        if self.data.right_labels is not None:
            self._draw_right_labels()
            self._make_room_for_right_labels()

        min_tick_y = self._clean_x_ticklabels()

        # xlabel.
        if self.xlabel is not None:
            self._manage_xlabel(min_tick_y)
        else:
            # No xlabel: only make room for the x tick labels.
            delta_y0 = abs(self._y0 - min_tick_y)
            self._y0 = self._y0 + delta_y0
            self._height = self._height - delta_y0
            self._set_position()

        self.canvas.draw()

        # Printing.
        if file_name is not None:
            self.canvas.print_png(file_name)
        

        
        #############################################################
        
    
    def _set_line(self, line_dic):
        """
        Store the description of the optional vertical line.

        "line_dic" has to provide the 'number', 'label' and 'color' keys.
        They are copied to "line_x", "line_label" and "line_color". The
        attributes are only modified when all three keys are available.

        Raises
        ------
        TypeError
            If "line_dic" is not subscriptable by key or if one of the
            three keys is missing.
        """
        try:
            number = line_dic['number']
            label = line_dic['label']
            color = line_dic['color']
        except (LookupError, TypeError) as err:
            text = (
                '"line_dic" has to define three keys: '
                "'number', 'label' and 'color'."
            )
            # The original error is chained to help debugging.
            raise TypeError(text) from err

        self.line_x = number
        self.line_label = label
        self.line_color = color


    def _manage_title(self):
        """
        Draw the title and reduce the axes height to make room for it.

        The room taken is the rendered title height plus the
        "title_pad" configuration value.
        """
        title_pad = self.conf['title_pad']
        pad_in_points = self.fig_coord_to_points(self.fig,
                                                 title_pad,
                                                 axis = 'y')
        title_text = self.ax.set_title(
            self.title,
            pad = pad_in_points,
            fontsize = self.conf['title_font_size'],
            fontweight = 'bold')

        # The title has to be drawn before its extent can be measured.
        self.canvas.draw()
        renderer = self.canvas.get_renderer()
        extent = title_text.get_window_extent(renderer = renderer)

        title_height = self.disp_to_fig_coord(self.fig,
                                              extent.height,
                                              axis = 'y')

        self._height -= title_height + title_pad
        self._set_position()
    
    
    def _make_room_for_left_labels(self):
        """
        Constraint solving for left labels.

        The left edge of the axes is moved to the right by the distance
        between the current edge and the leftmost y tick label.
        "left_label_data" (the y position and vertical alignment of
        each y tick label) is stored for further processing and
        will be used to align left and right labels.
        When there is no y tick label, the axes position is left as is.
        """
        self.canvas.draw()
        # Neither the renderer nor the transform changes while looping.
        renderer = self.canvas.get_renderer()
        to_fig_coord = self.fig.transFigure.inverted()

        left_label_data = [] # To align left and right labels.
        min_x = None
        for left_label in self.ax.get_yticklabels():
            _, y = left_label.get_position()
            bbox = left_label.get_window_extent(renderer = renderer)
            lab_x, _ = to_fig_coord.transform((bbox.x0, bbox.y0))

            if min_x is None or lab_x < min_x:
                min_x = lab_x # In figure coordinates.
            left_label_data.append((y, left_label.get_va()))

        if min_x is not None:
            delta_x0 = abs(self._x0 - min_x)
            self._x0 = self._x0 + delta_x0
            self._width = self._width - delta_x0
            self._set_position()

        self._left_label_data = left_label_data
    
    
    
    def _manage_ylabel(self):
        """
        Draw the ylabel and move the left edge of the axes to make
        room for it.

        The room taken is the rendered label width plus the "pad"
        configuration value scaled by "_x_coeff".
        """
        x_pad = self._x_coeff * self.conf['pad']
        label_pad = self.fig_coord_to_points(self.fig, x_pad)
        label_text = self.ax.set_ylabel(
            self.ylabel,
            labelpad = label_pad,
            fontweight = 'bold',
            fontsize = self.conf['axis_title_font_size'])

        # The label has to be drawn before its extent can be measured.
        self.canvas.draw()
        renderer = self.canvas.get_renderer()
        extent = label_text.get_window_extent(renderer = renderer)

        label_width = self.disp_to_fig_coord(self.fig, extent.width)
        shift = label_width + x_pad

        self._x0 += shift
        self._width -= shift
        self._set_position()
    
        

    def _draw_legend(self):
        """
        Draw the figure legend and measure its width.

        One proxy patch is created per legend color. The vertical line,
        when it was drawn, is appended as the last entry. The legend
        and its width (converted with "disp_to_fig_coord") are stored
        in "legend" and "_legend_width".
        """
        font_size = self.conf['axis_title_font_size']

        handles = []
        texts = []
        for index, tint in enumerate(self.data.legend_colors):
            # Proxy artists.
            handles.append(mpatches.Patch(facecolor = tint,
                                          alpha = self.conf['color_alpha']))
            texts.append(self.data.legend_labels[index])

        if self.vertical_line is not None:
            handles.append(self.vertical_line)
            texts.append(self.line_label)

        extra = {}
        if self.legend_title is not None:
            extra['title'] = self.legend_title

        lgd = self.fig.legend(handles,
                              texts,
                              loc ='center left',
                              frameon = False,
                              labelspacing = 0.25,
                              borderpad = 0,
                              borderaxespad = 0,
                              prop = {'size' : font_size},
                              **extra)
        legend_title = lgd.get_title()
        legend_title.set_fontsize(font_size)
        legend_title.set_fontweight('bold')
        legend_title.set_multialignment('center')

        # The legend has to be drawn before its extent can be measured.
        self.canvas.draw()

        # Constraint solving.
        renderer = self.canvas.get_renderer()
        lgd_width = lgd.get_window_extent(renderer = renderer).width

        self.legend = lgd
        self._legend_width = self.disp_to_fig_coord(self.fig, lgd_width)

        logging.info('legend width in pixels {}\n'.format(lgd_width))
        
    

    def _make_room_for_legend(self):
        """
        Anchor the legend next to the right margin and shrink the axes
        by the legend width plus the horizontal padding.
        """
        taken = self._legend_width
        anchor_x = 1 - taken - self._x_coeff * self.conf['margin']
        self.legend.set_bbox_to_anchor((anchor_x, 0.5))

        self._width = self._width - taken - self._x_coeff * self.conf['pad']
        self._set_position()
    
    
        
    def _draw_right_labels(self):
        """
        Draw the optional text label at the right end of every bar.

        For each bar, the label found at the same index of
        ``self.data.right_labels`` is written at the bar width, using
        the y position and vertical alignment recorded in
        ``self._left_label_data``. The created text artists are stored
        in ``self._right_label_texts``; the list is empty when
        ``self.data.right_labels`` is None.
        """
        labels = self.data.right_labels
        texts = []

        if labels is not None:
            for i, bar in enumerate(self.bars):
                y, va = self._left_label_data[i]
                texts.append(
                    self.ax.text(bar.get_width(), y,
                                 ' {}'.format(labels[i]),
                                 verticalalignment=va,
                                 fontweight='normal',
                                 zorder=10))
            if texts:
                # A single redraw once every label exists, instead of a
                # full re-render of the figure after each label.
                self.canvas.draw()

        self._right_label_texts = texts
        
                                
    
    def _make_room_for_right_labels(self):
        
        """
        Constraint solving in figure coordinates.
        A bisection technique is used.
        """
        def _objective_function(coeff_array,
                                label_array,
                                x):
            #
            return max(x, max(coeff_array * x + label_array))

        bar_coeff = []
        text_widths = []
        for i, bar in enumerate(self.bars):
            bar_coeff.append(self._get_bar_coeff(bar))
            t = self._right_label_texts[i]
            text_widths.append(self._get_text_width(t))
        
        coeff_array = np.array(bar_coeff)
        label_array = np.array(text_widths)
        
        f = partial(_objective_function,
                    coeff_array,
                    text_widths)
        
        min_w = self.conf['min_ax_width']
        max_it = self.conf['right_label_max_it']
        tolerance = self.conf['right_label_solver_tolerance']
        
        # Two special cases.
        if f(self._width) == self._width:        
            pass
        # To check whether a solution exists.
        elif f(min_w) < self._width:
            w_b = self._width
            w_a = min_w
            i = 0
            # To prevent from infinite loops.
            while abs(w_b - w_a) > tolerance and i < max_it:
                new_w = w_a + (w_b - w_a) / 2
                if f(new_w) < self._width:
                    w_a = new_w
                else:
                    w_b = new_w
                logging.info('w_a {}\nw_b {}\n'.format(w_a, w_b))
                i += 1
            self._width = w_a
        else:
            self._width = min_w

        self._set_position()

        if i == max_it:
            logging.warning("""
right_label_max_it {} has been hit.
""".format(max_it))
    
    
    def _get_bar_coeff(self, bar):
        """
        Return the x position of the right edge of ``bar`` expressed in
        axes coordinates.

        The right edge (``x1`` of the bar bounding box, assuming that
        the bar starts at x0 = 0) is converted from data to display
        coordinates, then from display to axes coordinates.

        Original author's note: the value can't be greater than 0.95 if
        xmargin = 0.05.
        """
        right_edge = bar.get_bbox().x1
        disp_x, disp_y = self.ax.transData.transform((right_edge, 0))
        to_axes = self.ax.transAxes.inverted()
        axes_x, _unused_y = to_axes.transform((disp_x, disp_y))
        return axes_x


    def _get_text_width(self, t):
        """
        Return the width of the text artist ``t`` in figure coordinates.

        The width is measured in display units (pixels) with the canvas
        renderer, then converted with ``disp_to_fig_coord``.
        """
        renderer = self.canvas.get_renderer()
        pixel_width = t.get_window_extent(renderer=renderer).width
        return self.disp_to_fig_coord(self.fig, pixel_width)
    
    
            
    def _clean_x_ticklabels(self):
        """
        Hide the X tick labels which overlap a label kept on their
        right.

        The visible tick labels are scanned from the last one to the
        first one. The last label is always kept; each earlier label is
        hidden when its extent overlaps the most recently kept one,
        otherwise it becomes the new reference.

        Returns the lowest bottom edge (``y0``) among the kept labels,
        converted to figure coordinates.
        """
        self.canvas.draw()

        candidates = self.get_visible_ticklabels(
            self.ax,
            self.ax.xaxis.get_ticklabels(which='both'))
        renderer = self.canvas.get_renderer()
        extents = [lab.get_window_extent(renderer=renderer)
                   for lab in candidates]

        kept = extents[-1]
        lowest_y = kept.y0
        earlier = zip(reversed(candidates[:-1]), reversed(extents[:-1]))
        for lab, extent in earlier:
            if extent.overlaps(kept):
                lab.set_visible(False)
                continue
            kept = extent
            lowest_y = min(lowest_y, kept.y0)

        to_figure = self.fig.transFigure.inverted()
        _unused_x, tick_y = to_figure.transform((0, lowest_y))
        return tick_y
    

    
    def _manage_xlabel(self,
                       min_tick_y):
        """
        Draw the X axis title and raise the bottom edge of the axes so
        that the tick labels and the title fit below it.

        ``min_tick_y`` is the lowest Y position of the tick labels in
        figure coordinates (negative, according to the original
        author's note). The bottom edge moves up by the distance
        between ``_y0`` and ``min_tick_y``, plus the title height, plus
        ``conf['pad']``; the axes height shrinks by the same amount.
        """
        pad_fraction = self.conf['pad']
        title = self.ax.set_xlabel(
            self.xlabel,
            labelpad=self.fig_coord_to_points(self.fig,
                                              pad_fraction,
                                              axis='y'),
            fontweight='bold',
            fontsize=self.conf['axis_title_font_size'],
        )

        # Draw first, then measure the title in display units.
        self.canvas.draw()
        extent = title.get_window_extent(
            renderer=self.canvas.get_renderer())
        title_height = self.disp_to_fig_coord(self.fig,
                                              extent.height,
                                              axis='y')

        lift = abs(self._y0 - min_tick_y) + title_height + pad_fraction
        self._y0 = self._y0 + lift
        self._height = self._height - lift
        self._set_position()

        self.canvas.draw()

    
    def _set_position(self):
        """
        Apply the stored ``_x0``, ``_y0``, ``_width`` and ``_height`` to
        the axes and log the four values.
        """
        self.ax.set_position([self._x0,
                              self._y0,
                              self._width,
                              self._height])

        edges = ('x0', 'y0', 'width', 'height')
        report = ''.join('{} {}\n'.format(edge, getattr(self, '_' + edge))
                         for edge in edges)
        logging.info('Position of the Axes instance edges\n' + report)
                
                
        
    def disp_to_fig_coord(self,
                          fig,
                          dist,
                          axis = 'x'):
        """
        Express a distance given in display units as a fraction of the
        figure size.

        The figure extent in display units is ``fig.dpi`` times its
        size in inches; the width is used when ``axis`` is 'x', the
        height for any other value.
        """
        width_in, height_in = fig.get_size_inches()
        inches = width_in if axis == 'x' else height_in
        return dist / (fig.dpi * inches)

    
    def points_to_fig_coord(self,
                            fig,
                            points,
                            axis = 'x'):
        """
        Convert a length in points to a fraction of the figure size.

        There are 72 points per inch. The figure width is used when
        ``axis`` is 'x', the height for any other value ('y' stands for
        the Y axis).
        """
        width_in, height_in = fig.get_size_inches()
        inches = points / 72
        if axis != 'x':
            return inches / height_in
        return inches / width_in


    def fig_coord_to_points(self,
                            fig,
                            fraction,
                            axis = 'x'):
        """
        Convert a fraction of the figure size to a length in points.

        The fraction is multiplied by the figure width (``axis`` equal
        to 'x') or by the figure height (any other value, 'y' standing
        for the Y axis), both in inches, then by 72 points per inch.
        """
        width_in, height_in = fig.get_size_inches()
        inches = width_in if axis == 'x' else height_in
        return fraction * inches * 72

    
        
    def get_visible_ticklabels(self,
                               ax,
                               labels):
        """
        Return the tick labels that are actually displayed.

        Only a part of the labels built by the Matplotlib machinery are
        displayed. A label is kept when its x position lies within the
        X limits of ``ax`` (bounds included), it is flagged as visible
        and its text is not empty. The input order is preserved.
        """
        lower, upper = ax.get_xlim()
        return [
            label
            for label in labels
            if lower <= label.get_position()[0] <= upper
            and label.get_visible()
            and label.get_text()
        ]


    def build_formatter(self, default_formatter):
        """
        Build a tick formatter using a custom scientific notation.

        The returned callable takes ``(x, pos)``. Values whose magnitude
        is greater than 1e6 or lower than 1e-3 are written as a
        ``mantissa x 10^exponent`` math-text string with a one decimal
        mantissa. Zero is written '0'. Every other value is delegated to
        ``default_formatter(x, pos)``.

        The thresholds are applied to the absolute value, so that
        ordinary negative numbers are handled by the default formatter
        instead of being treated as very small numbers.
        """
        def formatter(x, pos):
            magnitude = abs(x)
            if not (magnitude > 1e6 or magnitude < 1e-3):
                return default_formatter(x, pos)

            mantissa, exponent = '{:.1e}'.format(x).split('e')
            if float(mantissa) == 0:
                # A tick formatter is expected to return a string.
                return '0'
            # int() drops the sign of positive exponents and the zero
            # padding while keeping the sign of negative exponents.
            return r'${} \times 10^{{{}}}$'.format(mantissa, int(exponent))

        return formatter


    
    def print_pdf(self, file_name):
        """
        Save the figure as a single page PDF file named ``file_name``.
        """
        from matplotlib.backends.backend_pdf import PdfPages

        # The context manager closes the PDF file even when savefig
        # raises, so that no file handle is left open.
        with PdfPages(file_name) as pdf:
            pdf.savefig(figure = self.fig)

    def print_png(self, file_name):
        """
        Write the figure as a PNG by passing ``file_name`` to the
        ``print_png`` method of the canvas.
        """
        self.canvas.print_png(file_name)


    def _repr_png_(self):
        """
        PNG representation, for notebook integration.

        Returns a tuple made of the PNG bytes produced by the canvas
        and a metadata dictionary giving the width and the height
        (size in inches times dpi) as strings.
        """
        width_in, height_in = self.fig.get_size_inches()

        # The image is rendered into an in-memory bytes buffer.
        buffer = BytesIO()
        self.canvas.print_png(buffer)

        dpi = self.fig.dpi
        metadata = {'width' : str(width_in * dpi),
                    'height': str(height_in * dpi)}
        return (buffer.getvalue(), metadata)
Пример #13
0
def plot_pianoroll(
    ax: Axes,
    pianoroll: ndarray,
    is_drum: bool = False,
    resolution: Optional[int] = None,
    downbeats: Optional[Sequence[int]] = None,
    preset: str = "full",
    cmap: str = "Blues",
    xtick: str = "auto",
    ytick: str = "octave",
    xticklabel: bool = True,
    yticklabel: str = "auto",
    tick_loc: Sequence[str] = ("bottom", "left"),
    tick_direction: str = "in",
    label: str = "both",
    grid_axis: str = "both",
    grid_linestyle: str = ":",
    grid_linewidth: float = 0.5,
    **kwargs,
):
    """
    Plot a piano roll.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to plot the piano roll on.
    pianoroll : ndarray, shape=(?, 128), (?, 128, 3) or (?, 128, 4)
        Piano roll to plot. For a 3D piano-roll array, the last axis can
        be either RGB or RGBA.
    is_drum : bool
        Whether it is a percussion track. Defaults to False.
    resolution : int
        Time steps per quarter note. Required if `xtick` is 'beat'.
    downbeats : sequence of int
        Positions along the x-axis (time steps) at which a downbeat
        boundary line is drawn. One vertical line is drawn per value.
    preset : {'full', 'frame', 'plain'}
        Preset theme. For 'full' preset, ticks, grid and labels are on.
        For 'frame' preset, ticks and grid are both off. For 'plain'
        preset, the x- and y-axis are both off. Defaults to 'full'.
    cmap : str or :class:`matplotlib.colors.Colormap`
        Colormap. Will be passed to :func:`matplotlib.pyplot.imshow`.
        Only effective when `pianoroll` is 2D. Defaults to 'Blues'.
    xtick : {'auto', 'beat', 'step', 'off'}
        Tick format for the x-axis. For 'auto' mode, set to 'beat' if
        `resolution` is given, otherwise set to 'step'. Defaults to
        'auto'.
    ytick : {'octave', 'pitch', 'off'}
        Tick format for the y-axis. Defaults to 'octave'. The value
        'step' is accepted as an alias of 'pitch' for backward
        compatibility.
    xticklabel : bool
        Whether to add tick labels along the x-axis.
    yticklabel : {'auto', 'name', 'number', 'off'}
        Tick label format for the y-axis. For 'name' mode, use pitch
        name as tick labels. For 'number' mode, use pitch number. For
        'auto' mode, set to 'name' if `ytick` is 'octave' and 'number'
        if `ytick` is 'pitch'. Defaults to 'auto'.
    tick_loc : sequence of {'bottom', 'top', 'left', 'right'}
        Tick locations. Defaults to `('bottom', 'left')`.
    tick_direction : {'in', 'out', 'inout'}
        Tick direction. Defaults to 'in'.
    label : {'x', 'y', 'both', 'off'}
        Whether to add labels to x- and y-axes. Defaults to 'both'.
    grid_axis : {'x', 'y', 'both', 'off'}
        Whether to add grids to the x- and y-axes. Defaults to 'both'.
    grid_linestyle : str
        Grid line style. Will be passed to
        :meth:`matplotlib.axes.Axes.grid`.
    grid_linewidth : float
        Grid line width. Will be passed to
        :meth:`matplotlib.axes.Axes.grid`.
    **kwargs
        Keyword arguments to be passed to
        :meth:`matplotlib.axes.Axes.imshow`.

    Returns
    -------
    The object returned by ``ax.imshow``.

    Raises
    ------
    ValueError
        If `pianoroll` is neither 2D nor 3D, if one of the option
        strings is not among its allowed values, or if `xtick` is
        'beat' while `resolution` is None.

    """
    # Plot the piano roll (time runs along the x-axis, hence the
    # transposition of the first two axes).
    if pianoroll.ndim == 2:
        transposed = pianoroll.T
    elif pianoroll.ndim == 3:
        transposed = pianoroll.transpose(1, 0, 2)
    else:
        raise ValueError("`pianoroll` must be a 2D or 3D numpy array")

    img = ax.imshow(
        transposed,
        cmap=cmap,
        aspect="auto",
        vmin=0,
        vmax=1 if pianoroll.dtype == np.bool_ else 127,
        origin="lower",
        interpolation="none",
        **kwargs,
    )

    # Format ticks and labels
    if xtick == "auto":
        xtick = "beat" if resolution is not None else "step"
    elif xtick not in ("beat", "step", "off"):
        raise ValueError(
            "`xtick` must be one of 'auto', 'beat', 'step' or 'off', not "
            f"{xtick}.")
    if yticklabel == "auto":
        yticklabel = "name" if ytick == "octave" else "number"
    elif yticklabel not in ("name", "number", "off"):
        raise ValueError(
            "`yticklabel` must be one of 'auto', 'name', 'number' or 'off', "
            f"not {yticklabel}.")

    if preset == "full":
        ax.tick_params(
            direction=tick_direction,
            bottom=("bottom" in tick_loc),
            top=("top" in tick_loc),
            left=("left" in tick_loc),
            right=("right" in tick_loc),
            labelbottom=xticklabel,
            labelleft=(yticklabel != "off"),
            labeltop=False,
            labelright=False,
        )
    elif preset == "frame":
        ax.tick_params(
            direction=tick_direction,
            bottom=False,
            top=False,
            left=False,
            right=False,
            labelbottom=False,
            labeltop=False,
            labelleft=False,
            labelright=False,
        )
    elif preset == "plain":
        ax.axis("off")
    else:
        raise ValueError(
            f"`preset` must be one of 'full', 'frame' or 'plain', not {preset}"
        )

    # Format x-axis
    if xtick == "beat" and preset != "frame":
        if resolution is None:
            raise ValueError(
                "`resolution` must not be None when `xtick` is 'beat'.")
        n_beats = pianoroll.shape[0] // resolution
        # Major ticks mark the beat boundaries and carry no label; the
        # beat numbers are put on zero-width minor ticks placed in the
        # middle of each beat.
        ax.set_xticks(resolution * np.arange(n_beats) - 0.5)
        ax.set_xticklabels("")
        ax.set_xticks(resolution * (np.arange(n_beats) + 0.5) - 0.5,
                      minor=True)
        ax.set_xticklabels(np.arange(1, n_beats + 1), minor=True)
        ax.tick_params(axis="x", which="minor", width=0)

    # Format y-axis
    if ytick == "octave":
        ax.set_yticks(np.arange(0, 128, 12))
        if yticklabel == "name":
            ax.set_yticklabels(["C{}".format(i - 2) for i in range(11)])
    elif ytick in ("pitch", "step"):
        # 'pitch' is the documented value; 'step' is what this branch
        # used to test and is kept so that existing callers still work.
        ax.set_yticks(np.arange(0, 128))
        if yticklabel == "name":
            if is_drum:
                ax.set_yticklabels(
                    [note_number_to_drum_name(i) for i in range(128)])
            else:
                ax.set_yticklabels(
                    [note_number_to_name(i) for i in range(128)])
    elif ytick != "off":
        raise ValueError(
            f"`ytick` must be one of 'octave', 'pitch' or 'off', not {ytick}.")

    # Format axis labels
    if label not in ("x", "y", "both", "off"):
        raise ValueError(
            f"`label` must be one of 'x', 'y', 'both' or 'off', not {label}.")

    if label in ("x", "both"):
        if xtick == "step" or not xticklabel:
            ax.set_xlabel("time (step)")
        else:
            ax.set_xlabel("time (beat)")

    if label in ("y", "both"):
        if is_drum:
            ax.set_ylabel("key name")
        else:
            ax.set_ylabel("pitch")

    # Plot the grid
    if grid_axis not in ("x", "y", "both", "off"):
        raise ValueError(
            "`grid` must be one of 'x', 'y', 'both' or 'off', not "
            f"{grid_axis}.")
    if grid_axis != "off":
        ax.grid(
            axis=grid_axis,
            color="k",
            linestyle=grid_linestyle,
            linewidth=grid_linewidth,
        )

    # Plot downbeat boundaries
    if downbeats is not None:
        for downbeat in downbeats:
            ax.axvline(x=downbeat, color="k", linewidth=1)

    return img
def show_perf_vs_size(
    x_list: List[np.ndarray],
    y_list: List[np.ndarray],
    label_list: List[str],
    *,
    xlabel: str = None,
    ylabel: str = None,
    title: str = None,
    ax: Axes = None,
    xticks=(0, 25, 50, 75, 100),
    yticks=(0, 0.5, 1),
    xlim=(0, 100),
    ylim=(0, 1),
    xticklabels=('0', '25', '50', '75', '100'),
    yticklabels=('0', '0.5', '1'),
    style_list=None,
    linewidth=1,
    show_legend=True,
    legend_param=None,
    vline=None,
    hline=None,
    xlabel_param=None,
    # letter=None,
):
    """Plot one performance curve per entry of the input lists.

    x being model size, number of parameter, dataset size, etc.
    y being performance.

    Parameters
    ----------
    x_list, y_list, label_list
        X values, Y values and legend label of each curve. The three
        lists must have the same length.
    xlabel, ylabel, title
        Optional axis labels and title; nothing is set when None.
    ax
        Axes to draw on. When None, the current Matplotlib axes
        (``matplotlib.pyplot.gca()``) are used.
    xticks, yticks, xlim, ylim, xticklabels, yticklabels
        Ticks, limits and tick labels applied to the axes.
    style_list
        One ``(linestyle, color, marker)`` tuple per curve. When None,
        every curve uses a solid line with Matplotlib's default color
        and marker.
    linewidth
        Width of the curves and of the optional reference lines.
    show_legend, legend_param
        Whether to draw a legend, and the keyword arguments passed to
        ``ax.legend`` when given.
    vline, hline
        Positions of optional dashed black vertical / horizontal lines.
    xlabel_param
        Keyword arguments passed to ``ax.set_xticklabels``.

    Raises
    ------
    ValueError
        If `x_list`, `y_list` and `label_list` differ in length.
    """
    if not len(x_list) == len(y_list) == len(label_list):
        # An explicit exception instead of an assert, which is stripped
        # when Python runs with -O.
        raise ValueError(
            'x_list, y_list and label_list must have the same length')

    if ax is None:
        # Local import: pyplot is only needed for this fallback.
        import matplotlib.pyplot as plt
        ax = plt.gca()

    if style_list is None:
        # Solid line; None lets Matplotlib pick the color and marker.
        style_list = [('-', None, None)] * len(x_list)

    if xlabel_param is None:
        xlabel_param = dict()

    # if letter is not None:
    #     ax.text(0, 1, letter, horizontalalignment='left', verticalalignment='top',
    #             transform=ax.get_figure().transFigure, fontweight='bold')

    for idx, (x_this, y_this,
              label_this) in enumerate(zip(x_list, y_list, label_list)):
        linestyle, color, marker = style_list[idx]
        ax.plot(x_this,
                y_this,
                linestyle=linestyle,
                color=color,
                marker=marker,
                label=label_this,
                linewidth=linewidth)

    if vline is not None:
        # color maybe adjusted later
        ax.axvline(vline, color='black', linewidth=linewidth, linestyle='--')

    if hline is not None:
        # color maybe adjusted later
        ax.axhline(hline, color='black', linewidth=linewidth, linestyle='--')

    # ax.set_xlim(0, 1)
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    ax.set_xticklabels(xticklabels, **xlabel_param)
    ax.set_yticklabels(yticklabels)

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)

    if show_legend:
        if legend_param is None:
            ax.legend()
        else:
            ax.legend(**legend_param)
Пример #15
0
    def plot_z_trend_histogram(self,
                               axis: Axes = None,
                               polar: bool = True,
                               normed: bool = True) -> None:
        """
        Plot the redshift-trend histogram stored for the current
        aperture.

        The histogram is read from the file
        ``redshift_rot0rot4_histogram_aperture_<aperture_id>.npy`` found
        in ``self.path``; when the file does not exist,
        ``self.make_simhist()`` is called first. Row 0 of the loaded
        array is used as the x values, row 2 as the step curve and row 3
        as the half-width of the shaded band around it.

        Args:
            axis: Axes to draw on. When None, a new subplot is added to
                ``self.figure``.
            polar: Whether to add a polar inset showing the same curve.
            normed: Whether to divide rows 2 and 3 by the sum of
                ``self.simulation.sample_completeness`` before plotting.
        """

        if axis is None:
            axis = self.figure.add_subplot(111)

        # The aperture radius is expressed in units of r200 of the
        # cluster with ID 0 at redshift 'z000p000'.
        cluster = Cluster(simulation_name=self.simulation.simulation_name,
                          clusterID=0,
                          redshift='z000p000')
        aperture_float = self.get_apertures(cluster)[
            self.aperture_id] / cluster.r200

        # Build the histogram file when it is missing.
        if not os.path.isfile(
                os.path.join(
                    self.path,
                    f'redshift_rot0rot4_histogram_aperture_{self.aperture_id}.npy'
                )):
            warnings.warn(
                f"File redshift_rot0rot4_histogram_aperture_{self.aperture_id}.npy not found."
            )
            print("self.make_simhist() activated.")
            self.make_simhist()

        print(
            f"Retrieving npy files: redshift_rot0rot4_histogram_aperture_{self.aperture_id}.npy"
        )
        # NOTE(review): allow_pickle=True executes pickled objects found
        # in the file; only safe if the file is produced by this code
        # base - TODO confirm.
        sim_hist = np.load(os.path.join(
            self.path,
            f'redshift_rot0rot4_histogram_aperture_{self.aperture_id}.npy'),
                           allow_pickle=True)
        sim_hist = np.asarray(sim_hist)

        if normed:
            # Rows 2 and 3 are scaled in place by the total sample count.
            norm_factor = np.sum(self.simulation.sample_completeness)
            sim_hist[2] /= norm_factor
            sim_hist[3] /= norm_factor
            y_label = r"Sample fraction"
        else:
            y_label = r"Number of samples"

        # NOTE(review): this is a non-raw f-string holding LaTeX
        # backslash commands, which Python reports as invalid escape
        # sequences; a raw f-string looks intended - TODO confirm.
        # The ':d' format assumes integer values - TODO confirm.
        items_labels = f""" REDSHIFT TRENDS - HISTOGRAM
							Number of clusters: {self.simulation.totalClusters:d}
							$z$ = 0.0 - 1.8
							Total samples: {np.sum(self.simulation.sample_completeness):d} $\equiv N_\mathrm{{clusters}} \cdot N_\mathrm{{redshifts}}$
							Aperture radius = {aperture_float:.2f} $R_{{200\ true}}$"""
        print(items_labels)

        # One color per simulation name; the same mapping feeds the
        # curve and the legend patches.
        sim_colors = {
            'ceagle': 'pink',
            'celr_e': 'lime',
            'celr_b': 'orange',
            'macsis': 'aqua',
        }

        # Dashed reference line at x = 90.
        axis.axvline(90, linestyle='--', color='k', alpha=0.5, linewidth=2)
        axis.step(sim_hist[0],
                  sim_hist[2],
                  color=sim_colors[self.simulation.simulation_name],
                  where='mid')
        # Shaded band: row 2 plus / minus row 3.
        axis.fill_between(sim_hist[0],
                          sim_hist[2] + sim_hist[3],
                          sim_hist[2] - sim_hist[3],
                          step='mid',
                          color=sim_colors[self.simulation.simulation_name],
                          alpha=0.2,
                          edgecolor='none',
                          linewidth=0)

        axis.set_ylabel(y_label, size=25)
        axis.set_xlabel(
            r"$\Delta \theta \equiv (\mathbf{L}_\mathrm{gas},\mathrm{\widehat{CoP}},\mathbf{L}_\mathrm{stars})$\quad[degrees]",
            size=25)
        axis.set_xlim(0, 180)
        # NOTE(review): the Y limits are fixed to (0, 0.1) whether or
        # not the data were normalised - TODO confirm this is intended
        # when normed is False.
        axis.set_ylim(0, 0.1)
        axis.text(0.03,
                  0.97,
                  items_labels,
                  horizontalalignment='left',
                  verticalalignment='top',
                  transform=axis.transAxes,
                  size=15)

        if polar:
            # Half-disc polar inset showing the same curve, with the x
            # values converted from degrees to radians.
            inset_axis = self.figure.add_axes([0.75, 0.65, 0.25, 0.25],
                                              projection='polar')
            inset_axis.patch.set_alpha(0)  # Transparent background
            inset_axis.set_theta_zero_location('N')
            inset_axis.set_thetamin(0)
            inset_axis.set_thetamax(180)
            inset_axis.set_xticks(np.pi / 180. *
                                  np.linspace(0, 180, 5, endpoint=True))
            inset_axis.set_yticks([])
            inset_axis.step(sim_hist[0] / 180 * np.pi,
                            sim_hist[2],
                            color=sim_colors[self.simulation.simulation_name],
                            where='mid')
            inset_axis.fill_between(
                sim_hist[0] / 180 * np.pi,
                sim_hist[2] + sim_hist[3],
                sim_hist[2] - sim_hist[3],
                step='mid',
                color=sim_colors[self.simulation.simulation_name],
                alpha=0.2,
                edgecolor='none',
                linewidth=0)

        # Legend patches for the four simulations, whichever one is
        # plotted.
        patch_ceagle = Patch(facecolor=sim_colors['ceagle'],
                             label='C-EAGLE',
                             edgecolor='k',
                             linewidth=1)
        patch_celre = Patch(facecolor=sim_colors['celr_e'],
                            label='CELR-E',
                            edgecolor='k',
                            linewidth=1)
        patch_celrb = Patch(facecolor=sim_colors['celr_b'],
                            label='CELR-B',
                            edgecolor='k',
                            linewidth=1)
        patch_macsis = Patch(facecolor=sim_colors['macsis'],
                             label='MACSIS',
                             edgecolor='k',
                             linewidth=1)

        leg2 = axis.legend(
            handles=[patch_ceagle, patch_celre, patch_celrb, patch_macsis],
            loc='lower center',
            handlelength=1,
            fontsize=20)
        axis.add_artist(leg2)