コード例 #1
0
def main():
    """Plot the true-energy targets and dump their histograms as text.

    For every entry returned by ``get_true_energies`` the energy
    distribution is plotted twice (log and linear vertical scale) into
    ``<outdir>/targets``. The histogram contents and bin edges returned
    by the last (linear-scale) call of ``plot_energy`` are then written
    next to the plots as ``<label>_hist.txt`` / ``<label>_bins.txt``.
    """
    setup_logging()
    args = parse_cmdargs()

    dgen, _, _, outdir, _, eval_specs = standard_eval_prologue(
        args, PRESETS_EVAL)

    target_dir = os.path.join(outdir, 'targets')
    os.makedirs(target_dir, exist_ok=True)

    energies = get_true_energies(dgen)
    sample_weights = dgen.weights

    for name, values in energies.items():
        for use_log in (True, False):
            fig, _, rhist = plot_energy(
                values,
                sample_weights,
                eval_specs['name_map'][name],
                eval_specs['hist'][name],
                use_log,
            )
            fig_path = os.path.join(target_dir,
                                    '%s_log(%s)' % (name, use_log))
            save_fig(fig, fig_path, args.ext)

        # `rhist` comes from the final (linear-scale) iteration above.
        hist_path = os.path.join(target_dir, "%s_hist.txt" % (name))
        bins_path = os.path.join(target_dir, "%s_bins.txt" % (name))
        np.savetxt(hist_path, rhist.hist)
        np.savetxt(bins_path, rhist.bins_x)
コード例 #2
0
def plot_2d_embedding_scatter(truth, preds, labels, fname_base, ext):
    """Make and save a scatterplot of 2D embedded data, colored by class.

    Parameters
    ----------
    truth : ndarray, shape (N,)
        True class index of each sample.
    preds : ndarray, shape (N, >=2)
        Embedded coordinates; columns 0 and 1 are used as x and y.
    labels : list of str
        Class names. Class ``i`` is drawn with color ``'C<i>'`` and legend
        label ``labels[i]``.
    fname_base : str
        Path prefix; the plot is saved as ``fname_base + "_scatter"``.
    ext : str or list of str
        Extension of the plot. If list then the plot will be saved in
        multiple formats.
    """
    f, ax = plt.subplots()

    # Lower the opacity for larger samples so that individual points stay
    # visible and dense regions do not turn into a solid blob.
    n_points = len(truth)
    if n_points >= 20000:
        alpha = 0.25
    elif n_points >= 10000:
        alpha = 0.50
    else:
        alpha = 0.75

    for class_idx, label in enumerate(labels):
        class_points = preds[truth == class_idx, :]

        ax.scatter(class_points[:, 0],
                   class_points[:, 1],
                   label=label,
                   marker=',',
                   color='C%d' % (class_idx, ),
                   alpha=alpha)

    add_nice_legend(ax)
    save_fig(f, fname_base + "_scatter", ext)
    # pyplot keeps every figure alive until explicitly closed; without this
    # repeated calls accumulate open figures (and memory).
    plt.close(f)
コード例 #3
0
def plot_distributions(truth, preds, weights, bins, labels, plotdir, ext):
    """Plot and save histograms of the predicted score of every target.

    For each target (column of `preds`) two plots are produced, one with
    a logarithmic and one with a linear vertical axis. They are saved as
    ``distrib_<label>_log(<True|False>)`` inside `plotdir`.

    Parameters
    ----------
    truth : ndarray, shape (N_SAMPLES,)
        True targets.
    preds : ndarray, shape (N_SAMPLES, N_TARGETS)
        Predicted target scores; column ``i`` belongs to ``labels[i]``.
    weights : ndarray, shape (N_SAMPLES,)
        Sample weights.
    bins : int
        Number of histogram bins.
    labels : list of str, len(N_TARGETS)
        Target names.
    plotdir : str
        Output directory of the plots.
    ext : str or list of str
        Plot extension(s); a list saves the plots in several formats.
    """
    for idx, target_label in enumerate(labels):
        target_scores = preds[:, idx]

        for use_log in (True, False):
            fig, _ = plot_detailed_distributions(
                truth, target_scores, weights, bins, labels, use_log,
                target_label
            )

            out_path = os.path.join(
                plotdir, 'distrib_%s_log(%s)' % (target_label, use_log)
            )
            save_fig(fig, out_path, ext)
