예제 #1
0
def plot_hist_auto(dfs, branch, cut_BDT=None, **kwargs):
    """ Retrieve the latex name and the unit of the branch,
    then plot its histogram with :py:func:`plot_hist`.

    Before calling :py:func:`plot_hist`, the keys ``'fig_name'``, ``'title'``
    and ``'folder_name'`` of ``kwargs`` are filled in from the names of the
    dataframes and from ``cut_BDT``.

    Parameters
    ----------
    dfs             : dict(str:pandas.Dataframe)
        Dictionary {name of the dataframe : pandas dataframe}
    branch          : str
        branch (for instance: ``'B0_M'``), which should be in the dataframe(s)
    cut_BDT         : float or str
        ``BDT > cut_BDT`` cut. Used in the name of the saved figure
        and in the title.
    **kwargs        : dict
        arguments passed in :py:func:`plot_hist`
        (except ``branch``, ``latex_branch`` and ``unit``)

    Returns
    -------
    Whatever :py:func:`plot_hist` returns, that is

    fig : matplotlib.figure.Figure
        Figure of the plot (only if ``ax`` is not specified)
    ax : matplotlib.figure.Axes
        Axis of the plot (only if ``ax`` is not specified)
    """
    latex_branch, unit = pt.get_latex_branches_units(branch)
    data_names = string.list_into_string(list(dfs.keys()))

    # Make sure both keys exist before they are read just below
    add_in_dic('fig_name', kwargs)
    add_in_dic('title', kwargs)
    kwargs['fig_name'] = pt._get_fig_name_given_BDT_cut(
        fig_name=kwargs['fig_name'],
        cut_BDT=cut_BDT,
        branch=branch,
        data_name=data_names)
    kwargs['title'] = pt._get_title_given_BDT_cut(title=kwargs['title'],
                                                  cut_BDT=cut_BDT)

    # Name of the folder = list of the names of the data
    pt._set_folder_name_from_data_name(kwargs, data_names)

    return plot_hist(dfs, branch, latex_branch, unit, **kwargs)
예제 #2
0
def _set_folder_name_from_data_name(kwargs, data_names):
    """ Fill in the ``"folder_name"`` entry of ``kwargs`` (in place) with the
    name(s) of the data, unless a folder name is already set.

    Parameters
    ----------
    kwargs: dict
        dictionary whose key ``"folder_name"`` is filled in
        (it is first ensured to exist via ``add_in_dic``)
    data_names : str or list(str)
        name of the dataset(s); a list is joined with
        ``string.list_into_string``
    """
    add_in_dic('folder_name', kwargs)

    # An explicitly provided folder name is kept untouched
    if kwargs['folder_name'] is not None:
        return

    kwargs['folder_name'] = (data_names if isinstance(data_names, str)
                             else string.list_into_string(data_names))
