Exemple #1
0
def plot_rel_resp_data(rel_resp_df, analyspar, sesspar, stimpar, permpar, 
                       figpar, title=None, small=True):
    """
    plot_rel_resp_data(rel_resp_df, analyspar, sesspar, stimpar, permpar, 
                       figpar)

    Plots relative response errorbar data across sessions.

    Required args:
        - rel_resp_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - rel_reg or rel_exp (list): 
                data stats for regular or expected data (me, err)
            - rel_unexp (list): data stats for unexpected data (me, err)
            for reg/exp/unexp data types, session comparisons, e.g. 1v2:
            - {data_type}_raw_p_vals_{}v{} (float): uncorrected p-value for 
                data differences between sessions 
            - {data_type}_p_vals_{}v{} (float): p-value for data between 
                sessions, corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - stimpar (dict):
            dictionary with keys of StimPar namedtuple. 
            stimpar["stimtype"] must be "gabors" (rel_reg and rel_unexp 
            columns are plotted) or "visflow" (rel_exp and rel_unexp columns 
            are plotted).
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - small (bool):
            if True, subplots are smaller
            default: True

    Returns:
        - ax (2D array): 
            array of subplots
    """

    # select the data columns to plot; validate the stimulus type before any 
    # figure is created, so an invalid value does not leave an orphan figure
    if stimpar["stimtype"] == "gabors":
        data_types = ["rel_reg", "rel_unexp"]
    elif stimpar["stimtype"] == "visflow":
        data_types = ["rel_exp", "rel_unexp"]
    else:
        gen_util.accepted_values_error(
            "stimpar['stimtype']", stimpar["stimtype"], ["gabors", "visflow"]
            )

    sess_ns = misc_analys.get_sess_ns(sesspar, rel_resp_df)

    figpar = sess_plot_util.fig_init_linpla(figpar)

    figpar["init"]["sharey"] = "row"
    if small:
        figpar["init"]["subplot_hei"] = 4.1
        figpar["init"]["subplot_wid"] = 2.6
        figpar["init"]["gs"] = {"hspace": 0.2, "wspace": 0.25}
    else:
        figpar["init"]["subplot_hei"] = 4.4
        figpar["init"]["subplot_wid"] = 3.0
        figpar["init"]["gs"] = {"wspace": 0.3}

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    for (line, plane), lp_df in rel_resp_df.groupby(["lines", "planes"]):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        # collect, in session order, the dataframe index of the single row 
        # for each session present for this line/plane
        sess_indices = []
        lp_sess_ns = []
        for sess_n in sess_ns:
            rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
            if len(rows) == 1:
                sess_indices.append(rows.index[0])
                lp_sess_ns.append(sess_n)
            elif len(rows) > 1:
                raise RuntimeError("Expected 1 row per line/plane/session.")

        # reference line at a relative response of 1
        sub_ax.axhline(
            y=1, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5, zorder=-13
            )

        # styles are indexed by data type: reg/exp data first, unexp second
        colors = ["gray", col]
        fmts = ["-d", "-o"]
        alphas = [0.6, 0.8]
        ms = [12, None]
        for d, data_type in enumerate(data_types):
            # each entry holds (me, err...): column 0 is the mean, the 
            # remaining columns are the error values
            data = np.asarray([lp_df.loc[i, data_type] for i in sess_indices])
            plot_util.plot_errorbars(
                sub_ax, data[:, 0], data[:, 1:].T, lp_sess_ns, color=colors[d], 
                alpha=alphas[d], ms=ms[d], fmt=fmts[d], line_dash=dash)

    # two passes: the dry run establishes the data heights, then the real 
    # pass adds the between-session significance markers above them
    highest = None
    for dry_run in [True, False]:
        for data_type in data_types:
            ctrl = ("unexp" not in data_type)
            highest = add_between_sess_sig(
                ax, rel_resp_df, permpar, data_col=data_type, highest=highest, 
                ctrl=ctrl, p_val_prefix=True, dry_run=dry_run)
            if not dry_run:
                # raise the next data type's markers a bit above these ones
                highest = [val * 1.05 for val in highest]

    sess_plot_util.format_linpla_subaxes(ax, fluor=analyspar["fluor"], 
        datatype="roi", sess_ns=sess_ns, kind="reg", xticks=sess_ns, 
        modif_share=False)

    return ax

    