コード例 #4
0
def plot_separate_foms(rhist_fom_list, labels, fom_label, plotdir, ext):
    """Plot every FOM histogram on its own figure, marking its maximum.

    One file ``fom_<fom_label>_<label>`` is written to `plotdir` per entry
    of `labels`.

    Parameters
    ----------
    rhist_fom_list : list of cafplot.RHist1D
        Figures of merit; element ``i`` is plotted against ``labels[i]``.
    labels : list of str
        X axis label of each plot.
    fom_label : str
        Name of the FOM.
    plotdir : str
        Output directory of the plots.
    ext : str or list of str
        Plot extension(s); a list saves the plots in several formats.
    """
    for idx, target_label in enumerate(labels):
        fom = rhist_fom_list[idx]

        fig, ax = plot_single_fom(fom, target_label, fom_label, color='C0')

        peak_pos, peak_val = find_fom_maxval(fom)
        plot_maxval_line(ax, peak_pos, peak_val, color='C1')
        ax.legend()

        out_name = 'fom_%s_%s' % (fom_label, target_label)
        save_fig(fig, os.path.join(plotdir, out_name), ext)
        plt.close(fig)
コード例 #5
0
def plot_whist(whist_train, whist_test, bins, plotdir, ext):
    """Plot inverse of the TrueE histogram (the sample weights).

    Parameters
    ----------
    whist_train : ndarray or None
        Weight histogram of the training sample. If None, only the test
        histogram is drawn.
    whist_test : ndarray
        Weight histogram of the test sample.
    bins : ndarray
        Bin edges shared by both histograms.
    plotdir : str
        Directory where the plot will be saved. The file name is
        ``weights_<whist_train is None>``.
    ext : str or list of str
        Extension of the plot. If list then the plot will be saved in
        multiple formats.
    """
    f, ax = plt.subplots()
    ax.set_yscale('log')

    if whist_train is not None:
        plot_nphist1d_base(ax,
                           whist_train,
                           bins,
                           histtype='step',
                           linewidth=2,
                           label='Train')

    plot_nphist1d_base(ax,
                       whist_test,
                       bins,
                       histtype='step',
                       linewidth=2,
                       label='Test')

    ax.legend()
    ax.minorticks_on()
    ax.grid(True, which='major', linestyle='dashed', linewidth=1.0)
    ax.grid(True, which='minor', linestyle='dashed', linewidth=0.5)
    ax.set_xlabel('Target')
    ax.set_ylabel('Weight')

    save_fig(f, os.path.join(plotdir, "weights_%s" % (whist_train is None)),
             ext)
    # Free the figure; pyplot otherwise keeps it open indefinitely.
    plt.close(f)
コード例 #6
0
def plot_binstats(list_of_pred_true_weight_label_color,
                  plot_specs_abs,
                  plot_specs_rel,
                  fname,
                  ext,
                  stat_list=('mean', 'rms')):
    """Make and save binstat plots of energy resolution vs true energy.

    Parameters
    ----------
    list_of_pred_true_weight_label_color : list
        List of tuples of the form (pred, true, weights, label, color) where:
        pred : dict
            Dictionary where keys are energy labels and values are the
            `ndarray` (shape (N,)) of predicted energies.
        true : dict
            Dictionary where keys are energy labels and values are the
            `ndarray` (shape (N,)) of true energies.
        weights : `ndarray`, shape (N,)
            Sample weights.
        label : str
            Plot label.
        color : str
            Line color.
        Lines for all elements of the `list_of_pred_true_weight_label_color`
        will be drawn on each plot.
    plot_specs_abs : dict
        Dictionary where keys are energy labels and values are `PlotSpec` that
        specify axes and bins of the absolute energy resolution plots.
        A separate plot is made for each key of this dictionary.
    plot_specs_rel : dict
        Dictionary where keys are energy labels and values are `PlotSpec` that
        specify axes and bins of the relative energy resolution plots.
        It must contain every key of `plot_specs_abs`.
    fname : str
        Prefix of the path that will be used to build plot file names:
        ``<fname>_<energy label>_<stat>_<rel|abs>``.
    ext : str or list of str
        Extension of the plot. If list then the plot will be saved in multiple
        formats.
    stat_list : sequence of str, optional
        Statistic properties for which binstat plots will be made.
        Default: ('mean', 'rms').
    """
    # Energy labels are taken from the absolute specs and used for both the
    # relative and the absolute plots.
    plot_types = plot_specs_abs.keys()

    for is_rel, spec, rel_label in zip([True, False],
                                       [plot_specs_rel, plot_specs_abs],
                                       ['rel', 'abs']):
        for k in plot_types:
            for stat in stat_list:
                f, _ = plot_binstat_base(list_of_pred_true_weight_label_color,
                                         k, spec[k], stat, is_rel)

                fullname = "%s_%s_%s_%s" % (fname, k, stat, rel_label)
                save_fig(f, fullname, ext)
                plt.close(f)
