def init_res_fig(n_subplots, max_sess=None, modif=False):
    """
    init_res_fig(n_subplots)

    Initializes a figure in which to plot summary results.

    Required args:
        - n_subplots (int): number of subplots

    Optional args:
        - max_sess (int): maximum number of sessions plotted
                          default: None
        - modif (bool)  : if True, plots are made in a modified (simplified) 
                          way, using the line/plane figure initialization 
                          parameters
                          default: False

    Returns:
        - fig (plt Fig): figure
        - ax (plt Axis): axis
    """

    subplot_hei = 14
    subplot_wid = 7.5
    if max_sess is not None:
        # scale the width with the number of sessions (7.5 for 4 sessions)
        subplot_wid *= max_sess / 4.0

    if modif:
        figpar_init = sess_plot_util.fig_init_linpla(sharey=True)["init"]
        fig, ax = plot_util.init_fig(n_subplots, **figpar_init)
    else:
        # Without this branch, modif=False left fig/ax unassigned, and 
        # modif=True had its figure overwritten by this second call.
        fig, ax = plot_util.init_fig(
            n_subplots, 2, sharey=True, subplot_hei=subplot_hei, 
            subplot_wid=subplot_wid)

    return fig, ax
# NOTE(review): the signature below appears truncated in this copy of the
# file (the parameter list is never closed). The docstring documents
# permpar and figpar as required and permute, corr_type, title and small as
# optional, so the list presumably continues with those. Confirm against
# version control.
def plot_idx_correlations(idx_corr_df,
    """
    plot_idx_correlations(idx_corr_df, permpar, figpar)

    Plots ROI USI index correlations across sessions.

    Required args:
        - idx_corr_df (pd.DataFrame):
            dataframe with one row per line/plane, and the 
            following columns, in addition to the basic sess_df columns:

            for session comparisons, e.g. 1v2
            - {}v{}{norm_str}_corrs (float): intersession ROI index correlations
            - {}v{}{norm_str}_corr_stds (float): bootstrapped intersession ROI 
                index correlation standard deviation
            - {}v{}_null_CIs (list): adjusted null CI for intersession ROI 
                index correlations
            - {}v{}_raw_p_vals (float): p-value for intersession correlations
            - {}v{}_p_vals (float): p-value for intersession correlations, 
                corrected for multiple comparisons and tails

        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - permute (str):
            type of permutation to do ("tracking", "sess" or "all")
            default: "sess"
        - corr_type (str):
            type of correlation run, i.e. "corr" or "R_sqr"
            default: "corr"
        - title (str):
            plot title
            default: None
        - small (bool):
            if True, smaller subplots are plotted
            default: True

    Returns:
        - ax (2D array): 
            array of subplots
    """

    # For "sess"/"all" permutations, corr_type is prefixed with "diff_". For 
    # "corr", the normalized ("_norm") columns are then used, and the title 
    # is adjusted accordingly.
    norm = False
    if permute in ["sess", "all"]:
        corr_type = f"diff_{corr_type}"
        if corr_type == "diff_corr":
            norm = True
            # NOTE(review): title is documented as defaulting to None. In that 
            # case title.replace() raises AttributeError here, and this branch 
            # is taken with the documented permute/corr_type defaults. 
            # Consider guarding with `if title is not None`.
            title = title.replace("Correlations", "Normalized correlations")

    # column-name suffix selecting normalized vs. raw correlation data
    norm_str = "_norm" if norm else ""

    # one subplot column per session pair, rounded up to an even count
    sess_pairs = get_sorted_sess_pairs(idx_corr_df, norm=norm)
    n_pairs = int(np.ceil(len(sess_pairs) / 2) * 2)  # multiple of 2

    figpar = sess_plot_util.fig_init_linpla(figpar,
                                            n_sub=int(n_pairs / 2))

    # n_pairs columns. With n_pairs * 2 subplots (see init_fig below), this 
    # presumably yields one row per plane, since ax is indexed as ax[pl, s].
    figpar["init"]["ncols"] = n_pairs
    figpar["init"]["sharey"] = "row"

    figpar["init"]["gs"] = {"hspace": 0.25}
    # NOTE(review): as written, the first three "small" settings are 
    # immediately overwritten by the three that follow, so `small` has no 
    # effect. An `else:` before the second group appears to be missing.
    if small:
        figpar["init"]["subplot_wid"] = 2.7
        figpar["init"]["subplot_hei"] = 4.21
        figpar["init"]["gs"]["wspace"] = 0.2
        figpar["init"]["subplot_wid"] = 3.3
        figpar["init"]["subplot_hei"] = 4.71
        figpar["init"]["gs"]["wspace"] = 0.3

    fig, ax = plot_util.init_fig(n_pairs * 2, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    # NOTE(review): plane_pts is not used in the visible code; the code that 
    # applies these y limits may be missing.
    plane_pts = get_idx_corr_ylims(idx_corr_df, norm=norm)
    # x tick labels, one per line, filled in by the loop below
    lines = [None, None]

    # NOTE(review): the next line appears garbled. A logging call (likely 
    # logging f"{comp_info}:") seems to have been fused onto the assignment. 
    # Restore it from version control.
    comp_info = misc_analys.get_comp_info(permpar)"{comp_info}:", extra={"spacing": "\n"})
    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        # li: x position index for the line; pl: subplot row for the plane
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
        # keep the part before the first "-", writing "23" as "2/3"
        lines[li] = line.split("-")[0].replace("23", "2/3")

        if len(lp_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        row = lp_df.loc[lp_df.index[0]]

        # accumulates the p-values reported for this line/plane
        lp_sig_str = f"{linpla_name:6}:"
        for s, sess_pair in enumerate(sess_pairs):
            sub_ax = ax[pl, s]
            # NOTE(review): this `if` has no body in this copy; the 
            # statement(s) it guarded seem to be missing.
            if s == 0:

            # column-name prefix for this session comparison, e.g. "1v2"
            col_base = f"{sess_pair[0]}v{sess_pair[1]}"

            # CI[0] and CI[2] are used as the null CI bounds, and CI[1] as the 
            # reference value for `side` below
            CI = row[f"{col_base}_null_CIs"]
            # NOTE(review): extr is not used in the visible code, and nothing 
            # in this loop draws y, err or the CI on sub_ax. The plotting 
            # calls appear to be missing.
            extr = np.asarray([CI[0], CI[2]])

            y = row[f"{col_base}{norm_str}_corrs"]

            err = row[f"{col_base}{norm_str}_corr_stds"]

            # add significance markers
            p_val = row[f"{col_base}_p_vals"]
            # sign of the correlation relative to CI[1]
            side = np.sign(y - CI[1])
            sensitivity = misc_analys.get_sensitivity(permpar)
            # NOTE(review): this call is never closed in this copy; its 
            # remaining arguments are missing.
            sig_str = misc_analys.get_sig_symbol(p_val,

            # NOTE(review): `high` (the larger of the upper CI bound and 
            # y + err) is not used in the visible code. The call that draws 
            # sig_str seems to be missing.
            if len(sig_str):
                high = np.max([CI[-1], y + err])

            sess_str = f"S{sess_pair[0]}v{sess_pair[1]}: "
            # NOTE(review): the next line appears garbled. Its trailing 
            # `, extra={"spacing": TAB})` suggests a logging call was fused 
            # onto the string assignment.
            lp_sig_str = f"{lp_sig_str}{TAB}{sess_str}{p_val:.5f}{sig_str:3}", extra={"spacing": TAB})

    # Add plane, line info to plots
    # NOTE(review): the two lines below are the tail of a call whose opening 
    # line(s) are missing in this copy.
                                         lines=["", ""],
                                         xticks=[0, 1],

    # one x position per line, padded on both sides
    xs = np.arange(len(lines))
    pad_x = 0.6 * (xs[1] - xs[0])
    for row_n in range(len(ax)):
        for col_n in range(len(ax[row_n])):
            sub_ax = ax[row_n, col_n]
            sub_ax.tick_params(axis="x", which="both", bottom=False)
            sub_ax.set_xticklabels(lines, weight="bold")
            sub_ax.set_xlim([xs[0] - pad_x, xs[-1] + pad_x])


            if row_n == 0:
                if col_n < len(sess_pairs):
                    s1, s2 = sess_pairs[col_n]
                    sess_pair_title = f"Session {s1} v {s2}"
                    # NOTE(review): sess_pair_title is not used in the visible 
                    # code; the call that uses it (presumably setting a column 
                    # title) seems to be missing.


    return ax
def run_sess_lstm(sessid, args):
    """
    run_sess_lstm(sessid, args)

    Prepares windowed stimulus and ROI data for one session and trains an 
    LSTM (PredLSTM, or ConvPredROILSTM if args.conv) to predict ROI traces 
    from stimulus parameter windows. It then records the model at the epoch 
    with the lowest validation loss, and plots the loss curves and example 
    target vs. predicted validation traces.

    NOTE: as written, the function raises NotImplementedError right after 
    the training stimulus and ROI data are prepared. Test data creation, 
    training and plotting are therefore never reached.

    Required args:
        - sessid (int): session ID, passed to sess_gen_util.init_sessions()
        - args (Argument parser): parser with the following attributes used:
            parallel (bool)  : if True (and plt_bkend is not None), the 
                               matplotlib backend is switched again
            plt_bkend (str)  : matplotlib backend
            seed (int)       : seed (re-set via rand_util.seed_all)
            device (str)     : torch device
            lr_ex (num)      : learning rate exponent (lr = 10**(-lr_ex))
            conv (bool)      : if True, a ConvPredROILSTM is used
            out_ch (int)     : number of output channels (conv model only)
            datadir (str)    : data directory
            mouse_df (str)   : mouse dataframe, passed to init_sessions
            runtype (str)    : run type, passed to init_sessions
            stimtype (str)   : stimulus type
            output (str)     : main output directory
            hidden_dim (int) : LSTM hidden dimension
            num_layers (int) : number of LSTM layers
            dropout (float)  : dropout probability
            batchsize (int)  : batch size
            n_epochs (int)   : number of training epochs
    """

    if args.parallel and args.plt_bkend is not None:
        plt.switch_backend(args.plt_bkend) # needs to be repeated within joblib

    args.seed = rand_util.seed_all(args.seed, args.device, seed_torch=True)

    # fraction of windows used for training (the rest for validation)
    train_p = 0.8
    lr = 1. * 10**(-args.lr_ex)
    # NOTE(review): as written, conv_str/outch_str are always reset to "" 
    # below; an `else:` before the last two assignments appears to be missing.
    if args.conv:
        conv_str = "_conv"
        outch_str = f"_{args.out_ch}outch"
        conv_str = ""
        outch_str = ""

    # Input output parameters
    n_stim_s  = 0.6
    n_roi_s = 0.3

    # Stim/traces for training
    train_gabfr = 0
    train_post = 0.9 # up to C
    roi_train_pre = 0 # from A
    stim_train_pre   = 0.3 # from preceding grayscreen

    # Stim/traces for testing (separated for unexp vs exp)
    test_gabfr = 3
    test_post  = 0.6 # up to grayscreen
    roi_test_pre = 0 # from D/U
    stim_test_pre   = 0.3 # from preceding C

    sess = sess_gen_util.init_sessions(
        sessid, args.datadir, args.mouse_df, args.runtype, full_table=False, 
        fluor="dff", dend="extr", run=True, temp_log="warning")[0]

    # NOTE(review): the get_analysdir() call below is never closed in this 
    # copy; its remaining arguments are missing.
    analysdir = sess_gen_util.get_analysdir(
        sess.mouse_n, sess.sess_n, sess.plane, stimtype=args.stimtype, 
    dirname = Path(args.output, analysdir)
    file_util.createdir(dirname, log_dir=False)

    # Must not scale ROIs or running BEFOREHAND. Must do after to use only 
    # network available data.

    # seq x frame x gabor x par
    # NOTE(review): the next line is the tail of a logging call 
    # ("Preparing stimulus parameter dataframe") whose first line is missing.
        extra={"spacing": "\n"})
    # NOTE(review): the last line of this call has a logging message 
    # ("Adding ROI data") fused onto it.
    train_stim_wins, run_stats = sess_data_util.get_stim_data(
        sess, args.stimtype, n_stim_s, train_gabfr, stim_train_pre, 
        train_post, gabk=16, run=True)"Adding ROI data")
    # NOTE(review): this call is never closed in this copy.
    xran, train_roi_wins, roi_stats = sess_data_util.get_roi_data(
        sess, args.stimtype, n_roi_s, train_gabfr, roi_train_pre, train_post, 

    logger.warning("Preparing windowed datasets (too slow - to be improved)")
    # Unconditional: everything below this raise is currently unreachable.
    raise NotImplementedError("Not implemented properly - some error leads "
        "to excessive memory requests.")
    # test windows, one entry per unexp value (0, then 1)
    test_stim_wins = []
    test_roi_wins  = []
    # NOTE(review): in the visible code, stim_wins/roi_wins are never appended 
    # to test_stim_wins/test_roi_wins, and the get_stim_data() call is not 
    # closed.
    for unexp in [0, 1]:
        stim_wins = sess_data_util.get_stim_data(
            sess, args.stimtype, n_stim_s, test_gabfr, stim_test_pre, 
            test_post, unexp, gabk=16, run_mean=run_stats[0], 
        roi_wins = sess_data_util.get_roi_data(sess, args.stimtype, n_roi_s,  
                           test_gabfr, roi_test_pre, test_post, unexp, gabk=16, 
                           roi_means=roi_stats[0], roi_stds=roi_stats[1])[1]

    n_pars = train_stim_wins.shape[-1] # n parameters (121)
    n_rois = train_roi_wins.shape[-1] # n ROIs

    # NOTE(review): this string expression is never closed in this copy.
    hyperstr = (f"{args.hidden_dim}hd_{args.num_layers}hl_{args.lr_ex}lrex_"

    # NOTE(review): this call is never closed in this copy.
    dls = data_util.create_dls(train_stim_wins, train_roi_wins, train_p=train_p, 
                            test_p=0, batchsize=args.batchsize, thresh_cl=0, 
    train_dl, val_dl, _ = dls

    # NOTE(review): the init_dl() call is never closed, and a logging 
    # message ("Running LSTM") is fused onto the append line.
    test_dls = []
    for s in [0, 1]:
        dl = data_util.init_dl(test_stim_wins[s], test_roi_wins[s], 
        test_dls.append(dl)"Running LSTM")
    # NOTE(review): as written, the PredLSTM assignment always overwrites the 
    # ConvPredROILSTM; an `else:` appears to be missing.
    if args.conv:
        lstm = ConvPredROILSTM(args.hidden_dim, n_rois, out_ch=args.out_ch, 
                            num_layers=args.num_layers, dropout=args.dropout)
        lstm = PredLSTM(n_pars, args.hidden_dim, n_rois, 
                        num_layers=args.num_layers, dropout=args.dropout)

    # NOTE(review): the right-hand side of this assignment is missing 
    # (presumably moving the model to args.device).
    lstm =
    # NOTE(review): `size_average` is deprecated in PyTorch; 
    # `reduction="sum"` is the equivalent of size_average=False.
    lstm.loss_fn = torch.nn.MSELoss(size_average=False)
    lstm.opt = torch.optim.Adam(lstm.parameters(), lr=lr)

    loss_df = pd.DataFrame(
        np.nan, index=range(args.n_epochs), columns=["train", "val"])
    min_val = np.inf
    # NOTE(review): a logging call ("====> Epoch {ep}") is fused onto the 
    # `for` line below.
    for ep in range(args.n_epochs):"====> Epoch {ep}", extra={"spacing": "\n"})
        # NOTE(review): as written, a training pass always follows the 
        # evaluation pass on epoch 0; an `else:` appears to be missing.
        if ep == 0:
            train_loss = run_dl(lstm, train_dl, args.device, train=False)    
            train_loss = run_dl(lstm, train_dl, args.device, train=True)
        val_loss = run_dl(lstm, val_dl, args.device, train=False)
        # losses are normalized by the number of samples in each dataset
        # NOTE(review): chained assignment (df[col].loc[i] = x) may not write 
        # to loss_df under pandas copy-on-write; loss_df.loc[ep, "train"] is 
        # the safe form. Two logging calls are also fused onto the second line.
        loss_df["train"].loc[ep] = train_loss/train_dl.dataset.n_samples
        loss_df["val"].loc[ep] = val_loss/val_dl.dataset.n_samples"Training loss  : {loss_df['train'].loc[ep]}")"Validation loss: {loss_df['val'].loc[ep]}")

        # record model on the first epoch, and whenever validation loss 
        # reaches a new low
        if ep == 0 or val_loss < min_val:
            prev_model = glob.glob(str(Path(dirname, f"{hyperstr}_ep*.pth")))
            prev_df = glob.glob(str(Path(dirname, f"{hyperstr}.csv")))
            min_val = val_loss
            saved_ep = ep
            # NOTE(review): this `if` has no body in this copy (presumably 
            # the removal of the previously saved files).
            if len(prev_model) == 1 and len(prev_df) == 1:

            savename = f"{hyperstr}_ep{ep}"
            savefile = Path(dirname, savename)
            # NOTE(review): the next line is a fragment of a save call whose 
            # head (and the use of savefile) is missing.
  {"net": lstm.state_dict(), "opt": lstm.opt.state_dict()},
            file_util.saveinfo(loss_df, hyperstr, dirname, "csv")

    # NOTE(review): this call is never closed in this copy.
    plot_util.linclab_plt_defaults(font=["Arial", "Liberation Sans"], 
    fig, ax = plt.subplots(1)
    # NOTE(review): the plot_traces() call is never closed in this copy.
    for dataset in ["train", "val"]:
        plot_util.plot_traces(ax, range(args.n_epochs), np.asarray(loss_df[dataset]), 
                  label=dataset, title=f"Average loss (MSE) ({n_rois} ROIs)", 
    fig.savefig(Path(dirname, f"{hyperstr}_loss"))

    savemod = Path(dirname, f"{hyperstr}_ep{saved_ep}.pth")
    # NOTE(review): checkpoint is never loaded into lstm in the visible code, 
    # so the predictions below use the last-epoch weights, not the best ones.
    checkpoint = torch.load(savemod)

    # NOTE(review): np.random.choice samples with replacement by default, so 
    # the same validation window may be drawn more than once.
    n_samples = 20
    val_idx = np.random.choice(range(val_dl.dataset.n_samples), n_samples)
    val_samples = val_dl.dataset[val_idx]
    xrans = data_util.get_win_xrans(xran, val_samples[1].shape[1], val_idx.tolist())

    # NOTE(review): this call is never closed in this copy.
    fig, ax = plot_util.init_fig(n_samples, ncols=4, sharex=True, subplot_hei=2, 

    # the model is fed time-first input; its output is reshaped back to 
    # batch-first to match the targets
    with torch.no_grad():
        batch_len, seq_len, n_items = val_samples[1].shape
        pred_tr = lstm(val_samples[0].transpose(1, 0).to(args.device))
        pred_tr = pred_tr.view([seq_len, batch_len, n_items]).transpose(1, 0)

    # NOTE(review): pred_tr is computed from input on args.device; .numpy() 
    # fails on non-CPU tensors, so a .cpu() may be needed.
    for lab, data in zip(["target", "pred"], [val_samples[1], pred_tr]):
        data = data.numpy()
        for n in range(n_samples):
            # plot one randomly chosen ROI per sample
            roi_n = np.random.choice(range(data.shape[-1]))
            sub_ax = plot_util.get_subax(ax, n)
            plot_util.plot_traces(sub_ax, xrans[n], data[n, :, roi_n], 
                label=lab, xticks="auto")
            plot_util.set_ticks(sub_ax, "x", xran[0], xran[-1], n=7)

    # NOTE(review): this call is never closed in this copy.
    sess_plot_util.plot_labels(ax, train_gabfr, plot_vals="exp", pre=roi_train_pre, 

    fig.suptitle(f"Target vs predicted validation traces ({n_rois} ROIs)")
    fig.savefig(Path(dirname, f"{hyperstr}_traces"))
def plot_glm_expl_var(analyspar, sesspar, stimpar, extrapar, glmpar,
                      sess_info, all_expl_var, figpar=None, savedir=None):
    """
    plot_glm_expl_var(analyspar, sesspar, stimpar, extrapar, glmpar,
                      sess_info, all_expl_var)

    From dictionaries, plots explained variance for different variables for 
    each ROI, as 3D scatter plots (one column per plotted session, with 
    overall explained variance per coefficient in the first row and unique 
    explained variance in the second).

    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis 
                                parameters
            ["analysis"] (str): analysis type (e.g., "v")
        - glmpar (dict)       : dictionary with keys of GLMPar namedtuple
        - sess_info (dict)    : dictionary containing information from each
                                session
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session

        - all_expl_var (list) : list of dictionaries with explained variance 
                                for each session set, with each glm 
                                coefficient as a key:
            ["full"] (list)    : list of full explained variance stats for 
                                 every ROI, structured as ROI x stats
            ["coef_all"] (dict): max explained variance for each ROI with each
                                 coefficient as a key, structured as ROI x stats
            ["coef_uni"] (dict): unique explained variance for each ROI with 
                                 each coefficient as a key, 
                                 structured as ROI x stats
            ["rois"] (list)    : ROI numbers (-1 for GLMs fit to 
                                 mean/median ROI activity). Sessions for 
                                 which this is [-1] or "all" are not plotted.

    Optional args:
        - figpar (dict): dictionary containing the following figure parameter 
                         dictionaries
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None

    Returns:
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved
    """

    # printable ("_pr") and file-name versions of the parameter strings
    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"], 
        stimpar["gabk"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], "roi", "print")

    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"]) 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], "roi")

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]

    n_sess = len(mouse_ns)
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, style="par")

    # only sessions with per-ROI results are plotted; n_sess is redefined as 
    # the number of plotted sessions
    plot_bools = [ev["rois"] not in [[-1], "all"] for ev in all_expl_var]
    n_sess = sum(plot_bools)

    # xyzc_dims: coefficients shown on the x, y and z axes and as colour; 
    # log_dims: coefficients whose statistics are logged
    # NOTE(review): the visflow list below is never closed in this copy, and 
    # there is no branch for other stimulus types (xyzc_dims would then be 
    # undefined).
    if stimpar["stimtype"] == "gabors":
        xyzc_dims = ["unexpected", "gabor_frame", "run_data", "pup_diam_data"]
        log_dims = xyzc_dims + ["gabor_mean_orientation"]
    elif stimpar["stimtype"] == "visflow":
        xyzc_dims = [
            "unexpected", "main_flow_direction", "run_data", "pup_diam_data"
        log_dims = xyzc_dims
    # start plotting
    # NOTE(review): the next line is the tail of a logging call ("Plotting GLM 
    # full and unique explained variance for ...") whose first line is missing.
        f"{', '.join(xyzc_dims)}.", extra={"spacing": "\n"})

    if n_sess > 0:
        if figpar is None:
            figpar = sess_plot_util.init_figpar()

        figpar = copy.deepcopy(figpar)
        cmap = plot_util.linclab_colormap(nbins=100, no_white=True)

        if figpar["save"]["use_dt"] is None:
            figpar["save"]["use_dt"] = gen_util.create_time_str()
        figpar["init"]["ncols"] = n_sess
        figpar["init"]["sharex"] = False
        figpar["init"]["sharey"] = False
        figpar["init"]["gs"] = {"wspace": 0.2, "hspace": 0.35}
        figpar["save"]["fig_ext"] = "png"
        fig, ax = plot_util.init_fig(2 * n_sess, **figpar["init"], proj="3d")

        fig.suptitle("Explained variance per ROI", y=1)

        # get colormap range, shared across all sessions and both variance 
        # types, from the colour dimension (xyzc_dims[3])
        c_range = [np.inf, -np.inf]
        c_key = xyzc_dims[3]

        for expl_var in all_expl_var:
            for var_type in ["coef_all", "coef_uni"]:
                rs = np.where(np.asarray(expl_var["rois"]) != -1)[0]
                if c_key in expl_var[var_type].keys():
                    c_data = np.asarray(expl_var[var_type][c_key])[rs, 0]
                    # adjust colormap range
                    c_range[0] = np.min([c_range[0], min(c_data)])
                    c_range[1] = np.max([c_range[1], max(c_data)])
        # NOTE(review): as written, rounding only happens when the range is 
        # not finite (right after the dummy range is set); an `else:` before 
        # the rounding line appears to be missing.
        if not np.isfinite(sum(c_range)):
            c_range = [-0.5, 0.5] # dummy range
            c_range = plot_util.rounded_lims(c_range, out=True)

    # NOTE(review): a logging call ("No plots, as only results across ROIs 
    # are included") is fused onto the `else:` below.
    else:"No plots, as only results across ROIs are included")
        fig = None

    # i indexes the sess_info lists (one entry per item of all_expl_var)
    i = 0
    for expl_var in all_expl_var:
        # collect info for plotting and logging results across ROIs
        rs = np.where(np.asarray(expl_var["rois"]) != -1)[0]
        all_rs = np.where(np.asarray(expl_var["rois"]) == -1)[0]
        # NOTE(review): the two lines after `raise` are unreachable as 
        # indented. They presumably belong after the `if`, otherwise full_ev 
        # is undefined below.
        if len(all_rs) != 1:
            raise RuntimeError("Expected only one result for all ROIs.")
            all_rs = all_rs[0]
            full_ev = expl_var["full"][all_rs]

        # NOTE(review): a logging call is fused onto the end of this title 
        # expression (`, extra=...)`).
        title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}\n(sess {sess_ns[i]}, "
                f"{lines[i]} {planes[i]}{dendstr_pr},{nroi_strs[i]})"), extra={"spacing": "\n"})

        math_util.log_stats(full_ev, stat_str="\nFull explained variance")

        # width used to left-justify dimension names in the logged stats
        dim_length = max([len(dim) for dim in log_dims])
        for v, var_type in enumerate(["coef_all", "coef_uni"]):
            if var_type == "coef_all":
                sub_title = "Explained variance per coefficient"
            elif var_type == "coef_uni":
                sub_title = "Unique explained variance\nper coefficient"
            # NOTE(review): the next line is the tail of a logging call whose 
            # first line is missing.
  , extra={"spacing": "\n"})

            dims_all = []
            for key in log_dims:
                if key in xyzc_dims:
                    # get mean/med
                    # NOTE(review): this `if` has no body in this copy.
                    if key not in expl_var[var_type].keys():

                    dims_all.append(np.asarray(expl_var[var_type][key])[rs, 0])
                    # NOTE(review): the next line is the tail of a stats 
                    # logging call whose first line is missing.
                    stat_str=key.ljust(dim_length), log_spacing=TAB

            # NOTE(review): this `if` has no body in this copy, and it checks 
            # the last session's plot_bools entry rather than entry i.
            if not plot_bools[-1]:

            # NOTE(review): as written, y and subpl_title are always set to 
            # 1.02 and sub_title; an `else:` before them appears to be missing.
            if v == 0:
                y = 1.12
                subpl_title = f"{title}\n{sub_title}"
                y = 1.02
                subpl_title = sub_title

            # retrieve values and names for each dimension, including dummy 
            # dimensions
            use_xyzc_dims = []
            n_vals = None
            dummies = []
            # x, y and z axis label padding
            pads = [16, 16, 20]
            # NOTE(review): as written, dimension names are only recorded for 
            # "dummy" dimensions, n_vals is len("dummy"), and `dummies` is 
            # never filled. An `else:` branch (and an append) appears to be 
            # missing.
            for d, dim in enumerate(dims_all):
                dim_name = xyzc_dims[d].replace("_", " ")
                if " direction"  in dim_name:
                    dim_name = dim_name.replace(" direction", "\ndirection")
                    pads[d] = 24
                if isinstance(dim, str) and dim == "dummy":
                    use_xyzc_dims.append(f"{dim_name} (dummy)")
                    n_vals = len(dim)
            for d in dummies:
                dims_all[d] = np.zeros(n_vals)

            [x_data, y_data, z_data, c_data] = dims_all

            # row v: variance type, column i: session
            # NOTE(review): the scatter() call below is never closed.
            sub_ax = ax[v, i]
            im = sub_ax.scatter(
                x_data, y_data, z_data, c=c_data, cmap=cmap, 
                vmin=c_range[0], vmax=c_range[1]
            sub_ax.set_title(subpl_title, y=y)
            # sub_ax.set_zlim3d(0, 1.0)

            # adjust padding for z axis
            sub_ax.tick_params(axis='z', which='major', pad=10)

            # add labels
            sub_ax.set_xlabel(use_xyzc_dims[0], labelpad=pads[0])
            sub_ax.set_ylabel(use_xyzc_dims[1], labelpad=pads[1])
            sub_ax.set_zlabel(use_xyzc_dims[2], labelpad=pads[2])

            # the full EV is added as an empty (label-only) line, presumably 
            # for a legend entry
            # NOTE(review): the log_stats() call below is never closed.
            if v == 0:
                full_ev_lab = math_util.log_stats(
                    full_ev, stat_str="Full EV", ret_str_only=True
                sub_ax.plot([], [], c="k", label=full_ev_lab)

        i += 1

    # NOTE(review): the two lines after the `if` are the tail of a call 
    # (presumably adding a colorbar for `im`) whose first line is missing.
    if fig is not None:
            fig, im, n_sess, label=use_xyzc_dims[3],
            space_fact=np.max([2, n_sess])

        # plot 0 planes, and lines
        for sub_ax in ax.reshape(-1):
            # axis limits, each extended by 2% of the range on both sides
            # NOTE(review): the list comprehension below is never closed.
            all_lims = [sub_ax.get_xlim(), sub_ax.get_ylim(), sub_ax.get_zlim()]
            xs, ys, zs = [
                [vs[0] - (vs[1] - vs[0]) * 0.02, vs[1] + (vs[1] - vs[0]) * 0.02]
                for vs in all_lims
            for plane in ["x", "y", "z"]:
                if plane == "x":
                    xx, yy = np.meshgrid(xs, ys)
                    zz = xx * 0
                    x_flat = xs
                    y_flat, z_flat = [0, 0], [0, 0]
                elif plane == "y":
                    yy, zz = np.meshgrid(ys, zs)
                    xx = yy * 0
                    y_flat = ys
                    z_flat, x_flat = [0, 0], [0, 0]
                elif plane == "z":
                    zz, xx = np.meshgrid(zs, xs)
                    yy = zz * 0
                    z_flat = zs
                    x_flat, y_flat = [0, 0], [0, 0]
                sub_ax.plot_surface(xx, yy, zz, alpha=0.05, color="k")
                # NOTE(review): the next line is the tail of a dashed-line 
                # plotting call whose first line is missing.
                    x_flat, y_flat, z_flat, alpha=0.4, color="k", ls=(0, (2, 2))

    # NOTE(review): the Path( call below is never closed in this copy.
    if savedir is None:
        savedir = Path(

    savename = (f"roi_glm_ev_{sessstr}{dendstr}")

    # NOTE(review): when n_sess == 0, figpar is still the input value here 
    # (possibly None, since it is only initialized inside `if n_sess > 0`), 
    # and fig is None.
    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename                              
def plot_mag_change(analyspar, sesspar, stimpar, extrapar, permpar, quantpar, 
                    sess_info, mags, figpar=None, savedir=None):
    """
    plot_mag_change(analyspar, sesspar, stimpar, extrapar, permpar, quantpar, 
                    sess_info, mags) 

    From dictionaries, plots magnitude of change in unexpected and expected
    responses across quantiles. One figure is made without scaling, and 
    one with scaling.

    Returns figure save directory path and name.

    Required args:
        - analyspar (dict): dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)  : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)  : dictionary with keys of StimPar namedtuple
        - extrapar (dict) : dictionary containing additional analysis 
                            parameters
            ["analysis"] (str): analysis type (e.g., "m")
            ["datatype"] (str): datatype (e.g., "run", "roi")
            ["seed"]     (int): seed value used
        - permpar (dict)  : dictionary with keys of PermPar namedtuple 
        - quantpar (dict) : dictionary with keys of QuantPar namedtuple
        - sess_info (dict): dictionary containing information from each
                            session
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session

        - mags (dict)     : dictionary containing magnitude data to plot
            ["L2"] (array-like)    : nested list containing L2 norms, 
                                     structured as: 
                                         sess x scaling x unexp
            ["L2_sig"] (list)      : L2 significance results for each session 
                                         ("hi", "lo" or "no")
            ["mag_sig"] (list)     : magnitude significance results for each 
                                     session ("hi", "lo" or "no")
            ["mag_st"] (array-like): array or nested list containing magnitude 
                                     stats across ROIs, structured as: 
                                         sess x scaling x unexp x stats

    Optional args:
        - figpar (dict): dictionary containing the following figure parameter 
                         dictionaries
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None

    Returns:
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved (without the 
                          scaling suffix)
    """
    # printable ("_pr") and file-name versions of the parameter strings
    # NOTE(review): the first call below is never closed in this copy.
    sessstr_pr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"], 
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")
    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"],stimpar["visflow_size"], stimpar["gabk"]) 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])
    datatype = extrapar["datatype"]
    dimstr = sess_str_util.datatype_dim_str(datatype)
    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]
    # NOTE(review): this call is never closed in this copy.
    nroi_strs = sess_str_util.get_nroi_strs(
        sess_info, empty=(datatype!="roi"), style="par"

    n_sess = len(mouse_ns)

    # 1-indexed numbers of the two compared quantiles (pos_idx presumably 
    # maps negative indices to positive ones)
    qu_ns = [gen_util.pos_idx(q, quantpar["n_quants"]) + 1 
        for q in quantpar["qu_idx"]]
    if len(qu_ns) != 2:
        raise ValueError(f"Expected 2 quantiles, not {len(qu_ns)}.")
    mag_st = np.asarray(mags["mag_st"])

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()
    # widen the figure with the number of sessions
    figpar["init"]["subplot_wid"] *= n_sess/2.0
    # index sc of each value selects the scaling axis of mag_st
    scales = [False, True]

    # get plot elements
    barw = 0.75
    # legend labels, one bar per label for each session
    leg = ["exp", "unexp"]    
    cent, bar_pos, xlims = plot_util.get_barplot_xpos(n_sess, len(leg), barw)   
    # NOTE(review): this title expression is never closed in this copy.
    title = (u"Magnitude ({}) of difference in activity".format(statstr_pr) +
        f"\nbetween Q{qu_ns[0]} and {qu_ns[1]} across {dimstr} "
    # NOTE(review): labels is not used in the visible code.
    labels = [f"Mouse {mouse_ns[i]} sess {sess_ns[i]},\n {lines[i]} {planes[i]}"
        f"{dendstr_pr}{nroi_strs[i]}" for i in range(n_sess)]

    # NOTE(review): figs and axs are never appended to in the visible code, 
    # so the significance and saving loops below iterate over nothing, and 
    # fulldir would be undefined at return.
    figs, axs = [], []
    for sc, scale in enumerate(scales):
        scalestr_pr = sess_str_util.scale_par_str(scale, "print")
        fig, ax = plot_util.init_fig(1, **figpar["init"])
        sub_ax = ax[0, 0]
        # always set ticks (even again) before setting labels
        title_scale = u"{}{}".format(title, scalestr_pr)
        # NOTE(review): the next line is the tail of an axis-labelling call 
        # whose first line is missing.
            sub_ax, fluor=analyspar["fluor"], area=True, scale=scale, x_ax="", 
        for s, lab in enumerate(leg):
            xpos = list(zip(*bar_pos))[s]
            # NOTE(review): the next three lines are the tail of a bar 
            # plotting call whose first line is missing.
                sub_ax, xpos, mag_st[:, sc, s, 0], err=mag_st[:, sc, s, 1:], 
                width=barw, xlims=xlims, xticks="None", label=lab, capsize=4,
                title=title_scale, hline=0)
    # add significance markers
    for i in range(n_sess):
        signif = mags["mag_sig"][i]
        if signif in ["hi", "lo"]:
            xpos = bar_pos[i]
            # NOTE(review): this loop rebinds `ax`, so the turn_off_extra() 
            # call below uses whichever `ax` was bound last.
            for sc, (ax, scale) in enumerate(zip(axs, scales)):
                yval = mag_st[i, sc, :, 0]
                yerr = mag_st[i, sc, :, 1:]
                plot_util.plot_barplot_signif(ax[0, 0], xpos, yval, yerr)
    plot_util.turn_off_extra(ax, n_sess)

    # figure directory
    # NOTE(review): the Path( call below is never closed in this copy.
    if savedir is None:
        savedir = Path(
    # only the last saved figure logs its directory
    log_dir = False
    for i, (fig, scale) in enumerate(zip(figs, scales)):
        if i == len(figs) - 1:
            log_dir = True
        scalestr = sess_str_util.scale_par_str(scale)
        savename = f"{datatype}_mag_diff_{sessstr}{dendstr}"
        savename_full = f"{savename}{scalestr}"
        fulldir = plot_util.savefig(
            fig, savename_full, savedir, log_dir=log_dir, ** figpar["save"])

    return fulldir, savename
def plot_traces_by_qu_unexp_sess(analyspar, sesspar, stimpar, extrapar, 
                                quantpar, sess_info, trace_stats, figpar=None, 
                                savedir=None, modif=False):
    plot_traces_by_qu_unexp_sess(analyspar, sesspar, stimpar, extrapar, 
                                quantpar, sess_info, trace_stats)

    From dictionaries, plots traces by quantile/unexpected with each session in a 
    separate subplot.
    Returns figure name and save directory path.
    Required args:
        - analyspar (dict)  : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)    : dictionary with keys of SessPar namedtuple
        - stimpar (dict)    : dictionary with keys of StimPar namedtuple
        - extrapar (dict)   : dictionary containing additional analysis 
            ["analysis"] (str): analysis type (e.g., "t")
            ["datatype"] (str): datatype (e.g., "run", "roi")
        - quantpar (dict)   : dictionary with keys of QuantPar namedtuple
        - sess_info (dict)  : dictionary containing information from each
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            if extrapar["datatype"] == "roi":
            ["nrois"] (list)      : number of ROIs in session

        - trace_stats (dict): dictionary containing trace stats information
            ["xrans"] (list)           : time values for the frames, for each 
            ["all_stats"] (list)       : list of 4D arrays or lists of trace 
                                         data statistics across ROIs for each
                                         session, structured as:
                                            sess x unexp x quantiles x
                                            stats (me, err) x frames
            ["all_counts"] (array-like): number of sequences, structured as:
                                                sess x unexp x quantiles
    Optional args:
        - figpar (dict): dictionary containing the following figure parameter 
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None    
        - modif (bool) : if True, modified (slimmed-down) plots are created
                         default: False
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved
    # Purpose: one subplot per session, with traces split by expected /
    # unexpected and by quantile; the figure is then saved under `savename`.

    # Strings for titles ("print" variants) and for the save name.
    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"],
        stimpar["gabk"], "print")
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")
    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"])
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])
    datatype = extrapar["datatype"]
    dimstr = sess_str_util.datatype_dim_str(datatype)

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]
    # ROI count strings are left empty unless the data are ROI data
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, empty=(datatype!="roi")) 

    n_sess = len(mouse_ns)

    # per-session time values and statistics as arrays; counts used as given
    xrans      = [np.asarray(xran) for xran in trace_stats["xrans"]]
    all_stats  = [np.asarray(sessst) for sessst in trace_stats["all_stats"]]
    all_counts = trace_stats["all_counts"]

    # cols: one color sequence per exp/unexp condition, indexed by quantile
    # (col[q] below)
    cols, lab_cols = sess_plot_util.get_quant_cols(quantpar["n_quants"])
    # more quantiles -> more transparent traces, capped at 0.4
    alpha = np.min([0.4, 0.8 / quantpar["n_quants"]])

    unexps = ["exp", "unexp"]
    n = 6  # number of x ticks passed to the trace plotting call (n_xticks)
    if figpar is None:
        figpar = sess_plot_util.init_figpar()
    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])
    for i in range(n_sess):
        sub_ax = plot_util.get_subax(ax, i)
        for s, [col, leg_ext] in enumerate(zip(cols, unexps)):
            for q, qu_idx in enumerate(quantpar["qu_idx"]):
                qu_lab = ""
                if quantpar["n_quants"] > 1:
                    qu_lab = "{} ".format(sess_str_util.quantile_str(
                        qu_idx, quantpar["n_quants"], str_type="print"
                # modif: compact title, legend label only on the first subplot
                if modif:
                    line = "2/3" if "23" in lines[i] else "5"
                    plane = "somata" if "soma" in planes[i] else "dendrites"
                    title = f"M{mouse_ns[i]} - layer {line} {plane}{dendstr_pr}"
                    leg = f"{qu_lab}{leg_ext}" if i == 0 else None
                    y_ax = None if i == 0 else ""

                    # NOTE(review): no `else:` header is visible before these
                    # assignments, so as shown they overwrite the modif values
                    # set just above — confirm.
                    title=(f"Mouse {mouse_ns[i]} - {stimstr_pr}, " 
                        u"{}\n".format(statstr_pr) + f"across {dimstr} (sess "
                        f"{sess_ns[i]}, {lines[i]} {planes[i]}{dendstr_pr}"
                    # legend includes the number of sequences for this trace
                    leg = f"{qu_lab}{leg_ext} ({all_counts[i][s][q]})"
                    y_ax = None

                    sub_ax, xrans[i], all_stats[i][s, q, 0], 
                    all_stats[i][s, q, 1:], title, color=col[q], alpha=alpha, 
                    label=leg, n_xticks=n, xticks="auto")
                    sub_ax, fluor=analyspar["fluor"], datatype=datatype, 
    plot_util.turn_off_extra(ax, n_sess)

    # Gabor stimuli: frame info (gabfr, pre, post) and label colors are passed
    # to an axis-formatting call
    if stimpar["stimtype"] == "gabors": 
            ax, stimpar["gabfr"], "both", pre=stimpar["pre"], 
            post=stimpar["post"], cols=lab_cols, 
    if savedir is None:
        savedir = Path(

    # quantile suffix in the save name only when more than one quantile is used
    qu_str = f"_{quantpar['n_quants']}q"
    if quantpar["n_quants"] == 1:
        qu_str = ""

    savename = f"{datatype}_av_{sessstr}{dendstr}{qu_str}"
    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
def plot_rel_resp_data(rel_resp_df, analyspar, sesspar, stimpar, permpar, 
                       figpar, title=None, small=True):
    plot_rel_resp_data((rel_resp_df, analyspar, sesspar, stimpar, permpar, 

    Plots relative response errorbar data across sessions.

    Required args:
        - rel_resp_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - rel_reg or rel_exp (list): 
                data stats for regular or expected data (me, err)
            - rel_unexp (list): data stats for unexpected data (me, err)
            for reg/exp/unexp data types, session comparisons, e.g. 1v2:
            - {data_type}_raw_p_vals_{}v{} (float): uncorrected p-value for 
                data differences between sessions 
            - {data_type}_p_vals_{}v{} (float): p-value for data between 
                sessions, corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - stimpar (dict):
            dictionary with keys of StimPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - small (bool):
            if True, subplots are smaller
            default: False

        - ax (2D array): 
            array of subplots

    # NOTE(review): the signature default is small=True, while the docstring
    # above says "default: False" — confirm which one is intended.

    # Purpose: one subplot per line/plane, with the control (reg/exp) and
    # unexpected relative responses plotted across sessions, plus
    # between-session significance markers.
    sess_ns = misc_analys.get_sess_ns(sesspar, rel_resp_df)

    figpar = sess_plot_util.fig_init_linpla(figpar)

    # y limits shared within each row of subplots
    figpar["init"]["sharey"] = "row"
    # NOTE(review): no `else:` is visible in this branch. As written, every
    # assignment below runs when small is True (the later ones win: height
    # 4.4, width 3.0, gs {"wspace": 0.3}) and none run otherwise — confirm.
    if small:
        figpar["init"]["subplot_hei"] = 4.1
        figpar["init"]["subplot_wid"] = 2.6
        figpar["init"]["gs"] = {"hspace": 0.2, "wspace": 0.25}
        figpar["init"]["gs"]["wspace"] = 0.25
        figpar["init"]["subplot_hei"] = 4.4
        figpar["init"]["subplot_wid"] = 3.0
        figpar["init"]["gs"] = {"wspace": 0.3}

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    # data columns to plot: the control (regular/expected) one, then unexpected
    if stimpar["stimtype"] == "gabors":
        data_types = ["rel_reg", "rel_unexp"]
    elif stimpar["stimtype"] == "visflow":
        data_types = ["rel_exp", "rel_unexp"]
            "stimpar['stimtype']", stimpar["stimtype"], ["gabors", "visflow"]

    # one subplot per line/plane, at ax[plane index, line index]
    for (line, plane), lp_df in rel_resp_df.groupby(["lines", "planes"]):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        # for each session, find its single matching row (duplicates raise)
        sess_indices = []
        lp_sess_ns = []
        for sess_n in sess_ns:
            rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
            if len(rows) == 1:
            elif len(rows) > 1:
                raise RuntimeError("Expected 1 row per line/plane/session.")

            # black reference line at y=1, drawn behind the data (low zorder)
            y=1, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5, zorder=-13

        # per data type styling: gray diamond markers for the first (control)
        # type, line/plane-colored circle markers for the unexpected type
        colors = ["gray", col]
        fmts = ["-d", "-o"]
        alphas = [0.6, 0.8]
        ms = [12, None]
        for d, data_type in enumerate(data_types):
            # column 0: statistic; remaining columns: error values
            data = np.asarray([lp_df.loc[i, data_type] for i in sess_indices])
                sub_ax, data[:, 0], data[:, 1:].T, lp_sess_ns, color=colors[d], 
                alpha=alphas[d], ms=ms[d], fmt=fmts[d], line_dash=dash)
    # Two passes: the dry run (dry_run=True) gets the data heights; the real
    # pass adds the markers, and after each data type the returned heights
    # are raised by 5% so that successive markers do not overlap.
    highest = None
    for dry_run in [True, False]: # to get correct data heights
        for data_type in data_types:
            # data types without "unexp" in their name are flagged as controls
            ctrl = ("unexp" not in data_type)
            highest = add_between_sess_sig(
                ax, rel_resp_df, permpar, data_col=data_type, highest=highest, 
                ctrl=ctrl, p_val_prefix=True, dry_run=dry_run)
            if not dry_run:
                highest = [val * 1.05 for val in highest] # increment a bit

    sess_plot_util.format_linpla_subaxes(ax, fluor=analyspar["fluor"], 
        datatype="roi", sess_ns=sess_ns, kind="reg", xticks=sess_ns, 

    return ax

def plot_sess_data(data_df, analyspar, sesspar, permpar, figpar, 
                   between_sess_sig=True, data_col="diff_stats", 
                   decoder_data=False, title=None, wide=False):
    plot_sess_data(data_df, analyspar, sesspar, permpar, figpar)

    Plots errorbar data across sessions.

    Required args:
        - data_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - {data_key} (list): data stats (me, err)
            - null_CIs (list): adjusted null CI for data
            - raw_p_vals (float): uncorrected p-value for data within 
            - p_vals (float): p-value for data within sessions, 
                corrected for multiple comparisons and tails
            for session comparisons, e.g. 1v2:
            - raw_p_vals_{}v{} (float): uncorrected p-value for data 
                differences between sessions 
            - p_vals_{}v{} (float): p-value for data between sessions, 
                corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - between_sess_sig (bool):
            if True, significance between sessions is logged and plotted
            default: True         
        - data_col (str):
            dataframe column in which data to plot is stored
            default: "diff_stats"
        - decoder_data (bool):
            if True, data plotted is decoder data
            default: False
        - title (str):
            plot title
            default: None
        - wide (bool):
            if True, subplots are wider
            default: False
        - ax (2D array): 
            array of subplots

    # Purpose: one subplot per line/plane showing data_col across sessions,
    # with null CIs, within-session significance markers and, optionally,
    # between-session significance markers.
    sess_ns = misc_analys.get_sess_ns(sesspar, data_df)

    figpar = sess_plot_util.fig_init_linpla(figpar)
    # decoder data: y limits shared across all subplots; otherwise per row
    sharey = True if decoder_data else "row"
    figpar["init"]["sharey"] = sharey
    figpar["init"]["subplot_hei"] = 4.4
    figpar["init"]["gs"] = {"hspace": 0.2}
    # NOTE(review): no `else:` is visible in this branch. As written, when
    # wide is True the width ends up 2.6 (the second assignment wins), and
    # nothing is set when wide is False — confirm.
    if wide:
        figpar["init"]["subplot_wid"] = 3.0
        figpar["init"]["gs"]["wspace"] = 0.3
        figpar["init"]["subplot_wid"] = 2.6
        figpar["init"]["gs"]["wspace"] = 0.3

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    if title is not None:
        fig.suptitle(title, y=0.97, weight="bold")

    # significance settings derived from the permutation parameters
    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    for pass_n in [0, 1]: # add significance markers on the second pass
        if pass_n == 1:
  "{comp_info}:", extra={"spacing": "\n"})
        for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
            li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane
            line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]

            # for each session, find its single matching row (duplicates raise)
            sess_indices = []
            lp_sess_ns = []
            for sess_n in sess_ns:
                rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
                if len(rows) == 1:
                elif len(rows) > 1:
                    raise RuntimeError("Expected 1 row per line/plane/session.")

            # column 0: statistic; remaining columns: error values
            data = np.asarray([lp_df.loc[i, data_col] for i in sess_indices])

            if pass_n == 0:
                # plot errorbars
                    sub_ax, data[:, 0], data[:, 1:].T, lp_sess_ns, color=col, 
                    alpha=0.8, xticks="auto", line_dash=dash

            if pass_n == 1:
                # plot CIs
                # null CI columns: 0 and 2 are plotted as the interval, 1 as
                # its median
                CIs = np.asarray(
                    [lp_df.loc[i, "null_CIs"] for i in sess_indices]
                CI_meds = CIs[:, 1]
                CIs = CIs[:, np.asarray([0, 2])]

                plot_util.plot_CI(sub_ax, CIs.T, med=CI_meds, x=lp_sess_ns, 
                    width=0.45, color="lightgrey", med_col="gray", med_rat=0.03, 

                # add significance markers within sessions
                # marks are placed relative to the top of each errorbar
                # (statistic + last error column)
                y_maxes = data[:, 0] + data[:, -1]
                # +1 / -1 (or 0) depending on whether the statistic is above
                # or below the null CI median
                sides = [
                    np.sign(sub[0] - CI_med) 
                    for sub, CI_med in zip(data, CI_meds)
                p_vals_corr = [lp_df.loc[i, "p_vals"] for i in sess_indices]
                lp_sig_str = f"{line_plane_name:6} (within session):"
                for s, sess_n in enumerate(lp_sess_ns):
                    sig_str = misc_analys.get_sig_symbol(
                        p_vals_corr[s], sensitivity=sensitivity, side=sides[s], 
                        tails=permpar["tails"], p_thresh=permpar["p_val"]

                    # an empty symbol means no marker for this session
                    if len(sig_str):
                        plot_util.add_signif_mark(sub_ax, sess_n, y_maxes[s], 
                            rel_y=0.15, color=col, mark=sig_str)  

                    lp_sig_str = (
                        f"{lp_sig_str}{TAB} S{sess_n}: "

      , extra={"spacing": TAB})
    if between_sess_sig:
        add_between_sess_sig(ax, data_df, permpar, data_col=data_col)

    # decoder data: area flag disabled and an accuracy y label is used
    area, ylab = True, None
    if decoder_data:
        area = False
        # NOTE(review): no `else:` is visible here, so as written ylab ends up
        # "Accuracy %" for balanced data and stays None otherwise — confirm.
        if "balanced" in data_col:
            ylab = "Balanced accuracy (%)" 
            ylab = "Accuracy %"

    sess_plot_util.format_linpla_subaxes(ax, fluor=analyspar["fluor"], 
        area=area, ylab=ylab, datatype="roi", sess_ns=sess_ns, kind="reg", 
        xticks=sess_ns, modif_share=False)

    return ax
def plot_roi_tracking(roi_mask_df, figpar, title=None):
    plot_roi_tracking(roi_mask_df, figpar)
    Plots ROI tracking examples, for different session permutations, and union 
    across permutations.

    Required args:
        - roi_mask_df (pd.DataFrame in dict format):
            dataframe with a row for each mouse, and the following 
            columns, in addition to the basic sess_df columns: 
            - "roi_mask_shapes" (list): shape into which ROI mask indices index 
                (sess x hei x wid)
            - "union_n_conflicts" (int): number of conflicts after union
            for "union", "fewest" and "most" tracked ROIs:
            - "{}_registered_roi_mask_idxs" (list): list of mask indices, 
                registered across sessions, for each session 
                (flattened across ROIs) ((sess, hei, wid) x val),
                ordered by {}_sess_ns if "fewest" or "most"
            - "{}_n_tracked" (int): number of tracked ROIs
            for "fewest", "most" tracked ROIs:
            - "{}_sess_ns" (list): ordered session number 
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters  

    Optional args:
        - title (str):
            plot title
            default: None

        - ax (2D array): 
            array of subplots

    # Purpose: a single example (exactly one dataframe row) is shown as one
    # row of subplots: fewest-match and most-match session orders, a text
    # spacer, and the union minus conflicts.
    if len(roi_mask_df) != 1:
        raise ValueError("Expected only one row in roi_mask_df")
    roi_mask_row = roi_mask_df.loc[roi_mask_df.index[0]]

    # "" is the spacer column that carries the "Union - conflicts" arrow title
    columns = ["fewest", "most", "", "union"]

    # figpar is modified in place (no fig_init_linpla copy in this function)
    figpar["init"]["ncols"] = len(columns)
    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["subplot_hei"] = 5.05
    figpar["init"]["subplot_wid"] = 5.05
    figpar["init"]["gs"] = {"wspace": 0.06}

    # MUST ADJUST if anything above changes [right, bottom, width, height]
    new_axis_coords = [0.905, 0.125, 0.06, 0.74] 

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    # manually positioned extra axis, used for the scale marker below
    sub_ax_scale = fig.add_axes(new_axis_coords)

    if title is not None:
        fig.suptitle(title, y=1.05, weight="bold")

    sess_cols = get_sess_cols(roi_mask_df)
    alpha = 0.6
    for c, column in enumerate(columns):
        sub_ax = ax[0, c]

        # first subplot: line/plane name as y label, and start of the log text
        if c == 0:
            lp_col = plot_helper_fcts.get_line_plane_idxs(
                roi_mask_row["lines"], roi_mask_row["planes"]

            lp_name = plot_helper_fcts.get_line_plane_name(
                roi_mask_row["lines"], roi_mask_row["planes"]
            sub_ax.set_ylabel(lp_name, fontweight="bold", color=lp_col)
            log_info = f"Conflicts and matches for a {lp_name} example:"

        # NOTE(review): as shown, the "" column falls through to the set_title
        # and mask-drawing code below, reusing y and ord_sess_ns from the
        # previous column. A `continue` may be intended here — confirm.
        if column == "":
            subplot_title = \
                "     Union - conflicts\n...   ====================>"
            sub_ax.set_title(subplot_title, fontweight="bold", y=0.5)
            for spine in ["right", "left", "top", "bottom"]:

        if column in ["fewest", "most"]:
            y = 1.01
            ord_sess_ns = roi_mask_row[f"{column}_sess_ns"]
            ord_sess_ns_str = ", ".join([str(n) for n in ord_sess_ns])

            n_matches = int(roi_mask_row[f"{column}_n_tracked"])
            subplot_title = f"{n_matches} matches\n(sess {ord_sess_ns_str})"
            log_info = (f"{log_info}\n{TAB}"
                f"{column.capitalize()} matches (sess {ord_sess_ns_str}): "
        elif column == "union":
            y = 1.04
            ord_sess_ns = roi_mask_row["sess_ns"]
            n_union = int(roi_mask_row[f"{column}_n_tracked"])
            n_conflicts = int(roi_mask_row[f"{column}_n_conflicts"])
            # matches reported for the union = union size minus conflicts
            n_matches = n_union - n_conflicts

            subplot_title = f"{n_matches} matches"
            log_info = (f"{log_info}\n{TAB}"
                "Union - conflicts: "
                f"{n_union} - {n_conflicts} = {n_matches} matches"

        sub_ax.set_title(subplot_title, fontweight="bold", y=y)

        roi_masks = create_sess_roi_masks(
        # masks are ordered by ord_sess_ns, so each session's position in that
        # order selects its mask; the color comes from the session number
        for sess_n in roi_mask_row["sess_ns"]:
            col = sess_cols[int(sess_n)]
            s = ord_sess_ns.index(sess_n)
            add_roi_mask(sub_ax, roi_masks[s], col=col, alpha=alpha)

    # add scale marker
    # image height (roi_mask_shapes is documented as sess x hei x wid)
    hei_len = roi_mask_row["roi_mask_shapes"][1]
        sub_ax_scale, side_len=hei_len, ori="vertical", quadrant=3, fontsize=20
        ), extra={"spacing": "\n"})

    # add legend
        ax[0, columns.index("")], 
        bbox_to_anchor=(0.67, 0.3), 

    return ax
def plot_roi_masks_overlayed(roi_mask_df, figpar, title=None):
    plot_roi_masks_overlayed(roi_mask_df, figpar)

    Plots ROI masks overlayed across sessions, optionally cropped.

    Required args:
        - roi_mask_df (pd.DataFrame in dict format):
            dataframe with a row for each mouse, and the following 
            columns, in addition to the basic sess_df columns: 
            - "registered_roi_mask_idxs" (list): list of mask indices, 
                registered across sessions, for each session 
                (flattened across ROIs) ((sess, hei, wid) x val)
            - "roi_mask_shapes" (list): shape into which ROI mask indices index 
                (sess x hei x wid)
            and optionally, if cropping:
            - "crop_fact" (num): factor by which to crop masks (> 1) 
            - "shift_prop_hei" (float): proportion by which to shift cropped 
                mask center vertically from left edge [0, 1]
            - "shift_prop_wid" (float): proportion by which to shift cropped 
                mask center horizontally from left edge [0, 1]

        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters  

    Optional args:
        - title (str):
            plot title
            default: None

        - ax (2D array): 
            array of subplots

    # NOTE(review): the docstring says shift_prop_hei shifts "vertically from
    # left edge"; presumably a top/bottom reference edge is meant — confirm.

    # Purpose: one subplot per line/plane, with each session's ROI masks
    # overlaid in its session color, plus a legend and a scale marker.

    # cropping is enabled whenever the crop parameters are in the dataframe
    crop = "crop_fact" in roi_mask_df.columns

    figpar = sess_plot_util.fig_init_linpla(figpar)

    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["subplot_hei"] = 5.2
    figpar["init"]["subplot_wid"] = 5.2
    figpar["init"]["gs"] = {"wspace": 0.03, "hspace": 0.32}

    # MUST ADJUST if anything above changes [right, bottom, width, height]
    new_axis_coords = [0.885, 0.11, 0.1, 0.33]
    if crop: # move to the left
        new_axis_coords[0] = 0.04

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    # manually positioned extra axis, used for the scale marker below
    sub_ax_scale = fig.add_axes(new_axis_coords)

    if title is not None:
        fig.suptitle(title, y=0.95, weight="bold")

    sess_cols = get_sess_cols(roi_mask_df)
    alpha = 0.6
    # NOTE(review): hei_lens is never appended to in the visible code, so
    # np.unique(hei_lens) below is empty and the length check would raise.
    # A line collecting each plane's mask height may be missing — confirm.
    hei_lens = []
    for (line, plane), lp_mask_df in roi_mask_df.groupby(["lines", "planes"]):
        li, pl, _, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        if len(lp_mask_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        lp_row = lp_mask_df.loc[lp_mask_df.index[0]]
        roi_masks = create_sess_roi_masks(lp_row, crop=crop)

        # masks are indexed in the order of lp_row["sess_ns"]
        for s, sess_n in enumerate(lp_row["sess_ns"]):
            col = sess_cols[int(sess_n)]
            add_roi_mask(sub_ax, roi_masks[s], col=col, alpha=alpha)

    # add legend
        ax[0, 1], sess_cols, bbox_to_anchor=(0.7, -0.01), alpha=alpha

    # add scale marker
    # a single scale bar is used, so all planes must share one image height
    hei_lens = np.unique(hei_lens)
    if len(hei_lens) != 1:
        raise NotImplementedError(
            "Adding scale bar not implemented if ROI mask image heights are "
            "different for different planes."
    # the scale bar quadrant depends on whether masks are cropped
    quadrant = 1 if crop else 3
        sub_ax_scale, side_len=hei_lens[0], ori="vertical", quadrant=quadrant, 
    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(ax, ylab="", kind="map")

    return ax
def plot_roi_masks_overlayed_with_proj(roi_mask_df, figpar, title=None):
    plot_roi_masks_overlayed_with_proj(roi_mask_df, figpar)

    Plots ROI mask contours overlayed over imaging planes, and ROI masks 
    overlayed over each other across sessions.

    Required args:
        - roi_mask_df (pd.DataFrame in dict format):
            dataframe with a row for each mouse, and the following 
            columns, in addition to the basic sess_df columns: 

            - "max_projections" (list): pixel intensities of maximum projection 
                for the plane (hei x wid)
            - "registered_roi_mask_idxs" (list): list of mask indices, 
                registered across sessions, for each session 
                (flattened across ROIs) ((sess, hei, wid) x val)
            - "roi_mask_idxs" (list): list of mask indices for each session, 
                and each ROI (sess x (ROI, hei, wid) x val) (not registered)
            - "roi_mask_shapes" (list): shape into which ROI mask indices index 
                (sess x hei x wid)

            - "crop_fact" (num): factor by which to crop masks (> 1) 
            - "shift_prop_hei" (float): proportion by which to shift cropped 
                mask center vertically from left edge [0, 1]
            - "shift_prop_wid" (float): proportion by which to shift cropped 
                mask center horizontally from left edge [0, 1]

        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters  

    Optional args:
        - title (str):
            plot title
            default: None

        - ax (2D array): 
            array of subplots

    # Layout: n_sess subplot columns per line, and two subplot rows per plane
    # (see the init_fig call and the 2-row ax_grp slice below).
    n_lines = len(roi_mask_df["lines"].unique())
    n_planes = len(roi_mask_df["planes"].unique())

    sess_cols = get_sess_cols(roi_mask_df)
    n_sess = len(sess_cols)
    n_cols = n_sess * n_lines

    figpar = sess_plot_util.fig_init_linpla(figpar)

    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["subplot_hei"] = 2.3
    figpar["init"]["subplot_wid"] = 2.3
    figpar["init"]["gs"] = {"wspace": 0.2, "hspace": 0.2}
    figpar["init"]["ncols"] = n_cols

    fig, ax = plot_util.init_fig(n_cols * n_planes * 2, **figpar["init"])

    if title is not None:
        fig.suptitle(title, y=0.93, weight="bold")

    # cropping is enabled whenever the crop parameters are in the dataframe
    crop = "crop_fact" in roi_mask_df.columns

    # NOTE(review): sess_cols was already computed above with the same call;
    # this recomputation looks redundant.
    sess_cols = get_sess_cols(roi_mask_df)
    alpha = 0.6
    # the projections are drawn one below this zorder (proj_zorder below)
    raster_zorder = -12

    for (line, plane), lp_mask_df in roi_mask_df.groupby(["lines", "planes"]):
        li, pl, _, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        # element 2 is the line/plane color (as unpacked elsewhere in the file)
        lp_col = plot_helper_fcts.get_line_plane_idxs(line, plane)[2]
        lp_name = plot_helper_fcts.get_line_plane_name(line, plane)

        if len(lp_mask_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        lp_row = lp_mask_df.loc[lp_mask_df.index[0]]

        # identify subplots
        # NOTE(review): base_row equals 2 * pl (two rows per plane) only when
        # n_planes == 2 or pl == 0. base_col scales by n_lines rather than by
        # n_sess, and ax_grp spans n_sess + 1 columns. Confirm that this
        # layout arithmetic is intended for all line/session counts.
        base_row = (pl % n_planes) * n_planes
        base_col = (li % n_lines) * n_lines

        ax_grp = ax[base_row : base_row + 2, base_col : base_col + n_sess + 1]

        # add imaging planes and masks
        imaging_planes = add_proj_and_roi_masks(
            ax_grp, lp_row, sess_cols, crop=crop, alpha=alpha, 
            proj_zorder=raster_zorder - 1

        # add markings
        # shared subplot: second row of the group, middle session column
        shared_row = base_row + 1
        shared_col = base_col + int((n_sess - 1) // 2)
        shared_sub_ax = ax[shared_row, shared_col]

        if shared_col == 0:
            shared_sub_ax.set_ylabel(lp_name, fontweight="bold", color=lp_col)
            lp_sub_ax = ax[shared_row, 0]
            # unit axis limits so the name can be placed at (0.5, 0.5)
            lp_sub_ax.set_xlim([0, 1])
            lp_sub_ax.set_ylim([0, 1])
                0.5, 0.5, lp_name, fontweight="bold", color=lp_col, 
                ha="center", va="center", fontsize="x-large"

        # add scale bar
        if n_sess < 2:
            raise NotImplementedError(
                "Scale bar placement not implemented for fewer than 2 "
        scale_ax = ax[shared_row, -1]
        # scale bar length: width (last axis) of the first imaging plane
        wid_len = imaging_planes[0].shape[-1]
            scale_ax, side_len=wid_len, ori="horizontal", quadrant=1, 
            )"Rasterizing imaging plane images...", extra={"spacing": TAB})
    # presumably rasterizes the subplots in even-indexed rows, per the log
    # message above — confirm
    for i in range(ax.shape[0]):
        for j in range(ax.shape[1]):
            sub_ax = ax[i, j]
            if not(i % 2):

    # add legend
    if n_sess < 2:
        raise NotImplementedError(
            "Legend placement not implemented for fewer than 2 sessions."
        ax[-1, -1], sess_cols, bbox_to_anchor=(1, 0.6), alpha=alpha, 

    return ax
def plot_imaging_planes(imaging_plane_df, figpar, title=None):
    plot_imaging_planes(imaging_plane_df, figpar)
    Plots imaging planes.

    Required args:
        - imaging_plane_df (pd.DataFrame in dict format):
            dataframe with a row for each mouse, and the following 
            columns, in addition to the basic sess_df columns: 
            - "max_projections" (list): pixel intensities of maximum projection 
                for the plane (hei x wid)
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters  

    Optional args:
        - title (str):
            plot title
            default: None

        - ax (2D array): 
            array of subplots

    # Purpose: one subplot per line/plane showing its maximum projection, plus
    # a shared vertical scale marker.
    figpar = sess_plot_util.fig_init_linpla(figpar)

    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["subplot_hei"] = 2.4
    figpar["init"]["subplot_wid"] = 2.4
    figpar["init"]["gs"] = {"wspace": 0.25, "hspace": 0.2}

    # MUST ADJUST if anything above changes [right, bottom, width, height]
    new_axis_coords = [0.91, 0.115, 0.15, 0.34]

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    # manually positioned extra axis, used for the scale marker below
    sub_ax_scale = fig.add_axes(new_axis_coords)

    if title is not None:
        fig.suptitle(title, y=1, weight="bold")

    # NOTE(review): hei_lens is never appended to in the visible code, so
    # np.unique(hei_lens) below is empty and the length check would raise.
    # A line collecting each plane's image height may be missing — confirm.
    hei_lens = []
    # the projections are drawn one below this zorder
    raster_zorder = -12
    for (line, plane), lp_mask_df in imaging_plane_df.groupby(["lines", "planes"]):
        li, pl, _, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        if len(lp_mask_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        lp_row = lp_mask_df.loc[lp_mask_df.index[0]]
        # add projection
        # documented as (hei x wid) pixel intensities
        imaging_plane = np.asarray(lp_row["max_projections"])
        add_imaging_plane(sub_ax, imaging_plane, alpha=0.98, 
            zorder=raster_zorder - 1
    # add scale marker
    # a single scale bar is used, so all planes must share one image height
    hei_lens = np.unique(hei_lens)
    if len(hei_lens) != 1:
        raise NotImplementedError(
            "Adding scale bar not implemented if ROI mask image heights are "
            "different for different planes."
        sub_ax_scale, side_len=hei_lens[0], ori="vertical", quadrant=3, 
        )"Rasterizing imaging plane images...", extra={"spacing": TAB})
    for sub_ax in ax.reshape(-1):

    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(ax, ylab="", kind="map")
    for sub_ax in ax.reshape(-1):

    return ax
def plot_pup_diff_corr(analyspar, sesspar, stimpar, extrapar, 
                       sess_info, corr_data, figpar=None, savedir=None):
    # Plots, for each session (one column of two subplots per session), the
    # relationship between unexpected-locked changes in pupil diameter and
    # running or ROI data, saves the figure with plot_util.savefig and
    # returns (fulldir, savename).
    # NOTE(review): this copy of the function is damaged. The docstring's
    # triple-quote delimiters are missing (and fulldir / savename, listed
    # under "Optional args", are really the return values), and several
    # statements below are truncated or have lost lines, so the block does
    # not parse as written. Each damaged spot is flagged below; restore the
    # missing lines from version control rather than guessing at them.
    plot_pup_diff_corr(analyspar, sesspar, stimpar, extrapar, 
                       sess_info, corr_data)

    From dictionaries, plots correlation between unexpected-locked changes in 
    pupil diameter and running or ROI data for each session.

    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis 
            ["analysis"] (str): analysis type (e.g., "c")
            ["datatype"] (str): datatype (e.g., "run", "roi")
        - sess_info (dict)    : dictionary containing information from each
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session
        - corr_data (dict)    : dictionary containing data to plot:
            ["corrs"] (list): list of correlation values between pupil and 
                              running or ROI differences for each session
            ["diffs"] (list): list of differences for each session, structured
                                  as [pupil, ROI/run] x trials x frames
    Optional args:
        - figpar (dict) : dictionary containing the following figure parameter 
                          default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (Path): path of directory in which to save plots.
                          default: None
        - fulldir (Path): final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"], 
        stimpar["gabk"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")

    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"],stimpar["visflow_size"], stimpar["gabk"]) 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])
    datatype = extrapar["datatype"]
    datastr = sess_str_util.datatype_par_str(datatype)

    # NOTE(review): there is no `else` branch, so a datatype other than
    # "roi" or "run" leaves label_str / full_label_str unbound and the
    # function fails later with a NameError; an explicit ValueError here
    # would be clearer.
    if datatype == "roi":
        label_str = sess_str_util.fluor_par_str(
            analyspar["fluor"], str_type="print")
        full_label_str = u"{}, {} across ROIs".format(
            label_str, analyspar["stats"])
    elif datatype == "run":
        label_str = datastr
        full_label_str = datastr
    # Axis-label suffix describing how the changes were summarized.
    # NOTE(review): the statement below fuses the end of the `lab_app`
    # assignment with the argument list of a progress-message call whose
    # callee is missing; as written it does not parse.
    lab_app = (f" ({analyspar['stats']} over "
        f"{stimpar['pre']}/{stimpar['post']} sec)")"Plotting pupil vs {datastr} changes.")
    # Greek capital delta, used as the "change in" prefix of axis labels.
    delta = "\u0394"

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]

    n_sess = len(mouse_ns)
    # NOTE(review): the get_nroi_strs(...) call below is missing its closing
    # parenthesis, and nroi_strs (like planes and dendstr_pr) is not used in
    # any of the remaining visible lines.
    nroi_strs = sess_str_util.get_nroi_strs(
        sess_info, empty=(datatype!="roi"), style="comma"

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    # Work on a copy so the caller's figpar is not modified below.
    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()
    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["ncols"] = n_sess
    fig, ax = plot_util.init_fig(2 * n_sess, **figpar["init"])
    suptitle = (f"Relationship between pupil diam. and {datastr} changes, "
        "locked to unexpected events")
    for i, sess_diffs in enumerate(corr_data["diffs"]):
        sub_axs = ax[:, i]
        # NOTE(review): the `title` expression is unterminated; its closing
        # line is missing.
        title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}, " + 
            u"{}".format(statstr_pr) + f"\n(sess {sess_ns[i]}, {lines[i]} "
        # top plot: correlations
        corr = f"Corr = {corr_data['corrs'][i]:.2f}"
        # NOTE(review): the first line of the call that takes the arguments
        # below (and its closing line) is missing.
            sess_diffs[0], sess_diffs[1], marker=".", linestyle="None", 
        sub_axs[0].set_title(title, y=1.01)
        sub_axs[0].set_xlabel(u"{} pupil diam.{}".format(delta, lab_app))
        if i == 0:
            sub_axs[0].set_ylabel(u"{} {}\n{}".format(
                delta, full_label_str, lab_app))
        # bottom plot: differences across occurrences
        data_lab = u"{} {}".format(delta, label_str)   
        pup_lab = u"{} pupil diam.".format(delta)
        cols = []
        scaled = []
        for d, lab in enumerate([pup_lab, data_lab]):
            # NOTE(review): the start of the statement that min-max scales
            # sess_diffs[d] is missing; the next line reads scaled[-1], so
            # the result was presumably appended to `scaled`. Nothing in the
            # visible lines appends to `cols` either, although it is passed
            # as `color` to the shading call below.
                np.asarray(sess_diffs[d]), sc_type="min_max")[0])
            art, = sub_axs[1].plot(scaled[-1], marker=".")
            if i == n_sess - 1: # only for last graph
                # NOTE(review): this `if` has lost its body.
        sub_axs[1].set_xlabel("Unexpected event occurrence")
        if i == 0:
            # NOTE(review): the first line of this `if` body (the call that
            # receives the label string below) is missing.
                u"{} response locked\nto unexpected onset (scaled)".format(delta))
        # shade area between lines
        # NOTE(review): the first line of the shading call is missing.
            sub_axs[1], scaled[0], scaled[1], color=cols, alpha=0.4)

    fig.suptitle(suptitle, fontsize="xx-large", y=1)

    if savedir is None:
        # NOTE(review): the Path(...) arguments and closing parenthesis are
        # missing.
        savedir = Path(

    savename = f"{datatype}_diff_corr_{sessstr}{dendstr}"

    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename                              
def plot_pup_roi_stim_corr(analyspar, sesspar, stimpar, extrapar, 
                           sess_info, corr_data, figpar=None, savedir=None):
    # Plots, with one subplot per session, each ROI's pupil/ROI change
    # correlation for the first stimulus type in corr_data["stim_order"]
    # (x axis) against that for the second (y axis), saves the figure with
    # plot_util.savefig and returns (fulldir, savename).
    # NOTE(review): this copy of the function is damaged. The docstring's
    # triple-quote delimiters are missing (and fulldir / savename, listed
    # under "Optional args", are really the return values), and several
    # statements below are truncated or have lost lines, so the block does
    # not parse as written. Each damaged spot is flagged below.
    plot_pup_roi_stim_corr(analyspar, sesspar, stimpar, extrapar, 
                           sess_info, corr_data)

    From dictionaries, plots correlation between unexpected-locked changes in 
    pupil diameter and each ROI, for gabors versus visual flow responses for 
    each session.
    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis 
            ["analysis"] (str): analysis type (e.g., "r")
            ["datatype"] (str): datatype (e.g., "roi")
        - sess_info (dict)    : dictionary containing information from each
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session
        - corr_data (dict)    : dictionary containing data to plot:
            ["stim_order"] (list): ordered list of stimtypes
            ["roi_corrs"] (list) : nested list of correlations between pupil 
                                   and ROI responses changes locked to 
                                   unexpected, structured as 
                                       session x stimtype x ROI
            ["corrs"] (list)     : list of correlation between stimtype
                                   correlations for each session
    Optional args:
        - figpar (dict) : dictionary containing the following figure parameter 
                          default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (Path): path of directory in which to save plots.
                          default: None
        - fulldir (Path): final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved

    stimstr_prs = []
    for stimtype in corr_data["stim_order"]:
        stimstr_pr = sess_str_util.stim_par_str(
            stimtype, stimpar["visflow_dir"], stimpar["visflow_size"], 
            stimpar["gabk"], "print")
        # Drop a trailing "s" so the string reads as singular in labels.
        stimstr_pr = stimstr_pr[:-1] if stimstr_pr[-1] == "s" else stimstr_pr
        # NOTE(review): nothing in the visible lines appends stimstr_pr to
        # stimstr_prs, so stimstr_prs stays empty and stimstr_prs[0] /
        # stimstr_prs[1] below would raise IndexError; an append line
        # appears to be missing.
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")

    sessstr = f"sess{sesspar['sess_n']}_{sesspar['plane']}" 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])

    label_str = sess_str_util.fluor_par_str(
        analyspar["fluor"], str_type="print")
    # NOTE(review): the statement below fuses the end of the `lab_app`
    # assignment with the argument list of a progress-message call whose
    # callee is missing; as written it does not parse.
    lab_app = (f" ({analyspar['stats']} over "
        f"{stimpar['pre']}/{stimpar['post']} sec)")"Plotting pupil-ROI difference correlations for "
        "{} vs {}.".format(*corr_data["stim_order"]))

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]

    n_sess = len(mouse_ns)
    # NOTE(review): nroi_strs, planes and dendstr_pr are not used in any of
    # the remaining visible lines (presumably by the missing end of the
    # subplot title).
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, style="comma")

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    # Work on a copy so the caller's figpar is not modified below.
    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()
    figpar["init"]["sharex"] = True
    figpar["init"]["sharey"] = True
    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])
    suptitle = (u"Relationship between pupil diam. and {} changes, locked to "
        "unexpected events\n{} for each ROI ({} vs {})".format(
            label_str, lab_app, *corr_data["stim_order"]))
    for i, sess_roi_corrs in enumerate(corr_data["roi_corrs"]):
        sub_ax = plot_util.get_subax(ax, i)
        # NOTE(review): the `title` expression is unterminated; its closing
        # line is missing.
        title = (f"Mouse {mouse_ns[i]} (sess {sess_ns[i]}, {lines[i]} "
        # top plot: correlations
        corr = f"Corr = {corr_data['corrs'][i]:.2f}"
        # NOTE(review): the first line of the call that takes the arguments
        # below (and its closing line) is missing.
            sess_roi_corrs[0], sess_roi_corrs[1], marker=".", linestyle="None", 
        sub_ax.set_title(title, y=1.01)
        # NOTE(review): Axes.is_last_row() / is_first_col() were deprecated
        # in Matplotlib 3.4 and removed in 3.6; the replacements are
        # sub_ax.get_subplotspec().is_last_row() / .is_first_col().
        if sub_ax.is_last_row():
            sub_ax.set_xlabel(f"{stimstr_prs[0].capitalize()} correlations")
        if sub_ax.is_first_col():
            sub_ax.set_ylabel(f"{stimstr_prs[1].capitalize()} correlations")

    plot_util.turn_off_extra(ax, n_sess)

    fig.suptitle(suptitle, fontsize="xx-large", y=1)

    if savedir is None:
        # NOTE(review): the Path(...) arguments and closing parenthesis are
        # missing.
        savedir = Path(

    savename = f"roi_diff_corrbyroi_{sessstr}{dendstr}"

    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename                           
def plot_idx_corr_scatterplots(idx_corr_df,
    # NOTE(review): the remainder of this signature is missing. The body
    # reads permpar, figpar, permute and title, and the docstring lists
    # permute (default "sess") and title (default None) as optional, so the
    # signature was presumably
    # (idx_corr_df, permpar, figpar, permute="sess", title=None) -- confirm
    # against version control. The docstring's triple-quote delimiters are
    # also missing, so the block does not parse as written.
    # NOTE(review): the docstring types `permute` as (bool) although it
    # describes string values ("tracking", "sess" or "all"); it should read
    # (str), and "to due" should read "to do".
    # Plots one USI correlation scatterplot per line/plane (4 subplots,
    # indexed ax[plane, line]), builds a per-line/plane p-value summary
    # string, and returns the subplot array.
    plot_idx_corr_scatterplots(idx_corr_df, permpar, figpar)

    Plots ROI USI index correlation scatterplots for individual session 

    Required args:
        - idx_corr_df (pd.DataFrame):
            dataframe with one row per line/plane, and the 
            following columns, in addition to the basic sess_df columns:

            for correlation data (normalized if corr_type is "diff_corr") for 
            session comparisons (x, y), e.g. 1v2
            - binned_rand_stats (list): number of random correlation values per 
                bin (xs x ys)
            - corr_data_xs (list): USI values for x
            - corr_data_ys (list): USI values for y
            - corrs (float): correlation between session data (x and y)
            - p_vals (float): p-value for correlation, corrected for 
                multiple comparisons and tails
            - rand_corr_meds (float): median of the random correlations
            - raw_p_vals (float): p-value for intersession correlations
            - regr_coefs (float): regression correlation coefficient (slope)
            - regr_intercepts (float): regression correlation intercept
            - x_bin_mids (list): x mid point for each random correlation bin
            - y_bin_mids (list): y mid point for each random correlation bin

        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - permute (bool):
            type of permutation to due ("tracking", "sess" or "all")
            default: "sess"
        - title (str):
            plot title
            default: None

        - ax (2D array): 
            array of subplots

    # Session-permutation modes plot the second session as a difference
    # from the first (see the y-axis label below).
    diffs = False
    if permute in ["sess", "all"]:
        diffs = True

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="reg")

    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["subplot_hei"] = 4
    figpar["init"]["subplot_wid"] = 4
    figpar["init"]["gs"] = {"hspace": 0.4, "wspace": 0.4}

    fig, ax = plot_util.init_fig(4, **figpar["init"])

    if title is not None:
        fig.suptitle(title, fontweight="bold", y=0.97)

    sess_ns = None

    # first pass to plot
    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        if len(lp_df) != 1:
            raise RuntimeError("Expected exactly one row.")
        lp_row = lp_df.loc[lp_df.index[0]]

        # Axis labels are taken from the first line/plane; every other
        # line/plane must compare the same sessions.
        if sess_ns is None:
            sess_ns = lp_row["sess_ns"]
            xlabel = f"Session {sess_ns[0]} USIs"
            ylabel = f"Session {sess_ns[1]} USIs"
            if diffs:
                ylabel = f"Session {sess_ns[1]} - {sess_ns[0]} USIs"

        # NOTE(review): if "sess_ns" holds numpy arrays rather than lists,
        # `!=` returns an array and this `elif` raises ValueError; confirm
        # the column type.
        elif sess_ns != lp_row["sess_ns"]:
            raise RuntimeError("Expected all sess_ns to match.")

        # NOTE(review): the density_data list is unterminated, and the calls
        # that would plot density_data and use `alpha` are missing (neither
        # is used in the visible lines).
        density_data = [
            lp_row["x_bin_mids"], lp_row["y_bin_mids"],

        # Point opacity decays with the number of points (0.3 at 300).
        alpha = 0.3**(len(lp_row["corr_data_xs"]) / 300)

    # Add plane, line info to plots
    # NOTE(review): the formatting call that belongs here (presumably the
    # consumer of xlabel / ylabel, which are otherwise unused) is missing.

    # second pass to add plot markings
    # NOTE(review): the next line fuses the get_comp_info(...) assignment
    # with the argument list of a message call whose callee is missing; as
    # written it does not parse.
    comp_info = misc_analys.get_comp_info(permpar)"{comp_info}:", extra={"spacing": "\n"})
    sig_str = ""
    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
        sub_ax = ax[pl, li]

        # add markers back in (removed due to kind='reg')
        sub_ax.tick_params(axis="x", which="both", bottom=True, top=False)

        lp_row = lp_df.loc[lp_df.index[0]]

        p_val_corr = lp_row["p_vals"]
        # NOTE(review): the add_scatterplot_markers(...) call is truncated
        # after its first argument.
        lp_sig_str = add_scatterplot_markers(sub_ax,

        # NOTE(review): the second line below fuses the end of the sig_str
        # assignment with the tail of a message call whose beginning is
        # missing; as written it does not parse.
        sig_str = (f"{sig_str}{TAB}{line_plane_name}: "
                   f"{p_val_corr:.5f}{lp_sig_str:3}"), extra={"spacing": TAB})

    return ax
def plot_roi_correlations(corr_df, figpar, title=None, log_scale=True):
    # Plots one pre-binned ROI correlation histogram per line/plane and
    # session (4 x n_sess subplots), optionally with a log-scaled y axis,
    # and returns the subplot array.
    # NOTE(review): this copy of the function is damaged. The docstring's
    # triple-quote delimiters and "Returns:" header are missing, and several
    # statements below are truncated or have lost lines, so the block does
    # not parse as written. Each damaged spot is flagged below.
    plot_roi_correlations(corr_df, figpar)

    Plots correlation histograms.

    Required args:
        - corr_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the 
            following columns, in addition to the basic sess_df columns:
            - bin_edges (list): first and last bin edge
            - corrs_binned (list): number of correlation values per bin
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - log_scale (bool):
            if True, a near logarithmic scale is used for the y axis (with a 
            linear range to reach 0, and break marks to mark the transition 
            from linear to log range)
            default: True

        - ax (2D array): 
            array of subplots

    # One subplot column per session number from the minimum to the maximum
    # in corr_df, including any session number absent from corr_df.
    sess_ns = np.arange(corr_df.sess_ns.min(), corr_df.sess_ns.max() + 1)
    n_sess = len(sess_ns)

    # NOTE(review): the fig_init_linpla(...) call below is unterminated;
    # its remaining arguments and closing parenthesis are missing.
    figpar = sess_plot_util.fig_init_linpla(figpar,
    figpar["init"]["subplot_hei"] = 3.0
    figpar["init"]["subplot_wid"] = 2.8
    figpar["init"]["sharex"] = log_scale
    if log_scale:
        figpar["init"]["sharey"] = True

    fig, ax = plot_util.init_fig(4 * len(sess_ns), **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=1.02, weight="bold")


    log_base = 2
    for (line, plane), lp_df in corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        for s, sess_n in enumerate(sess_ns):
            sess_rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
            if len(sess_rows) == 0:
                # NOTE(review): this branch has lost its body (presumably
                # `continue`, to skip sessions absent for this line/plane).
            elif len(sess_rows) > 1:
                raise RuntimeError("Expected exactly one row.")
            sess_row = sess_rows.loc[sess_rows.index[0]]

            # Sessions sit side by side within each line's block of n_sess
            # columns.
            sub_ax = ax[pl, s + li * n_sess]

            weights = np.asarray(sess_row["corrs_binned"])

            # Rebuild equally spaced bin edges from the stored first and
            # last edges.
            bin_edges = np.linspace(*sess_row["bin_edges"], len(weights) + 1)

            # NOTE(review): the call that draws the histogram from
            # `weights` and `bin_edges` (presumably also using `col`) is
            # missing; none of the three is used in the visible lines.

            sub_ax.tick_params(axis="x", which="both", bottom=True, top=False)

            if log_scale:
                sub_ax.set_yscale("log", base=log_base)
                sub_ax.set_xlim(-1, 1)
                # NOTE(review): autoscale(axis="x", tight=True) re-enables
                # x autoscaling and immediately re-fits the limits, undoing
                # the set_xlim(-1, 1) above. An `else:` line appears to be
                # missing before this call.
                sub_ax.autoscale(axis="x", tight=True)

            sub_ax.autoscale(axis="y", tight=True)

    if log_scale:  # update x ticks
        set_symlog_scale(ax, log_base=log_base, col_per_grp=n_sess, n_ticks=4)

    else:  # update x and y ticks
        for i in range(ax.shape[0]):
            for j in range(int(ax.shape[1] / n_sess)):
                sub_axes = ax[i, j * n_sess:(j + 1) * n_sess]
                # NOTE(review): the loop body that uses sub_axes is missing.


    return ax
def plot_sess_traces(data_df, analyspar, sesspar, figpar, 
                     trace_col="trace_stats", row_col="sess_ns", 
                     row_order=None, split="by_exp", title=None, size="reg"):
    # Plots trace statistics with one subplot column per line and one block
    # of len(row_order) rows per plane, and returns the subplot array.
    # NOTE(review): this copy of the function is damaged. The docstring's
    # triple-quote delimiters and "Returns:" header are missing, the `split`
    # entry is truncated and lists `default: False` although the signature
    # default is "by_exp", and several statements below are truncated or
    # have lost lines, so the block does not parse as written. Each damaged
    # spot is flagged below.
    plot_sess_traces(data_df, analyspar, sesspar, figpar) 
    Plots traces from dataframe.

    Required args:
        - data_df (pd.DataFrame):
            traces data frame with, in addition to the basic sess_df columns, 
            columns specified by trace_col, row_col, and a "time_values" column
        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - trace_col (str):
             dataframe column containing trace statistics, as 
             split x ROIs x frames x stats 
             default: "trace_stats"
        - row_col (str):
            dataframe column specifying the variable that defines rows 
            within each line/plane
            default: "sess_ns"
        - row_order (list):
            ordered list specifying the order in which to plot from row_col.
            If None, automatic sorting order is used.
            default: None 
        - split (str):
            data split, e.g. "exp_lock", "unexp_lock", "stim_onset" or 
            default: False
        - title (str):
            plot title
            default: None
        - size (str):
            subplot sizes
            default: "reg"

        - ax (2D array): 
            array of subplots
    # retrieve session numbers, and infer row_order, if necessary
    # (sess_ns is only passed on for axis formatting when rows are sessions,
    # and it stays None if row_order was not given.)
    sess_ns = None
    if row_col == "sess_ns":
        sess_ns = row_order
        if row_order is None:
            row_order = misc_analys.get_sess_ns(sesspar, data_df)

    elif row_order is None:
        row_order = data_df[row_col].unique()

    # NOTE(review): the fig_init_linpla(...) call below is missing its
    # closing parenthesis.
    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="traces", n_sub=len(row_order), sharey=False

    if size == "small":
        figpar["init"]["subplot_hei"] = 1.51
        figpar["init"]["subplot_wid"] = 3.7
    elif size == "wide":
        figpar["init"]["subplot_hei"] = 1.36
        figpar["init"]["subplot_wid"] = 4.8
        figpar["init"]["gs"] = {"wspace": 0.3, "hspace": 0.5}
    elif size == "reg":
        figpar["init"]["subplot_hei"] = 1.36
        figpar["init"]["subplot_wid"] = 3.4
        # NOTE(review): an `else:` line appears to be missing before the
        # next call; as written it runs for size == "reg" rather than only
        # for unrecognized sizes.
        gen_util.accepted_values_error("size", size, ["small", "wide", "reg"])

    fig, ax = plot_util.init_fig(len(row_order) * 4, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=1.0, weight="bold")

    for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)

        for r, row_val in enumerate(row_order):
            rows = lp_df.loc[lp_df[row_col] == row_val]
            if len(rows) == 0:
                # NOTE(review): this branch has lost its body (presumably
                # `continue`, to skip rows absent for this line/plane).
            elif len(rows) > 1:
                # NOTE(review): the RuntimeError(...) call below is missing
                # its closing parenthesis.
                raise RuntimeError(
                    "Expected row_order instances to be unique per line/plane."
            row = rows.loc[rows.index[0]]

            sub_ax = ax[r + pl * len(row_order), li]

            if line == "L2/3-Cux2":
                exp_col = "darkgray" # oddly, lighter than gray
                # NOTE(review): an `else:` line appears to be missing here;
                # as written exp_col is always "gray" for this line and is
                # unbound (or stale from an earlier iteration) for others.
                exp_col = "gray"

            # NOTE(review): the first line of the call that takes the
            # arguments below (and its closing parenthesis) is missing.
                sub_ax, row["time_values"], row[trace_col], split=split, 
                col=col, ls=dash, exp_col=exp_col, lab=False
    for sub_ax in ax.reshape(-1):
        plot_util.set_minimal_ticks(sub_ax, axis="y")

    sess_plot_util.format_linpla_subaxes(ax, fluor=analyspar["fluor"], 
        area=False, datatype="roi", sess_ns=sess_ns, xticks=None, 
        kind="traces", modif_share=False)

   # fix x ticks and lims
    plot_util.set_interm_ticks(ax, 3, axis="x", fontweight="bold")
    # NOTE(review): `row` here is whatever the last loop iteration left
    # behind (unbound if no row was ever found), and nothing in the visible
    # lines applies `xlims` to the axes; a line appears to be missing.
    xlims = [np.min(row["time_values"]), np.max(row["time_values"])]
    if split != "by_exp":
        xlims = [-xlims[1], xlims[1]]

    return ax
def plot_snr_sigmeans_nrois(data_df,
    # NOTE(review): the middle of this signature is missing. The body reads
    # figpar and datatype, and the docstring lists datatype (default
    # "snrs") as optional, so the signature was presumably
    # (data_df, figpar, datatype="snrs", title="ROI SNRs") -- confirm
    # against version control.
                            title="ROI SNRs"):
    # Plots ROI SNRs, signal means or ROI counts (selected by `datatype`)
    # with one subplot per line/plane, and returns the subplot array.
    # NOTE(review): the docstring's triple-quote delimiters and "Returns:"
    # header are missing, and several statements below have lost lines, so
    # the block does not parse as written. Each damaged spot is flagged.
    plot_snr_sigmeans_nrois(data_df, figpar)

    Plots SNR, signal means or number of ROIs, depending on the case.

    Required args:
        - data_df (pd.DataFrame):
            dataframe with SNR, signal mean or number of ROIs data for each 
            session, in addition to the basic sess_df columns
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - datatype (str):
            type of data to plot, also corresponding to column name
            default: "snrs"
        - title (str):
            plot title
            default: "ROI SNRs"

        - ax (2D array): 
            array of subplots

    # Every session number from the minimum to the maximum in data_df.
    sess_ns = np.arange(data_df.sess_ns.min(), data_df.sess_ns.max() + 1)

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="reg")
    figpar["init"]["sharey"] = "row"

    figpar["init"]["subplot_hei"] = 4.4
    figpar["init"]["gs"] = {"wspace": 0.2, "hspace": 0.2}
    if datatype != "nrois":
        figpar["init"]["subplot_wid"] = 3.2
        # NOTE(review): an `else:` line appears to be missing here; as
        # written both assignments run when datatype != "nrois" (leaving
        # 2.5) and neither runs for "nrois".
        figpar["init"]["subplot_wid"] = 2.5

    fig, ax = plot_util.init_fig(4, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=0.97, weight="bold")

    for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)

        sub_ax = ax[pl, li]

        if datatype == "snrs":
            # NOTE(review): this branch and the next have lost their bodies.
        elif datatype == "signal_means":
        elif datatype != "nrois":
            gen_util.accepted_values_error("datatype", datatype,
                                           ["snrs", "signal_means", "nrois"])

        if datatype == "nrois":
            plot_nrois(sub_ax, lp_df, sess_ns=sess_ns, col=col, dash=dash)
            # NOTE(review): as written, "nrois" then falls through to the
            # box-plot code below; a `continue` may be missing.

        data = []
        use_sess_ns = []
        for sess_n in sess_ns:
            rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
            if len(rows) > 0:
                # NOTE(review): this `if` has lost its body; nothing in the
                # visible lines appends to `data` or `use_sess_ns`.

        # NOTE(review): the first line of the box-plot call that takes the
        # keyword arguments below is missing.
                       whis=[5, 95],
                       capprops=dict(color=col, linewidth=3.0),
                       whiskerprops=dict(color=col, linewidth=3.0),
                       medianprops=dict(color=col, linewidth=3.0))


    return ax
def plot_ex_gabor_traces(ex_traces_df, stimpar, figpar, title=None):
    # Plots example Gabor traces, one ROI per subplot, arranged in a
    # per_rows x per_cols grid for each line/plane, and returns the subplot
    # array. Raises ValueError if stimpar["stimtype"] is not "gabors".
    # NOTE(review): this copy of the function is damaged. The docstring's
    # triple-quote delimiters and "Returns:" header are missing, and many
    # statements below are truncated or have lost lines, so the block does
    # not parse as written. Each damaged spot is flagged below.
    plot_ex_gabor_traces(ex_traces_df, stimpar, figpar)

    Plots example Gabor traces.

    Required args:
        - ex_traces_df (pd.DataFrame):
            dataframe with a row for each ROI, and the following columns, 
            in addition to the basic sess_df columns: 
            - time_values (list): values for each frame, in seconds
            - roi_ns (list): selected ROI number
             - traces_sm (list): selected ROI sequence traces, smoothed, with 
                dims: seq x frames
            - trace_stat (list): selected ROI trace mean or median
        - stimpar (dict):
            dictionary with keys of StimPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
    Optional args:
        - title (str):
            plot title
            default: None

        - ax (2D array): 
            array of subplots

    if stimpar["stimtype"] != "gabors":
        raise ValueError("Expected stimpar['stimtype'] to be 'gabors'.")

    # n_per: largest number of example ROIs in any line/plane, then reset to
    # the size of the near-square per_rows x per_cols grid.
    # NOTE(review): the np.max(...) call below is missing its closing
    # parenthesis.
    group_columns = ["lines", "planes"]
    n_per = np.max(
        [len(lp_df) for _, lp_df in ex_traces_df.groupby(group_columns)]
    per_rows, per_cols = math_util.get_near_square_divisors(n_per)
    n_per = per_rows * per_cols

    # NOTE(review): the fig_init_linpla(...) call below is missing its
    # closing parenthesis.
    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="traces", n_sub=per_rows
    figpar["init"]["subplot_hei"] = 1.36
    figpar["init"]["subplot_wid"] = 2.47
    figpar["init"]["ncols"] = per_cols * 2

    # NOTE(review): the init_fig(...) call below is missing its closing
    # parenthesis.
    fig, ax = plot_util.init_fig(
        plot_helper_fcts.N_LINPLA * n_per, **figpar["init"]
    if title is not None:
        fig.suptitle(title, y=1.03, weight="bold")

    # ylims collects per-subplot y limits (NaN until filled) so they can be
    # re-applied after axis formatting below.
    # NOTE(review): the next line fuses the ylims assignment with the
    # argument list of a message call whose callee is missing; as written it
    # does not parse.
    ylims = np.full(ax.shape + (2, ), np.nan)"Plotting individual traces...", extra={"spacing": TAB})
    raster_zorder = -12
    for (line, plane), lp_df in ex_traces_df.groupby(["lines", "planes"]):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)
        for i, idx in enumerate(lp_df.index):
            # Fill each line/plane's grid column by column.
            row_idx = int(pl * per_rows + i % per_rows)
            col_idx = int(li * per_cols + i // per_rows)
            sub_ax = ax[row_idx, col_idx]

            # NOTE(review): the arguments of this call (and its closing
            # parenthesis) are missing apart from `zorder`.
            ylims[row_idx, col_idx] = plot_ex_gabor_roi_traces(
                zorder=raster_zorder - 1

        # NOTE(review): time_values is used after the loop, so only the last
        # line/plane group's values are kept (unbound if ex_traces_df is
        # empty).
        time_values = np.asarray(lp_df.loc[lp_df.index[-1], "time_values"])
    # NOTE(review): the format_linpla_subaxes(...) call below is missing its
    # remaining arguments and closing parenthesis.
    sess_plot_util.format_linpla_subaxes(ax, fluor="dff", 
        area=False, datatype="roi", sess_ns=None, xticks=None, kind="traces", 

   # fix x ticks and lims
    for sub_ax in ax.reshape(-1):
        xlims = [time_values[0], time_values[-1]]
        xticks = np.linspace(*xlims, 6)
        # NOTE(review): the calls applying xlims / xticks are missing.
    plot_util.set_interm_ticks(ax, 3, axis="x", fontweight="bold", skip=False)
    for sub_ax in ax.reshape(-1):
        # NOTE(review): this loop has lost its body.
    # reset y limits
    for r in range(ax.shape[0]):
        for c in range(ax.shape[1]):
            if not np.isfinite(ylims[r, c].sum()):
                # NOTE(review): this `if` has lost its body (presumably
                # `continue`, to skip subplots whose ylims were never set).
            ax[r, c].set_ylim(ylims[r, c])

        # NOTE(review): the first line of the call that takes the arguments
        # below (and its closing parenthesis) is missing.
        ax, 2, axis="y", share=False, weight="bold", update_ticks=True

    # NOTE(review): the message-call arguments on the next line were
    # absorbed into the comment, so they no longer execute.
    # rasterize the gray lines"Rasterizing individual traces...", extra={"spacing": TAB})
    for sub_ax in ax.reshape(-1):
        # NOTE(review): this loop has lost its body.

    return ax
def plot_idxs(idx_df, sesspar, figpar, plot="items", density=True, n_bins=40, 
              title=None, size="reg"):
    # Plots binned ROI index (plot="items") or index-percentile
    # (plot="percs") histograms with their confidence-interval limits, one
    # subplot per session, line and plane, and returns the subplot array.
    # NOTE(review): the docstring summary below ("Returns exact color for a
    # specific line.") is a copy-paste error and does not describe this
    # function. This copy is also damaged: the docstring's triple-quote
    # delimiters and "Returns:" header are missing, and several statements
    # below are truncated or have lost lines, so the block does not parse as
    # written. Each damaged spot is flagged below.
    plot_idxs(idx_df, sesspar, figpar)

    Returns exact color for a specific line.

    Required args:
        - idx_df (pd.DataFrame):
            dataframe with indices for different line/plane combinations, and 
            the following columns, in addition to the basic sess_df columns:
            - rand_idx_binned (list): bin counts for the random ROI indices
            - bin_edges (list): first and last bin edge
            - CI_edges (list): confidence interval limit values
            - CI_perc (list): confidence interval percentile limits
            if plot == "items":
            - roi_idx_binned (list): bin counts for the ROI indices
            if plot == "percs":
            - perc_idx_binned (list): bin counts for the ROI index percentiles
            - n_signif_lo (int): number of significant ROIs (low) 
            - n_signif_hi (int): number of significant ROIs (high)

        - sesspar (dict): 
            dictionary with keys of SessPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - plot (str): 
            type of data to plot ("items" or "percs")
            default: "items"
        - density (bool): 
            if True, histograms are plotted as densities
            default: True
        - n_bins (int): 
            number of bins to use in histograms
            default: 40
        - title (str): 
            plot title
            default: None
        - size (str): 
            plot size ("reg", "small" or "tall")
            default: "reg"
        - ax (2D array): 
            array of subplots

    if plot == "items":
        data_key = "roi_idx_binned"
        CI_key = "CI_edges"
    elif plot == "percs":
        data_key = "perc_idx_binned"
        CI_key = "CI_perc"
        # NOTE(review): an `else:` line appears to be missing before the
        # next call; as written it runs for plot == "percs".
        gen_util.accepted_values_error("plot", plot, ["items", "percs"])

    sess_ns = misc_analys.get_sess_ns(sesspar, idx_df)

    n_plots = len(sess_ns) * 4
    # NOTE(review): this writes into the caller's figpar before
    # fig_init_linpla rebinds the name, so it mutates the caller's dict;
    # the same value is set again further down, so this line can likely be
    # dropped once it is confirmed fig_init_linpla does not read it.
    figpar["init"]["sharey"] = "row"
    figpar = sess_plot_util.fig_init_linpla(figpar, kind="idx", 
        n_sub=len(sess_ns), sharex=(plot == "percs"))

    y = 1
    if size == "reg":
        subplot_hei = 3.2
        subplot_wid = 5.5
    elif size == "small":
        y = 1.04
        subplot_hei = 2.40
        subplot_wid = 3.75
        figpar["init"]["gs"] = {"hspace": 0.25, "wspace": 0.30}
    elif size == "tall":
        y = 0.98
        subplot_hei = 5.3
        subplot_wid = 5.55
        # NOTE(review): an `else:` line appears to be missing before the
        # next call; as written it runs for size == "tall".
        gen_util.accepted_values_error("size", size, ["reg", "small", "tall"])
    figpar["init"]["subplot_hei"] = subplot_hei
    figpar["init"]["subplot_wid"] = subplot_wid
    figpar["init"]["sharey"] = "row"
    fig, ax = plot_util.init_fig(n_plots, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=y, weight="bold")

    for (line, plane), lp_df in idx_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)

        for s, sess_n in enumerate(sess_ns):
            rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
            if len(rows) == 0:
                # NOTE(review): this branch has lost its body (presumably
                # `continue`, to skip sessions absent for this line/plane).
            elif len(rows) > 1:
                # NOTE(review): the RuntimeError(...) call below is missing
                # its closing parenthesis.
                raise RuntimeError(
                    "Expected sess_ns to be unique per line/plane."
            row = rows.loc[rows.index[0]]

            sub_ax = ax[s + pl * len(sess_ns), li]

            # get percentage significant label
            # (share of ROIs significant in either tail, out of all ROIs)
            perc_label = None
            if "n_signif_lo" in row.keys() and "n_signif_hi" in row.keys():
                n_sig_lo, n_sig_hi = row["n_signif_lo"], row["n_signif_hi"]
                nrois = np.sum(row["nrois"])
                perc_signif = np.sum([n_sig_lo, n_sig_hi]) / nrois * 100
                perc_label = (f"{perc_signif:.2f}% sig\n"
                    f"({n_sig_lo}-/{n_sig_hi}+ of {nrois})")                

            # NOTE(review): the first line of the histogram call that takes
            # the arguments below is missing.
                sub_ax, row[data_key], row[CI_key], n_bins=n_bins, 
                plot=plot, col=col, density=density, perc_label=perc_label)
            if size == "small":
                # NOTE(review): this `if` has lost its body.

    # Add plane, line info to plots
    y_lab = "Density" if density else f"N ROIs" 
    sess_plot_util.format_linpla_subaxes(ax, datatype="roi", ylab=y_lab, 
        xticks=None, sess_ns=None, kind="idx", modif_share=False, 
        xlab="Index", single_lab=True)

    # Add indices after setting formatting
    if plot == "percs":
        nticks = 5
        xticks = [int(np.around(x, 0)) for x in np.linspace(0, 100, nticks)]
        # NOTE(review): set_xticklabels is used without a matching
        # set_xticks in the visible lines; recent Matplotlib versions warn
        # (FixedFormatter without FixedLocator), and the labels may not line
        # up with the tick positions.
        for sub_ax in ax[-1]:
            sub_ax.set_xticklabels(xticks, weight="bold")
    elif plot == "items":
        nticks = 3
        # NOTE(review): the first line of the call that takes the arguments
        # below (and its closing parenthesis) is missing.
            ax, nticks, axis="x", weight="bold", share=False, skip=False
        # NOTE(review): an `else:` line appears to be missing before the
        # next call; as written it runs for plot == "items".
        gen_util.accepted_values_error("plot", plot, ["items", "percs"])
    return ax
# plot_full_traces: plots the full trace of each session in its own column of
# subplots and saves the figure.
# - "roi" data: a top subplot with each ROI's trace (from roi_tr) and a bottom
#   subplot with the statistic across ROIs (from trace_info["all_tr"]).
# - other data (e.g. "run"): a single subplot with the trace.
# Stimulus blocks (trace_info["all_edges"]) are marked with bars and labelled
# with their parameter strings (trace_info["all_pars"]).
# Returns (fulldir, savename), where fulldir is what plot_util.savefig returns.
# Raises ValueError if datatype is "roi" and roi_tr is None.
# NOTE(review): this chunk appears to have lost lines. The docstring below has
# no triple-quote delimiters and no "Returns:" header (fulldir/savename are
# return values, not optional args). Several statements are also missing an
# `else:`, the opening line of a multi-line call, or a closing parenthesis.
# These spots are flagged inline; the function does not parse as shown.
def plot_full_traces(analyspar, sesspar, extrapar, sess_info, trace_info, 
                     roi_tr=None, figpar=None, savedir=None):
    plot_full_traces(analyspar, sesspar, extrapar, sess_info, trace_info)

    From dictionaries, plots full traces for each session in a separate subplot.
    Returns figure name and save directory path.
    Required args:
        - analyspar (dict)  : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)    : dictionary with keys of SessPar namedtuple
        - extrapar (dict)   : dictionary containing additional analysis 
            ["analysis"] (str): analysis type (e.g., "f")
            ["datatype"] (str): datatype (e.g., "run", "roi")
        - sess_info (dict)  : dictionary containing information from each
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session

        - trace_info (dict): dictionary containing trace information
            ["all_tr"] (nested list): trace values structured as
                                          sess x 
                                          (me/err if datatype is "roi" x)
            ["all_edges"] (list)    : frame edge values for each parameter, 
                                      structured as sess x block x 
                                                    edges ([start, end])
            ["all_pars"] (list)     : stimulus parameter strings structured as 
                                                    sess x block
    Optional args:
        - roi_tr (list): trace values for each ROI, structured as 
                         sess x ROI x frames
                         default: None
        - figpar (dict): dictionary containing the following figure parameter 
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None    
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")
    sessstr = f"sess{sesspar['sess_n']}_{sesspar['plane']}"
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])

    datatype = extrapar["datatype"]

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, empty=(datatype!="roi")) 
    n_sess = len(mouse_ns)

    if figpar is None:
        figpar = sess_plot_util.init_figpar()
    # deep copy, so that the overrides below do not modify the caller's figpar
    figpar = copy.deepcopy(figpar)
    # one column of subplots per session
    figpar["init"]["subplot_wid"] = 10
    figpar["init"]["ncols"] = n_sess
    if datatype == "roi":
        figpar["save"]["fig_ext"] = "jpg"
        figpar["init"]["sharex"] = False
        figpar["init"]["sharey"] = False
        # set subplot ratios and removes extra space between plots vertically
        gs = {"height_ratios": [5, 1], "hspace": 0.1} 
        n_rows = 2
        if roi_tr is None:
            raise ValueError("Cannot plot data as ROI traces are missing "
                "(not recorded in dictionary, likely due to size).")
    # NOTE(review): an `else:` line appears to be missing here. As shown, the
    # two lines below would always reset gs/n_rows, even for "roi" data.
        gs = None
        n_rows = 1

    # NOTE(review): the warnings.warn(...) call below is truncated: its last
    # argument line and closing parenthesis are missing. The message says that
    # plotting is skipped, so an early return presumably followed. Confirm.
    if datatype == "roi" and not figpar["save"]["save_fig"]:
        warnings.warn("Figure plotting is being skipped. Since full ROI traces "
            "are not saved to dictionary, to actually plot traces, analysis "
            "will have to be rerun with 'save_fig' set to True.", 

    fig, ax = plot_util.init_fig(n_sess*n_rows, gs=gs, **figpar["init"])

    label_height = 0.8
    if datatype == "roi":
        fig.subplots_adjust(top=0.92) # remove extra white space at top
        label_height = 0.55
    for i in range(n_sess):
        # NOTE(review): this title f-string is truncated; its continuation line
        # and closing parenthesis are missing.
        title = (f"Mouse {mouse_ns[i]} (sess {sess_ns[i]}, {lines[i]} "
        sub_axs = ax[:, i]
        sub_axs[0].set_title(title, y=1.02)
        if datatype == "roi":
            xran = range(len(trace_info["all_tr"][i][1]))   
            # each ROI (top subplot)
            plot_util.plot_sep_data(sub_axs[0], np.asarray(roi_tr[i]), 0.1)
            # NOTE(review): the opening line of the call that takes the two
            # argument lines below (axis-labelling arguments) is missing.
                sub_axs[0], fluor=analyspar["fluor"], scale=True, 
                datatype=datatype, x_ax="")
            # average across ROIs (bottom subplot)
            av_tr = np.asarray(trace_info["all_tr"][i])
            subtitle = u"{} across ROIs".format(statstr_pr)
            # NOTE(review): the call plotting av_tr is missing its opening line
            # and its final argument line. It passes row 0 as the statistic and
            # rows 1: as the error; `subtitle` is presumably passed on the
            # missing last line.
                sub_axs[1], xran, av_tr[0], av_tr[1:], lw=0.2, xticks="auto",
        # NOTE(review): an `else:` line (non-"roi" data, e.g. running) appears
        # to be missing before the next three lines.
            xran = range(len(trace_info["all_tr"][i]))
            run_tr = np.asarray(trace_info["all_tr"][i])
            sub_axs[0].plot(run_tr, lw=0.2)
        for b, block in enumerate(trace_info["all_edges"][i]):
            # all block labels to the lower plot
            # NOTE(review): the opening line of the call that places this
            # block's parameter string at the mean of its [start, end] edges
            # is missing.
                sub_axs[-1], trace_info["all_pars"][i][b], np.mean(block), 
                label_height, color="k")
            # NOTE(review): another call (axis-labelling arguments) is missing
            # its opening and closing lines here. As shown, the formatting calls
            # that follow run once per block inside this loop.
                sub_axs[-1], fluor=analyspar["fluor"], datatype=datatype, 
            plot_util.remove_ticks(sub_axs[-1], True, False)
            plot_util.remove_graph_bars(sub_axs[-1], bars="horiz")
            # add lines to both plots
            for r in range(n_rows):
                plot_util.add_bars(sub_axs[r], bars=block)
    if savedir is None:
        # NOTE(review): this Path(...) call is truncated; its arguments and
        # closing parenthesis are missing.
        savedir = Path(

    y = 1 if datatype == "run" else 0.98
    fig.suptitle("Full traces across sessions", fontsize="xx-large", y=y)

    savename = f"{datatype}_tr_{sessstr}{dendstr}"
    fulldir = plot_util.savefig(
        fig, savename, savedir, dpi=400, **figpar["save"])

    return fulldir, savename
# plot_perc_sig_usis: for a single session, plots the percentage of ROIs with
# significant low- and high-tail USIs for each line/plane, with one subplot
# per tail. Values are shown relative to the per-tail chance level
# (permpar["p_val"] / 2 * 100).
# - Not by_mouse: each line/plane gets a bootstrapped error bar and (in the
#   branch whose `else:` appears lost) its adjusted null CI.
# - by_mouse: per-mouse values and their mean or median are plotted.
# Significance markers are added above points. Per-tail significance summaries
# are accumulated in tail_sig_str, apparently for logging.
# Returns the array of subplots.
# Raises NotImplementedError if perc_sig_df covers more than one session.
# NOTE(review): lines appear to be missing from this chunk:
# - the signature is cut off after `by_mouse=False,` (the body also uses
#   `title`, which is not in the visible signature);
# - the docstring has no triple-quote delimiters or "Returns:" header;
# - several `else:`/`continue`/call-opening lines are gone (flagged inline).
# The function does not parse as shown.
def plot_perc_sig_usis(perc_sig_df, analyspar, permpar, figpar, by_mouse=False, 
    plot_perc_sig_usis(perc_sig_df, analyspar, figpar)

    Plots percentage of significant USIs.

    Required args:
        - perc_sig_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns:
            for sig in ["lo", "hi"]: for low vs high ROI indices
            - perc_sig_{sig}_idxs (num): percent significant ROIs (0-100)
            - perc_sig_{sig}_idxs_stds (num): bootstrapped standard deviation 
                over percent significant ROIs
            - perc_sig_{sig}_idxs_CIs (list): adjusted CI for percent sig. ROIs 
            - perc_sig_{sig}_idxs_null_CIs (list): adjusted null CI for percent 
                sig. ROIs
            - perc_sig_{sig}_idxs_raw_p_vals (num): uncorrected p-value for 
                percent sig. ROIs
            - perc_sig_{sig}_idxs_p_vals (num): p-value for percent sig. 
                ROIs, corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - by_mouse (bool):
            if True, plotting is done per mouse
            default: False
        - title (str):
            plot title
            default: None
        - ax (2D array): 
            array of subplots

    perc_sig_df = perc_sig_df.copy(deep=True)

    # NaNs are omitted from the per-mouse statistic unless bad values have
    # already been removed (analyspar["rem_bad"])
    nanpol = None if analyspar["rem_bad"] else "omit"

    sess_ns = perc_sig_df["sess_ns"].unique()
    if len(sess_ns) != 1:
        # NOTE(review): the closing parenthesis of this raise is missing.
        raise NotImplementedError(
            "Plotting function implemented for 1 session only."

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="idx", n_sub=1, 
        sharex=True, sharey=True)

    figpar["init"]["sharey"] = True
    figpar["init"]["subplot_wid"] = 3.4
    figpar["init"]["gs"] = {"wspace": 0.18}
    if by_mouse:
        figpar["init"]["subplot_hei"] = 8.4
    # NOTE(review): an `else:` line appears to be missing here. As shown,
    # subplot_hei is always overwritten with 3.5, even when by_mouse is True.
        figpar["init"]["subplot_hei"] = 3.5

    # one subplot per tail (low, high)
    fig, ax = plot_util.init_fig(2, **figpar["init"])
    if title is not None:
        y = 0.98 if by_mouse else 1.07
        fig.suptitle(title, y=y, weight="bold")

    tail_order = ["Low tail", "High tail"]
    tail_keys = ["lo", "hi"]
    # chance percentage per tail: half of the p-value threshold, in percent
    chance = permpar["p_val"] / 2 * 100

    ylims = get_perc_sig_ylims(perc_sig_df, high_pt_min=40)
    n_linpla = plot_helper_fcts.N_LINPLA

    # NOTE(review): two statements have been fused on the next line: the
    # assignment `comp_info = misc_analys.get_comp_info(permpar)` and the tail
    # of what appears to be a logging call (`...(f"{comp_info}:", extra=...)`).
    comp_info = misc_analys.get_comp_info(permpar)"{comp_info}:", extra={"spacing": "\n"})
    for t, (tail, key) in enumerate(zip(tail_order, tail_keys)):
        sub_ax = ax[0, t]
        sub_ax.set_title(tail, fontweight="bold")

        # replace bottom spine with line at 0
        sub_ax.axhline(y=0, c="k", lw=4.0)

        data_key = f"perc_sig_{key}_idxs"

        # null CI bounds ([low, high]) and medians, indexed by x position
        CIs = np.full((plot_helper_fcts.N_LINPLA, 2), np.nan)
        CI_meds = np.full(plot_helper_fcts.N_LINPLA, np.nan)

        tail_sig_str = f"{tail:9}:"
        # NOTE(review): linpla_names is never appended to in the visible code,
        # yet it is used as x tick labels below. A line such as
        # `linpla_names.append(linpla_name)` appears to be missing.
        linpla_names = []
        for (line, plane), lp_df in perc_sig_df.groupby(["lines", "planes"]):
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
            x_index = 2 * li + pl
            linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
            # NOTE(review): a `continue` line appears to be missing after the
            # next `if`; as shown, it has no body.
            if len(lp_df) == 0:
            elif len(lp_df) > 1 and not by_mouse:
                raise RuntimeError("Expected a single row per line/plane.")

            lp_df = lp_df.sort_values("mouse_ns") # sort by mouse
            df_indices = lp_df.index.tolist()

            if by_mouse:
                # plot means or medians per mouse
                mouse_data = lp_df[data_key].to_numpy()
                # NOTE(review): this call and the next are truncated; their
                # closing lines are missing.
                mouse_cols = plot_util.get_hex_color_range(
                    len(lp_df), col=col, 
                mouse_data_mean = math_util.mean_med(
                    mouse_data, stats=analyspar["stats"], nanpol=nanpol
                # degenerate "CI" (both bounds equal the mean/median),
                # presumably so that only the central marker is visible
                CI_dummy = np.repeat(mouse_data_mean, 2)
                plot_util.plot_CI(sub_ax, CI_dummy, med=mouse_data_mean, 
                    x=x_index, width=0.6, med_col=col, med_rat=0.01)
            # NOTE(review): an `else:` line (not by_mouse) appears to be
            # missing here.
                # collect confidence interval data
                row = lp_df.loc[df_indices[0]]
                mouse_cols = [col]
                # null CIs are assumed to be [low, median, high], based on the
                # indexing below (confirm against the analysis code).
                # NOTE(review): the closing `]` line of the next expression is
                # missing.
                CIs[x_index] = np.asarray(row[f"{data_key}_null_CIs"])[
                    np.asarray([0, 2])
                CI_meds[x_index] = row[f"{data_key}_null_CIs"][1]

            if by_mouse:
                perc_p_vals = []
                rel_y = 0.05
            # NOTE(review): an `else:` line appears to be missing here.
                tail_sig_str = f"{tail_sig_str}{TAB}{linpla_name}: "
                rel_y = 0.1

            for df_i, mouse_col in zip(df_indices, mouse_cols):
                # plot UFOs
                err = None
                no_line = True
                if not by_mouse:
                    err = perc_sig_df.loc[df_i, f"{data_key}_stds"]
                    no_line = False
                # indicate bootstrapped error with wider capsize
                # NOTE(review): the plotting call that takes the arguments
                # below is missing its opening line and closing parenthesis.
                    sub_ax, x_index, perc_sig_df.loc[df_i, data_key], err,
                    color=mouse_col, capsize=8, no_line=no_line

                # add significance markers
                p_val = perc_sig_df.loc[df_i, f"{data_key}_p_vals"]
                perc = perc_sig_df.loc[df_i, data_key]
                nrois = np.sum(perc_sig_df.loc[df_i, "nrois"])
                # direction of the deviation from chance (0 if exactly at chance)
                side = np.sign(perc - chance)
                # NOTE(review): the next two calls are missing their closing
                # parentheses.
                sensitivity = misc_analys.get_binom_sensitivity(
                    nrois, null_perc=chance, side=side
                sig_str = misc_analys.get_sig_symbol(
                    p_val, sensitivity=sensitivity, side=side, 
                    tails=permpar["tails"], p_thresh=permpar["p_val"]

                if len(sig_str):
                    # place the marker above the error bar, if there is one
                    perc_high = perc + err if err is not None else perc
                    # NOTE(review): the final argument line / closing
                    # parenthesis of this call is missing.
                    plot_util.add_signif_mark(sub_ax, x_index, perc_high, 
                        rel_y=rel_y, color=mouse_col, fontsize=24, 

                if by_mouse:
                    # NOTE(review): the opening line of the next argument line
                    # (presumably `perc_p_vals.append(`) is missing. The rest
                    # of the `tail_sig_str = (` expression is missing too.
                        (int(np.around(perc)), p_val, sig_str)
                    tail_sig_str = (

            if by_mouse: # sort p-value logging by percentage value
                tail_sig_str = f"{tail_sig_str}\n\t{linpla_name:6}: "
                order = np.argsort([vals[0] for vals in perc_p_vals])
                for i in order:
                    perc, p_val, sig_str = perc_p_vals[i]
                    perc_str = f"(~{perc}%)"
                    # NOTE(review): this expression is truncated.
                    tail_sig_str = (
                        f"{tail_sig_str}{TAB}{perc_str:6} "
        # add chance information
        if by_mouse:
            # NOTE(review): two lines are missing here: the opening line of the
            # call drawing a line at y=chance, and an `else:` line before the
            # plot_CI call.
                y=chance, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5, 
            plot_util.plot_CI(sub_ax, CIs.T, med=CI_meds, 
                # NOTE(review): the next line fuses the end of this plot_CI
                # call with the tail of what appears to be a call logging
                # tail_sig_str (`...(tail_sig_str, extra={"spacing": TAB})`).
                x=np.arange(n_linpla), width=0.45, med_rat=0.025, zorder=-12), extra={"spacing": TAB})
    for sub_ax in fig.axes:
        sub_ax.tick_params(axis="x", which="both", bottom=False) 
        # NOTE(review): the opening line of this tick-setting call is missing.
            sub_ax, min_tick=0, max_tick=n_linpla - 1, n=n_linpla, pad_p=0.2)
        sub_ax.set_xticklabels(linpla_names, rotation=90, weight="bold")

    ax[0, 0].set_ylabel("%", fontweight="bold")
    plot_util.set_interm_ticks(ax, 3, axis="y", weight="bold", share=True)

    # adjustment if tick interval is repeated in the negative
    if ax[0, 0].get_ylim()[0] < 0:
        ax[0, 0].set_ylim([ylims[0], ax[0, 0].get_ylim()[1]])

    return ax
# plot_traces_by_qu_lock_sess: for each session, plots trace statistics by
# quantile, locked as given by trace_stats["lock"]. Traces are also split by
# unexpected-sequence length when trace_stats contains "unexp_lens". Each
# subplot also shows the expected-sequence trace and markers for the
# unexpected-sequence lengths. The figure is then saved.
# Returns (fulldir, savename), where fulldir is what plot_util.savefig returns.
# Raises ValueError if any session's exp_stats has more than one quantile.
# NOTE(review): many lines of this chunk appear to be missing. The docstring
# has no triple-quote delimiters or "Returns:" header, and several `else:`,
# `continue` and call-opening/closing lines are gone (flagged inline). The
# function does not parse as shown.
def plot_traces_by_qu_lock_sess(analyspar, sesspar, stimpar, extrapar, 
                                quantpar, sess_info, trace_stats, 
                                figpar=None, savedir=None, modif=False):
    plot_traces_by_qu_lock_sess(analyspar, sesspar, stimpar, extrapar, 
                                quantpar, sess_info, trace_stats)

    From dictionaries, plots traces by quantile, locked to transitions from 
    unexpected to expected or v.v. with each session in a separate subplot.
    Returns figure name and save directory path.
    Required args:
        - analyspar (dict)  : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)    : dictionary with keys of SessPar namedtuple
        - stimpar (dict)    : dictionary with keys of StimPar namedtuple
        - extrapar (dict)   : dictionary containing additional analysis 
            ["analysis"] (str): analysis type (e.g., "l")
            ["datatype"] (str): datatype (e.g., "run", "roi")
        - quantpar (dict)   : dictionary with keys of QuantPar namedtuple
        - sess_info (dict)  : dictionary containing information from each
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            if datatype == 
            ["nrois"] (list)      : number of ROIs in session

        - trace_stats (dict): dictionary containing trace stats information
            ["xrans"] (list)           : time values for the 2p frames for each 
            ["all_stats"] (list)       : list of 4D arrays or lists of trace 
                                         data statistics across ROIs for each 
                                         session, structured as:
                                            (unexp_len x) quantiles x
                                            stats (me, err) x frames
            ["all_counts"] (array-like): number of sequences, structured as:
                                                sess x (unexp_len x) quantiles
            ["lock"] (str)             : value to which segments are locked:
                                         "unexp", "exp" or "unexp_split"
            ["baseline"] (num)         : number of seconds used for baseline
            ["exp_stats"] (list)       : list of 3D arrays or lists of trace 
                                         data statistics across ROIs for
                                         expected sampled sequences, 
                                         structured as:
                                            quantiles (1) x stats (me, err) 
                                            x frames
            ["exp_counts"] (array-like): number of sequences corresponding to
                                         exp_stats, structured as:
                                            sess x quantiles (1)
            if data is by unexp_len:
            ["unexp_lens"] (list)       : number of consecutive segments for
                                         each unexp_len, structured by session
    Optional args:
        - figpar (dict): dictionary containing the following figure parameter 
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None   
        - modif (bool) : if True, modified (slimmed-down) plots are created
                         default: False
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved
    # NOTE(review): this assignment mutates the caller's analyspar dict in
    # place, because no copy is made first, so the caller's "dend" value is
    # lost. Consider copying analyspar before overriding it.
    analyspar["dend"] = None
    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"],
        stimpar["gabk"], "print")
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")
    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"])
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])
    basestr = sess_str_util.base_par_str(trace_stats["baseline"])
    basestr_pr = sess_str_util.base_par_str(trace_stats["baseline"], "print")

    datatype = extrapar["datatype"]
    dimstr = sess_str_util.datatype_dim_str(datatype)

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, empty=(datatype!="roi")) 

    n_sess = len(mouse_ns)

    xrans      = [np.asarray(xran) for xran in trace_stats["xrans"]]
    all_stats  = [np.asarray(sessst) for sessst in trace_stats["all_stats"]]
    exp_stats  = [np.asarray(expst) for expst in trace_stats["exp_stats"]]
    all_counts = trace_stats["all_counts"]
    exp_counts = trace_stats["exp_counts"]

    lock  = trace_stats["lock"]
    col_idx = 0
    # any lock containing "unexp" (e.g. "unexp_split") is treated as "unexp"
    # and uses color index 1
    if "unexp" in lock:
        lock = "unexp"
        col_idx = 1
    # plot unexp_lens default values
    if stimpar["stimtype"] == "gabors":
        DEFAULT_UNEXP_LEN = [3.0, 4.5, 6.0]
        if stimpar["gabfr"] not in ["any", "all"]:
            offset = sess_str_util.gabfr_nbrs(stimpar["gabfr"])
    # NOTE(review): an `else:` line appears to be missing before the next line.
    # Also, as shown, `offset = 0` below runs unconditionally after this branch
    # and overwrites the offset computed from stimpar["gabfr"]. That would make
    # the later `0.3 * offset` corrections and `offset=offset` argument no-ops.
    # If 0 is meant as the default, it belongs before the `if`. Confirm.
        DEFAULT_UNEXP_LEN = [2.0, 3.0, 4.0]
    offset = 0
    unexp_lab, len_ext = "", ""
    unexp_lens = [[None]] * n_sess
    unexp_len_default = True
    if "unexp_lens" in trace_stats.keys():
        unexp_len_default = False
        unexp_lens = trace_stats["unexp_lens"]
        len_ext = "_bylen"
        if stimpar["stimtype"] == "gabors":
            # sl * 1.5/5 == sl * 0.3, presumably converting segment counts to
            # seconds (confirm).
            # NOTE(review): the closing bracket line of this comprehension is
            # missing.
            unexp_lens = [
                [sl * 1.5/5 - 0.3 * offset for sl in sls] for sls in unexp_lens

    # sign applied to the length markers: positive for "unexp" lock, else
    # negative
    inv = 1 if lock == "unexp" else -1
    if modif:
        st_val = -2.0
        end_val  = 6.0
        n_ticks = int((end_val - st_val) // 2 + 1)
    # NOTE(review): an `else:` line appears to be missing here.
        n_ticks = 21

    if figpar is None:
        figpar = sess_plot_util.init_figpar()
    figpar = copy.deepcopy(figpar)
    if modif:
        figpar["init"]["subplot_wid"] = 6.5
    # NOTE(review): an `else:` line appears to be missing here.
        figpar["init"]["subplot_wid"] *= 2

    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])
    # running range of the expected traces across sessions (used for y limits)
    exp_min, exp_max = np.inf, -np.inf
    for i, (stats, counts) in enumerate(zip(all_stats, all_counts)):
        sub_ax = plot_util.get_subax(ax, i)

        # plot expected data
        if exp_stats[i].shape[0] != 1:
            raise ValueError("Expected only one quantile for exp_stats.")

        n_lines = quantpar["n_quants"] * len(unexp_lens[i])
        cols = sess_plot_util.get_quant_cols(n_lines)[0][col_idx]
        # too few colors: use color=None (the default color cycle) for all lines
        if len(cols) < n_lines:
            cols = [None] * n_lines

        if modif:
            line = "2/3" if "23" in lines[i] else "5"
            plane = "somata" if "soma" in planes[i] else "dendrites"
            title = f"M{mouse_ns[i]} - layer {line} {plane}{dendstr_pr}"
            lab = "exp" if i == 0 else None
            y_ax = None if i == 0 else ""

            # restrict plotting to the smallest frame range covering
            # [st_val, end_val]
            st, end = 0, len(xrans[i])
            # NOTE(review): the next two filter(...) expressions are missing
            # their closing parentheses.
            st_vals = list(filter(
                lambda j: xrans[i][j] <= st_val, range(len(xrans[i]))
            end_vals = list(filter(
                lambda j: xrans[i][j] >= end_val, range(len(xrans[i]))
            if len(st_vals) != 0:
                st = st_vals[-1]
            if len(end_vals) != 0:
                end = end_vals[0] + 1
            time_slice = slice(st, end)

        # NOTE(review): an `else:` line appears to be missing here, and the
        # title expression below is truncated.
            title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}, "
                u"{} ".format(statstr_pr) + f"{lock} locked across {dimstr}"
                f"{basestr_pr}\n(sess {sess_ns[i]}, {lines[i]} {planes[i]}"
            lab = f"exp (no lock) ({exp_counts[i][0]})"
            y_ax = None
            st = 0
            end = len(xrans[i])
            time_slice = slice(None) # use all

        # add length markers
        use_unexp_lens = unexp_lens[i]
        if unexp_len_default:
            use_unexp_lens = DEFAULT_UNEXP_LEN
        leng_col = sess_plot_util.get_quant_cols(1)[0][col_idx][0]
        for leng in use_unexp_lens:
            # NOTE(review): a `continue` line appears to be missing after the
            # next `if`; as shown, it has no body.
            if leng is None:
            edge = leng * inv
            # clip the marker edge to the plotted x range
            if edge < 0:
                edge = np.max([xrans[i][st], edge])
            elif edge > 0:
                edge = np.min([xrans[i][end - 1], edge])
            # NOTE(review): the opening line of the call spanning 0 to `edge`
            # (alpha=0.1) is missing.
                sub_ax, 0, edge, color=leng_col, alpha=0.1)

        # NOTE(review): the axis-labelling call that takes the arguments below
        # is missing its opening line and closing parenthesis.
            sub_ax, fluor=analyspar["fluor"], datatype=datatype, y_ax=y_ax
        plot_util.add_bars(sub_ax, hbars=0)
        alpha = np.min([0.4, 0.8 / n_lines])

        if stimpar["stimtype"] == "gabors":
            # NOTE(review): the call that takes the arguments below is missing
            # its opening line and closing parenthesis.
                sub_ax, xrans[i], offset=offset, bars_omit=[0] + unexp_lens[i]

        # NOTE(review): the call plotting the expected trace (statistic row 0,
        # error rows 1:) is missing its opening line and final argument line.
            sub_ax, xrans[i][time_slice], exp_stats[i][0][0, time_slice], 
            exp_stats[i][0][1:, time_slice], n_xticks=n_ticks,
            alpha=alpha, label=lab, alpha_line=0.8, color="darkgray", 

        # get expected data range to adjust y lims
        exp_min = np.min([exp_min, np.nanmin(exp_stats[i][0][0])])
        exp_max = np.max([exp_max, np.nanmax(exp_stats[i][0][0])])

        n = 0 # count lines plotted
        for s, unexp_len in enumerate(unexp_lens[i]):
            if unexp_len is not None:
                # this rebinds `counts`/`stats` (also the outer loop variables)
                # to the slice for this unexp_len
                counts, stats = all_counts[i][s], all_stats[i][s]       
                # remove offset   
                unexp_lab = f"unexp len {unexp_len + 0.3 * offset}"
            # NOTE(review): an `else:` line appears to be missing here.
                unexp_lab = "unexp" if modif else f"{lock} lock"
            for q, qu_idx in enumerate(quantpar["qu_idx"]):
                qu_lab = ""
                if quantpar["n_quants"] > 1:
                    # NOTE(review): the closing parentheses of this expression
                    # are missing.
                    qu_lab = "{} ".format(sess_str_util.quantile_str(
                        qu_idx, quantpar["n_quants"], str_type="print"
                lab = f"{qu_lab}{unexp_lab}"
                if modif:
                    lab = lab if i == 0 else None
                # NOTE(review): an `else:` line appears to be missing here.
                    lab = f"{lab} ({counts[q]})"
                if n == 2 and cols[n] is None:
                    sub_ax.plot([], []) # to advance the color cycle (past gray)
                plot_util.plot_traces(sub_ax, xrans[i][time_slice], 
                    stats[q][0, time_slice], stats[q][1:, time_slice], title, 
                    alpha=alpha, label=lab, n_xticks=n_ticks, alpha_line=0.8, 
                    color=cols[n], xticks="auto")
                n += 1
            if unexp_len is not None:
                # NOTE(review): this call (bars at unexp_len, drawn in the last
                # plotted line's color) is missing its opening line and closing
                # parenthesis.
                    sub_ax, hbars=unexp_len, color=sub_ax.lines[-1].get_color(), 
    plot_util.turn_off_extra(ax, n_sess)

    if savedir is None:
        # NOTE(review): the first argument line of this Path(...) call appears
        # to be missing.
        savedir = Path(
            f"{lock}_lock", basestr.replace("_", ""))

    if not modif:
        if stimpar["stimtype"] == "visflow":
            # NOTE(review): `sub_ax` here is the last subplot left over from
            # the loop above, so only that subplot's y limits are confined to
            # the expected-data range. Confirm this is intended.
            plot_util.rel_confine_ylims(sub_ax, [exp_min, exp_max], 5)

    qu_str = f"_{quantpar['n_quants']}q"
    if quantpar["n_quants"] == 1:
        qu_str = ""
    # NOTE(review): this savename expression is truncated; qu_str/dendstr are
    # presumably appended on the missing line.
    savename = (f"{datatype}_av_{lock}_lock{len_ext}{basestr}_{sessstr}"
    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
# plot_tracked_idxs: for each line/plane subplot, plots every tracked ROI's
# USI across sessions as a separate thin line. Sessions with no data for a
# mouse stay NaN. A reference line at y=0 is also drawn.
# Returns the array of subplots.
# Raises RuntimeError if a mouse's number of ROIs differs across sessions, or
# if a mouse has more than one row for a session.
# NOTE(review): the docstring has no triple-quote delimiters or "Returns:"
# header. Several lines (e.g. the append to lp_data, call-opening lines) also
# appear to be missing from this chunk (flagged inline). The function does
# not parse as shown.
def plot_tracked_idxs(idx_only_df, sesspar, figpar, title=None, wide=False):
    plot_tracked_idxs(idx_only_df, sesspar, figpar)

    Plots tracked ROI USIs as individual lines.

    Required args:
        - idx_only_df (pd.DataFrame):
            dataframe with one row per (mouse/)session/line/plane, and the 
            following columns, in addition to the basic sess_df columns:
            - roi_idxs (list): index for each ROI

        - sesspar (dict): 
            dictionary with keys of SessPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - wide (bool):
            if True, subplots are wider
            default: False

        - ax (2D array): 
            array of subplots    

    sess_ns = misc_analys.get_sess_ns(sesspar, idx_only_df)

    figpar = sess_plot_util.fig_init_linpla(figpar)

    figpar["init"]["sharey"] = "row"
    figpar["init"]["subplot_hei"] = 4.1
    figpar["init"]["subplot_wid"] = 2.5
    figpar["init"]["gs"] = {"wspace": 0.25, "hspace": 0.2}
    if wide:
        figpar["init"]["subplot_wid"] = 3.3
        # NOTE(review): 0.25 is already the wspace set above, so `wide`
        # currently only widens subplots. Confirm whether a different spacing
        # was intended.
        figpar["init"]["gs"]["wspace"] = 0.25

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    for (line, plane), lp_df in idx_only_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        # mouse_ns
        lp_mouse_ns = sorted(lp_df["mouse_ns"].unique())

        lp_data = []
        for mouse_n in lp_mouse_ns:
            mouse_df = lp_df.loc[lp_df["mouse_ns"] == mouse_n]
            nrois = mouse_df["nrois"].unique()
            if len(nrois) != 1:
                # NOTE(review): this message names idx_stats_df (the parameter
                # is idx_only_df) and contains a double space ("number  of").
                raise RuntimeError(
                    "Each mouse in idx_stats_df should retain the same number "
                    " of ROIs across sessions.")
            # sessions x ROIs; sessions with no row for this mouse stay NaN
            mouse_data = np.full((len(sess_ns), nrois[0]), np.nan)
            for s, sess_n in enumerate(sess_ns):
                rows = mouse_df.loc[mouse_df["sess_ns"] == sess_n]
                if len(rows) == 1:
                    mouse_data[s] = rows.loc[rows.index[0], "roi_idxs"]
                elif len(rows) > 1:
                    # NOTE(review): two lines appear to be missing: the closing
                    # parenthesis of this raise, and a
                    # `lp_data.append(mouse_data)` line after this loop.
                    # Without the append, lp_data stays empty and
                    # np.concatenate below raises ValueError.
                    raise RuntimeError(
                        "Expected 1 row per line/plane/session/mouse."

        # join all mice along the ROI axis: sessions x ROIs (all mice)
        lp_data = np.concatenate(lp_data, axis=1)

        # NOTE(review): the call drawing a reference line at y=0 is missing its
        # opening line and closing parenthesis.
            y=0, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5, 
        # one line per column of lp_data (i.e. per ROI) across sessions
        sub_ax.plot(sess_ns, lp_data, color=col, lw=2, alpha=0.3)
    # Add plane, line info to plots
    # NOTE(review): the call that takes the arguments below is missing its
    # opening line and closing parenthesis.
        ax, datatype="roi", xticks=sess_ns, ylab="", kind="reg"

    for sub_ax in ax.reshape(-1):
        xticks = sub_ax.get_xticks()
        # NOTE(review): the tick-setting call that takes the arguments below is
        # missing its opening line and final argument line.
            sub_ax, "x", np.min(xticks), np.max(xticks), n=len(xticks), 
    return ax
# plot_autocorr: for each session, plots autocorrelation statistics across
# lags. Per the docstring, these are per ROI or "actual lag" vs "10x lag"
# traces, depending on autocorrpar["byitem"]. Bars are drawn at seq_bars, the
# y limits are fixed to [0, 1], and the figure is then saved.
# Returns (fulldir, savename), where fulldir is what plot_util.savefig returns.
# NOTE(review): the docstring has no triple-quote delimiters or "Returns:"
# header. Several lines (`else:` lines, call-opening lines) also appear to be
# missing from this chunk (flagged inline). The function does not parse as
# shown.
def plot_autocorr(analyspar, sesspar, stimpar, extrapar, autocorrpar, 
                  sess_info, autocorr_data, figpar=None, savedir=None):
    plot_autocorr(analyspar, sesspar, stimpar, extrapar, autocorrpar, 
                  sess_info, autocorr_data)

    From dictionaries, plots autocorrelation during stimulus blocks.

    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis 
            ["analysis"] (str): analysis type (e.g., "a")
            ["datatype"] (str): datatype (e.g., "run", "roi")
        - autocorrpar (dict)  : dictionary with keys of AutocorrPar namedtuple
        - sess_info (dict)    : dictionary containing information from each
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session

        - autocorr_data (dict): dictionary containing data to plot:
            ["xrans"] (list): list of lag values in seconds for each session
            ["stats"] (list): list of 3D arrays (or nested lists) of
                              autocorrelation statistics, structured as:
                                     sessions stats (me, err) 
                                     x ROI or 1x and 10x lag 
                                     x lag
    Optional args:
        - figpar (dict): dictionary containing the following figure parameter 
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved

    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"], 
        stimpar["gabk"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")

    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"],stimpar["visflow_size"], stimpar["gabk"]) 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])
    datatype = extrapar["datatype"]
    if datatype == "roi":
        fluorstr_pr = sess_str_util.fluor_par_str(
            analyspar["fluor"], str_type="print")
        if autocorrpar["byitem"]:
            title_str = u"{}\nautocorrelation".format(fluorstr_pr)
        # NOTE(review): an `else:` line appears to be missing here. Also, the
        # format string below has no "{}" placeholder, so fluorstr_pr is not
        # included in this title. Confirm whether that is intended.
            title_str = "\nautocorr. acr. ROIs" .format(fluorstr_pr)

    elif datatype == "run":
        datastr = sess_str_util.datatype_par_str(datatype)
        title_str = u"\n{} autocorrelation".format(datastr)
    # NOTE(review): title_str is left unassigned for any other datatype. That
    # would raise UnboundLocalError when the subplot titles are built below.

    if stimpar["stimtype"] == "gabors":
        seq_bars = [-1.5, 1.5] # light lines
    # NOTE(review): an `else:` line appears to be missing here.
        seq_bars = [-1.0, 1.0] # light lines

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, empty=(datatype!="roi")) 

    n_sess = len(mouse_ns)

    xrans = autocorr_data["xrans"]
    stats = [np.asarray(stat) for stat in autocorr_data["stats"]]

    lag_s = autocorrpar["lag_s"]
    # one x tick per unit of lag from -lag_s to lag_s. np.linspace needs an
    # integer sample count, so this assumes lag_s is an int (confirm).
    xticks = np.linspace(-lag_s, lag_s, lag_s*2+1)
    yticks = np.linspace(0, 1, 6)

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    byitemstr = ""
    if autocorrpar["byitem"]:
        byitemstr = "_byroi"

    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])
    for i in range(n_sess):
        sub_ax = plot_util.get_subax(ax, i)
        title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}, "
            u"{} ".format(statstr_pr) + f"{title_str} (sess "
            f"{sess_ns[i]}, {lines[i]} {planes[i]}{dendstr_pr}{nroi_strs[i]})")
        # transpose to ROI/lag x stats x series
        sess_stats = stats[i].transpose(1, 0, 2) 
        for s, sub_stats in enumerate(sess_stats):
            lab = None
            if not autocorrpar["byitem"]:
                lab = ["actual lag", "10x lag"][s]
            # NOTE(review): the opening line of the trace-plotting call that
            # takes the arguments below is missing.
                sub_ax, xrans[i], sub_stats[0], sub_stats[1:], xticks=xticks, 
                yticks=yticks, alpha=0.2, label=lab)

        plot_util.add_bars(sub_ax, hbars=seq_bars)
        sub_ax.set_ylim([0, 1])
        sub_ax.set_title(title, y=1.02)
        # NOTE(review): Axes.is_last_row() was deprecated in Matplotlib 3.4 and
        # removed in 3.6. On newer versions, use
        # sub_ax.get_subplotspec().is_last_row() instead.
        if sub_ax.is_last_row():
            sub_ax.set_xlabel("Lag (s)")

    plot_util.turn_off_extra(ax, n_sess)

    if savedir is None:
        # NOTE(review): this Path(...) call is truncated; its arguments and
        # closing parenthesis are missing.
        savedir = Path(

    savename = (f"{datatype}_autocorr{byitemstr}_{sessstr}{dendstr}")

    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
# NOTE(review): this definition appears to have lost lines. The signature
# stops after "title=None," although the body reads `wide` (documented below
# with default False); the docstring has no enclosing quotes or "Returns:"
# header and gives "default: False" for bootstr_err while the signature uses
# None; and several calls below are missing their opening and/or closing
# lines. Restore it from version control before relying on it.
def plot_tracked_idx_stats(idx_stats_df, sesspar, figpar, permpar=None,
                           absolute=True, between_sess_sig=True, 
                           by_mouse=False, bootstr_err=None, title=None, 
    plot_tracked_idx_stats(idx_stats_df, sesspar, figpar)

    Plots tracked ROI USI statistics.

    Required args:
        - idx_stats_df (pd.DataFrame):
            dataframe with one row per session, and the following columns, in 
            addition to the basic sess_df columns:
            - roi_idxs (list): index statistics
            - abs_roi_idxs (list): absolute index statistics
        - sesspar (dict): 
            dictionary with keys of SessPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple. Required if 
            between_sess_sig is True.
            default: None
        - absolute (bool):
            if True, data statistics are on absolute ROI indices
            default: True
        - between_sess_sig (bool):
            if True, significance between sessions is logged and plotted
            default: True
        - by_mouse (bool):
            if True, plotting is done per mouse
            default: False
        - bootstr_err (bool):
            if True, error is bootstrapped standard deviation
            default: False
        - title (str):
            plot title
            default: None
        - wide (bool):
            if True, subplots are wider
            default: False

        - ax (2D array): 
            array of subplots    

    sess_ns = misc_analys.get_sess_ns(sesspar, idx_stats_df)

    # Figure layout: N_LINPLA subplots (indexed below as ax[plane, line]),
    # sharing the y axis within each row.
    figpar = sess_plot_util.fig_init_linpla(figpar)

    figpar["init"]["sharey"] = "row"
    figpar["init"]["subplot_hei"] = 4.1
    figpar["init"]["subplot_wid"] = 2.6
    figpar["init"]["gs"] = {"wspace": 0.25, "hspace": 0.2}
    if wide:
        figpar["init"]["subplot_wid"] = 3.3

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")
    # Column to plot: signed or absolute ROI index statistics.
    # NOTE(review): the docstring lists the columns as "roi_idxs" /
    # "abs_roi_idxs", but the code requires "roi_idx_stats" /
    # "abs_roi_idx_stats".
    data_col = "roi_idx_stats"
    if absolute:
        data_col = f"abs_{data_col}"
    if data_col not in idx_stats_df.columns:
        raise KeyError(f"Expected to find {data_col} in idx_stats_df columns.")

    # one subplot per line/plane group
    for (line, plane), lp_df in idx_stats_df.groupby(["lines", "planes"]):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        # NOTE(review): the opening line of the call below (presumably a
        # horizontal reference line at y=0 on sub_ax) and its closing line
        # are missing.
            y=0, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5, 

        # Either one pooled trace in the line/plane colour, or one trace per
        # mouse with colours derived from that colour.
        mouse_ns = ["any"]
        mouse_cols = [col]
        if by_mouse:
            mouse_ns = sorted(lp_df["mouse_ns"].unique())
            # NOTE(review): the closing line of this call is missing.
            mouse_cols = plot_util.get_hex_color_range(
                len(mouse_ns), col=col, 

        for mouse_n, mouse_col in zip(mouse_ns, mouse_cols):
            sub_df = lp_df
            if by_mouse:
                sub_df = lp_df.loc[lp_df["mouse_ns"] == mouse_n]
            # index labels and session numbers of the sessions to plot,
            # in sess_ns order
            sess_indices = []
            sub_sess_ns = []

            for sess_n in sess_ns:
                rows = sub_df.loc[sub_df["sess_ns"] == sess_n]
                # NOTE(review): the body of this `if` is missing; given how
                # sess_indices and sub_sess_ns are used below, it presumably
                # appends rows.index[0] and sess_n -- confirm.
                if len(rows) == 1:

            # one row of statistics per plotted session; presumably column 0
            # is the central value and the remaining columns are error
            # values -- confirm against the producer of data_col
            data = np.asarray([sub_df.loc[i, data_col] for i in sess_indices])

            # plot errorbars
            # pooled traces are drawn more opaque than per-mouse traces
            alpha = 0.6 if by_mouse else 0.8
            # caps (size 8) are only drawn for bootstrapped error
            capsize = 8 if bootstr_err else None
            # NOTE(review): the opening line of this plotting call and its
            # closing line are missing.
                sub_ax, data[:, 0], data[:, 1:].T, sub_sess_ns, color=mouse_col, 
                alpha=alpha, xticks="auto", line_dash=dash, capsize=capsize,

    if between_sess_sig:
        if permpar is None:
            # NOTE(review): the closing parenthesis of this raise is missing.
            raise ValueError(
                "If 'between_sess_sig' is True, must provide permpar."
        if by_mouse:
            # NOTE(review): the closing parenthesis of this raise is missing,
            # and the message has typos ("signifiance",
            # "if 'by_mouse' if True").
            raise NotImplementedError(
                "Plotting between session statistical signifiance is not "
                "implemented if 'by_mouse' if True."

        # NOTE(review): the opening line of this call (presumably the
        # between-session significance helper) and its closing line are
        # missing.
            ax, idx_stats_df, permpar, data_col=data_col
    # Add plane, line info to plots
    # NOTE(review): the opening line of this call and its closing line are
    # missing.
        ax, datatype="roi", xticks=sess_ns, ylab="", kind="reg"
    return ax
# NOTE(review): this definition appears to have lost lines. The signature
# stops after "pop_stats=True," although the body reads `title` (documented
# below); the docstring has no enclosing quotes or "Returns:" header and
# names the first argument "stim_stats_df" instead of "stim_data_df"; an
# `else:` seems to be missing in the line/plane loop; and several calls are
# missing their opening and/or closing lines. Restore it from version control
# before relying on it.
def plot_stim_data_df(stim_data_df, stimpar, permpar, figpar, pop_stats=True, 
    plot_stim_data_df(stim_data_df, stimpar, permpar, figpar)

    Plots stimulus comparison data.

    Required args:
        - stim_stats_df (pd.DataFrame):
            dataframe with one row per line/plane and one for all line/planes 
            together, and the basic sess_df columns, in addition to, 
            for each stimtype:
            - stimtype (list): absolute fractional change statistics (me, err)
            - raw_p_vals (float): uncorrected p-value for data differences 
                between stimulus types 
            - p_vals (float): p-value for data differences between stimulus 
                types, corrected for multiple comparisons and tails
        - stimpar (dict): 
            dictionary with keys of StimPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - pop_stats (bool):
            if True, analyses are run on population statistics, and not 
            individual tracked ROIs
            default: True
        - title (str):
            plot title
            default: None
        - ax (2D array): 
            array of subplots 
            (does not include added subplot for all line/plane data together)

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="reg")

    # Figure layout: N_LINPLA subplots (indexed below as ax[plane, line]),
    # sharing the y axis within each row.
    figpar["init"]["subplot_wid"] = 2.1
    figpar["init"]["subplot_hei"] = 4.2
    figpar["init"]["gs"] = {"hspace": 0.20, "wspace": 0.3}
    figpar["init"]["sharey"] = "row"
    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])
    # NOTE(review): unlike the sibling plotting functions, title is not
    # checked against None before being passed to suptitle().
    fig.suptitle(title, y=0.98, weight="bold")

    # Extra axis for the pooled "all" line/plane data, placed to the right of
    # the subplot grid (left edge at 1.05 in figure coordinates).
    sub_ax_all = fig.add_axes([1.05, 0.11, 0.3, 0.77])

    stimtypes = stimpar["stimtype"][:] # shallow copy of the list

    # indicate bootstrapped error with wider capsize
    capsize = 8 if pop_stats else 6

    # Per line/plane values and colours, reused below to draw dots on the
    # "all" subplot.
    # NOTE(review): no visible line appends to lp_data or cols; those
    # appends appear to be missing from the loop.
    lp_data = []
    cols = []
    for (line, plane), lp_df in stim_data_df.groupby(["lines", "planes"]):
        x = [0, 1]
        # NOTE(review): the closing line of this call is missing; confirm
        # whether the stacked array was transposed, since y and err are taken
        # from its rows below.
        data = np.vstack(
            [lp_df[stimtypes[0]].tolist(), lp_df[stimtypes[1]].tolist()]
        y = data[0]
        err = data[1:]

        if line != "all" and plane != "all":
            # individual line/plane: its own subplot, colour and dash style
            # NOTE(review): the closing line of this call is missing.
            li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane
            alpha = 0.5
            sub_ax = ax[pl, li]
            # NOTE(review): an `else:` (for the pooled "all" row) appears to
            # be missing here; as written, the lines below would override the
            # line/plane settings for every row.
            col = plot_helper_fcts.NEARBLACK
            dash = None
            alpha = 0.2
            sub_ax = sub_ax_all
            sub_ax.set_title("all", fontweight="bold")
        # NOTE(review): the opening line of this plotting call and its
        # closing line are missing.
            sub_ax, x, y=y, err=err, width=0.5, lw=None, alpha=alpha, 
            color=col, ls=dash, capsize=capsize

    # add dots to the all subplot
    # x_vals holds four offsets, so at most four line/plane dots per stimulus
    # type can be placed (x_vals[i] would raise IndexError beyond that).
    x_vals = np.asarray([-0.17, 0.25, -0.25, 0.17]) # to spread dots out
    lw = 4
    ms = 200
    # For each stimulus type, the line/plane dot with the i-th smallest value
    # is shifted horizontally by x_vals[i].
    for s, _ in enumerate(stimtypes):
        lp_stim_data = [data[s] for data in lp_data]
        sorter = np.argsort(lp_stim_data)
        for i, idx in enumerate(sorter):
            x_val = s + x_vals[i]
            # NOTE(review): the opening lines of the three calls below
            # (presumably scatter calls on sub_ax_all) and their closing
            # lines are missing. Layers stack by zorder 10 < 11 < 12.
            # white behind
                x=x_val, y=lp_stim_data[idx], s=ms, linewidth=lw, alpha=0.8, 
                color="white", zorder=10
            # colored dots
                x=x_val, y=lp_stim_data[idx], s=ms, alpha=0.6, linewidth=0, 
                color=cols[idx], zorder=11
            # dot borders
                x=x_val, y=lp_stim_data[idx], s=ms, color="None", 
                edgecolor=cols[idx], linewidth=lw, alpha=1, zorder=12

    # add between stim significance 
    add_between_stim_sig(ax, sub_ax_all, stim_data_df, permpar)

    # add plane, line info to plots
    # NOTE(review): the closing line of this call is missing.
    sess_plot_util.format_linpla_subaxes(ax, datatype="roi", lines=None, 
        planes=["", ""], xticks=[0, 1], ylab="Absolute fractional change", 
        kind="reg", xlab=""
    # adjust plot details
    # Display names for the x tick labels ("visflow" shown on two lines);
    # list.index() raises ValueError if "visflow" is not among stimtypes.
    stimtype_names = stimtypes[:]
    stimtype_names[stimtypes.index("visflow")] = "visual\nflow"
    for sub_ax in fig.axes:
        # start every y axis at 0, keeping the current upper limit
        y_max = sub_ax.get_ylim()[1]
        sub_ax.set_ylim([0, y_max])
        sub_ax.set_xticks([0, 1])
        # NOTE(review): the opening line of this call (presumably
        # sub_ax.set_xticklabels) and its closing line are missing. It is
        # passed `stimtypes`, which leaves `stimtype_names` above unused --
        # it likely should be `stimtype_names`.
            stimtypes, weight="bold", rotation=45, ha="right"
        sub_ax.tick_params(axis="x", bottom=False)
    # match the "all" axis x range to the grid subplots
    sub_ax_all.set_xlim(ax[0, 0].get_xlim())
    # NOTE(review): the opening line of this call and its closing line are
    # missing.
        np.asarray(sub_ax_all), 4, axis="y", share=False, weight="bold"

    return ax
# NOTE(review): this definition appears to have lost lines. The signature
# stops after split="by_exp", although the body reads `title` (documented
# below). The docstring has no enclosing quotes or "Returns:" header. Its
# `split` entry is cut off mid-sentence and gives "default: False" while the
# signature default is "by_exp". Several calls are also missing their opening
# and/or closing lines. Restore it from version control before relying on it.
def plot_pupil_run_trace_stats(trace_df, analyspar, figpar, split="by_exp", 
    plot_pupil_run_trace_stats(trace_df, analyspar, figpar)

    Plots pupil and running trace statistics.

    Required args:
        - trace_df (pd.DataFrame):
            dataframe with one row per session number, and the following 
            columns, in addition to the basic sess_df columns: 
            - run_trace_stats (list): 
                running velocity trace stats (split x frames x stats (me, err))
            - run_time_values (list):
                values for each frame, in seconds
                (only 0 to, unless split is "by_exp")
            - pupil_trace_stats (list): 
                pupil diameter trace stats (split x frames x stats (me, err))
            - pupil_time_values (list):
                values for each frame, in seconds
                (only 0 to, unless split is "by_exp")    

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - split (str):
            data split, e.g. "exp_lock", "unexp_lock", "stim_onset" or 
            default: False
        - title (str):
            plot title
            default: None

        - ax (2D array): 
            array of subplots

    # only split="by_exp" is supported
    if split != "by_exp":
        raise NotImplementedError("Only implemented split 'by_exp'.")

    # Data must be unscaled, as the y labels below give physical units.
    # NOTE(review): the closing parenthesis of this raise is missing.
    if analyspar["scale"]:
        raise NotImplementedError(
            "Expected running and pupil data to not be scaled."

    datatypes = ["run", "pupil"]

    # Figure layout: a single column with one subplot per datatype, sharing
    # the x axis but not the y axis.
    figpar["init"]["subplot_wid"] = 4.2
    figpar["init"]["subplot_hei"] = 2.2
    figpar["init"]["gs"] = {"hspace": 0.3}
    figpar["init"]["ncols"] = 1
    figpar["init"]["sharey"] = False
    figpar["init"]["sharex"] = True

    fig, ax = plot_util.init_fig(len(datatypes), **figpar["init"])

    if title is not None:
        fig.suptitle(title, weight="bold", y=1.0)

    # NOTE(review): the closing parenthesis of this raise is missing.
    if len(trace_df) != 1:
        raise NotImplementedError(
            "Only implemented for a trace_df with one row."
    row_idx = trace_df.index[0]

    exp_col = plot_util.LINCLAB_COLS["gray"]
    unexp_col = plot_util.LINCLAB_COLS["red"]
    for d, datatype in enumerate(datatypes):
        sub_ax = ax[d, 0]

        # e.g. "run_time_values" / "run_trace_stats" for datatype "run"
        time_values = trace_df.loc[row_idx, f"{datatype}_time_values"]
        trace_stats = trace_df.loc[row_idx, f"{datatype}_trace_stats"]

        # NOTE(review): the opening line of this trace-plotting call and its
        # closing line are missing.
            sub_ax, time_values, trace_stats, split=split, col=unexp_col, 
            lab=False, exp_col=exp_col, hline=False
        if datatype == "run":
            ylabel = "Running\nvelocity\n(cm/s)"
        elif datatype == "pupil":
            ylabel = "Pupil\ndiameter\n(mm)"
        sub_ax.set_ylabel(ylabel, weight="bold")

    # fix x ticks and lims
    plot_util.set_interm_ticks(ax, 3, axis="x", fontweight="bold")
    # NOTE(review): time_values is left over from the last loop iteration
    # (the pupil data). xlims is never applied in the visible code, and the
    # `split != "by_exp"` branch is unreachable given the check at the top.
    xlims = [np.min(time_values), np.max(time_values)]
    if split != "by_exp":
        xlims = [-xlims[1], xlims[1]]
    # sub_ax is the last (bottom) subplot here; the x axis is shared
    sub_ax.set_xlabel("Time (s)", weight="bold")

    # expand y lims a bit and fix y ticks
    for sub_ax in ax.reshape(-1):
        plot_util.expand_lims(sub_ax, axis="y", prop=0.21)

    # NOTE(review): the opening line of this call and its closing line are
    # missing.
        ax, 2, axis="y", share=False, weight="bold", update_ticks=True

    return ax