Exemple #2
0
def plot_sess_data(data_df, analyspar, sesspar, permpar, figpar, 
                   between_sess_sig=True, data_col="diff_stats", 
                   decoder_data=False, title=None, wide=False):
    """
    plot_sess_data(data_df, analyspar, sesspar, permpar, figpar)

    Plots errorbar data across sessions.

    Required args:
        - data_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - {data_col} (list): data stats (me, err)
            - null_CIs (list): adjusted null CI for data (low, median, high)
            - raw_p_vals (float): uncorrected p-value for data within 
                sessions
            - p_vals (float): p-value for data within sessions, 
                corrected for multiple comparisons and tails
            for session comparisons, e.g. 1v2:
            - raw_p_vals_{}v{} (float): uncorrected p-value for data 
                differences between sessions 
            - p_vals_{}v{} (float): p-value for data between sessions, 
                corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - between_sess_sig (bool):
            if True, significance between sessions is logged and plotted
            default: True         
        - data_col (str):
            dataframe column in which data to plot is stored
            default: "diff_stats"
        - decoder_data (bool):
            if True, data plotted is decoder data (y axes are shared across 
            all subplots, and an accuracy y label is used)
            default: False
        - title (str):
            plot title
            default: None
        - wide (bool):
            if True, subplots are wider
            default: False
        
    Returns:
        - ax (2D array): 
            array of subplots
    """

    sess_ns = misc_analys.get_sess_ns(sesspar, data_df)

    figpar = sess_plot_util.fig_init_linpla(figpar)
    
    figpar["init"]["sharey"] = True if decoder_data else "row"
    figpar["init"]["subplot_hei"] = 4.4
    figpar["init"]["subplot_wid"] = 3.0 if wide else 2.6
    figpar["init"]["gs"] = {"hspace": 0.2, "wspace": 0.3}

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    if title is not None:
        fig.suptitle(title, y=0.97, weight="bold")

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    # pass 0 plots the data; pass 1 adds null CIs and within-session 
    # significance markers, and logs the p-values
    for pass_n in [0, 1]:
        if pass_n == 1:
            logger.info(f"{comp_info}:", extra={"spacing": "\n"})
            
        for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
            li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane
                )
            line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]

            # collect, in session order, the dataframe index of the single 
            # row for each session present for this line/plane
            sess_indices = []
            lp_sess_ns = []
            for sess_n in sess_ns:
                rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
                if len(rows) == 1:
                    sess_indices.append(rows.index[0])
                    lp_sess_ns.append(sess_n)
                elif len(rows) > 1:
                    raise RuntimeError("Expected 1 row per line/plane/session.")

            # column 0 is the mean, the remaining columns are error values
            data = np.asarray([lp_df.loc[i, data_col] for i in sess_indices])

            if pass_n == 0:
                # plot errorbars
                plot_util.plot_errorbars(
                    sub_ax, data[:, 0], data[:, 1:].T, lp_sess_ns, color=col, 
                    alpha=0.8, xticks="auto", line_dash=dash
                    )

            if pass_n == 1:
                # plot CIs: column 1 is the median, columns 0 and 2 the limits
                CIs = np.asarray(
                    [lp_df.loc[i, "null_CIs"] for i in sess_indices]
                    )
                CI_meds = CIs[:, 1]
                CIs = CIs[:, np.asarray([0, 2])]

                plot_util.plot_CI(sub_ax, CIs.T, med=CI_meds, x=lp_sess_ns, 
                    width=0.45, color="lightgrey", med_col="gray", med_rat=0.03, 
                    zorder=-12)

                # add significance markers within sessions, above the 
                # mean + last error column
                y_maxes = data[:, 0] + data[:, -1]
                # side of the null median on which each mean lies
                sides = [
                    np.sign(sub[0] - CI_med) 
                    for sub, CI_med in zip(data, CI_meds)
                    ]
                p_vals_corr = [lp_df.loc[i, "p_vals"] for i in sess_indices]
                lp_sig_str = f"{line_plane_name:6} (within session):"
                for s, sess_n in enumerate(lp_sess_ns):
                    sig_str = misc_analys.get_sig_symbol(
                        p_vals_corr[s], sensitivity=sensitivity, side=sides[s], 
                        tails=permpar["tails"], p_thresh=permpar["p_val"]
                        )

                    if len(sig_str):
                        plot_util.add_signif_mark(sub_ax, sess_n, y_maxes[s], 
                            rel_y=0.15, color=col, mark=sig_str)  

                    lp_sig_str = (
                        f"{lp_sig_str}{TAB} S{sess_n}: "
                        f"{p_vals_corr[s]:.5f}{sig_str:3}"
                        )

                logger.info(lp_sig_str, extra={"spacing": TAB})
        
    if between_sess_sig:
        add_between_sess_sig(ax, data_df, permpar, data_col=data_col)

    area, ylab = True, None
    if decoder_data:
        area = False
        if "balanced" in data_col:
            ylab = "Balanced accuracy (%)" 
        else:
            # same unit format as the balanced accuracy label
            ylab = "Accuracy (%)"

    sess_plot_util.format_linpla_subaxes(ax, fluor=analyspar["fluor"], 
        area=area, ylab=ylab, datatype="roi", sess_ns=sess_ns, kind="reg", 
        xticks=sess_ns, modif_share=False)

    return ax