예제 #3
0
def plot_scatter2d(dfs,
                   branches,
                   latex_branches,
                   units=None,
                   low=None,
                   high=None,
                   n_bins=100,
                   colors=('g', 'r', 'orange', 'b'),
                   data_name=None,
                   title=None,
                   fig_name=None,
                   folder_name=None,
                   fontsize_label=default_fontsize['label'],
                   save_fig=True,
                   ax=None,
                   get_sc=False,
                   pos_text_LHC=None,
                   **params):
    """  Scatter plot of 2 branches, for one or several datasets.

    Parameters
    ----------
    dfs               : dict(str:pandas.Dataframe)
        Dictionary {name of the dataset: dataset}. Each dataset is drawn
        with its own color and labelled with its name.
    branches          : [str, str]
        names of the two branches (x-axis, then y-axis)
    latex_branches    : [str, str]
        latex names of the two branches
    units             : str or [str, str]
        Common unit or list of two units of the two branches.
        ``None`` (default) means no unit for either branch.
    low               : float or [float, float]
        low  value(s) of the x and y axes
    high              : float or [float, float]
        high value(s) of the x and y axes
    n_bins            : int or [int, int]
        not used by a scatter plot (kept for compatibility with the
        signature of :py:func:`plot_hist2d`)
    colors            : str or list(str)
        color(s) used for the dataset(s). A single string is used for every
        dataset; a list is cycled through if it is shorter than ``dfs``.
    data_name         : str or list(str)
        name of the data, used to define the name of the saved figure
        in the case ``fig_name`` is not defined.
        By default, the names (keys) of ``dfs``.
    title             : str
        title of the figure
    fig_name          : str
        name of the saved figure
    folder_name       : str
        name of the folder where to save the figure
    fontsize_label    : float
        fontsize of the label of the axes
    save_fig          : bool
        specifies if the figure is saved (forced to ``False`` if ``ax``
        is specified)
    ax                : matplotlib.axes.Axes
        axis where to plot
    get_sc            : bool
        if True: also return the scatter plot(s)
    pos_text_LHC      : dict, list or str
        passed to :py:func:`HEA.plot.tools.set_text_LHCb` as the ``pos`` argument.
    **params          : dict
        passed to ``ax.scatter``

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure of the plot (only if ``ax`` is not specified)
    ax : matplotlib.figure.Axes
        Axis of the plot (only if ``ax`` is not specified)
    scs : matplotlib.PathCollection or list(matplotlib.PathCollection)
        scatter plot or list of scatter plots (only if ``get_sc`` is ``True``)
    """
    # low, high and units into a list of size 2
    low = el_to_list(low, 2)
    high = el_to_list(high, 2)
    units = el_to_list([None, None] if units is None else units, 2)

    # A single color applies to every dataset (indexing a str would
    # otherwise pick its individual characters)
    if isinstance(colors, str):
        colors = [colors]

    # A figure is only saved when this function creates it
    if ax is not None:
        save_fig = False

    fig, ax = get_fig_ax(ax)

    title = string.add_text(None, title, default=None)
    ax.set_title(title, fontsize=25)

    # The loop variable must not be called ``data_name``: it would overwrite
    # the parameter that is used below to name the saved figure.
    scs = []
    for k, (name, df) in enumerate(dfs.items()):
        scs.append(ax.scatter(df[branches[0]],
                              df[branches[1]],
                              c=colors[k % len(colors)],
                              label=name,
                              **params))
    if len(scs) == 1:
        scs = scs[0]

    ax.set_xlim([low[0], high[0]])
    ax.set_ylim([low[1], high[1]])

    # Ticks, LHCb text and axis labels
    pt.set_label_ticks(ax)
    pt.set_text_LHCb(ax, pos=pos_text_LHC)

    set_label_2Dhist(ax, latex_branches, units, fontsize=fontsize_label)

    # Save the figure
    if save_fig:
        if data_name is None:
            data_name = list(dfs.keys())
        if not isinstance(data_name, str):
            data_name = string.list_into_string(data_name, '_')
        pt.save_fig(
            fig, fig_name, folder_name,
            string.add_text(string.list_into_string(branches, '_vs_'),
                            data_name, '_'))

    if fig is not None:
        return (fig, ax, scs) if get_sc else (fig, ax)
    if get_sc:
        return scs
    return None
예제 #4
0
def plot_hist2d(df,
                branches,
                latex_branches,
                units,
                low=None,
                high=None,
                n_bins=100,
                log_scale=False,
                title=None,
                fig_name=None,
                folder_name=None,
                data_name=None,
                save_fig=True,
                ax=None,
                pos_text_LHC=None,
                fontsize_label=25):
    """  Plot a 2D histogram of 2 branches, with a color bar.

    Parameters
    ----------
    df                : pandas.Dataframe
        Dataframe that contains the 2 branches to plot
    branches          : [str, str]
        names of the two branches (x-axis, then y-axis)
    latex_branches    : [str, str]
        latex names of the two branches
    units             : str or [str, str]
        Common unit or list of two units of the two branches
    low               : float or [float, float]
        low  value(s) of the branches. Missing values are redefined from
        the data by ``pt._redefine_low_high``.
    high              : float or [float, float]
        high value(s) of the branches. Missing values are redefined from
        the data by ``pt._redefine_low_high``.
    n_bins            : int or [int, int]
        number of bins
    log_scale         : bool
        if true, the colorbar is in logscale
    title             : str
        title of the figure
    fig_name          : str
        name of the saved figure
    folder_name       : str
        name of the folder where to save the figure
    data_name         : str
        name of the data, this is used to define the title and the name of
        the figure, in the case ``fig_name`` is not defined.
    save_fig          : bool
        specifies if the figure is saved
    ax                : matplotlib.axes.Axes
        axis where to plot
    pos_text_LHC      : dict, list or str
        passed to :py:func:`HEA.plot.tools.set_text_LHCb` as the ``pos`` argument.
    fontsize_label    : float
        fontsize of the label of the axes

    Returns
    -------
    Whatever ``end_plot_function`` returns, that is

    fig : matplotlib.figure.Figure
        Figure of the plot (only if ``ax`` is not specified)
    ax : matplotlib.figure.Axes
        Axis of the plot (only if ``ax`` is not specified)
    """
    # low, high and units into a list of size 2.
    # ``list`` copies, so that the caller's lists are not modified below.
    low = list(el_to_list(low, 2))
    high = list(el_to_list(high, 2))
    units = el_to_list(units, 2)

    for i in range(2):
        low[i], high[i] = pt._redefine_low_high(low[i], high[i],
                                                df[branches[i]])

    # Plotting
    fig, ax = get_fig_ax(ax)

    title = string.add_text(data_name, title, default=None)
    ax.set_title(title, fontsize=25)

    # ``norm=None`` is the matplotlib default, i.e. a linear color scale
    _, _, _, h = ax.hist2d(df[branches[0]],
                           df[branches[1]],
                           range=[[low[0], high[0]], [low[1], high[1]]],
                           bins=n_bins,
                           norm=LogNorm() if log_scale else None)

    # Ticks, LHCb text, labels and color bar
    pt.set_label_ticks(ax)
    pt.set_text_LHCb(ax, pos=pos_text_LHC)

    set_label_2Dhist(ax, latex_branches, units, fontsize=fontsize_label)
    # Attach the color bar to the figure and axis actually used, rather than
    # to pyplot's "current" figure (which may differ when ``ax`` is given).
    cbar = ax.figure.colorbar(h, ax=ax)
    cbar.ax.tick_params(labelsize=20)

    return end_plot_function(fig,
                             save_fig=save_fig,
                             fig_name=fig_name,
                             folder_name=folder_name,
                             default_fig_name=string.add_text(
                                 string.list_into_string(branches, '_vs_'),
                                 data_name, '_'),
                             ax=ax)
