Exemple #1
0
def plot_roc(X_test, y_test, bdt, BDT_name=None):
    """ Plot the ROC curve of a trained BDT on a test sample and save it.

    The figure is saved through ``pt.save_fig`` under the name ``ROC``
    in the folder ``BDT/{BDT_name}``.

    Parameters
    ----------
    X_test        : numpy.ndarray
        signal and background concatenated, testing sample
    y_test        : numpy.array
        true labels of the testing sample:
        0 if the event is background, 1 if it is signal
    bdt           : sklearn.ensemble.AdaBoostClassifier or sklearn.ensemble.GradientBoostingClassifier
        trained BDT (only its ``decision_function`` method is used)
    BDT_name      : str
        name of the BDT, used as the name of the sub-folder of ``BDT/``
        where the plot is saved

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure of the plot
    ax : matplotlib.figure.Axes
        Axis of the plot
    """

    # Get the results -----
    # Score given by the BDT to each event of the test sample
    decisions = bdt.decision_function(X_test)

    # fpr / tpr: false / true positive rates obtained by scanning a
    # threshold on ``decisions``. The thresholds themselves are not needed.
    fpr, tpr, _ = roc_curve(y_test, decisions)
    roc_auc = auc(fpr, tpr)

    # Plot the results -----
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, lw=1, label=f'ROC (area = {roc_auc:.2f})')
    # Diagonal = performance of a random classifier
    ax.plot([0, 1], [0, 1], '--', color=(0.6, 0.6, 0.6), label='Luck')
    ax.set_xlim([-0.05, 1.05])
    ax.set_ylim([-0.05, 1.05])
    ax.set_xlabel('False Positive Rate', fontsize=25)
    ax.set_ylabel('True Positive Rate', fontsize=25)

    ax.legend(loc="lower right", fontsize=20.)
    pt.show_grid(ax)
    pt.fix_plot(ax,
                factor_ymax=1.1,
                show_leg=False,
                fontsize_ticks=20.,
                ymin_to_0=False)

    # Save the results -----
    pt.save_fig(fig, "ROC", folder_name=f'BDT/{BDT_name}')

    return fig, ax
Exemple #2
0
def end_plot_function(fig,
                      save_fig=True,
                      fig_name=None,
                      folder_name=None,
                      default_fig_name=None,
                      ax=None):
    """ Apply a tight layout, optionally save the figure, and hand back the figure and axis.

    Parameters
    ----------
    fig : matplotlib.figure.Figure or None
        Figure of the plot. If ``None``, nothing is saved nor returned.
    save_fig : bool
        if ``True`` (and ``fig`` is not ``None``), the figure is saved
        with ``pt.save_fig``
    fig_name : str
        name of the file that will be saved
    folder_name : str
        name of the folder where the figure will be saved
    default_fig_name : str
        name of the figure that will be saved, in the case ``fig_name`` is ``None``
    ax : matplotlib.figure.Axes
        Axis of the plot, returned unchanged alongside ``fig``

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure of the plot (only if ``fig`` is not ``None``)
    ax : matplotlib.figure.Axes
        Axis of the plot (only if ``fig`` is not ``None``)
    """
    # Acts on the current pyplot figure
    plt.tight_layout()

    # Without a figure there is nothing to save or to return
    if fig is None:
        return None

    if save_fig:
        pt.save_fig(fig,
                    fig_name=fig_name,
                    folder_name=folder_name,
                    default_fig_name=default_fig_name)

    return fig, ax