Exemple #3
0
def plot_sess_traces(data_df, analyspar, sesspar, figpar, 
                     trace_col="trace_stats", row_col="sess_ns", 
                     row_order=None, split="by_exp", title=None, size="reg"):
    """
    plot_sess_traces(data_df, analyspar, sesspar, figpar) 
    
    Plots traces from dataframe.

    Required args:
        - data_df (pd.DataFrame):
            traces data frame with, in addition to the basic sess_df columns, 
            columns specified by trace_col, row_col, and a "time_values" column
        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - trace_col (str):
             dataframe column containing trace statistics, as 
             split x ROIs x frames x stats 
             default: "trace_stats"
        - row_col (str):
            dataframe column specifying the variable that defines rows 
            within each line/plane
            default: "sess_ns"
        - row_order (list):
            ordered list specifying the order in which to plot from row_col.
            If None, automatic sorting order is used.
            default: None 
        - split (str):
            data split, e.g. "by_exp", "exp_lock", "unexp_lock", "stim_onset" 
            or "stim_offset". If not "by_exp", x limits are made symmetric 
            around 0.
            default: "by_exp"
        - title (str):
            plot title
            default: None
        - size (str):
            subplot sizes ("small", "wide" or "reg")
            default: "reg"

    Returns:
        - ax (2D array): 
            array of subplots
    """
    
    # retrieve session numbers, and infer row_order, if necessary
    # NOTE(review): when row_col is "sess_ns" and row_order is None, sess_ns 
    # stays None even though row_order is inferred; confirm this is intended
    sess_ns = None
    if row_col == "sess_ns":
        sess_ns = row_order
        if row_order is None:
            row_order = misc_analys.get_sess_ns(sesspar, data_df)

    elif row_order is None:
        row_order = data_df[row_col].unique()

    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="traces", n_sub=len(row_order), sharey=False
        )

    if size == "small":
        figpar["init"]["subplot_hei"] = 1.51
        figpar["init"]["subplot_wid"] = 3.7
    elif size == "wide":
        figpar["init"]["subplot_hei"] = 1.36
        figpar["init"]["subplot_wid"] = 4.8
        figpar["init"]["gs"] = {"wspace": 0.3, "hspace": 0.5}
    elif size == "reg":
        figpar["init"]["subplot_hei"] = 1.36
        figpar["init"]["subplot_wid"] = 3.4
    else:
        gen_util.accepted_values_error("size", size, ["small", "wide", "reg"])

    fig, ax = plot_util.init_fig(len(row_order) * 4, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=1.0, weight="bold")

    # time values of the last row plotted, used below to set the x limits
    time_values = None
    for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)

        for r, row_val in enumerate(row_order):
            rows = lp_df.loc[lp_df[row_col] == row_val]
            if len(rows) == 0:
                continue
            elif len(rows) > 1:
                raise RuntimeError(
                    "Expected row_order instances to be unique per line/plane."
                    )
            row = rows.loc[rows.index[0]]

            # each plane occupies a band of len(row_order) subplot rows
            sub_ax = ax[r + pl * len(row_order), li]

            if line == "L2/3-Cux2":
                exp_col = "darkgray" # oddly, lighter than gray
            else:
                exp_col = "gray"

            plot_traces(
                sub_ax, row["time_values"], row[trace_col], split=split, 
                col=col, ls=dash, exp_col=exp_col, lab=False
                )
            time_values = row["time_values"]
            
    for sub_ax in ax.reshape(-1):
        plot_util.set_minimal_ticks(sub_ax, axis="y")

    sess_plot_util.format_linpla_subaxes(ax, fluor=analyspar["fluor"], 
        area=False, datatype="roi", sess_ns=sess_ns, xticks=None, 
        kind="traces", modif_share=False)

    # fix x ticks and lims
    plot_util.set_interm_ticks(ax, 3, axis="x", fontweight="bold")
    if time_values is not None:
        xlims = [np.min(time_values), np.max(time_values)]
        if split != "by_exp":
            # locked splits are centered on the lock point
            xlims = [-xlims[1], xlims[1]]
        # apply to every subplot, rather than only to the last one visited
        for sub_ax in ax.reshape(-1):
            sub_ax.set_xlim(xlims)

    return ax