예제 #5
0
파일: BDT.py 프로젝트: anthony-correia/HEA
def signal_background(data1,
                      data2,
                      column=None,
                      range_column=None,
                      grid=True,
                      xlabelsize=None,
                      ylabelsize=None,
                      sharex=False,
                      sharey=False,
                      figsize=None,
                      layout=None,
                      n_bins=40,
                      fig_name=None,
                      folder_name=None,
                      colors=('red', 'green'),
                      **kwds):
    """Draw, for each numerical column, the normalised histograms of
    ``data1`` and ``data2`` on the same axis, and save the result in
    ``{loc['plot']}/BDT/{folder_name}/1D_hist_{fig_name}``
    (in ``{loc['plot']}/BDT/1D_hist_{fig_name}`` if ``folder_name`` is ``None``).

    Parameters
    ----------
    data1        : pandas.Dataframe
        First dataset, drawn with ``colors[1]`` and labelled ``'background'``
    data2        : pandas.Dataframe
        Second dataset, drawn with ``colors[0]`` and labelled ``'signal'``
    column       : str or list(str)
        If passed, will be used to limit data to a subset of columns
    range_column : list([float or None, float or None] or None)
        ``[low, high]`` range of each plotted column, in the order of the
        numerical columns. ``None`` values are left to
        ``pt.redefine_low_high`` to define.
    grid         : bool
        Whether to show the major grid lines
    xlabelsize   : int
        currently ignored
    ylabelsize   : int
        currently ignored
    sharex       : bool
        if ``True``, the X axis will be shared amongst all subplots.
    sharey       : bool
        if ``True``, the Y axis will be shared amongst all subplots.
    figsize      : tuple
        the size of the figure to create in inches by default
    layout       :
        currently ignored
    n_bins       : int
        number of histogram bins to be used
    fig_name     : str
        name of the saved file. By default, built from the names of the
        plotted columns.
    folder_name  : str
        name of the folder (inside ``BDT/``) where to save the plot
    colors       : [str, str]
        colors used for ``data2`` (``colors[0]``) and ``data1`` (``colors[1]``)
    **kwds       : dict
        other plotting keyword arguments. Only ``alpha``
        (transparency of the histograms, default ``0.5``) is used.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure of the plot
    axes : numpy.ndarray(matplotlib.axes.Axes)
        2D array of the axes of the plot
    """
    alpha = kwds.get('alpha', 0.5)

    if column is not None:
        # column is not a list, convert it into a list.
        if not isinstance(column, (list, np.ndarray, Index)):
            column = [column]
        data1 = data1[column]
        data2 = data2[column]

    # Keep only the numerical columns
    data1 = data1._get_numeric_data()
    data2 = data2._get_numeric_data()
    columns = list(data1.columns)
    naxes = len(columns)  # one axis per column to plot
    if naxes == 0:
        raise ValueError("No numerical column to plot.")

    # At most ``max_nrows`` rows, and as many columns as needed
    max_nrows = 4
    nrows = min(naxes, max_nrows)
    ncols = -(-naxes // max_nrows)  # ceil(naxes / max_nrows)
    fig, axes = plt.subplots(nrows=nrows,
                             ncols=ncols,
                             squeeze=False,
                             sharex=sharex,
                             sharey=sharey,
                             figsize=figsize)
    flat_axes = axes.ravel()

    # Local list: the caller's ``range_column`` is never modified
    if range_column is None:
        range_column = [None] * naxes

    for i, col in enumerate(columns):
        ax = flat_axes[i]

        col_range = range_column[i]
        if col_range is None:
            col_range = (None, None)
        # presumably fills the ``None`` bounds from both datasets — see
        # ``pt.redefine_low_high``
        low, high = pt.redefine_low_high(col_range[0],
                                         col_range[1],
                                         [data1[col], data2[col]])

        h.plot_hist_alone(ax,
                          data1[col].dropna().values,
                          n_bins,
                          low,
                          high,
                          colors[1],
                          mode_hist=True,
                          alpha=alpha,
                          density=True,
                          label='background',
                          label_ncounts=True)
        h.plot_hist_alone(ax,
                          data2[col].dropna().values,
                          n_bins,
                          low,
                          high,
                          colors[0],
                          mode_hist=True,
                          alpha=alpha,
                          density=True,
                          label='signal',
                          label_ncounts=True)

        bin_width = (high - low) / n_bins
        latex_branch, unit = RVariable.get_latex_branch_unit_from_branch(col)
        # NOTE(review): the histograms are drawn with ``density=True`` but
        # labelled with ``density=False`` — confirm this is intended.
        h.set_label_hist(ax,
                         latex_branch,
                         unit,
                         bin_width=bin_width,
                         density=False,
                         fontsize=20)
        pt.fix_plot(ax,
                    factor_ymax=1 + 0.3,
                    show_leg=True,
                    fontsize_ticks=15.,
                    fontsize_leg=20.)
        if grid:
            pt.show_grid(ax, which='major')

    # Hide the axes of the grid that are not used
    for ax in flat_axes[naxes:]:
        ax.axis('off')

    if fig_name is None:
        fig_name = string.list_into_string(
            list(column) if column is not None else columns)

    fig.tight_layout()
    folder = 'BDT' if folder_name is None else f'BDT/{folder_name}'
    pt.save_fig(fig, f"1D_hist_{fig_name}", folder_name=folder)

    return fig, axes
예제 #6
0
파일: BDT.py 프로젝트: anthony-correia/HEA
def correlations(data, fig_name=None, folder_name=None, title=None, **kwds):
    """ Plot the pairwise correlation matrix between the columns of ``data``
    and save the figure in ``{loc['plot']}/BDT/{folder_name}/corr_matrix_{fig_name}``
    (in ``{loc['plot']}/BDT/corr_matrix_{fig_name}`` if ``folder_name`` is ``None``).

    Parameters
    ----------
    data         : pandas.Dataframe
        dataset
    fig_name     : str
        name of the saved file. By default, built from the names of the
        columns of the correlation matrix.
    folder_name  : str
        name of the folder (inside ``BDT/``) where to save the plot
    title        : str
        text appended to ``"Correlations"`` in the title of the figure
    **kwds       : dict
        other keyword arguments, passed to ``pandas.DataFrame.corr()``

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure of the plot
    ax : matplotlib.figure.Axes
        Axis of the plot
    """
    corrmat = data.corr(**kwds)

    fig, ax1 = plt.subplots(ncols=1, figsize=(12, 10))

    # Red-blue pseudo-color plot; correlation coefficients lie in [-1, 1]
    heatmap1 = ax1.pcolor(corrmat,
                          cmap=plt.get_cmap("RdBu"),
                          vmin=-1,
                          vmax=+1)
    fig.colorbar(heatmap1, ax=ax1)

    ax1.set_title(string.add_text("Correlations", title, ' - '))

    branches = list(corrmat.columns.values)
    labels = [RVariable.get_latex_branch_unit_from_branch(branch)[0]
              for branch in branches]

    # shift location of ticks to center of the bins
    ticks = np.arange(len(labels)) + 0.5
    ax1.set_xticks(ticks, minor=False)
    ax1.set_yticks(ticks, minor=False)
    ax1.set_xticklabels(labels, minor=False, ha='right', rotation=70)
    ax1.set_yticklabels(labels, minor=False)

    fig.tight_layout()

    if fig_name is None:
        fig_name = string.list_into_string(branches)

    folder = 'BDT' if folder_name is None else f'BDT/{folder_name}'
    pt.save_fig(fig, f"corr_matrix_{fig_name}", folder_name=folder)

    return fig, ax1