Exemple #3
0
def plot_scatter2d(dfs,
                   branches,
                   latex_branches,
                   units=[None, None],
                   low=None,
                   high=None,
                   n_bins=100,
                   colors=['g', 'r', 'orange', 'b'],
                   data_name=None,
                   title=None,
                   fig_name=None,
                   folder_name=None,
                   fontsize_label=default_fontsize['label'],
                   save_fig=True,
                   ax=None,
                   get_sc=False,
                   pos_text_LHC=None,
                   **params):
    """  Plot a 2D scatter plot of 2 branches, for one or several datasets.

    Parameters
    ----------
    dfs               : dict(str: pandas.Dataframe)
        Dictionary associating the name of each dataset (used as its legend
        label) with the corresponding dataframe.
    branches          : [str, str]
        names of the two branches
    latex_branches    : [str, str]
        latex names of the two branches
    units             : str or [str, str]
        Common unit or list of two units of the two branches
    low               : float or [float, float]
        low  value(s) of the branches
    high              : float or [float, float]
        high value(s) of the branches
    n_bins            : int or [int, int]
        not used by this function (kept for interface compatibility)
    colors            : str or list(str)
        color(s) used for the scatter plot(s); the k-th dataset uses ``colors[k]``
    data_name         : str or list(str)
        name of the data, used to build the default name of the figure
        in the case ``fig_name`` is not defined.
        If ``None``, the keys of ``dfs`` are used instead.
    title             : str
        title of the figure
    fig_name          : str
        name of the saved figure
    folder_name       : str
        name of the folder where to save the figure
    fontsize_label    : float
        fontsize of the label of the axes
    save_fig          : bool
        specifies if the figure is saved (never saved if ``ax`` is given)
    ax            : matplotlib.axes.Axes
        axis where to plot
    get_sc            : bool
        if True: get the scatter plot
    pos_text_LHC    : dict, list or str
        passed to :py:func:`HEA.plot.tools.set_text_LHCb` as the ``pos`` argument.
    **params          : dict
        passed to ``ax.scatter``

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure of the plot (only if ``get_fig_ax`` returns a figure,
        presumably when ``ax`` is not specified)
    ax : matplotlib.figure.Axes
        Axis of the plot (same condition as ``fig``)
    scs : matplotlib.PathCollection or list(matplotlib.PathCollection)
        scatter plot or list of scatter plots (only if ``get_sc`` is ``True``)
    """

    # low, high and units into a list of size 2
    low = el_to_list(low, 2)
    high = el_to_list(high, 2)
    units = el_to_list(units, 2)

    # A single color name must not be indexed character by character
    if isinstance(colors, str):
        colors = [colors]

    # When drawing on an existing axis, saving is left to the caller
    if ax is not None:
        save_fig = False

    fig, ax = get_fig_ax(ax)

    title = string.add_text(None, title, default=None)
    ax.set_title(title, fontsize=25)

    # One scatter plot per dataset, labelled by its key in ``dfs``.
    # The loop variable is deliberately NOT called ``data_name`` so that the
    # ``data_name`` argument is not overwritten.
    dataset_names = list(dfs.keys())
    scs = [
        ax.scatter(df[branches[0]],
                   df[branches[1]],
                   c=colors[k],
                   label=dataset_name,
                   **params)
        for k, (dataset_name, df) in enumerate(dfs.items())
    ]
    if len(scs) == 1:
        scs = scs[0]

    ax.set_xlim([low[0], high[0]])
    ax.set_ylim([low[1], high[1]])

    # Labels and ticks
    pt.set_label_ticks(ax)
    pt.set_text_LHCb(ax, pos=pos_text_LHC)

    set_label_2Dhist(ax, latex_branches, units, fontsize=fontsize_label)

    # Save the figure
    if save_fig:
        # Default to the names of all the plotted datasets
        if data_name is None and dataset_names:
            data_name = dataset_names
        pt.save_fig(
            fig, fig_name, folder_name,
            string.add_text(string.list_into_string(branches, '_vs_'),
                            string.list_into_string(data_name, '_'), '_'))

    if fig is not None:
        if get_sc:
            return fig, ax, scs
        else:
            return fig, ax
    else:
        if get_sc:
            return scs