Exemple #4
0
def plot_tracked_idx_stats(idx_stats_df, sesspar, figpar, permpar=None,
                           absolute=True, between_sess_sig=True, 
                           by_mouse=False, bootstr_err=None, title=None, 
                           wide=False):
    """
    plot_tracked_idx_stats(idx_stats_df, sesspar, figpar)

    Plots tracked ROI USI statistics.

    Required args:
        - idx_stats_df (pd.DataFrame):
            dataframe with one row per session, and the following columns, in 
            addition to the basic sess_df columns:
            - roi_idx_stats (list): index statistics (me, err)
            or, if absolute is True,
            - abs_roi_idx_stats (list): absolute index statistics (me, err)
        - sesspar (dict): 
            dictionary with keys of SessPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple. Required if 
            between_sess_sig is True.
            default: None
        - absolute (bool):
            if True, data statistics are on absolute ROI indices
            default: True
        - between_sess_sig (bool):
            if True, significance between sessions is logged and plotted
            default: True
        - by_mouse (bool):
            if True, plotting is done per mouse
            default: False
        - bootstr_err (bool):
            if True, error is bootstrapped standard deviation (errorbars are 
            drawn with caps)
            default: None
        - title (str):
            plot title
            default: None
        - wide (bool):
            if True, subplots are wider
            default: False

    Returns:
        - ax (2D array): 
            array of subplots    

    Raises:
        - ValueError: if between_sess_sig is True and permpar is None
        - NotImplementedError: if between_sess_sig and by_mouse are both True
        - KeyError: if the expected data column is missing from idx_stats_df
    """

    # validate arguments before any figure is created
    if between_sess_sig:
        if permpar is None:
            raise ValueError(
                "If 'between_sess_sig' is True, must provide permpar."
                )
        if by_mouse:
            raise NotImplementedError(
                "Plotting between session statistical significance is not "
                "implemented if 'by_mouse' is True."
                )

    data_col = "roi_idx_stats"
    if absolute:
        data_col = f"abs_{data_col}"
    
    if data_col not in idx_stats_df.columns:
        raise KeyError(f"Expected to find {data_col} in idx_stats_df columns.")

    sess_ns = misc_analys.get_sess_ns(sesspar, idx_stats_df)

    figpar = sess_plot_util.fig_init_linpla(figpar)

    figpar["init"]["sharey"] = "row"
    figpar["init"]["subplot_hei"] = 4.1
    figpar["init"]["subplot_wid"] = 3.3 if wide else 2.6
    figpar["init"]["gs"] = {"wspace": 0.25, "hspace": 0.2}

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    # errorbar styling does not vary across subplots or mice
    alpha = 0.6 if by_mouse else 0.8
    capsize = 8 if bootstr_err else None

    for (line, plane), lp_df in idx_stats_df.groupby(["lines", "planes"]):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        # reference line at an index of 0
        sub_ax.axhline(
            y=0, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5, 
            zorder=-13
            )

        # without by_mouse, a single pseudo-mouse covers all rows
        mouse_ns = ["any"]
        mouse_cols = [col]
        if by_mouse:
            mouse_ns = sorted(lp_df["mouse_ns"].unique())
            mouse_cols = plot_util.get_hex_color_range(
                len(mouse_ns), col=col, 
                interval=plot_helper_fcts.MOUSE_COL_INTERVAL
                )

        for mouse_n, mouse_col in zip(mouse_ns, mouse_cols):
            sub_df = lp_df
            if by_mouse:
                sub_df = lp_df.loc[lp_df["mouse_ns"] == mouse_n]
            
            # sessions with exactly one matching row are plotted; others 
            # are skipped
            sess_indices = []
            sub_sess_ns = []
            for sess_n in sess_ns:
                rows = sub_df.loc[sub_df["sess_ns"] == sess_n]
                if len(rows) == 1:
                    sess_indices.append(rows.index[0])
                    sub_sess_ns.append(sess_n)

            # column 0 is the mean, the remaining columns are error values
            data = np.asarray([sub_df.loc[i, data_col] for i in sess_indices])

            # plot errorbars
            plot_util.plot_errorbars(
                sub_ax, data[:, 0], data[:, 1:].T, sub_sess_ns, color=mouse_col, 
                alpha=alpha, xticks="auto", line_dash=dash, capsize=capsize,
                )

    if between_sess_sig:
        seq_plots.add_between_sess_sig(
            ax, idx_stats_df, permpar, data_col=data_col
            )
    
    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(
        ax, datatype="roi", xticks=sess_ns, ylab="", kind="reg"
        )
    
    return ax