コード例 #7
0
def plot_fom(list_of_rhist_stats_labels_pos_colors, plot_specs, fname, ext):
    """Make and save plots of relative energy resolution histograms.

    Parameters
    ----------
    list_of_rhist_stats_labels_pos_colors : list
        List of tuples of the form (rhist, stats, label, pos, color) where:
        rhist : dict
            Dictionary where keys are energy labels and the values are
            `cafplot.RHist1D` containing relative energy resolution histograms.
            C.f.  `calc_fom_stats_hists`.
        stats : dict
            Dictionary where keys are energy labels and values are the
            dictionaries of values of various statistical properties of the
            relative energy resolution.
            C.f. `calc_fom_stats_hists`.
        label : str
            Plot label.
        pos : str
            Position of the textbox with statistical information.
            It should be of the form "([top|bottom]-)?[left|right]".
        color : str
            Line color.
        Histograms for all elements of `list_of_rhist_stats_labels_pos_colors`
        will be drawn on each plot.
    plot_specs : dict
        Dictionary where keys are energy labels and values are `PlotSpec` that
        specify axes and bins of the relative energy resolution plots.
        A separate plot will be made for each key in `plot_specs`; every
        `rhist` and `stats` dictionary must contain all of these keys.
    fname : str
        Prefix of the path that will be used to build plot file names:
        ``<fname>_<energy label>``.
    ext : str or list of str
        Extension of the plot. If list then the plot will be saved in multiple
        formats.
    """

    for k, spec in plot_specs.items():

        f, ax = plt.subplots()

        for (rhist_dict, stats_dict, label, pos, color) in \
                list_of_rhist_stats_labels_pos_colors:

            plot_fom_base(ax, rhist_dict[k], stats_dict[k], label, pos, color,
                          spec)

        remove_bottom_margin(ax)

        save_fig(f, '%s_%s' % (fname, k), ext)
        plt.close(f)
コード例 #8
0
ファイル: hist.py プロジェクト: wswxyq/lstm_ee
def plot_energy_hists(
    list_of_data_weight_label_color, plot_specs, fname, ext, log=False
):
    """Make and save plots of energy histograms.

    For every energy label in `plot_specs`, one plot is made for each
    combination of ratio plot type (None, 'auto', 'fixed') and statistical
    error display (True, False).

    Parameters
    ----------
    list_of_data_weight_label_color : list
        List of tuples of the form (data, weights, label, color) where:
        data : dict
            Dictionary where keys are energy labels and values are the
            `ndarray` (shape (N,)) of energies.
        weights : `ndarray`, shape (N,)
            Sample weights.
        label : str
            Plot label.
        color : str
            Line color.
        Lines for all elements of the `list_of_data_weight_label_color`
        will be drawn on each plot.
    plot_specs : dict
        Dictionary where keys are energy labels and values are `PlotSpec` that
        specify axes and bins of the energy plots. A separate set of plots
        is made for each key.
    fname : str
        Prefix of the path that will be used to build plot file names:
        ``<fname>_<energy label>_ratio-<type>_staterr-<True|False>``.
    ext : str or list of str
        Extension of the plot. If list then the plot will be saved in multiple
        formats.
    log : bool
        If True then the vertical axis will have logarithmic scale.
        Default: False.
    """

    for k, spec in plot_specs.items():
        for ratio_plot_type in [None, 'auto', 'fixed']:
            for stat_err in [True, False]:

                f, _ = plot_hist_base(
                    list_of_data_weight_label_color, k, spec,
                    ratio_plot_type, stat_err, log
                )

                fullname = "%s_%s_ratio-%s_staterr-%s" % (
                    fname, k, ratio_plot_type, stat_err
                )
                save_fig(f, fullname, ext)
                plt.close(f)