Exemple #4
0
def plot_hist_fit(df,
                  branch,
                  latex_branch=None,
                  unit=None,
                  weights=None,
                  obs=None,
                  n_bins=50,
                  low_hist=None,
                  high_hist=None,
                  color='black',
                  bar_mode=True,
                  models=None,
                  models_names=None,
                  models_types=None,
                  linewidth=2.5,
                  colors=None,
                  title=None,
                  plot_pull=True,
                  bar_mode_pull=True,
                  show_leg=None,
                  fontsize_leg=default_fontsize['legend'],
                  loc_leg='upper left',
                  show_chi2=False,
                  params=None,
                  latex_params=None,
                  colWidths=[0.04, 0.01, 0.06, 0.06],
                  fontsize_res=default_fontsize['legend'],
                  loc_res='upper right',
                  fig_name=None,
                  folder_name=None,
                  data_name=None,
                  save_fig=True,
                  pos_text_LHC=None):
    """ Plot complete histogram with fitted curve, pull histogram and results of the fits. Save it in the plot folder.

    Parameters
    ----------
    df            : pandas.Dataframe
        dataframe that contains the branch to plot
    branch        : str
        name of the branch to plot and that was fitted
    latex_branch  : str
        latex name of the branch, to be used in the xlabel of the plot
    unit            : str
        Unit of the physical quantity
    weights         : numpy.array
        weights passed to ``plot_hist_alone``
    obs           : zfit.Space
        Space used for the fit. Required: its ``limits`` give the range
        over which the fitted curves (and the pulls) are drawn.
    n_bins        : int
        number of desired bins of the histogram
    low_hist      : float
        lower range value for the histogram (if not specified, use the value contained in ``obs``)
    high_hist     : float
        upper range value for the histogram (if not specified, use the value contained in ``obs``)
    color         : str
        color of the histogram
    bar_mode     : bool

        * if True, plot with bars
        * else, plot with points and error bars
    models        : zfit.pdf.BasePDF or list(zfit.pdf.BasePDF) or list(list(zfit.pdf.BasePDF)) or ...
        passed to :py:func:`plot_fitted_curves`.
        If it is a list, its first element is taken as the "global" model
        used for the pull diagram.
    models_names : str or list(str) or list(list(str))
        passed to :py:func:`plot_fitted_curves`
    models_types  : str
        passed to :py:func:`plot_fitted_curves`
    linewidth     : float
        width of the fitted curve line
    colors        : str or list(str)
        colors of the fitted curves. The first one (or the only one)
        is also used for the pull diagram.
    title         : str
        title of the plot
    plot_pull     : bool
        if ``True``, plot the pull diagram
    bar_mode_pull: bool
        if ``True``, the pull diagram is plotted with bars instead of points + error bars
    show_leg      : bool
        if ``True``, show the legend
    fontsize_leg  : float
        fontsize of the legend
    loc_leg       : str
        position of the legend, ``loc`` argument in ``plt.legend``
    show_chi2     : bool
        if ``True``, show the :math:`\\chi^2` in the label of the x-axis of the pull diagram
    params        : dict[zfit.zfitParameter, float]
        Result ``result.params`` of the minimisation of the loss function (given by :py:func:`HEA.fit.fit.launch_fit`).
        If ``None``, no result table is drawn.
    latex_params  :
        Dictionary with the latex names of the params.
        Also indicates which params to show in the table among all those in ``params``
    colWidths     : [float, float, float, float]
        passed to :py:func:`plot_result_fit`
    fontsize_res   : float
        fontsize of the text in the result table
    loc_res       : str
        position of the result table, ``loc`` argument specified in ``plt.table``
    fig_name      : str
        name of the saved file
    folder_name   : str
        name of the folder where to save the plot
    data_name     : str
        name of the data used to constitute the name of the saved file, if ``fig_name`` is not specified.
    save_fig      : bool
        if ``True``, the figure is saved
    pos_text_LHC    : dict, list or str
        passed to :py:func:`HEA.plot.tools.set_text_LHCb` as the ``pos`` argument.

    Returns
    -------
    fig   : matplotlib.figure.Figure
        Figure of the plot
    ax[0] : matplotlib.figure.Axes
        Axis of the histogram + fitted curves + table
    ax[1] : matplotlib.figure.Axes
        Axis of the pull diagram (only if ``plot_pull`` is ``True``)
    """

    # Create figure: two stacked panels (3:1) when the pulls are shown
    if plot_pull:
        fig = plt.figure(figsize=(12, 10))
        gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
        ax = [plt.subplot(gs[i]) for i in range(2)]
    else:
        fig, ax = plt.subplots(figsize=(8, 6))
        ax = [ax]

    # Retrieve low, high (of x-axis) from the fit space
    low = float(obs.limits[0])
    high = float(obs.limits[1])

    # The histogram range defaults to the fit range
    if low_hist is None:
        low_hist = low
    if high_hist is None:
        high_hist = high

    if latex_branch is None:
        latex_branch = string._latex_format(branch)

    ax[0].set_title(title, fontsize=25)

    # Plot 1D histogram of data
    counts, edges, centres, err = plot_hist_alone(ax[0],
                                                  df[branch],
                                                  n_bins,
                                                  low_hist,
                                                  high_hist,
                                                  color,
                                                  bar_mode,
                                                  alpha=0.1,
                                                  weights=weights)

    # Label
    bin_width = get_bin_width(low_hist, high_hist, n_bins)
    set_label_hist(ax[0], latex_branch, unit, bin_width, fontsize=25)

    # Ticks
    pt.set_label_ticks(ax[0])
    pt.set_text_LHCb(ax[0], pos=pos_text_LHC)

    # Plot fitted curve
    if isinstance(models, list):
        model = models[0]  # the first model is the "global" one
    else:
        model = models

    plot_scaling = _get_plot_scaling(counts, low_hist, high_hist, n_bins)
    plot_fitted_curves(ax[0],
                       models,
                       plot_scaling,
                       low,
                       high,
                       models_names=models_names,
                       models_types=models_types,
                       line_width=linewidth,  # honour the caller's choice
                       colors=colors,
                       fontsize_legend=fontsize_leg,
                       loc_leg=loc_leg,
                       show_legend=show_leg)

    pt.change_range_axis(ax[0], factor_max=1.1)

    # The pulls are drawn with the color of the "global" model
    color_pull = colors if not isinstance(colors, list) else colors[0]
    # Plot pull histogram
    if plot_pull:
        plot_pull_diagram(ax[1],
                          model,
                          counts,
                          edges,
                          centres,
                          err,
                          color=color_pull,
                          low=low,
                          high=high,
                          plot_scaling=plot_scaling,
                          show_chi2=show_chi2,
                          bar_mode_pull=bar_mode_pull)

    # Plot the fitted parameters of the fit
    if params is not None:
        plot_result_fit(ax[0],
                        params,
                        latex_params=latex_params,
                        fontsize=fontsize_res,
                        colWidths=colWidths,
                        loc=loc_res)

    # Save result
    plt.tight_layout()
    if save_fig:
        pt.save_fig(fig, fig_name, folder_name, f'{branch}_{data_name}_fit')

    if plot_pull:
        return fig, ax[0], ax[1]
    else:
        return fig, ax[0]