Exemple #5
0
def plot_tracked_idxs(idx_only_df, sesspar, figpar, title=None, wide=False):
    """
    plot_tracked_idxs(idx_only_df, sesspar, figpar)

    Plots tracked ROI USIs as individual lines.

    Required args:
        - idx_only_df (pd.DataFrame):
            dataframe with one row per (mouse/)session/line/plane, and the 
            following columns, in addition to the basic sess_df columns:
            - roi_idxs (list): index for each ROI
            - nrois (int): number of ROIs (must be constant per mouse across 
                sessions)

        - sesspar (dict): 
            dictionary with keys of SessPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - wide (bool):
            if True, subplots are wider
            default: False

    Returns:
        - ax (2D array): 
            array of subplots    

    Raises:
        - RuntimeError: if a mouse's number of ROIs varies across sessions, 
            or if a mouse has more than one row for a line/plane/session
    """

    sess_ns = misc_analys.get_sess_ns(sesspar, idx_only_df)

    figpar = sess_plot_util.fig_init_linpla(figpar)

    figpar["init"]["sharey"] = "row"
    figpar["init"]["subplot_hei"] = 4.1
    figpar["init"]["subplot_wid"] = 3.3 if wide else 2.5
    figpar["init"]["gs"] = {"wspace": 0.25, "hspace": 0.2}

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    for (line, plane), lp_df in idx_only_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        lp_mouse_ns = sorted(lp_df["mouse_ns"].unique())

        # build a sessions x ROIs array per mouse (NaN for missing sessions)
        lp_data = []
        for mouse_n in lp_mouse_ns:
            mouse_df = lp_df.loc[lp_df["mouse_ns"] == mouse_n]
            nrois = mouse_df["nrois"].unique()
            if len(nrois) != 1:
                raise RuntimeError(
                    "Each mouse in idx_only_df should retain the same number "
                    "of ROIs across sessions.")
            
            mouse_data = np.full((len(sess_ns), nrois[0]), np.nan)
            for s, sess_n in enumerate(sess_ns):
                rows = mouse_df.loc[mouse_df["sess_ns"] == sess_n]
                if len(rows) == 1:
                    mouse_data[s] = rows.loc[rows.index[0], "roi_idxs"]
                elif len(rows) > 1:
                    raise RuntimeError(
                        "Expected 1 row per line/plane/session/mouse."
                        )
            lp_data.append(mouse_data)

        # stack all mice's ROIs side by side: sessions x (all ROIs)
        lp_data = np.concatenate(lp_data, axis=1)

        # reference line at an index of 0
        sub_ax.axhline(
            y=0, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5, 
            zorder=-13
            )
        # one line per ROI across sessions
        sub_ax.plot(sess_ns, lp_data, color=col, lw=2, alpha=0.3)
    
    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(
        ax, datatype="roi", xticks=sess_ns, ylab="", kind="reg"
        )

    for sub_ax in ax.reshape(-1):
        xticks = sub_ax.get_xticks()
        plot_util.set_ticks(
            sub_ax, "x", np.min(xticks), np.max(xticks), n=len(xticks), 
            pad_p=0.2
            )
    
    return ax