コード例 #9
0
def plot_overlayed_foms(fom_dict, foms_to_overlay, labels, plotdir, ext):
    """Make and save plots of several normalized FOMs overlayed together.

    Multiple plots will be made -- one for each target.
    Multiple FOMs will be overlayed on a single plot -- one for each element
    in `foms_to_overlay`. Each FOM is scaled so that its maximum is 1.

    The histograms stored in `fom_dict` are not modified; normalization is
    applied to a copy.

    Parameters
    ----------
    fom_dict : dict
        Dictionary of FOMs where keys are names of the FOM and values are
        lists of cafplot.RHist1D containing FOM for different targets.
    foms_to_overlay : list of str
        List of FOM names that will be overlayed on a single plot.
    labels : list of str
        List of target names. Each list in the `fom_dict` should have the
        same length as `labels`.
    plotdir : str
        Directory where plots will be saved. Files are named
        ``overlayed_foms_<label>``.
    ext : str or list of str
        Extension of the plots. If list then the plots will be saved in
        multiple formats.
    """
    import copy

    for pred_idx, x_label in enumerate(labels):
        f, ax = plt.subplots()

        for fom_idx, fom_label in enumerate(foms_to_overlay):
            # `scale` modifies the histogram in place, so work on a copy to
            # avoid silently rescaling the caller's FOMs.
            rhist_fom = copy.deepcopy(fom_dict[fom_label][pred_idx])

            norm = np.max(rhist_fom.hist)
            # An all-zero FOM cannot be normalized; plot it as is instead of
            # filling it with inf/NaN.
            if norm != 0:
                rhist_fom.scale(1 / norm)

            plot_rhist1d(ax,
                         rhist_fom,
                         fom_label.capitalize(),
                         histtype='step',
                         color='C%d' % (fom_idx, ))

        decorate_fom_axes(ax, x_label)

        ax.set_ylabel('FOMs')
        ax.set_title('Normalized FOMs')
        ax.legend()

        fname = 'overlayed_foms_%s' % (x_label, )
        save_fig(f, os.path.join(plotdir, fname), ext)
        plt.close(f)
コード例 #10
0
ファイル: sample_stats.py プロジェクト: usert5432/slice_lid
def plot_counts(counts, labels, plotdir, ext):
    """Make and save a bar plot of the percentage of events per category.

    Parameters
    ----------
    counts : array_like, len(labels)
        Number of events in each category.
    labels : list of str
        Category names, used as x tick labels.
    plotdir : str
        Directory where the plot ``sample_distribution_test`` is saved.
    ext : str or list of str
        Extension of the plot. If list then the plot will be saved in
        multiple formats.
    """
    norm = np.sum(counts)

    f, ax = plt.subplots()

    # Bin edges 0..len(labels): one unit-wide bin per category.
    rhist = RHist1D(np.arange(len(labels) + 1), counts)

    # Convert counts to percentages. With no events at all there is nothing
    # to normalize, and dividing by zero would fill the plot with inf/NaN.
    if norm != 0:
        rhist.scale(100 / norm)

    plot_rhist1d(ax, rhist, label="", histtype='bar', color='tab:blue')

    # Place each category name at the center of its bin.
    ax.set_xticks([i + 0.5 for i in range(len(labels))])
    ax.set_xticklabels(labels)

    ax.set_ylabel("Fraction [%]")
    ax.set_title("Distribution of events per category")

    save_fig(f, os.path.join(plotdir, "sample_distribution_test"), ext)
    # Free the figure; pyplot otherwise keeps it open indefinitely.
    plt.close(f)
