Example #1
def bayesian_confidence(ax: plt.Axes,
                        bin_edges: np.ndarray = np.array([-1000, -40, -20, -10, -4.5, -1.5, 1.5, 4.5, 10, 20, 40, 1000])):
    x, y = [], []
    for pid in DataExp1.pids:
        data = DataExp1(pid)
        model = data.build_model(models.ChoiceModel4Param)
        L = model.L + model.L_uniform
        x += list(logsumexp(L, b=model.is_chosen, axis=1) - logsumexp(L, b=1 - model.is_chosen, axis=1))
        y += list((model.df['confidence'] == 'high').astype(float))
    df = pd.DataFrame({'x': x, 'y': y})
    df['bin'] = pd.cut(df['x'], bin_edges, labels=False)
    x, y, yerr = [], [], []
    for i in range(len(bin_edges) - 1):
        _df = df[df['bin'] == i]
        if len(_df) == 0:
            continue
        x.append(_df['x'].mean())
        y.append(_df['y'].mean())
        yerr.append(_df['y'].sem())
    ax.errorbar(x, y, yerr, label=r'Human $\pm$ sem', c='darkgreen', fmt='.-', capsize=2, ms=2, capthick=0.5)
    ax.set_yticks([0, 1])
    ax.set_yticklabels(['Low', 'High'])
    ax.set_ylabel('Avg. reported\nconfidence', labelpad=-16)
    # ax.set_xticks([0, 0.5, 1])
    ax.set_xlabel(r'logit$(\,P_{\mathregular{ideal}}(S\,|\,\bf{X}$$)\,)$')
    ax.legend(loc='lower right', handler_map={ErrorbarContainer: HandlerErrorbar(yerr_size=0.25)})
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    plt.tight_layout()
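The non-obvious step above is the weighted logsumexp call: passing the choice indicator as the `b` weights argument restricts the sum to the chosen (or unchosen) options. A minimal, self-contained sketch of that step with made-up log-likelihood values:
import numpy as np
from scipy.special import logsumexp

# Toy log-likelihoods: one row per trial, one column per candidate structure.
L = np.array([[-1.2, -0.3, -2.5],
              [-0.1, -3.0, -1.7]])
is_chosen = np.array([[0, 1, 0],   # indicator of the chosen option in each trial
                      [1, 0, 0]])
# b= weights each exp(L) term, so this is log P(chosen) - log P(not chosen).
logit = logsumexp(L, b=is_chosen, axis=1) - logsumexp(L, b=1 - is_chosen, axis=1)
print(logit)  # one logit per trial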
Example #2
def plot_2E(ax: plt.Axes,
            bin_edges_human: np.ndarray = np.array(
                [-1000, -90, -60, -30, -4, -2, 0, 2, 4, 30, 60, 1000]),
            bin_edges_model: np.ndarray = np.linspace(-160, 100, 40)):
    Δ, p_human, p_model = [], [], []
    for pid in DataExp1.pids:
        data = DataExp1(pid)
        model = data.build_model(models.BayesianIdealObserver)
        df = model.predict(model.fit())
        δ = np.stack([
            np.log(df[s]) - np.log(np.sum(df.loc[:, df.columns != s], axis=1))
            for s in data.structures
        ])
        Δ += list(δ.T.flatten())
        p_model += list(
            data.cross_validate(models.ChoiceModel4Param).to_numpy().flatten())
        p_human += list(
            np.array([data.df['choice'] == s
                      for s in Exp1.structures]).T.flatten())
    df = pd.DataFrame({'Δ': Δ, 'p_human': p_human, 'p_model': p_model})
    x_human, y_human, yerr_human, x_model, y_model, yerr_model = [], [], [], [], [], []
    df['bin'] = pd.cut(df['Δ'], bin_edges_human, labels=False)
    for i in range(len(bin_edges_human) - 1):
        _df = df[df['bin'] == i]
        x_human.append(_df['Δ'].mean())
        y_human.append(_df['p_human'].mean())
        yerr_human.append(_df['p_human'].sem())
    df['bin'] = pd.cut(df['Δ'], bin_edges_model, labels=False)
    for i in range(len(bin_edges_model) - 1):
        _df = df[df['bin'] == i]
        x_model.append(_df['Δ'].mean())
        y_model.append(_df['p_model'].mean())
        yerr_model.append(_df['p_model'].sem())
    ax.errorbar(x_human,
                y_human,
                yerr_human,
                label='Human ± sem',
                color=colors['decision_human'],
                fmt='.',
                capsize=2,
                ms=2,
                capthick=0.5,
                zorder=1)
    ax.plot(x_model,
            y_model,
            color=colors['decision_model'],
            label='Model',
            ms=1,
            zorder=0)
    ax.set_xlabel(r'logit( $P_\mathregular{ideal}$($S\,|\,\bf{X}$) )')
    ax.set_ylabel(r'$P($choice=$S\,|\,\bf{X}$)')
    ax.set_ylim(0, 1)
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles[::-1],
              labels[::-1],
              loc='upper left',
              handler_map={ErrorbarContainer: HandlerErrorbar(yerr_size=0.25)})
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    plt.tight_layout()
Example #3
def binned_plot(xdata: List[float], ydata: List[float], q: int, ax: plt.Axes):
    df = pd.DataFrame({'x': xdata, 'y': ydata})
    df['bin'] = pd.qcut(xdata, q, labels=False, duplicates='drop')
    x, y, yerr = [], [], []
    for i in range(q):
        _df = df[df['bin'] == i]
        if len(_df) == 0:
            # qcut with duplicates='drop' can return fewer than q bins
            continue
        x.append(_df['x'].mean())
        y.append(_df['y'].mean())
        yerr.append(_df['y'].sem())
    ax.errorbar(x, y, yerr, label=r'Human $\pm$ sem', c='darkgreen', fmt='.-', capsize=2, ms=2, capthick=0.5)
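A hypothetical call to binned_plot with synthetic data, assuming the function above and its imports (numpy as np, pandas as pd, matplotlib.pyplot as plt, typing.List) are in scope; the data values are made up:
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
x = rng.normal(size=500)
y = 0.5 * x + rng.normal(scale=0.5, size=500)   # noisy linear relation (synthetic)

fig, ax = plt.subplots()
binned_plot(list(x), list(y), q=10, ax=ax)      # 10 quantile bins, mean ± sem per bin
ax.legend()
plt.show()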
Example #4
def plot_lt(df: pd.DataFrame, ax: plt.Axes) -> None:
    grouped = df.groupby('server_options.flush_every_n')
    for flush_every_n, n_group in grouped:
        grouped = n_group.groupby('num_clients')
        throughput = grouped['throughput'].agg(np.mean).sort_index()
        throughput_std = grouped['throughput'].agg(np.std).sort_index().fillna(0)
        latency = grouped['latency'].agg(np.mean).sort_index()
        latency_std = grouped['latency'].agg(np.std).sort_index().fillna(0)
        print(throughput_std)
        ax.errorbar(throughput / 100000, latency, yerr=latency_std,
                    xerr=throughput_std / 100000, fmt='.-', label=flush_every_n,
                    barsabove=True)
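A hypothetical call to plot_lt with a made-up benchmark DataFrame; the column names match those the function groups on, but the numbers are synthetic:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

rng = np.random.default_rng(1)
rows = []
for flush_every_n in (1, 10, 100):
    for num_clients in (1, 2, 4, 8):
        for _ in range(3):  # three repeated runs per configuration
            rows.append({
                'server_options.flush_every_n': flush_every_n,
                'num_clients': num_clients,
                'throughput': num_clients * 1e5 * rng.uniform(0.8, 1.2),
                'latency': num_clients * rng.uniform(0.9, 1.1),
            })
df = pd.DataFrame(rows)

fig, ax = plt.subplots()
plot_lt(df, ax)
ax.set_xlabel('throughput (x100k ops/s)')
ax.set_ylabel('latency')
ax.legend(title='flush_every_n')
plt.show()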
Example #5
 def format_trace_plot(plot: plt.Axes, trace_forward: np.ndarray,
                       trace_reverse: np.ndarray):
     x = np.arange(n_trace + 1)[1:] * trace_spacing * 100
     plot.errorbar(x,
                   trace_forward[:, 0],
                   yerr=2 * trace_forward[:, 1],
                   ecolor='b',
                   elinewidth=0,
                   mec='none',
                   mew=0,
                   linestyle='None',
                   zorder=10)
     plot.plot(
         x,
         trace_forward[:, 0],
         'b-',
         marker='o',
         mec='b',
         mfc='w',
         label='Forward',
         zorder=20,
     )
     plot.errorbar(x,
                   trace_reverse[:, 0],
                   yerr=2 * trace_reverse[:, 1],
                   ecolor='r',
                   elinewidth=0,
                   mec='none',
                   mew=0,
                   linestyle='None',
                   zorder=10)
     plot.plot(x,
               trace_reverse[:, 0],
               'r-',
               marker='o',
               mec='r',
               mfc='w',
               label='Reverse',
               zorder=20)
     y_fill_upper = [trace_forward[-1, 0] + 2 * trace_forward[-1, 1]
                     ] * 2
     y_fill_lower = [trace_forward[-1, 0] - 2 * trace_forward[-1, 1]
                     ] * 2
     xlim = [0, 100]
     plot.fill_between(xlim,
                       y_fill_lower,
                       y_fill_upper,
                       color='orchid',
                       zorder=5)
     plot.set_xlim(xlim)
     plot.legend()
     plot.set_xlabel("% Samples Analyzed", fontsize=20)
     plot.set_ylabel(r"$\Delta G$ in kcal/mol", fontsize=20)
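A rough usage sketch of the trace-plot helper with synthetic free-energy traces; n_trace and trace_spacing are module-level values in the original code, so they are simply assumed here:
import numpy as np
import matplotlib.pyplot as plt

n_trace, trace_spacing = 10, 0.1      # assumed module-level settings
rng = np.random.default_rng(2)
# column 0: free-energy estimate, column 1: its standard error
trace_forward = np.column_stack([-5.0 + 0.1 * rng.normal(size=n_trace).cumsum(),
                                 np.linspace(0.5, 0.1, n_trace)])
trace_reverse = np.column_stack([-5.0 + 0.1 * rng.normal(size=n_trace).cumsum(),
                                 np.linspace(0.5, 0.1, n_trace)])

fig, ax = plt.subplots()
format_trace_plot(ax, trace_forward, trace_reverse)
plt.show()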
Example #6
def plot_3B(ax: plt.Axes):
    data = pool(DataExp2)
    data.plot_stacked_bar(ax)
    n = len(ExpConfig.glo_exp2)
    y_human, y1, y2 = np.zeros(n), np.zeros(n), np.zeros(n)
    for pid in DataExp2.pids:
        data = DataExp2(pid)
        y_human += data.plot_line_human()[0]
        m1 = data.load_model(
            models.ChoiceModel4Param,
            DataExp1(pid).build_model(models.ChoiceModel4Param).fit())
        y1 += data.plot_line_model(m1.predict(m1.fit()))
        y2 += data.plot_line_model(
            data.cross_validate(models.ChoiceModel4Param))
    y_human, y1, y2 = y_human / len(DataExp2.pids), y1 / len(
        DataExp2.pids), y2 / len(DataExp2.pids)
    err = [np.sqrt(p * (1 - p) / len(DataExp2.pids) / 20) for p in y_human]
    ax.errorbar(DataExp2.x,
                y_human,
                err,
                label=r'Human $\pm$ sem',
                color=colors['decision_human'],
                capsize=5,
                capthick=1,
                lw=1,
                ms=3,
                fmt='o',
                zorder=3)
    ax.plot(DataExp2.x,
            y1,
            'o--',
            label='Transfer model',
            color=colors['decision_transfer'],
            lw=1,
            ms=3,
            zorder=2)
    ax.plot(DataExp2.x,
            y2,
            'o-',
            label='Fitted model',
            color=colors['decision_model'],
            lw=1,
            ms=3,
            zorder=2)
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles[::-1],
              labels[::-1],
              loc='upper right',
              handler_map={ErrorbarContainer: HandlerErrorbar(yerr_size=0.35)})
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    plt.tight_layout()
Example #7
 def format_trace_plot(plot: plt.Axes, trace_forward: np.ndarray, trace_reverse: np.ndarray):
     x = np.arange(n_trace + 1)[1:] * trace_spacing * 100
     plot.errorbar(x, trace_forward[:, 0], yerr=2 * trace_forward[:, 1], ecolor='b',
                   elinewidth=0, mec='none', mew=0, linestyle='None',
                   zorder=10)
     plot.plot(x, trace_forward[:, 0], 'b-', marker='o', mec='b', mfc='w', label='Forward', zorder=20,)
     plot.errorbar(x, trace_reverse[:, 0], yerr=2 * trace_reverse[:, 1], ecolor='r',
                   elinewidth=0, mec='none', mew=0, linestyle='None',
                   zorder=10)
     plot.plot(x, trace_reverse[:, 0], 'r-', marker='o', mec='r', mfc='w', label='Reverse', zorder=20)
     y_fill_upper = [trace_forward[-1, 0] + 2 * trace_forward[-1, 1]] * 2
     y_fill_lower = [trace_forward[-1, 0] - 2 * trace_forward[-1, 1]] * 2
     xlim = [0, 100]
     plot.fill_between(xlim, y_fill_lower, y_fill_upper, color='orchid', zorder=5)
     plot.set_xlim(xlim)
     plot.legend()
     plot.set_xlabel("% Samples Analyzed", fontsize=20)
     plot.set_ylabel(r"$\Delta G$ in kcal/mol", fontsize=20)
Example #8
 def show(self, ax: Axes, plot: "plt.Plot"):
     if not plot.plot_settings[lit.ERROR_BAR]:
         ax.plot(self.xvalues,
                 self.yvalues,
                 self.fmt,
                 color=self.color,
                 label=self.label,
                 **self.plot_kwargs)
     else:
         ax.errorbar(self.xvalues,
                     self.yvalues,
                     self.yerr,
                     self.xerr,
                     fmt=self.fmt,
                     color=self.color,
                     label=self.label,
                     **self.plot_kwargs,
                     **self.err_kwargs)
Example #9
def errorbar(ax: plt.Axes, data: FittingData):  # pylint: disable=C0103
    """
    Plot error bar to figure.

    :param ax: Figure axes.
    :type ax: matplotlib.pyplot.Axes
    :param data: Data to visualize
    :type data: eddington.fitting_data.FittingData
    """
    ax.errorbar(
        x=data.x,
        y=data.y,
        xerr=data.xerr,
        yerr=data.yerr,
        markersize=1,
        marker="o",
        linestyle="None",
    )
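A hypothetical usage sketch: a small stand-in dataclass exposes the same x/y/xerr/yerr attributes the helper reads from eddington.fitting_data.FittingData; it is not the real eddington class:
from dataclasses import dataclass
import numpy as np
import matplotlib.pyplot as plt

@dataclass
class FakeFittingData:   # stand-in for eddington's FittingData, illustration only
    x: np.ndarray
    y: np.ndarray
    xerr: np.ndarray
    yerr: np.ndarray

rng = np.random.default_rng(3)
x = np.linspace(0, 10, 20)
data = FakeFittingData(x=x,
                       y=2 * x + rng.normal(scale=1.0, size=x.size),
                       xerr=np.full(x.size, 0.1),
                       yerr=np.full(x.size, 1.0))

fig, ax = plt.subplots()
errorbar(ax, data)
plt.show()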
Example #10
    def plot_observations(self, axes: plt.Axes):
        sun_observations = Sun2009()
        r_r500, S_S500_50, S_S500_10, S_S500_90 = sun_observations.get_shortcut(
        )

        rexcess = Pratt2010(n_radial_bins=21)
        bin_median, bin_perc16, bin_perc84 = rexcess.combine_entropy_profiles(
            m500_limits=(1e14 * Solar_Mass, 5e14 * Solar_Mass),
            k500_rescale=True)

        r = np.array([*axes.get_xlim()])
        k = 1.40 * r**1.1
        axes.plot(r, k, c='grey', ls='--', label='VKB (2005)')

        asymmetric_error = np.array(
            list(zip(bin_median - bin_perc16, bin_perc84 - bin_median))).T
        axes.errorbar(rexcess.radial_bins,
                      bin_median,
                      yerr=asymmetric_error,
                      fmt='o',
                      markersize=2,
                      color='grey',
                      ecolor='lightgray',
                      elinewidth=0.7,
                      capsize=0,
                      label=rexcess.citation)
        asymmetric_error = np.array(
            list(zip(S_S500_50 - S_S500_10, S_S500_90 - S_S500_50))).T
        axes.errorbar(r_r500,
                      S_S500_50,
                      yerr=asymmetric_error,
                      fmt='^',
                      markersize=2,
                      color='grey',
                      ecolor='lightgray',
                      elinewidth=0.7,
                      capsize=0,
                      label=sun_observations.citation)
Example #11
    def plot(self, ax: plt.Axes):
        utils.format_axes(ax)
        with_both_bounds = np.logical_and(np.isfinite(self.sed_lower), np.isfinite(self.sed_upper))
        E_mean = np.sqrt(self.E_left * self.E_right)
        E_err_left = E_mean - self.E_left
        E_err_right = self.E_right - E_mean

        # check if the same data has already been plotted and listed on legend
        label = str(self)
        _, legend_texts = ax.get_legend_handles_labels()
        for legend_text in legend_texts:
            if label == legend_text:
                return

        fmt = self.marker
        ax.errorbar(
            E_mean[with_both_bounds],
            self.sed_mean[with_both_bounds],
            xerr=[E_err_left[with_both_bounds], E_err_right[with_both_bounds]],
            yerr=(
                (self.sed_mean - self.sed_lower)[with_both_bounds],
                (self.sed_upper - self.sed_mean)[with_both_bounds],
            ),
            fmt=fmt,
            color=self.color,
            label=label,
        )
        with_upper_bound = np.logical_not(with_both_bounds)
        ax.errorbar(
            E_mean[with_upper_bound],
            self.sed_upper[with_upper_bound],
            xerr=[E_err_left[with_upper_bound], E_err_right[with_upper_bound]],
            yerr=self.sed_upper[with_upper_bound] / 2,
            uplims=True,
            fmt=fmt,
            color=self.color,
        )
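A self-contained sketch of the two errorbar idioms used above, asymmetric (lower, upper) error pairs and upper limits drawn with uplims=True, using made-up flux points:
import numpy as np
import matplotlib.pyplot as plt

E = np.array([1.0, 2.0, 4.0, 8.0])        # toy energy bin centres
flux = np.array([3.0, 2.0, 1.2, 0.5])     # toy flux values

fig, ax = plt.subplots()
# Detections: asymmetric x and y errors given as (lower, upper) sequences.
ax.errorbar(E[:3], flux[:3],
            xerr=[0.2 * E[:3], 0.3 * E[:3]],
            yerr=([0.4, 0.3, 0.2], [0.5, 0.4, 0.3]),
            fmt='o', color='C0', label='detections')
# Last point shown as an upper limit: uplims=True draws a downward arrow,
# and yerr sets the arrow length.
ax.errorbar(E[3:], flux[3:], yerr=flux[3:] / 2, uplims=True, fmt='o', color='C0')
ax.loglog()
ax.legend()
plt.show()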
Example #12
def plot_single(axis: plt.Axes, plot: plots.Plottable):
    if plot.displayType == "marker":
        p = axis.errorbar(
            plot.x,
            plot.y,
            yerr=plot.yErr,
            xerr=plot.xErr,
            capsize=3,
            linestyle="None",
            markersize=10,
            marker=styles.style("marker"),
        )
    else:
        p = axis.plot(plot.x, plot.y, linestyle=styles.style(plot.displayType))
    if plot.label:
        p[0].set_label(plot.label)
Example #13
def median_plot(axes: plt.Axes, x: np.ndarray, y: np.ndarray,  **kwargs):

	perc84 = Line2D([], [], color='k', marker='^', linewidth=1, linestyle='-', markersize=3, label=r'$84^{th}$ percentile')
	perc50 = Line2D([], [], color='k', marker='o', linewidth=1, linestyle='-', markersize=3, label=r'median')
	perc16 = Line2D([], [], color='k', marker='v', linewidth=1, linestyle='-', markersize=3, label=r'$16^{th}$ percentile')
	legend = axes.legend(handles=[perc84, perc50, perc16], loc='lower right', handlelength=2)
	axes.add_artist(legend)
	data_plot = utils.medians_2d(x, y, **kwargs)
	axes.errorbar(data_plot['median_x'], data_plot['median_y'], yerr=data_plot['err_y'],
	              marker='o', ms=2, alpha=1, linestyle='-', capsize=0, linewidth=0.5)
	axes.errorbar(data_plot['median_x'], data_plot['percent16_y'], yerr=data_plot['err_y'],
	              marker='v', ms=2, alpha=1, linestyle='-', capsize=0, linewidth=0.5)
	axes.errorbar(data_plot['median_x'], data_plot['percent84_y'], yerr=data_plot['err_y'],
	              marker='^', ms=2, alpha=1, linestyle='-', capsize=0, linewidth=0.5)
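utils.medians_2d is not shown; it presumably returns binned medians with 16th/84th percentiles and an error estimate. A rough numpy-only sketch of that kind of statistic (an assumption, not the actual utils implementation):
import numpy as np

def binned_percentiles(x, y, bins=10):
    """Median, 16th and 84th percentile of y in equal-width x bins (illustrative)."""
    edges = np.linspace(x.min(), x.max(), bins + 1)
    idx = np.clip(np.digitize(x, edges) - 1, 0, bins - 1)
    out = {k: [] for k in ('median_x', 'median_y', 'percent16_y', 'percent84_y', 'err_y')}
    for i in range(bins):
        yi = y[idx == i]
        if yi.size == 0:
            continue
        p16, p84 = np.percentile(yi, [16, 84])
        out['median_x'].append(x[idx == i].mean())
        out['median_y'].append(np.median(yi))
        out['percent16_y'].append(p16)
        out['percent84_y'].append(p84)
        out['err_y'].append((p84 - p16) / 2 / np.sqrt(yi.size))  # rough error on the median
    return {k: np.asarray(v) for k, v in out.items()}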
Example #14
def _plot_strat_1d(
    strat: Strategy,
    ax: plt.Axes,
    true_testfun: Optional[Callable],
    cred_level: float,
    target_level: Optional[float],
    xlabel: str,
    ylabel: str,
    yes_label: str,
    no_label: str,
    gridsize: int,
):
    """Helper function for creating 1-d plots. See plot_strat for an explanation of the arguments."""

    x, y = strat.x, strat.y
    assert x is not None and y is not None, "No data to plot!"

    grid = strat.model.dim_grid(gridsize=gridsize)
    samps = norm.cdf(strat.model.sample(grid, num_samples=10000).detach())
    phimean = samps.mean(0)

    ax.plot(np.squeeze(grid), phimean)
    if cred_level is not None:
        upper = np.quantile(samps, cred_level, axis=0)
        lower = np.quantile(samps, 1 - cred_level, axis=0)
        ax.fill_between(
            np.squeeze(grid),
            lower,
            upper,
            alpha=0.3,
            hatch="///",
            edgecolor="gray",
            label=f"{cred_level*100:.0f}% posterior mass",
        )
    if target_level is not None:
        from aepsych.utils import interpolate_monotonic

        threshold_samps = [
            interpolate_monotonic(grid.squeeze().numpy(), s, target_level,
                                  strat.lb[0], strat.ub[0]) for s in samps
        ]
        thresh_med = np.mean(threshold_samps)
        thresh_lower = np.quantile(threshold_samps, q=1 - cred_level)
        thresh_upper = np.quantile(threshold_samps, q=cred_level)

        ax.errorbar(
            thresh_med,
            target_level,
            xerr=np.r_[thresh_med - thresh_lower,
                       thresh_upper - thresh_med][:, None],
            capsize=5,
            elinewidth=1,
            label=
            f"Est. {target_level*100:.0f}% threshold \n(with {cred_level*100:.0f}% posterior \nmass marked)",
        )

    if true_testfun is not None:
        true_f = true_testfun(grid)
        ax.plot(grid, true_f.squeeze(), label="True function")
        if target_level is not None:
            true_thresh = interpolate_monotonic(
                grid.squeeze().numpy(),
                true_f.squeeze(),
                target_level,
                strat.lb[0],
                strat.ub[0],
            )

            ax.plot(
                true_thresh,
                target_level,
                "o",
                label=f"True {target_level*100:.0f}% threshold",
            )

    ax.scatter(
        x[y == 0, 0],
        np.zeros_like(x[y == 0, 0]),
        marker=3,
        color="r",
        label=no_label,
    )
    ax.scatter(
        x[y == 1, 0],
        np.zeros_like(x[y == 1, 0]),
        marker=3,
        color="b",
        label=yes_label,
    )
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    return ax
Example #15
def plot_triple_data(
    ax: plt.Axes,
    lab_values: ty.List[ty.List[str]],
    lab_names: ty.Tuple[str, str, str, str],
    f_mean: np.ndarray,
    f_min: np.ndarray,
    f_max: np.ndarray,
    f_std: np.ndarray,
    outer_label: str = None,
    outer_value: str = None,
    mean_all: float = None,
    max_all: float = None,
    min_all: float = None,
    min_lower_limit: float = 0,
    hypa_labels: ty.Dict[str, ty.Dict[str, str]] = None,
):
    """Plot groups of groups of columns

    f_mean.shape (x, y, z)

    z groups of (groups of columns)
    y groups of columns
    x columns per group

        example of (3, 4, 2) shape
    ^
    |                   x x
    |x                  xxx xx   x
    |xxx xx             xxx xx  xx  x
    |xxx xxx x   x x    xxx xxx xxx x x
    |xxx xxx xxx xxx    xxx xxx xxx xxx
    |xxx xxx xxx xxx    xxx xxx xxx xxx
    .----------------------------------->
    """
    logg = logging.getLogger(f"c.{__name__}.plot_triple_data")
    logg.setLevel("INFO")
    logg.debug("Start plot_triple_data")
    logg.debug(f"f_mean.shape: {f_mean.shape}")
    logg.debug(f"outer_value: {outer_value} outer_label {outer_label}")
    logg.debug(f"lab_names: {lab_names}")
    logg.debug(f"lab_values: {lab_values}")
    logg.debug(f"hypa_labels: {hypa_labels}")

    # Fall back to the raw values so the *_disp variables below are always defined
    lab_values_disp = lab_values
    outer_value_disp = outer_value
    if outer_label is not None and outer_value is not None:
        if hypa_labels is not None:
            lab_values_disp = []
            for il, this_lab_values in enumerate(lab_values):
                this_lab_name = lab_names[il]
                new_disp_values = []
                for this_lab_value in this_lab_values:

                    # if I know how to translate
                    if this_lab_name in hypa_labels:
                        disp_lab_value = hypa_labels[this_lab_name][
                            this_lab_value]
                        logg.debug(f"disp_lab_value: {disp_lab_value}")

                    # use the current available label, no translation
                    else:
                        disp_lab_value = this_lab_value

                    new_disp_values.append(disp_lab_value)

                lab_values_disp.append(new_disp_values)

            # translate the outer label if you have the translation available
            if outer_label in hypa_labels:
                outer_value_disp = hypa_labels[outer_label][outer_value]
            else:
                outer_value_disp = outer_value

        else:
            lab_values_disp = lab_values
            outer_value_disp = outer_value
    logg.debug(f"lab_values_disp: {lab_values_disp}")
    logg.debug(f"outer_value_disp: {outer_value_disp}")

    title = ""
    if outer_label is not None and outer_value is not None:
        title += f"{outer_label}"
        # title += f": \\textbf{{{outer_value_disp}}}"
        title += f": $\\bf{{{outer_value_disp}}}$"
        title += "\n"
    title += f"{lab_names[0]}"
    title += f": {lab_values_disp[0]}"
    title += "\n"
    title += f" grouped by {lab_names[1]}"
    title += f": {lab_values_disp[1]}"
    title += "\n"
    title += f" grouped by {lab_names[2]}"
    title += f": {lab_values_disp[2]}"
    ax.set_title(title, fontsize=14)

    lab_fontsize = 14
    ax.set_ylabel("F-score (min/mean/max and std-dev)", fontsize=lab_fontsize)
    ax.set_xlabel(f"{lab_names[1]} ({lab_names[2]})", fontsize=lab_fontsize)

    f_dim = f_mean.shape

    # the width of each super group
    width_outer_group = 0.9

    # scale the available space by 0.8 to leave space between subgroups
    width_inner_group = width_outer_group * 0.8 / f_dim[1]

    # the columns are side by side
    width_inner_col = width_inner_group / f_dim[0]

    # where the z super groups of columns start
    x_outer_pos = np.arange(f_dim[2])

    # where the y sub groups start
    x_inner_pos = np.arange(f_dim[1]) * width_outer_group / f_dim[1]

    # where to put the ticks for each subgroup
    x_inner_ticks = x_inner_pos + (width_inner_group / 2)

    # all the ticks and the relative labels
    all_x_ticks = np.array([])
    all_xticklabels = []

    # if there are too many groups of columns draw less info
    too_many_groups = f_dim[1] * f_dim[2] > 12

    if not too_many_groups:
        err_capsize = 5
        std_capsize = 3
        std_capthick = 4
        xticklabels_rot = 0
        xticklabels_ha = "center"
    else:
        err_capsize = 3
        std_capsize = 2
        std_capthick = 3
        xticklabels_rot = 30
        xticklabels_ha = "right"

    # for each super group
    for iz in range(f_dim[2]):

        # where this group starts
        shift_group = x_outer_pos[iz]

        # where to put the ticks
        this_ticks = x_inner_ticks + shift_group
        all_x_ticks = np.hstack((all_x_ticks, this_ticks))
        this_labels = [
            f"{vy} ({lab_values_disp[2][iz]})" for vy in lab_values_disp[1]
        ]
        all_xticklabels.extend(this_labels)

        # reset the cycler
        cc = cycler(color=[
            "#1f77b4",
            "#ff7f0e",
            "#2ca02c",
            "#d62728",
            "#9467bd",
            "#8c564b",
            "#e377c2",
            "#7f7f7f",
            "#bcbd22",
            "#17becf",
        ])

        ax.set_prop_cycle(cc)

        # for each column
        for ix in range(f_dim[0]):

            # how much this batch of y columns must be shifted
            shift_col = width_inner_col * ix
            x_col = x_inner_pos + shift_col + shift_group

            # extract the values of the y columns in this batch
            y_f = f_mean[ix, :, iz]

            # only put the label for the first z slice
            the_label = lab_values_disp[0][ix] if iz == 0 else None

            # plot the bars
            ax.bar(
                x=x_col,
                height=y_f,
                width=width_inner_col,
                label=the_label,
                align="edge",
                capsize=err_capsize,
            )

            # compute the relative min/max
            y_min = y_f - f_min[ix, :, iz]
            y_max = f_max[ix, :, iz] - y_f
            y_err = np.vstack((y_min, y_max))
            ax.errorbar(
                x=x_col + width_inner_col / 2,
                y=y_f,
                yerr=y_err,
                linestyle="None",
                capsize=err_capsize,
                ecolor="k",
            )

            # get the standard deviation
            y_std = f_std[ix, :, iz]
            ax.errorbar(
                x_col + width_inner_col / 2,
                y_f,
                yerr=y_std,
                linestyle="None",
                capsize=std_capsize,
                capthick=std_capthick,
                ecolor="b",
            )

    if min_all is not None and min_all < min_lower_limit:
        # logg.debug(f"Resettin min_all: {min_all} to min_lower_limit")
        min_all = min_lower_limit

    if max_all is not None and min_all is not None:
        ax.set_ylim(top=max_all * 1.01, bottom=min_all * 0.99)
    elif max_all is not None:
        ax.set_ylim(top=max_all * 1.01)
    elif min_all is not None:
        ax.set_ylim(bottom=min_all * 0.99)

    if mean_all is not None:
        ax.axhline(mean_all)

        bottom, top = ax.get_ylim()
        mean_rescaled = ((mean_all - bottom) / (top - bottom)) * 1.01
        ax.annotate(
            text=f"{mean_all:.03f}",
            xy=(0.005, mean_rescaled),
            xycoords="axes fraction",
            fontsize=13,
        )

    ax.set_xticks(all_x_ticks)
    ax.set_xticklabels(
        labels=all_xticklabels,
        rotation=xticklabels_rot,
        horizontalalignment=xticklabels_ha,
    )
    ax.legend(title=f"{lab_names[0]}", title_fontsize=lab_fontsize)
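A hypothetical call to plot_triple_data with synthetic F-scores of shape (3, 4, 2), matching the (x, y, z) layout described in the docstring; labels and values are made up and assume the function and its imports are in scope:
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(4)
f_mean = rng.uniform(0.6, 0.9, size=(3, 4, 2))
f_std = rng.uniform(0.01, 0.05, size=(3, 4, 2))
f_min = f_mean - rng.uniform(0.02, 0.08, size=(3, 4, 2))
f_max = f_mean + rng.uniform(0.02, 0.08, size=(3, 4, 2))
lab_values = [['lstm', 'gru', 'cnn'],               # x: 3 columns per group
              ['1e-3', '1e-4', '1e-5', '1e-6'],     # y: 4 groups of columns
              ['adam', 'sgd']]                      # z: 2 outer groups
lab_names = ('model', 'learning rate', 'optimizer', 'dataset')

fig, ax = plt.subplots(figsize=(10, 5))
plot_triple_data(ax, lab_values, lab_names, f_mean, f_min, f_max, f_std,
                 outer_label='dataset', outer_value='fsdd',   # hypothetical values
                 mean_all=float(f_mean.mean()),
                 max_all=float(f_max.max()), min_all=float(f_min.min()))
plt.show()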
Example #16
def plot_3C(ax: plt.Axes,
            bin_edges_human: np.ndarray = np.array(
                [-1000, -4.5, -3, -1.5, 0, 1.5, 3, 4.5, 1000]),
            bin_edges_model: np.ndarray = np.array(
                [-1000, -4.5, -3, -1.5, 0, 1.5, 3, 4.5, 1000])):
    Δ, p_human, p1, p2 = [], [], [], []
    for pid in DataExp2.pids:
        data = DataExp2(pid)
        model = data.build_model(models.BayesianIdealObserver)
        df = model.predict(model.fit())
        Δ += list(np.log(df['C']) - np.log(df['H']))
        model = data.load_model(
            models.ChoiceModel4Param,
            DataExp1(pid).build_model(models.ChoiceModel4Param).fit())
        p1 += list(model.predict(model.fit())['C'])
        p2 += list(data.cross_validate(models.ChoiceModel4Param)['C'])
        p_human += list((data.df['choice'] == 'C') * 1.0)
    df = pd.DataFrame({'Δ': Δ, 'p_human': p_human, 'p1': p1, 'p2': p2})
    x_human, y_human, yerr_human, x_model, y1, yerr1, y2, yerr2 = [], [], [], [], [], [], [], []
    df['bin'] = pd.cut(df['Δ'], bin_edges_human, labels=False)
    for i in range(len(bin_edges_human) - 1):
        _df = df[df['bin'] == i]
        x_human.append(_df['Δ'].mean())
        y_human.append(_df['p_human'].mean())
        yerr_human.append(_df['p_human'].sem())
    df['bin'] = pd.cut(df['Δ'], bin_edges_model, labels=False)
    for i in range(len(bin_edges_model) - 1):
        _df = df[df['bin'] == i]
        x_model.append(_df['Δ'].mean())
        y1.append(_df['p1'].mean())
        yerr1.append(_df['p1'].sem())
        y2.append(_df['p2'].mean())
        yerr2.append(_df['p2'].sem())
    ax.errorbar(x_human,
                y_human,
                yerr_human,
                label=r'Human $\pm$ sem',
                color=colors['decision_human'],
                fmt='.',
                capsize=2,
                ms=2,
                capthick=0.5,
                zorder=1)
    ax.plot(x_model,
            y1,
            '--',
            color=colors['decision_transfer'],
            label='Transfer model',
            ms=1,
            zorder=0)
    ax.plot(x_model,
            y2,
            '-',
            color=colors['decision_model'],
            label='Fitted model',
            ms=1,
            zorder=0)
    ax.set_xlabel(r'logit( $P_\mathregular{ideal}(S=C\,|\,\bf{X}$) )')
    ax.set_ylabel(r'$P$(choice=$C\,|\,\bf{X}$)')
    ax.set_ylim(0, 1)
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles[::-1],
              labels[::-1],
              loc='lower right',
              handler_map={ErrorbarContainer: HandlerErrorbar(yerr_size=0.25)})
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    plt.tight_layout()
Example #17
    def plot_on_axes(self, axes: Axes, errorbar_kwargs: Union[dict, None] = None):
        """
        Plot this set of observational data as an errorbar().

        Parameters
        ----------

        axes: plt.Axes
            The matplotlib axes to plot the data on. This will either
            plot the data as a line or a set of errorbar points, with
            the short citation (self.citation) being included in the
            legend automatically.

        errorbar_kwargs: dict
            Optional keyword arguments to pass to plt.errorbar.
        """

        # Do this because dictionaries are mutable
        if errorbar_kwargs is not None:
            kwargs = errorbar_kwargs
        else:
            kwargs = {}

        # Ensure correct units throughout, in case somebody changed them
        if self.x_scatter is not None:
            self.x_scatter.convert_to_units(self.x.units)

        if self.y_scatter is not None:
            self.y_scatter.convert_to_units(self.y.units)

        if self.plot_as == "points":
            kwargs["linestyle"] = "none"
            kwargs["marker"] = "."
            kwargs["zorder"] = points_zorder

            # Need to "intelligently" size the markers
            kwargs["markersize"] = (
                rcParams["lines.markersize"]
                * (1.5 - tanh(2.0 * log10(len(self.x)) - 4.0))
                / 2.5
            )

            kwargs["alpha"] = (3.0 - tanh(2.0 * log10(len(self.x)) - 4.0)) / 4.0

            # Looks weird if errorbars are present
            if self.y_scatter is None:
                kwargs["markerfacecolor"] = "none"

            if len(self.x) > 1000:
                kwargs["rasterize"] = True
        elif self.plot_as == "line":
            kwargs["zorder"] = line_zorder

        # Make both the data name and redshift appear in the legend
        data_label = f"{self.citation} ($z={self.redshift:.1f}$)"

        axes.errorbar(
            self.x,
            self.y,
            yerr=self.y_scatter,
            xerr=self.x_scatter,
            **kwargs,
            label=data_label,
        )

        return
Example #18
    def overlay_entropy_profiles(self,
                                 axes: plt.Axes = None,
                                 r_units: str = 'r500',
                                 k_units: str = 'K500adi',
                                 vkb05_line: bool = True,
                                 color: str = 'k',
                                 alpha: float = 1.,
                                 markersize: float = 1,
                                 linewidth: float = 0.5) -> None:

        stand_alone = False
        if axes is None:
            stand_alone = True
            fig, axes = plt.subplots()
            axes.loglog()
            axes.set_xlabel(f'$r$ [{r_units}]')
            axes.set_ylabel(f'$K$ [${k_units}$]')
            axes.axvline(1, linestyle=':', color=color, alpha=alpha)

        # Set-up entropy data
        fields = [
            'K_500', 'K_1000', 'K_1500', 'K_2500', 'K_0p15r500', 'K_30kpc'
        ]
        K_stat = dict()
        if k_units == 'K500adi':
            K_conv = 1 / getattr(self, 'K_500_adi')
            axes.axhline(1, linestyle=':', color=color, alpha=alpha)
        elif k_units == 'keVcm^2':
            K_conv = np.ones_like(getattr(self, 'K_500_adi'))
            axes.fill_between(np.array(axes.get_xlim()),
                              y1=np.nanmin(self.K_500_adi),
                              y2=np.nanmax(self.K_500_adi),
                              facecolor='k',
                              alpha=0.3)
        else:
            raise ValueError("Conversion unit unknown.")
        for field in fields:
            data = np.multiply(getattr(self, field), K_conv)
            K_stat[field] = (np.nanpercentile(data,
                                              16), np.nanpercentile(data, 50),
                             np.nanpercentile(data, 84))
            K_stat[field.replace('K',
                                 'num')] = np.count_nonzero(~np.isnan(data))

        # Set-up radial distance data
        r_stat = dict()
        if r_units == 'r500':
            r_conv = 1 / getattr(self, 'r_500')
        elif r_units == 'r2500':
            r_conv = 1 / getattr(self, 'r_2500')
        elif r_units == 'kpc':
            r_conv = np.ones_like(getattr(self, 'r_2500'))
        else:
            raise ValueError("Conversion unit unknown.")
        for field in ['r_500', 'r_1000', 'r_1500', 'r_2500']:
            data = np.multiply(getattr(self, field), r_conv)
            if k_units == 'K500adi':
                data[np.isnan(self.K_500_adi)] = np.nan
            r_stat[field] = (np.nanpercentile(data,
                                              16), np.nanpercentile(data, 50),
                             np.nanpercentile(data, 84))
            r_stat[field.replace('r',
                                 'num')] = np.count_nonzero(~np.isnan(data))
        data = np.multiply(getattr(self, 'r_500') * 0.15, r_conv)
        if k_units == 'K500adi':
            data[np.isnan(self.K_500_adi)] = np.nan
        r_stat['r_0p15r500'] = (np.nanpercentile(data, 16),
                                np.nanpercentile(data, 50),
                                np.nanpercentile(data, 84))
        r_stat['num_0p15r500'] = np.count_nonzero(~np.isnan(data))
        data = np.multiply(
            np.ones_like(getattr(self, 'r_2500')) * 30 * unyt.kpc, r_conv)
        if k_units == 'K500adi':
            data[np.isnan(self.K_500_adi)] = np.nan
        r_stat['r_30kpc'] = (np.nanpercentile(data,
                                              16), np.nanpercentile(data, 50),
                             np.nanpercentile(data, 84))
        r_stat['num_30kpc'] = np.count_nonzero(~np.isnan(data))

        for suffix in [
                '_500', '_1000', '_1500', '_2500', '_0p15r500', '_30kpc'
        ]:
            x_low, x, x_hi = r_stat['r' + suffix]
            y_low, y, y_hi = K_stat['K' + suffix]
            num_objects = f"{r_stat['num' + suffix]}, {K_stat['num' + suffix]}"
            point_label = f"r{suffix:.<17s} Num(x,y) = {num_objects}"
            if stand_alone:
                axes.scatter(x, y, label=point_label, s=markersize)
                axes.errorbar(x,
                              y,
                              yerr=[[y_hi - y], [y - y_low]],
                              xerr=[[x_hi - x], [x - x_low]],
                              ls='none',
                              ms=markersize,
                              lw=linewidth)
            else:
                axes.scatter(x, y, color=color, alpha=alpha, s=markersize)
                axes.errorbar(x,
                              y,
                              yerr=[[y_hi - y], [y - y_low]],
                              xerr=[[x_hi - x], [x - x_low]],
                              ls='none',
                              ecolor=color,
                              alpha=alpha,
                              ms=markersize,
                              lw=linewidth)

        if vkb05_line:
            if r_units == 'r500' and k_units == 'K500adi':
                r = np.linspace(*axes.get_xlim(), 31)
                k = 1.40 * r**1.1 / self.hconv
                axes.plot(r, k, linestyle='--', color=color, alpha=alpha)
            else:
                print((
                    "The VKB05 adiabatic threshold should be plotted only when both "
                    "axes are in scaled units, since the line is calibrated on an NFW "
                    "profile with self-similar halos with an average concentration of "
                    "c_500 ~ 4.2 for the objects in the Sun et al. (2009) sample."
                ))

        if k_units == 'K500adi':
            r_r500, S_S500_50, S_S500_10, S_S500_90 = self.get_shortcut()

            plt.fill_between(r_r500,
                             S_S500_10,
                             S_S500_90,
                             color='grey',
                             alpha=0.5,
                             linewidth=0)
            plt.plot(r_r500, S_S500_50, c='k')

        if stand_alone:
            plt.legend()
            plt.show()
Example #19
    def plot_line(
        self,
        ax: Axes,
        x: unyt_array,
        y: unyt_array,
        label: Union[str, None] = None,
        x_lim: Union[List, None] = None,
        y_lim: Union[List, None] = None,
        min_num_points_highlight: int = 0,
    ):
        """
        Plot a line using these parameters on some axes, x against y.

        Parameters
        ----------

        ax: Axes
            Matplotlib axes to plot on.

        x: unyt_array
            Horizontal axis data

        y: unyt_array
            Vertical axis data

        label: str
            Label associated with this data that will be included in the
            legend.

        x_lim: Union[List, None]
            A 2-length list containing the lower and upper limits of the X-axis range.

        y_lim: Union[List, None]
            A 2-length list containing the lower and upper limits of the Y-axis range.

        min_num_points_highlight: int, optional
            Minimum number of data points with the highest values of x to highlight in
            the median-line or mean-line plots.

        Notes
        -----

        If self.scatter is set to "none", this is plotted assuming the scatter
        is zero.
        """

        if not self.plot:
            return

        centers, heights, errors, additional_x, additional_y = self.create_line(
            x=x, y=y, minimum_additional_points=min_num_points_highlight)

        if self.scatter == "none" or errors is None:
            (line, ) = ax.plot(centers, heights, label=label)
        elif self.scatter == "errorbar":
            (line, *_) = ax.errorbar(centers,
                                     heights,
                                     yerr=errors,
                                     label=label)
        elif self.scatter == "errorbar_both":
            (line, *_) = ax.errorbar(
                centers,
                heights,
                yerr=errors,
                xerr=abs(self.bins - centers),
                label=label,
                fmt=".",  # Do not plot as a line.
            )
        elif self.scatter == "shaded":
            (line, ) = ax.plot(centers, heights, label=label)

            # Deal with different + and -ve errors
            if errors.shape[0]:
                if errors.ndim > 1:
                    down, up = errors
                else:
                    up = errors
                    down = errors
            else:
                up = unyt_quantity(0, units=heights.units)
                down = unyt_quantity(0, units=heights.units)

            ax.fill_between(
                centers,
                heights - down,
                heights + up,
                color=line.get_color(),
                alpha=0.3,
                linewidth=0.0,
            )

        try:
            ax.scatter(additional_x.value,
                       additional_y.value,
                       color=line.get_color())

            # Enter only if the plot has a valid X-axis and Y-axis ranges and there are
            # any additional data points.
            if x_lim is not None and y_lim is not None and len(
                    additional_x) > 0:

                # Add arrows to the plot for each data point beyond X- or/and Y- axis
                # range
                self.highlight_data_outside_domain(
                    ax,
                    additional_x.value,
                    additional_y.value,
                    line.get_color(),
                    (x_lim[0].value, x_lim[1].value),
                    (y_lim[0].value, y_lim[1].value),
                )

        # In case the line object is undefined
        except NameError:
            ax.scatter(additional_x.value, additional_y.value)

        return
Example #20
def plot_gaussian(
    data, ax: plt.Axes, nBins=100, textpos="l", legend=False, short_text=False
):
    # make sure our data is an ndarray
    if type(data) == list:
        data = np.array(data)

    ### FITTING WITH A GAUSSIAN

    def func_gauss(x, N, mu, sigma):
        return N * stats.norm.pdf(x, mu, sigma)

    counts, bin_edges = np.histogram(data, bins=nBins)
    bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2
    s_counts = np.sqrt(counts)

    x = bin_centers[counts > 0]
    y = counts[counts > 0]
    sy = s_counts[counts > 0]

    popt_gauss, pcov_gauss = curve_fit(
        func_gauss, x, y, p0=[1, data.mean(), data.std()]
    )

    y_func = func_gauss(x, *popt_gauss)

    pKS = stats.ks_2samp(y, y_func)
    pKS_g1, pKS_g2 = pKS[0], pKS[1]

    # print('LOOK! \n \n \n pKS is {} \n \n \n '.format(pKS_g2))
    chi2_gauss = sum((y - y_func) ** 2 / sy ** 2)
    NDOF_gauss = nBins - 3
    prob_gauss = stats.chi2.sf(chi2_gauss, NDOF_gauss)


    if short_text == True:
        namesl = [
            "Gauss_N",
            "Gauss_Mu",
            "Gauss_Sigma",
        ]
        valuesl = [
            "{:.3f} +/- {:.3f}".format(val, unc)
            for val, unc in zip(popt_gauss, np.sqrt(np.diagonal(pcov_gauss)))  # std dev, not variance
        ]

        del namesl[0]  # remove gauss n
        del valuesl[0]
    else:
        namesl = [
            "Gauss_N",
            "Gauss_Mu",
            "Gauss_Sigma",
            "KS stat",
            "KS_pval",
            "Chi2 / NDOF",
            "Prob",
        ]
        valuesl = (
            [
                "{:.3f} +/- {:.3f}".format(val, unc)
                for val, unc in zip(popt_gauss, np.sqrt(np.diagonal(pcov_gauss)))  # std dev, not variance
            ]
            + ["{:.3f}".format(pKS_g1)]
            + ["{:.3f}".format(pKS_g2)]
            + ["{:.3f} / {}".format(chi2_gauss, NDOF_gauss)]
            + ["{:.3f}".format(prob_gauss)]
        )

    ax.errorbar(x, y, yerr=sy, xerr=0, fmt=".", elinewidth=1)
    ax.plot(x, y_func, "--", label="Gaussian")
    if textpos == "l":
        ax.text(
            0.02,
            0.98,
            nice_string_output(namesl, valuesl),
            family="monospace",
            transform=ax.transAxes,
            fontsize=10,
            verticalalignment="top",
            alpha=0.5,
        )
    elif textpos == "r":
        ax.text(
            0.6,
            0.98,
            nice_string_output(namesl, valuesl),
            family="monospace",
            transform=ax.transAxes,
            fontsize=10,
            verticalalignment="top",
            alpha=0.5,
        )
    if legend:
        ax.legend(loc="center left")
    return ax
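The core fit-and-plot step above, as a self-contained sketch without the nice_string_output text box: histogram counts with sqrt(N) errors, a Gaussian fitted via curve_fit, and both overlaid with errorbar/plot:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from scipy.optimize import curve_fit

def gauss(x, N, mu, sigma):
    return N * stats.norm.pdf(x, mu, sigma)

rng = np.random.default_rng(5)
data = rng.normal(loc=2.0, scale=0.5, size=5000)

counts, bin_edges = np.histogram(data, bins=100)
centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
mask = counts > 0
x, y, sy = centers[mask], counts[mask], np.sqrt(counts[mask])   # Poisson errors

popt, pcov = curve_fit(gauss, x, y, p0=[len(data) * (x[1] - x[0]), data.mean(), data.std()])

fig, ax = plt.subplots()
ax.errorbar(x, y, yerr=sy, fmt='.', elinewidth=1, label='binned data')
ax.plot(x, gauss(x, *popt), '--', label='Gaussian fit')
ax.legend()
plt.show()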
Example #21
def plot_glitch(
    data: az.InferenceData,
    group: str = "posterior",
    kind: str = "full",
    x_var: str = "n",
    quantiles: Optional[List[float]] = None,
    observed: Union[bool, str] = "auto",
    use_alpha: bool = True,
    ax: plt.Axes = None,
    **kwargs,
) -> plt.Axes:
    """Plot the glitch from either the prior or posterior predictive contained
    in inference data.

    Args:
        data (arviz.InferenceData): Inference data object.
        group (str): One of ['posterior', 'prior'].
        kind (str): Kind of glitch to plot. One of ['full', 'helium', 'BCZ'].
        x_var (str): Variable name for x-axis. One of ['n', 'nu']. If 'nu', the
            median value of 'nu' in ``data['group']`` is used.
        quantiles (iterable, optional): Quantiles to plot as confidence
            intervals. If None, defaults to the 68% confidence interval. Pass
            an empty list to plot no confidence intervals.
        observed (bool or str): Whether to plot observed data. Default is
            "auto" which will plot observed data when group is "posterior".
        use_alpha (bool): Whether to use alpha channel for transparency. If
            False, will shade with lightened solid color.
        ax (matplotlib.axes.Axes): Axis on which to plot the glitch.
        **kwargs: Keyword arguments to pass to :func:`matplotlib.pyplot.plot`.

    Raises:
        ValueError: If kind is not valid.

    Returns:
        matplotlib.axes.Axes: Axis on which the glitch is plot.
    """
    dim = ("chain", "draw")  # dim over which to take stats
    predictive = _validate_predictive_group(data, group)

    if quantiles is None:
        quantiles = [0.16, 0.84]

    if observed == "auto":
        observed = group == "posterior"

    nu = data.observed_data.nu
    nu_err = data.constant_data.nu_err

    if x_var == "n":
        x = predictive.n
        x_pred = predictive.n_pred
        xlabel = "n"
    elif x_var == "nu":
        x = predictive.nu.median(dim=dim)
        x_pred = predictive.nu_pred.median(dim=dim)
        xlabel = r"$\nu$ ($\mathrm{\mu Hz}$)"

    if ax is None:
        _, ax = plt.subplots()

    kindl = kind.lower()
    if kindl == "full":
        dnu = predictive["dnu_he"] + predictive["dnu_cz"]
        dnu_pred = predictive["dnu_he_pred"] + predictive["dnu_cz_pred"]
    else:
        if kindl in {"helium", "he"}:
            kind = "helium"  # Set to helium for consistency in legend label
            dnu_key = "dnu_he"
        elif kindl in {"bcz", "cz"}:
            kind = "BCZ"  # Set to BCZ for consistency in legend label
            dnu_key = "dnu_cz"
        else:
            # Raise error if kind is not valid
            raise ValueError(f"Kind '{kindl}' is not one of " +
                             "['full', 'helium', 'BCZ'].")
        dnu = predictive[dnu_key]
        dnu_pred = predictive[dnu_key + "_pred"]
        # label = dnu.attrs.get("symbol", r"$\delta\nu_{" + kind + "}$")
    label = f"{kind} glitch model"

    if observed:
        # Plot observed - prior predictive should be independent of obs
        res = nu - predictive["nu"]
        dnu_obs = dnu + res
        # glitch = label
        # if "+" in label:
        #     glitch = "$($" + label + "$)$"
        # TODO: should we show model error on dnu_obs here?
        ax.errorbar(
            x,
            dnu_obs.median(dim=dim),
            yerr=nu_err,
            color="k",
            marker="o",
            linestyle="none",
            label=r"observed",
        )

    dnu_med = dnu_pred.median(dim=dim)
    label = kwargs.pop("label", label)
    (line, ) = ax.plot(x_pred, dnu_med, label=label, **kwargs)

    # Fill quantiles with alpha decreasing away from the median
    dnu_quant = dnu_pred.quantile(quantiles, dim=dim)
    num_quant = len(quantiles) // 2
    num_alphas = num_quant * 2 + 1
    alphas = np.linspace(0.1, 0.5, num_alphas)
    base_color = line.get_color()

    if use_alpha:
        colors = [base_color] * num_alphas
    else:
        # Mimic alpha by lightening the base color
        colors = [_lighten_color(base_color, 1.5 * a) for a in alphas]
        alphas = [None] * num_alphas  # reset alphas to None

    for i in range(num_quant):
        delta = quantiles[-i - 1] - quantiles[i]
        ax.fill_between(
            x_pred,
            dnu_quant[i],
            dnu_quant[-i - 1],
            color=colors[2 * i + 1],  # <-- same as model line color
            alpha=alphas[2 * i + 1],
            label=f"{delta:.1%} CI",
        )

    ax.xaxis.set_major_locator(MaxNLocator(integer=True))  # integer x-ticks
    ax.set_xlabel(xlabel)
    # ax.set_xlabel("radial order")

    # ylabel = [dnu.attrs.get("symbol", r"$\delta\nu$")]
    # ylabel = [r"$\delta\nu$"]
    unit = u.Unit(dnu.attrs.get("unit", "uHz"))
    # if str(unit) != "":
    # ylabel.append(unit.to_string(format="latex_inline"))

    # ax.set_ylabel("/".join(ylabel))
    ax.set_ylabel(r"$\delta\nu$ " +
                  f"({unit.to_string(format='latex_inline')})")
    ax.legend()
    return ax
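A self-contained sketch of the confidence-band idiom used above: plot the median of (fake) posterior draws, then fill nested quantile pairs around it with fill_between, with the inner interval shaded more strongly:
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(6)
x = np.linspace(0, 10, 50)
draws = np.sin(x) + rng.normal(scale=0.3, size=(500, x.size))   # fake posterior draws

quantiles = [0.025, 0.16, 0.84, 0.975]        # two nested intervals (95% and 68%)
median = np.median(draws, axis=0)
q = np.quantile(draws, quantiles, axis=0)

fig, ax = plt.subplots()
(line,) = ax.plot(x, median, label='median')
num_quant = len(quantiles) // 2
alphas = np.linspace(0.1, 0.5, 2 * num_quant + 1)
for i in range(num_quant):
    delta = quantiles[-i - 1] - quantiles[i]
    ax.fill_between(x, q[i], q[-i - 1],
                    color=line.get_color(), alpha=alphas[2 * i + 1],
                    label=f"{delta:.1%} CI")
ax.legend()
plt.show()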
Example #22
def errorbox(
    plot: plt.Axes,
    plot_def: PlotDef,
    xaxis: Tuple[str, str],
    yaxis: Tuple[str, str],
    firstcolor: int = 0,
    firstshape: int = 0,
    markersize: int = 6,
    use_cross: bool = False,
) -> Optional[mpl.legend.Legend]:
    """Generate a figure with errorboxes that reflect the std dev of an entry.

    Args:
        plot: a pyplot Axes object
        plot_def: a `PlotDef` that defines properties of the plot
        xaxis: a tuple of two strings
        yaxis: a tuple of two strings
        firstcolor: index of the color that should be used for the first entry
        firstshape: index of the shape that should be used for the first entry
        markersize: size of the markers
        use_cross: if True, use cross instead of boxes
    """
    # ================================ constants for plots ========================================
    colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"]
    colors += ["#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]
    pale_colors = ["#aec7e8", "#ffbb78", "#98df8a", "#ff9896", "#c5b0d5"]
    pale_colors += ["#c49c94", "#f7b6d2", "#c7c7c7", "#dbdb8d", "#9edae5"]
    shapes = ["o", "X", "D", "s", "^", "v", "<", ">", "*", "p", "P"]
    hatch_pattern = ["O", "o", "|", "-", "/", "\\", "+", "x", "*"]

    xaxis_measure, yaxis_measure = xaxis[0], yaxis[0]
    filled_counter = firstcolor

    max_xmean = 0.0
    min_xmean = 1.0

    max_ymean = 0.0
    min_ymean = 1.0

    for i, entry in enumerate(plot_def.entries):
        if entry.do_fill:
            color = colors[filled_counter]
            filled_counter += 1
        else:
            color = pale_colors[i + firstcolor - filled_counter]
        i_shp = firstshape + i

        xmean: float = entry.values[xaxis_measure].mean()
        xstd: float = entry.values[xaxis_measure].std()
        if xmean + (0.5 * xstd) > max_xmean:
            max_xmean = xmean + (0.5 * xstd)
        if xmean - (0.5 * xstd) < min_xmean:
            min_xmean = xmean - (0.5 * xstd)

        ymean: float = entry.values[yaxis_measure].mean()
        ystd: float = entry.values[yaxis_measure].std()
        if ymean + (0.5 * ystd) > max_ymean:
            max_ymean = ymean + (0.5 * ystd)
        if ymean - (0.5 * ystd) < min_ymean:
            min_ymean = ymean - (0.5 * ystd)

        if use_cross:
            plot.errorbar(
                xmean,
                ymean,
                xerr=0.5 * xstd,
                yerr=0.5 * ystd,
                marker=shapes[i_shp % len(shapes)],
                linestyle="",  # no connecting lines
                color=color,
                ecolor="#555555",
                capsize=4,
                elinewidth=1.0,
                zorder=3 + 2 * i_shp,
                label=entry.label,
                markersize=markersize,
            )
        else:
            plot.bar(
                xmean,
                ystd,
                bottom=ymean - 0.5 * ystd,
                width=xstd,
                align="center",
                color="none",
                edgecolor=color,
                linewidth=3,
                zorder=3 + 2 * i_shp,
                hatch=hatch_pattern[i % len(hatch_pattern)],
            )
            plot.plot(
                xmean,
                ymean,
                marker=shapes[i_shp % len(shapes)],
                linestyle="",  # no connecting lines
                color=color,
                label=entry.label,
                zorder=4 + 2 * i_shp,
                markersize=markersize,
            )

    x_min = min_xmean * 0.99
    x_max = max_xmean * 1.01

    y_min = min_ymean * 0.99
    y_max = max_ymean * 1.01

    plot.set_xlim(x_min, x_max)
    plot.set_ylim(y_min, y_max)
    return common_plotting_settings(plot, plot_def, xaxis[1], yaxis[1])
Example #23
def plot_echelle(
    data: az.InferenceData,
    group="posterior",
    kind: str = "full",
    delta_nu: Optional[float] = None,
    quantiles: Optional[List[float]] = None,
    observed: Union[bool, str] = "auto",
    use_alpha: bool = True,
    ax: plt.Axes = None,
    **kwargs,
) -> plt.Axes:
    """Plot an echelle diagram of the data.

    Choose to plot the full mode, background model or glitchless model. This is
    compatible with data from inference on models like :class:`GlitchModel`.

    Args:
        data (az.InferenceData): Inference data object.
        group (str): One of ['posterior', 'prior']. Defaults to 'posterior'.
        kind (str): One of ['full', 'glitchless', 'background']. Defaults to
            'full' which plots the full model for nu. Use 'glitchless' to plot
            the model without the glitch components. Use 'background' to plot
            the background component of the model.
        delta_nu (float, optional): Large frequency separation to modulo by.
            If None, the median value from ``data['group']`` is used.
        quantiles (iterable, optional): Quantiles to plot as confidence
            intervals. If None, defaults to the 68% confidence interval. Pass
            an empty list to plot no confidence intervals.
        observed (bool or str): Whether to plot observed data. Default is
            "auto" which will plot observed data when group is "posterior".
        use_alpha (bool): Whether to use alpha channel for transparency. If
            False, will shade with lightened solid color.
        ax (matplotlib.axes.Axes): Axis on which to plot the echelle.
        **kwargs: Keyword arguments to pass to :func:`matplotlib.pyplot.plot`.

    Raises:
        ValueError: If kind is not valid.

    Returns:
        matplotlib.axes.Axes: Axis on which the echelle is plot.
    """
    if ax is None:
        _, ax = plt.subplots()

    if quantiles is None:
        quantiles = [0.16, 0.84]

    if observed == "auto":
        observed = group == "posterior"

    predictive = _validate_predictive_group(data, group)
    dim = ("chain", "draw")  # dim over which to take stats

    if delta_nu is None:
        if group == "prior":  # <-- currently no prior group
            delta_nu = predictive["delta_nu"].median().to_numpy()
        else:
            delta_nu = data[group]["delta_nu"].median().to_numpy()

    nu = data.observed_data.nu
    nu_err = data.constant_data.nu_err
    n_pred = predictive.n_pred

    if observed:
        # Plot observed - prior predictive should be independent of obs
        ax.errorbar(
            nu % delta_nu,
            nu,
            xerr=nu_err,
            color="k",
            marker="o",
            linestyle="none",
            label="observed",
        )

    # All mean function components for GP
    # full_mu = [
    # predictive["nu_bkg"].attrs.get("symbol", r"$\nu_\mathrm{bkg}$"),
    # predictive["dnu_he"].attrs.get("symbol", r"$\delta\nu_{He}$"),
    # predictive["dnu_cz"].attrs.get("symbol", r"$\delta\nu_{BCZ}$"),
    # ]
    kindl = kind.lower()
    if kindl == "full":
        y = predictive["nu_pred"]
        # label = r"$\mathrm{GP}($" + " + ".join(full_mu) + r"$,\,K)$"
    elif kindl == "background":
        y = predictive["nu_bkg_pred"]
        # label = full_mu[0]  # <-- just the background, no GP
    elif kindl == "glitchless":
        y = (predictive["nu_pred"] - predictive.get("dnu_he_pred", 0.0) -
             predictive.get("dnu_cz_pred", 0.0))
        y.attrs["unit"] = predictive["nu_pred"].attrs["unit"]
        # label = r"$\mathrm{GP}($" + full_mu[0] + r"$,\,K)$"
    else:
        raise ValueError(f"Kind '{kindl}' is not one of " +
                         "['full', 'background', 'glitchless'].")
    label = f"{kindl} model"

    y_mod = (y - n_pred * delta_nu) % delta_nu
    y_med = y.median(dim=dim)
    label = kwargs.pop("label", label)
    (line, ) = ax.plot(
        y_mod.median(dim=dim),
        y_med,
        label=label,
        **kwargs,
    )

    y_mod_quant = y_mod.quantile(quantiles, dim=dim)
    num_quant = len(quantiles) // 2
    num_alphas = num_quant * 2 + 1
    alphas = np.linspace(0.1, 0.5, num_alphas)
    base_color = line.get_color()

    if use_alpha:
        colors = [base_color] * num_alphas
    else:
        # Mimic alpha by lightening the base color
        colors = [_lighten_color(base_color, 1.5 * a) for a in alphas]
        alphas = [None] * num_alphas  # reset alphas to None

    for i in range(num_quant):
        delta = quantiles[-i - 1] - quantiles[i]
        ax.fill_betweenx(
            y_med,
            y_mod_quant[i],
            y_mod_quant[-i - 1],
            color=colors[2 * i + 1],
            alpha=alphas[2 * i + 1],
            label=f"{delta:.1%} CI",
        )

    # xlabel = [r"$\nu\,\mathrm{mod}.\,{" + f"{delta_nu:.2f}" + "}$"]
    unit = u.Unit(y.attrs.get("unit", ""))
    # if str(unit) != "":
    # xlabel.append(unit.to_string(format="latex_inline"))
    # ax.set_xlabel("/".join(xlabel))

    ax.set_xlabel(r"$\nu$ modulo " + f"{delta_nu:.2f} " +
                  f"({unit.to_string(format='latex_inline')})")
    # ylabel = [r"$\nu$"]
    unit = u.Unit(nu.attrs.get("unit", "uHz"))
    # if str(unit) != "":
    # ylabel.append(unit.to_string(format="latex_inline"))
    # ax.set_ylabel("/".join(ylabel))

    ax.set_ylabel(r"$\nu$ " + f"({unit.to_string(format='latex_inline')})")

    ax.legend()

    return ax