Exemple #5
0
def compare_train_test(bdt,
                       X_train,
                       y_train,
                       X_test,
                       y_test,
                       bins=30,
                       BDT_name="",
                       colors=['red', 'green']):
    """ Plot and save the overtraining plot of a BDT.

    The BDT output distributions of the training sample (filled histograms)
    and of the test sample (points with error bars) are compared, separately
    for signal and background. The figure is saved through ``pt.save_fig``
    under the name ``overtraining`` in the folder ``BDT/{BDT_name}``.
    The Kolmogorov-Smirnov statistics and p-values are also printed.

    Parameters
    ----------
    bdt           : sklearn.ensemble.AdaBoostClassifier or sklearn.ensemble.GradientBoostingClassifier
        trained BDT classifier (only its ``decision_function`` method is used)
    X_train : numpy.ndarray
        Array with signal and background data concatenated and shuffled for training
    y_train : numpy.array
        Array with 1 for the signal events, and 0 for background events (shuffled) for training
    X_test  : numpy.ndarray
        Array with signal and background data concatenated and shuffled for test
    y_test  : numpy.array
        Array with 1 for the signal events, and 0 for background events (shuffled) for test
    bins          : int
        number of bins of the plotted histograms
    BDT_name      : str
        name of the BDT, used for the folder where the figure is saved
    colors        : [str, str]
        colors used for the signal (``colors[0]``) and the
        background (``colors[1]``) distributions

    Returns
    -------
    fig              : matplotlib.figure.Figure
        Figure of the plot
    ax               : matplotlib.figure.Axes
        Axis of the plot
    ks_2samp_sig     : float
        Kolmogorov-Smirnov statistic for the signal distributions
    ks_2samp_bkg     : float
        Kolmogorov-Smirnov statistic for the background distributions
    pvalue_2samp_sig : float
        p-value of the Kolmogorov-Smirnov test for the signal distributions
    pvalue_2samp_bkg : float
        p-value of the Kolmogorov-Smirnov test for the background distributions
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    # BDT output of each sub-sample; signal is y > 0.5, background is y < 0.5.
    # decisions = [train signal, train background, test signal, test background]
    decisions = []
    for X, y in ((X_train, y_train), (X_test, y_test)):
        d_signal = bdt.decision_function(X[y > 0.5]).ravel()
        d_background = bdt.decision_function(X[y < 0.5]).ravel()
        decisions += [d_signal, d_background]
    train_sig, train_bkg, test_sig, test_bkg = decisions

    # Common range of the full plot
    low = min(np.min(d) for d in decisions)
    high = max(np.max(d) for d in decisions)
    low_high = (low, high)

    # Training sample: normalised step-filled histograms
    for sample, color, label in ((train_sig, colors[0], 'S (train)'),
                                 (train_bkg, colors[1], 'B (train)')):
        ax.hist(sample,
                color=color,
                alpha=0.5,
                range=low_high,
                bins=bins,
                histtype='stepfilled',
                density=True,
                label=label)

    # Test sample: normalised points with (rescaled) Poisson uncertainties
    for sample, color, label in ((test_sig, colors[0], 'S (test)'),
                                 (test_bkg, colors[1], 'B (test)')):
        hist, edges = np.histogram(sample,
                                   bins=bins,
                                   range=low_high,
                                   density=True)
        # ``hist * scale`` gives back the raw counts of THIS sample, so that
        # sqrt(counts) can be computed and then normalised the same way.
        scale = len(sample) / hist.sum()
        err = np.sqrt(hist * scale) / scale

        center = (edges[:-1] + edges[1:]) / 2
        ax.errorbar(center, hist, yerr=err, fmt='o', c=color, label=label)

    ax.set_xlabel("BDT output", fontsize=25.)
    ax.set_ylabel("Arbitrary units", fontsize=25.)
    ax.legend(loc='best', fontsize=20.)
    pt.show_grid(ax)

    pt.fix_plot(ax,
                factor_ymax=1.1,
                show_leg=False,
                fontsize_ticks=20.,
                ymin_to_0=False)

    pt.save_fig(fig, "overtraining", folder_name=f'BDT/{BDT_name}')

    # Kolmogorov-Smirnov tests: train vs test, for signal and for background
    ks_result_sig = ks_2samp(train_sig, test_sig)
    ks_result_bkg = ks_2samp(train_bkg, test_bkg)
    ks_2samp_sig = ks_result_sig.statistic
    ks_2samp_bkg = ks_result_bkg.statistic
    pvalue_2samp_sig = ks_result_sig.pvalue
    pvalue_2samp_bkg = ks_result_bkg.pvalue

    print('Kolmogorov-Smirnov statistic')
    print(f"signal    : {ks_2samp_sig}")
    print(f"Background: {ks_2samp_bkg}")

    print('p-value')
    print(f"signal    : {pvalue_2samp_sig}")
    print(f"Background: {pvalue_2samp_bkg}")
    return fig, ax, ks_2samp_sig, ks_2samp_bkg, pvalue_2samp_sig, pvalue_2samp_bkg
Exemple #6
0
def signal_background(data1,
                      data2,
                      column=None,
                      range_column=None,
                      grid=True,
                      xlabelsize=None,
                      ylabelsize=None,
                      sharex=False,
                      sharey=False,
                      figsize=None,
                      layout=None,
                      n_bins=40,
                      fig_name=None,
                      folder_name=None,
                      colors=['red', 'green'],
                      **kwds):
    """Draw, for each numeric column, the normalised histograms of ``data1``
    and ``data2`` on the same axis, and save the result through
    ``pt.save_fig`` under the name ``1D_hist_{fig_name}`` in the folder
    ``BDT/{folder_name}``.

    ``data1`` is drawn with ``colors[1]`` and labelled ``'background'``;
    ``data2`` is drawn with ``colors[0]`` and labelled ``'signal'``.

    Parameters
    ----------
    data1        : pandas.Dataframe
        First dataset (labelled ``'background'``)
    data2        : pandas.Dataframe
        Second dataset (labelled ``'signal'``)
    column       : str or list(str)
        If passed, will be used to limit data to a subset of columns
    range_column : list([float, float] or None)
        for each plotted column, the ``[low, high]`` range of its histogram.
        A ``None`` entry (or bound) is resolved by ``pt.redefine_low_high``
        from the data. The list is not modified.
    grid         : bool
        Whether to show axis grid lines
    xlabelsize   : int
        currently not used
    ylabelsize   : int
        currently not used
    sharex       : bool
        if ``True``, the X axis will be shared amongst all subplots.
    sharey       : bool
        if ``True``, the Y axis will be shared amongst all subplots.
    figsize      : tuple
        the size of the figure to create in inches by default
    layout       :
        currently not used
    n_bins       : int,
        number of histogram bins to be used
    fig_name    : str
        name of the saved file. By default, built from ``column``
        (or from the plotted columns if ``column`` is ``None``).
    folder_name  : str
        name of the folder where to save the plot
    colors       : [str, str]
        colors used for the two datasets
    **kwds       : dict
        accepted for interface compatibility; currently not forwarded
        to the plotting functions

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure of the plot
    axes : numpy.ndarray of matplotlib.figure.Axes
        2D array of the axes of the plot

    Raises
    ------
    ValueError
        if there is no numeric column to plot
    """
    if column is not None:
        # column is not a list, convert it into a list.
        if not isinstance(column, (list, np.ndarray, Index)):
            column = [column]
        data1 = data1[column]
        data2 = data2[column]

    data1 = data1._get_numeric_data()  # select only numbers
    data2 = data2._get_numeric_data()  # select only numbers
    plotted_columns = list(data1.columns)
    naxes = len(plotted_columns)  # one axis per available column
    if naxes == 0:
        raise ValueError("No numeric column to plot")

    # Grid of subplots: at most ``max_nrows`` rows, as many columns as needed
    max_nrows = 4
    fig, axes = plt.subplots(nrows=min(naxes, max_nrows),
                             ncols=(naxes + max_nrows - 1) // max_nrows,
                             squeeze=False,
                             sharex=sharex,
                             sharey=sharey,
                             figsize=figsize)
    flat_axes = axes.ravel()

    # One [low, high] pair per plotted column, without touching the
    # list passed by the caller
    if range_column is None:
        range_column = [None] * naxes
    ranges = [[None, None] if r is None else r for r in range_column]

    for i, col in enumerate(plotted_columns):
        # col = name of the column/variable
        ax = flat_axes[i]

        # Missing bounds are resolved from the two datasets
        low, high = pt.redefine_low_high(ranges[i][0],
                                         ranges[i][1],
                                         [data1[col], data2[col]])
        h.plot_hist_alone(ax,
                          data1[col].dropna().values,
                          n_bins,
                          low,
                          high,
                          colors[1],
                          mode_hist=True,
                          alpha=0.5,
                          density=True,
                          label='background',
                          label_ncounts=True)
        h.plot_hist_alone(ax,
                          data2[col].dropna().values,
                          n_bins,
                          low,
                          high,
                          colors[0],
                          mode_hist=True,
                          alpha=0.5,
                          density=True,
                          label='signal',
                          label_ncounts=True)

        bin_width = (high - low) / n_bins
        latex_branch, unit = RVariable.get_latex_branch_unit_from_branch(col)
        h.set_label_hist(ax,
                         latex_branch,
                         unit,
                         bin_width=bin_width,
                         density=False,
                         fontsize=20)
        pt.fix_plot(ax,
                    factor_ymax=1 + 0.3,
                    show_leg=True,
                    fontsize_ticks=15.,
                    fontsize_leg=20.)
        if grid:
            pt.show_grid(ax, which='major')

    # Hide the axes of the grid that are not used
    for unused_ax in flat_axes[naxes:]:
        unused_ax.axis('off')

    if fig_name is None:
        names = column if column is not None else plotted_columns
        fig_name = string.list_into_string(names)

    plt.tight_layout()
    pt.save_fig(fig, f"1D_hist_{fig_name}", folder_name=f'BDT/{folder_name}')

    return fig, axes
Exemple #7
0
def correlations(data, fig_name=None, folder_name=None, title=None, **kwds):
    """ Calculate pairwise correlation between features of the dataframe data,
    plot it as a heat map and save the figure through ``pt.save_fig`` under
    the name ``corr_matrix_{fig_name}`` in the folder ``BDT/{folder_name}``.

    Parameters
    ----------
    data         : pandas.Dataframe
        dataset
    fig_name     : str
        name of the saved file. By default, built from the names of
        the columns of the correlation matrix.
    folder_name  : str
        name of the folder where to save the plot
    title        : str
        text appended to ``"Correlations"`` in the title of the plot
    **kwds       : dict
        other keyword arguments, to be passed to ``pandas.DataFrame.corr()``

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure of the plot
    ax : matplotlib.figure.Axes
        Axis of the plot
    """

    # simply call df.corr() to get a table of
    # correlation values if you do not need
    # the fancy plotting
    corrmat = data.corr(**kwds)  # correlation

    fig, ax1 = plt.subplots(ncols=1, figsize=(12, 10))  # 1 plot

    opts = {
        'cmap': plt.get_cmap("RdBu"),  # red blue color mode
        'vmin': -1,
        'vmax': +1
    }  # correlation between -1 and 1
    heatmap1 = ax1.pcolor(corrmat, **opts)  # create a pseudo color plot
    plt.colorbar(heatmap1, ax=ax1)  # color bar

    title = string.add_text("Correlations", title, ' - ')
    ax1.set_title(title)

    # Names of the variables, and their latex version for the tick labels
    column_names = list(corrmat.columns.values)
    labels = [
        RVariable.get_latex_branch_unit_from_branch(name)[0]
        for name in column_names
    ]
    # shift location of ticks to center of the bins
    ax1.set_xticks(np.arange(len(labels)) + 0.5, minor=False)
    ax1.set_yticks(np.arange(len(labels)) + 0.5, minor=False)
    ax1.set_xticklabels(labels, minor=False, ha='right', rotation=70)
    ax1.set_yticklabels(labels, minor=False)

    plt.tight_layout()

    if fig_name is None:
        # Default name built from the variables of the matrix
        fig_name = string.list_into_string([str(name) for name in column_names])

    pt.save_fig(fig,
                f"corr_matrix_{fig_name}",
                folder_name=f'BDT/{folder_name}')

    return fig, ax1
Exemple #8
0
def plot_x_list_ys(x,
                   y,
                   name_x,
                   names_y,
                   latex_name_x=None,
                   latex_names_y=None,
                   fig_name=None,
                   folder_name=None,
                   log_scale=None,
                   save_fig=True,
                   **kwgs):
    """ plot y or a list of y as a function of x. If they are different curves, their points should have the same abscissa.

    ``y`` is turned into groups of curves by ``_el_or_list_to_2D_list``:
    each group is drawn in its own subplot, and the curves of a same group
    share that subplot.

    Parameters
    ----------
    x                : list(float)
        abscissa of the points
    y                : numpy.array(uncertainties.ufloat) or numpy.array(float) or list(numpy.array(uncertainties.ufloat)) or list(numpy.array(float)) or
        (list instead of numpy.array might work)
        ordinate of the points of the curve(s)
    name_x           : str
        name of the x variable, used for the name of the saved figure
    names_y          : str or list(str)
        name of each curve in ``y``, used for the name of the saved figure
    latex_name_x : str
        latex name of the x variable, used to label the x-axis
        (``name_x`` is used if it is not specified)
    latex_names_y    : str or list(str)
        latex name of each curve in ``y``, used for labelling each curve
        (``names_y`` is used if it is not specified)
    fig_name         : str
        name of the file to save
    folder_name      : str
        name of the folder where the image is saved
    log_scale        : 'both', 'x' or 'y'
        specifies which axis will be set in log scale
    save_fig         : bool
        if ``True``, the figure is saved
    **kwgs           : dict
        passed to ``plot_xys``

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure of the plot
    ax : matplotlib.figure.Axes or list(matplotlib.figure.Axes)
        Axis of the plot or list of axes of the plot
    """

    if latex_name_x is None:
        latex_name_x = name_x

    groups_ly = _el_or_list_to_2D_list(y)
    groups_names_y = _el_or_list_to_2D_list(names_y, str)

    if latex_names_y is not None:
        groups_latex_names_y = _el_or_list_to_2D_list(latex_names_y, str)
    else:
        groups_latex_names_y = groups_names_y

    n_groups = len(groups_ly)
    fig, axs = plt.subplots(n_groups, 1, figsize=(8, 4 * n_groups))

    for k, ly in enumerate(groups_ly):
        # ``plt.subplots`` returns a bare Axes (not an array) for one subplot
        ax = axs if n_groups == 1 else axs[k]

        # In the same groups_ly, we plot the curves in the same plot
        plot_xys(ax,
                 x,
                 ly,
                 xlabel=latex_name_x,
                 labels=groups_latex_names_y[k],
                 **kwgs)

        pt.set_log_scale(ax, axis=log_scale)

    plt.tight_layout()

    if save_fig:
        pt.save_fig(
            fig, fig_name, folder_name,
            f'{name_x}_vs_{list_into_string(flatten_2Dlist(names_y))}')

    return fig, axs