コード例 #11
0
def plot_2d_embedding_density(truth, preds, labels, bins, fname_base, ext):
    """Make and save a density plot of embedded coordinates.

    Parameters
    ----------
    truth : ndarray
        True class index of each sample (passed to `get_color_hist`).
    preds : ndarray
        Embedded coordinates (passed to `get_color_hist`).
    labels : list of str
        Class names. Class ``i`` gets legend color ``'C<i>'``.
    bins : int or sequence
        Binning passed to `get_color_hist`.
    fname_base : str
        Path prefix; the plot is saved as ``fname_base + "_density"``.
    ext : str or list of str
        Extension of the plot. If list then the plot will be saved in
        multiple formats.
    """
    color_hist, xbins, ybins = get_color_hist(truth, preds, labels, bins)
    f, ax = plt.subplots()

    # Swap the first two axes of the color image for imshow, which indexes
    # (row=y, column=x, channel). Presumably `color_hist` is laid out as
    # (x, y, channel) -- confirm against `get_color_hist`.
    ax.imshow(np.transpose(color_hist, (1, 0, 2)),
              interpolation='nearest',
              origin='lower',
              extent=[xbins[0], xbins[-1], ybins[0], ybins[-1]])

    ax.set_aspect('auto')

    # imshow draws no per-class artists, so build the legend by hand.
    handles = [
        mpatches.Patch(color='C%d' % (idx, ), label=label)
        for idx, label in enumerate(labels)
    ]

    add_nice_legend(ax, handles=handles)
    save_fig(f, fname_base + "_density", ext)
    # Free the figure; pyplot otherwise keeps it open indefinitely.
    plt.close(f)
コード例 #12
0
ファイル: error_matrix.py プロジェクト: usert5432/slice_lid
def plot_error_matrix_base(err_mat, labels, by_truth, fname, ext):
    """Make and save plot of a single error matrix

    Parameters
    ----------
    err_mat : ndarray, shape (N, N)
        Error matrix. First axis is true dimension, second is predicted.
    labels : list of str, len(N)
        List of labels for each target of the error matrix `err_mat`.
    by_truth : bool or None
        Indicates whether error matrix `err_mat` is normalized by true values.
        If False, then it is assumed that the error matrix is normalized by
        predicted values. If None, then no normalization title is shown.
    fname : str
        File path without extension where the error matrix will be saved.
    ext : str or list of str
        Extension of the plot. If list then the plot will be saved in multiple
        formats.
    """
    f, ax = plt.subplots()

    im = ax.imshow(err_mat)

    pplot_matrix_values(ax, err_mat)

    ax.set_xlabel("Prediction")
    ax.set_ylabel("Truth")

    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels)

    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, rotation='vertical', va='center')

    if by_truth is not None:
        if by_truth:
            ax.set_title('Normalized by Truth')
        else:
            ax.set_title('Normalized by Preds')

    f.colorbar(im)
    save_fig(f, fname, ext)
    # Free the figure; pyplot otherwise keeps it open indefinitely.
    plt.close(f)
コード例 #13
0
ファイル: aux.py プロジェクト: wswxyq/lstm_ee
def plot_rel_res_vs_true(pred_dict, true_dict, weights, plot_specs, fname,
                         ext):
    """
    Plot and save 2D histograms of relative energy resolution vs true energy.

    For every energy label in `pred_dict`, a plot is attempted with and
    without a logarithmic color normalization. A plot whose construction
    raises `ValueError` is reported on stdout and skipped; the rest are
    saved as ``<fname>_<label>_log(<True|False>)``.

    Parameters
    ----------
    pred_dict : dict
        Energy label -> `ndarray` (shape (N,)) of predicted energies.
    true_dict : dict
        Energy label -> `ndarray` (shape (N,)) of true energies. Must
        contain every key of `pred_dict`.
    weights : ndarray, shape (N,)
        Sample weights.
    plot_specs : dict
        Energy label -> `PlotSpec` with the plot style and histogram bins.
        Must contain every key of `pred_dict`.
    fname : str
        Path prefix of the produced plot files.
    ext : str or list of str
        Plot extension(s); a list saves the plot in several formats.
    """

    for key, pred in pred_dict.items():
        true = true_dict[key]
        spec = plot_specs[key]

        for log_norm in (True, False):
            try:
                fig, _ = plot_rel_res_vs_true_base(pred, true, weights, spec,
                                                   log_norm)
            except ValueError as exc:
                print("Failed to make plot: %s" % (str(exc)))
            else:
                save_fig(fig, "%s_%s_log(%s)" % (fname, key, log_norm), ext)
                plt.close(fig)