Exemple #6
0
def plot_idxs(idx_df, sesspar, figpar, plot="items", density=True, n_bins=40, 
              title=None, size="reg"):
    """
    plot_idxs(idx_df, sesspar, figpar)

    Plots ROI index histograms (or index percentile histograms) for each 
    session and line/plane combination.

    Required args:
        - idx_df (pd.DataFrame):
            dataframe with indices for different line/plane combinations, and 
            the following columns, in addition to the basic sess_df columns:
            - rand_idx_binned (list): bin counts for the random ROI indices
            - bin_edges (list): first and last bin edge
            - CI_edges (list): confidence interval limit values
            - CI_perc (list): confidence interval percentile limits
            if plot == "items":
            - roi_idx_binned (list): bin counts for the ROI indices
            if plot == "percs":
            - perc_idx_binned (list): bin counts for the ROI index percentiles
            optionally:
            - n_signif_lo (int): number of significant ROIs (low) 
            - n_signif_hi (int): number of significant ROIs (high)

        - sesspar (dict): 
            dictionary with keys of SessPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - plot (str): 
            type of data to plot ("items" or "percs")
            default: "items"
        - density (bool): 
            if True, histograms are plotted as densities
            default: True
        - n_bins (int): 
            number of bins to use in histograms
            default: 40
        - title (str): 
            plot title
            default: None
        - size (str): 
            plot size ("reg", "small" or "tall")
            default: "reg"
        
    Returns:
        - ax (2D array): 
            array of subplots
    """

    if plot == "items":
        data_key = "roi_idx_binned"
        CI_key = "CI_edges"
    elif plot == "percs":
        data_key = "perc_idx_binned"
        CI_key = "CI_perc"
    else:
        gen_util.accepted_values_error("plot", plot, ["items", "percs"])

    sess_ns = misc_analys.get_sess_ns(sesspar, idx_df)

    n_plots = len(sess_ns) * 4
    # NOTE(review): this assignment writes into the caller's figpar["init"] 
    # before fig_init_linpla is called; sharey is set again further below
    figpar["init"]["sharey"] = "row"
    figpar = sess_plot_util.fig_init_linpla(figpar, kind="idx", 
        n_sub=len(sess_ns), sharex=(plot == "percs"))

    # y is the suptitle height
    y = 1
    if size == "reg":
        subplot_hei = 3.2
        subplot_wid = 5.5
    elif size == "small":
        y = 1.04
        subplot_hei = 2.40
        subplot_wid = 3.75
        figpar["init"]["gs"] = {"hspace": 0.25, "wspace": 0.30}
    elif size == "tall":
        y = 0.98
        subplot_hei = 5.3
        subplot_wid = 5.55
    else:
        gen_util.accepted_values_error("size", size, ["reg", "small", "tall"])
    
    figpar["init"]["subplot_hei"] = subplot_hei
    figpar["init"]["subplot_wid"] = subplot_wid
    figpar["init"]["sharey"] = "row"
    
    fig, ax = plot_util.init_fig(n_plots, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=y, weight="bold")

    for (line, plane), lp_df in idx_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)

        for s, sess_n in enumerate(sess_ns):
            rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
            if len(rows) == 0:
                continue
            elif len(rows) > 1:
                raise RuntimeError(
                    "Expected sess_ns to be unique per line/plane."
                    )
            row = rows.loc[rows.index[0]]

            # each plane occupies a band of len(sess_ns) subplot rows
            sub_ax = ax[s + pl * len(sess_ns), li]

            # get percentage significant label, if significance counts exist
            perc_label = None
            if "n_signif_lo" in row.keys() and "n_signif_hi" in row.keys():
                n_sig_lo, n_sig_hi = row["n_signif_lo"], row["n_signif_hi"]
                nrois = np.sum(row["nrois"])
                perc_signif = np.sum([n_sig_lo, n_sig_hi]) / nrois * 100
                perc_label = (f"{perc_signif:.2f}% sig\n"
                    f"({n_sig_lo}-/{n_sig_hi}+ of {nrois})")

            plot_stim_idx_hist(
                sub_ax, row[data_key], row[CI_key], n_bins=n_bins, 
                rand_data=row["rand_idx_binned"], 
                orig_edges=row["bin_edges"], 
                plot=plot, col=col, density=density, perc_label=perc_label)
            
            if size == "small":
                sub_ax.legend(fontsize="small")

    # Add plane, line info to plots
    y_lab = "Density" if density else "N ROIs"
    sess_plot_util.format_linpla_subaxes(ax, datatype="roi", ylab=y_lab, 
        xticks=None, sess_ns=None, kind="idx", modif_share=False, 
        xlab="Index", single_lab=True)

    # Add indices after setting formatting (plot was validated above)
    if plot == "percs":
        nticks = 5
        xticks = [int(np.around(x, 0)) for x in np.linspace(0, 100, nticks)]
        for sub_ax in ax[-1]:
            sub_ax.set_xticks(xticks)
            sub_ax.set_xticklabels(xticks, weight="bold")
    else:
        nticks = 3
        plot_util.set_interm_ticks(
            ax, nticks, axis="x", weight="bold", share=False, skip=False
            )
        
    return ax