Esempio n. 1
0
    def draw_LX_bounds(self, ax: plt.Axes, redshifts_on: bool = True):
        """
        Shade and outline the luminosity bins in ``self.bins`` on ``ax``.

        Draws a translucent lime band from the lower bound of the first bin
        to the upper bound of the last bin, a faint horizontal line at every
        bin edge and, optionally, a redshift-range label for every other bin.

        Each entry of ``self.bins`` is unpacked as
        ``(luminosity_min, luminosity_max, redshift_min, redshift_max)``.

        :param ax: matplotlib axes to draw on; luminosity is on the y axis.
        :param redshifts_on: if True, label every second bin with its
            redshift range at the left edge of the axes.

        Side effect: sets ``self.hconv = 0.70 / self.h``.
        """
        self.hconv = 0.70 / self.h
        # Bin bounds are stored in units of 1e44 and rescaled by hconv**2
        # (presumably converting from an h=0.70 convention -- TODO confirm).
        scale = 1e44 * self.hconv**2

        ax.axhspan(self.bins[0][0] * scale,
                   self.bins[-1][1] * scale,
                   facecolor='lime',
                   linewidth=0,
                   alpha=0.2)
        ax.axhline(self.bins[0][0] * scale,
                   color='lime',
                   linewidth=1,
                   alpha=0.1)

        # Blended transform: x in axes coordinates (0 = left edge of the axes),
        # y in data coordinates. ax.get_xlim() already returns data
        # coordinates, so exponentiating it does not give the left edge (and
        # overflows for large limits).
        label_transform = ax.get_yaxis_transform()

        for i, (luminosity_min, luminosity_max, redshift_min,
                redshift_max) in enumerate(self.bins):

            ax.axhline(luminosity_max * scale,
                       color='lime',
                       linewidth=1,
                       alpha=0.1)

            # Print redshift bounds once every 2 bins to avoid clutter.
            if i % 2 == 0 and redshifts_on:
                # Geometric mean of the bin edges: the visual middle of the
                # bin when the y axis is log-scaled.
                ax.text(0,
                        np.sqrt(luminosity_min * luminosity_max) * scale,
                        f"$z$ = {redshift_min:.3f} - {redshift_max:.3f}",
                        transform=label_transform,
                        horizontalalignment='left',
                        verticalalignment='center',
                        color='k',
                        alpha=0.3)
Esempio n. 2
0
    def _add_shading_to_axes(self, ax: Axes) -> None:
        """
        Grey out the regions between the plot limits and the shade bounds.

        ``self.x_shade`` and ``self.y_shade`` each hold a (lower, upper) pair.
        A ``None`` entry means "do not shade on that side". Otherwise the span
        from the corresponding entry of ``self.x_lim``/``self.y_lim`` to the
        shade value is filled in grey, drawn behind other artists (zorder -10).
        """
        shade_style = {"zorder": -10, "alpha": 0.3, "color": "grey", "linewidth": 0}

        # Along the x axis: vertical bands at the left and right.
        x_low = self.x_shade[0]
        if x_low is not None:
            ax.axvspan(self.x_lim[0], x_low, **shade_style)

        x_high = self.x_shade[1]
        if x_high is not None:
            ax.axvspan(x_high, self.x_lim[1], **shade_style)

        # Along the y axis: horizontal bands at the bottom and top.
        y_low = self.y_shade[0]
        if y_low is not None:
            ax.axhspan(self.y_lim[0], y_low, **shade_style)

        y_high = self.y_shade[1]
        if y_high is not None:
            ax.axhspan(y_high, self.y_lim[1], **shade_style)
Esempio n. 3
0
File: plot.py Progetto: whynot-s/evo
def error_array(ax: plt.Axes,
                err_array: ListOrArray,
                x_array: typing.Optional[ListOrArray] = None,
                statistics: typing.Optional[typing.Dict[str, float]] = None,
                threshold: typing.Optional[float] = None,
                cumulative: bool = False,
                color: str = 'grey',
                name: str = "error",
                title: str = "",
                xlabel: str = "index",
                ylabel: typing.Optional[str] = None,
                subplot_arg: int = 111,
                linestyle: str = "-",
                marker: typing.Optional[str] = None):
    """
    high-level function for plotting raw error values of a metric
    :param ax: matplotlib axes to draw on (labels, title and legend are
               also applied to this axes)
    :param err_array: an nx1 array of values
    :param x_array: an nx1 array of x-axis values; if None, the sample
                    index is used
    :param statistics: optional dictionary of {metrics.StatisticsType.value: value};
                       each entry is drawn as a horizontal line, except "std"
                       which (when "mean" is also present) is drawn as a band
                       of total width std centred on the mean
    :param threshold: optional value for horizontal threshold line
    :param cumulative: set to True for cumulative plot
    :param color: matplotlib color of the error curve
    :param name: optional name of the value array
    :param title: optional plot title
    :param xlabel: optional x-axis label
    :param ylabel: optional y-axis label (defaults to name)
    :param subplot_arg: unused; kept for backward compatibility
    :param linestyle: matplotlib linestyle
    :param marker: optional matplotlib marker style for points
    """
    values = np.cumsum(err_array) if cumulative else err_array
    # Omit x values when not given so matplotlib plots against the index.
    xy = (values,) if x_array is None else (x_array, values)
    ax.plot(*xy, linestyle=linestyle, marker=marker, color=color, label=name)

    if statistics is not None:
        for stat_name, value in statistics.items():
            # Advance the axes' own colour cycle so each statistic gets a
            # distinct colour. `_get_lines.prop_cycler` was removed in
            # matplotlib 3.8; get_next_color() is the surviving equivalent.
            stat_color = ax._get_lines.get_next_color()
            if stat_name == "std" and "mean" in statistics:
                mean, std = statistics["mean"], statistics["std"]
                ax.axhspan(mean - std / 2,
                           mean + std / 2,
                           color=stat_color,
                           alpha=0.5,
                           label=stat_name)
            else:
                ax.axhline(y=value,
                           color=stat_color,
                           linewidth=2.0,
                           label=stat_name)

    if threshold is not None:
        ax.axhline(y=threshold,
                   color='red',
                   linestyle='dashed',
                   linewidth=2.0,
                   label="threshold")

    # Use the Axes API rather than pyplot: plt.* targets the *current* axes,
    # which need not be `ax` (e.g. when drawing into a non-active subplot).
    ax.set_ylabel(ylabel if ylabel else name)
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    ax.legend(frameon=True)