예제 #1
0
파일: fom.py 프로젝트: wswxyq/lstm_ee
def calc_fom_hist(pred, true, weights, bins, range):
    """Histogram the relative energy residual ``(pred - true) / true``.

    Parameters
    ----------
    pred : ndarray, shape (N,)
        Predicted energies.
    true : ndarray, shape (N,)
        True energies. Used as the denominator of the relative residual.
    weights : ndarray, shape (N,)
        Per-sample weights.
    bins : int or ndarray
        Number of bins when an int, or the bin edges when an ndarray
        (c.f. `np.histogram`).
    range : (float, float) or None
        (lower, upper) range of the histogram. If None, the range is
        determined according to the `np.histogram` rules.

    Returns
    -------
    cafplot.RHist1D
        `RHist1D` object holding the relative energy resolution histogram,
        as built by `RHist1D.from_data`.
    """
    # The `range` parameter name is part of the public signature, so the
    # shadowing of the builtin is silenced rather than renamed.
    # pylint: disable=redefined-builtin
    residual = pred - true
    relative_residual = residual / true

    return RHist1D.from_data(relative_residual, bins, weights, range)
예제 #2
0
파일: hist.py 프로젝트: wswxyq/lstm_ee
def plot_hist_base(
    list_of_data_weight_label_color, key, spec, ratio_plot_type, stat_err,
    log = False
):
    """Overlay several histograms of `data[key]` on one set of axes.

    Parameters
    ----------
    list_of_data_weight_label_color : iterable of 4-tuples
        Each item is ``(data, weights, label, color)``. `data` is indexed
        with `key` to obtain the values to histogram, `weights` are handed
        to `RHist1D.from_data`, `label` is used in the legend entry and
        `color` is used for all drawing of that histogram.
    key : object
        Index into each `data` item selecting the values to histogram.
    spec : object
        Plot specification. It must provide `bins_x` (binning passed to
        `RHist1D.from_data`), `decorate(ax, ratio_plot_type)` and, when a
        ratio plot is requested, `decorate_ratio(axr, ratio_plot_type)`.
    ratio_plot_type : object or None
        If not None, the figure is created by `make_figure_with_ratio` and
        the histograms are additionally passed to `plot_rhist1d_ratios`.
        The value itself is forwarded to the `spec` decorators.
    stat_err : bool
        If true, error bars are drawn for every histogram (and requested
        for the ratio plot).
    log : bool, optional
        If true, use a logarithmic y axis. Otherwise the bottom margin of
        the main axes is removed. Default: False.

    Returns
    -------
    f : figure
        The created figure.
    ax : axes
        The main axes. The ratio axes, if any, are not returned.
    """
    with_ratio = ratio_plot_type is not None

    if with_ratio:
        f, ax, axr = make_figure_with_ratio()
    else:
        f, ax = plt.subplots()

    if log:
        ax.set_yscale('log')

    # Kept as two parallel lists because that is the form expected by
    # `plot_rhist1d_ratios` below.
    rhists = []
    colors = []

    for data, weights, label, color in list_of_data_weight_label_color:
        rhist = RHist1D.from_data(data[key], spec.bins_x, weights)

        # Mean quoted in the legend is estimated from the binned data:
        # bin centers averaged with `rhist.hist` as weights.
        edges       = rhist.bins_x
        bin_centers = (edges[1:] + edges[:-1]) / 2
        mean        = np.average(bin_centers, weights = rhist.hist)
        legend_text = "%s. MEAN = %.3e" % (label, mean)

        plot_rhist1d(
            ax, rhist,
            histtype  = 'step',
            marker    = None,
            linestyle = '-',
            linewidth = 2,
            label     = legend_text,
            color     = color,
        )

        if stat_err:
            plot_rhist1d_error(
                ax, rhist, err_type = 'bar', color = color, linewidth = 2,
                alpha = 0.8
            )

        rhists.append(rhist)
        colors.append(color)

    spec.decorate(ax, ratio_plot_type)

    if not log:
        remove_bottom_margin(ax)

    ax.legend()

    if with_ratio:
        ratio_err_type = 'bar' if stat_err else None

        plot_rhist1d_ratios(
            axr, rhists, colors,
            err_kwargs = { 'err_type' : ratio_err_type },
        )
        spec.decorate_ratio(axr, ratio_plot_type)

    return f, ax
예제 #3
0
def get_sgn_bkg_preds(truth, preds, weights, truth_idx, pred_idx, **kwargs):
    """Build signal and background histograms of one predicted score.

    Samples whose `truth` equals `truth_idx` are treated as signal; all
    remaining samples are treated as background.

    Parameters
    ----------
    truth : ndarray, shape (N_SAMPLES,)
        Array of true targets.
    preds : ndarray, shape (N_SAMPLES, N_TARGETS)
        Array of predicted target scores.
    weights : ndarray, shape (N_SAMPLES,)
        Sample weights.
    truth_idx : int
        Value of `truth` that marks a sample as signal.
    pred_idx : int
        Column of `preds` to histogram.
    kwargs : dict
        Extra keyword arguments forwarded to `RHist1D.from_data`.

    Returns
    -------
    h_sgn : RHist1D
        Histogram of the selected score for signal samples.
    h_bkg : RHist1D
        Histogram of the selected score for background samples.
    """

    def _hist_for(mask):
        # Histogram the selected score column for the masked samples only.
        return RHist1D.from_data(
            preds[mask, pred_idx], weights=weights[mask], **kwargs
        )

    is_signal     = (truth == truth_idx)
    is_background = ~is_signal

    h_sgn = _hist_for(is_signal)
    h_bkg = _hist_for(is_background)

    return (h_sgn, h_bkg)
예제 #4
0
def plot_energy(data, weights, name, spec, log_scale=False):
    """Plot a single energy distribution with error bars.

    Parameters
    ----------
    data : ndarray
        Values to histogram.
    weights : ndarray or None
        Sample weights, used both for the histogram and for the mean that
        is quoted in the legend.
    name : str
        Name of the plotted quantity; inserted into the legend label.
    spec : object
        Plot specification providing `bins_x` and `decorate(ax)`.
    log_scale : bool, optional
        If true, use a logarithmic y axis. Default: False.

    Returns
    -------
    f : figure
        The created figure.
    ax : axes
        The axes the histogram was drawn on.
    rhist : RHist1D
        The histogram built from `data`.
    """
    # Styling shared by the histogram outline and its error bars.
    line_style = {'color': 'C0', 'linewidth': 2}

    f, ax = plt.subplots()
    if log_scale:
        ax.set_yscale('log')

    rhist = RHist1D.from_data(data, spec.bins_x, weights)

    # The mean in the legend comes from the unbinned data, not from the
    # histogram contents.
    mean = np.average(data, weights=weights)
    legend_text = "True %s. Mean: %.2e" % (name, mean)

    plot_rhist1d(ax, rhist, histtype='step', label=legend_text, **line_style)
    plot_rhist1d_error(ax, rhist, err_type='bar', **line_style)

    spec.decorate(ax)
    ax.legend()
    remove_bottom_margin(ax)

    return f, ax, rhist