コード例 #14
0
def plot_train_val_history(x_name, y_name, log, skip, plotdir, ext):
    """Make and save plot of the training metric for train/validation samples

    Parameters
    ----------
    x_name : str
        Label in the `log` whose values will be used as x axis.
    y_name : str
        Label in the `log` whose values will be used as y axis.
        It is assumed that metrics evaluated on the training dataset
        will have name `y_name` and metrics evaluated on the validation
        dataset will have name "val_" + `y_name`. The validation curve is
        labeled 'Test' on the plot.
    log : pandas.DataFrame
        `DataFrame` that holds training history.
    skip : int
        Index label of the first row to plot. Earlier rows are skipped
        (the selection uses `log.loc`, i.e. index labels, not positions).
    plotdir : str
        Directory where plot will be saved, as
        ``plot_<y_name>_vs_<x_name>_<skip>``.
    ext : str or list of str
        Extension of the plot. If list then the plot will be saved in multiple
        formats.
    """

    x = log.loc[skip:, x_name].values.ravel()

    y_train = log.loc[skip:, y_name]
    y_val = log.loc[skip:, 'val_' + y_name]

    f, ax = plt.subplots()

    plot_scatter_with_average(ax, x, y_train, 'C0', 'Train', 10)
    plot_scatter_with_average(ax, x, y_val, 'C1', 'Test', 10)

    ax.set_xlabel(x_name)
    ax.set_ylabel(y_name)
    ax.set_title('%s vs %s' % (y_name, x_name))

    ax.legend()

    fname = "%s/plot_%s_vs_%s_%d" % (plotdir, y_name, x_name, skip)
    save_fig(f, fname, ext)
    # Free the figure; pyplot otherwise keeps it open indefinitely.
    plt.close(f)
コード例 #15
0
def plot_profile(var_list,
                 stats,
                 stat_name,
                 base_stats=None,
                 label_x=None,
                 label_y=None,
                 sort_type=None,
                 annotate=False,
                 categorical=True,
                 fname=None,
                 ext='png'):
    """
    Plot an evaluation metric against a training config parameter.

    One plot is produced and saved for every combination of x axis scale
    (as chosen by `get_x_scales`) and y axis scale. The y scales are
    'linear' plus 'log', or 'symlog' instead of 'log' when some metric
    value is not positive. Files are named
    ``<fname>_xs(<x scale>)_ys(<y scale>)``.

    Parameters
    ----------
    var_list : list
        Values of the configuration parameter, one per trained model. They
        may be of any type and are used as x coordinates.
    stats : dict
        Metric name -> `ndarray` of that metric. Each `ndarray` has
        len(`var_list`) elements aligned with `var_list`.
    stat_name : str
        Key of `stats` whose values are used as y coordinates.
    base_stats : dict or None, optional
        Metric name -> `ndarray` of baseline metric values. Each array is
        assumed to be len(`var_list`) copies of one value (the baseline
        does not depend on training), so only its first element is used.
        If None, no baseline is drawn. Default: None.
    label_x : str or None, optional
        X axis label. Default: None.
    label_y : str or None, optional
        Y axis label. Default: None.
    sort_type : { 'x', 'y', None }, optional
        If not None, points are ordered by the named coordinate. Sorting
        by 'y' is handy for categorical x variables (e.g. model names)
        that have no natural order. Default: None.
    annotate : bool, optional
        If True, the y value is written next to each data point.
        Default: False.
    categorical : bool, optional
        Whether the x variable is categorical (e.g. model names) rather
        than numerical (e.g. learning rate). Categorical values are not
        converted to numbers or drawn on a log scale; this flag is passed
        on to `plot_profile_base`. Default: True.
    fname : str
        Path prefix of the produced plot files.
    ext : str or list of str
        Plot extension(s); a list saves the plots in several formats.
    """
    x_vals = prepare_x_var(var_list, categorical)
    y_vals = stats[stat_name].values.ravel()
    y_errs = get_err(stats, stat_name)

    x_vals, y_vals, y_errs = sort_data(x_vals, y_vals, y_errs, sort_type)

    x_scales = get_x_scales(x_vals, categorical)

    # A plain log axis cannot show zero or negative values.
    has_nonpositive = np.any(y_vals <= 0)
    y_scales = ['linear', 'symlog' if has_nonpositive else 'log']

    for scale_y in y_scales:
        for scale_x in x_scales:
            fig, ax = plot_profile_base(
                x_vals, y_vals, y_errs, label_x, label_y, annotate,
                categorical, scale_x, scale_y
            )
            plot_baseline_stats(ax, base_stats, stat_name)

            out_name = "%s_xs(%s)_ys(%s)" % (fname, scale_x, scale_y)
            save_fig(fig, out_name, ext)
            plt.close(fig)