Ejemplo n.º 1
0
def run_full_traces(sessions,
                    analysis,
                    analyspar,
                    sesspar,
                    figpar,
                    datatype="roi"):
    """
    run_full_traces(sessions, analysis, analyspar, sesspar, figpar)

    Plots full traces across an entire session. If ROI traces are plotted,
    each ROI is scaled and plotted separately and an average is plotted.
    
    Saves results and parameters relevant to analysis in a dictionary.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "f")
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data ("roi" or "run")
                          default: "roi"
    """

    # Validate the datatype once, up front. The frame type used to read block
    # edges depends only on the datatype, and an invalid value must be
    # reported even if no session has any stimulus.
    if datatype == "roi":
        fr_type = "twop"
    elif datatype == "run":
        fr_type = "stim"
    else:
        gen_util.accepted_values_error("datatype", datatype, ["roi", "run"])

    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")

    sessstr_pr = (f"session: {sesspar.sess_n}, "
                  f"plane: {sesspar.plane}{dendstr_pr}")

    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Plotting {datastr} traces across an entire "
        f"session\n({sessstr_pr}).",
        extra={"spacing": "\n"})

    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()

    # per-ROI traces are only collected for ROI data
    roi_tr = [] if datatype == "roi" else None
    all_tr, all_edges, all_pars = [], [], []
    for sess in sessions:
        # get the block edges and parameter descriptions for each stimulus
        edge_fr, par_descrs = [], []
        for stim in sess.stims:
            stim_str = stim.stimtype
            if stim.stimtype == "visflow":
                stim_str = "vis. flow"
            for b in stim.block_params.index:
                row = stim.block_params.loc[b]
                edge_fr.append([
                    int(row[f"start_frame_{fr_type}"]),
                    int(row[f"stop_frame_{fr_type}"])
                ])
                # only the first 2 stimulus parameter values are described
                par_vals = [row[param] for param in stim.stim_params]
                pars_str = "\n".join(str(par) for par in par_vals[:2])

                par_descrs.append(
                    sess_str_util.pars_to_descr(
                        f"{stim_str.capitalize()}\n{pars_str}"))

        if datatype == "roi":
            if sess.only_tracked_rois != analyspar.tracked:
                raise RuntimeError(
                    "sess.only_tracked_rois should match analyspar.tracked.")
            # NaNs are omitted from the statistics only when bad ROIs are kept
            nanpol = None if analyspar.rem_bad else "omit"
            roi_traces = sess.get_roi_traces(
                None, analyspar.fluor, analyspar.rem_bad,
                analyspar.scale)["roi_traces"]
            all_rois = gen_util.reshape_df_data(roi_traces, squeeze_cols=True)
            full_tr = math_util.get_stats(all_rois,
                                          analyspar.stats,
                                          analyspar.error,
                                          axes=0,
                                          nanpol=nanpol).tolist()
            roi_tr.append(all_rois.tolist())
        else:  # datatype == "run" (validated above)
            full_tr = sess.get_run_velocity(
                rem_bad=analyspar.rem_bad,
                scale=analyspar.scale).to_numpy().squeeze().tolist()
        all_tr.append(full_tr)
        all_edges.append(edge_fr)
        all_pars.append(par_descrs)

    extrapar = {
        "analysis": analysis,
        "datatype": datatype,
    }

    trace_info = {
        "all_tr": all_tr,
        "all_edges": all_edges,
        "all_pars": all_pars
    }

    sess_info = sess_gen_util.get_sess_info(sessions,
                                            analyspar.fluor,
                                            incl_roi=(datatype == "roi"),
                                            rem_bad=analyspar.rem_bad)

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "extrapar": extrapar,
        "sess_info": sess_info,
        "trace_info": trace_info
    }

    fulldir, savename = gen_plots.plot_full_traces(roi_tr=roi_tr,
                                                   figpar=figpar,
                                                   **info)
    file_util.saveinfo(info, savename, fulldir, "json")
Ejemplo n.º 2
0
def _log_unexp_corrs(traces1, traces2, unexps, spacing):
    """
    _log_unexp_corrs(traces1, traces2, unexps, spacing)

    Computes and logs the Pearson correlation between two sets of traces for 
    each unexpected value. Significant correlations (p < 0.05) are marked 
    with "*", and the highest correlation with " +".

    Required args:
        - traces1 (array-like): traces, indexed first by unexpected value
        - traces2 (array-like): traces, indexed first by unexpected value
        - unexps (list)       : unexpected value labels, in index order
        - spacing (str)       : spacing passed to the logger for each line

    Returns:
        - corr_vals (list): Pearson correlation coefficient for each 
                            unexpected value
    """

    corrs = [
        st.pearsonr(traces1[s], traces2[s]) for s in range(len(unexps))
    ]
    corr_max = np.argmax([corr[0] for corr in corrs])
    corr_vals = []
    for s, (unexp, corr) in enumerate(zip(unexps, corrs)):
        sig_str = "*" if corr[1] < 0.05 else ""
        high_str = " +" if corr_max == s else ""
        logger.info(
            f"{unexp}: {corr[0]:.4f} "
            f"(p={corr[1]:.2f}{sig_str}){high_str}",
            extra={"spacing": spacing})
        corr_vals.append(corr[0])
    return corr_vals


def run_trace_corr_acr_sess(sessions,
                            analysis,
                            analyspar,
                            sesspar,
                            stimpar,
                            figpar,
                            datatype="roi"):
    """
    run_trace_corr_acr_sess(sessions, analysis, analyspar, sesspar, 
                            stimpar, figpar)

    Retrieves trace statistics by session x unexp val and calculates 
    correlations across sessions per unexp val, both within each mouse 
    (first vs second session) and between every pair of mice (session by 
    session).
    
    Currently only logs results to the console. Does NOT save results and 
    parameters relevant to analysis in a dictionary. If the logger level is 
    above INFO, it is temporarily lowered to INFO and restored afterwards, 
    even if the analysis raises an error.

    Required args:
        - sessions (list)      : nested list of Session objects, structured 
                                 as mouse x session (each mouse needs at 
                                 least 2 sessions)
        - analysis (str)       : analysis type (e.g., "r")
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data (e.g., "roi", "run")
                          default: "roi"
    """

    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
                                            sesspar.plane, stimpar.visflow_dir,
                                            stimpar.visflow_size, stimpar.gabk,
                                            "print")

    datastr = sess_str_util.datatype_par_str(datatype)

    if sesspar.plane in ["any", "all"] and sesspar.runtype == "pilot":
        logger.warning("Planes may not match between sessions for a mouse!")

    logger.info(
        "Analysing and plotting correlations between unexpected vs "
        f"expected {datastr} traces between sessions ({sessstr_pr}).",
        extra={"spacing": "\n"})

    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()

    # results are only logged, so make sure INFO messages are emitted
    prev_level = logger.level
    if prev_level > logging.INFO:
        logger.setLevel(logging.INFO)
        logger.warning("Temporarily lowered log level for correlation "
                       "analysis results.")

    unexps = ["exp", "unexp"]

    try:
        # correlate average traces between sessions for each mouse and each
        # unexpected value
        all_counts = []
        all_me_tr = []
        all_corrs = []
        logger.info("Intramouse correlations", extra={"spacing": "\n"})
        for sess_grp in sessions:
            logger.info(
                f"Mouse {sess_grp[0].mouse_n}, sess {sess_grp[0].sess_n} "
                f"vs {sess_grp[1].sess_n} corr:")
            trace_info = quant_analys.trace_stats_by_qu_sess(
                sess_grp,
                analyspar,
                stimpar,
                1, [0],
                byroi=False,
                by_exp=True,
                datatype=datatype)
            # remove quant dim
            grp_stats = np.asarray(trace_info[1]).squeeze(2)
            all_counts.append(
                [[qu_c[0] for qu_c in c] for c in trace_info[2]])
            # get mean/median per grp (sess x unexp_val x frame)
            grp_me = grp_stats[:, :, 0]
            all_corrs.append(
                _log_unexp_corrs(grp_me[0], grp_me[1], unexps, TAB))
            all_me_tr.append(grp_me)

        # mice x sess x unexp x frame
        all_me_tr = np.asarray(all_me_tr)
        logger.info("Intermouse correlations", extra={"spacing": "\n"})
        all_mouse_corrs = []
        # each mouse is compared to every later mouse (last mouse: no pairs)
        for n, m1_sess_mes in enumerate(all_me_tr[:-1]):
            mouse_corrs = []
            for n_add, m2_sess_mes in enumerate(all_me_tr[n + 1:]):
                sess_corrs = []
                logger.info(f"Mouse {sessions[n][0].mouse_n} vs "
                            f"{sessions[n + 1 + n_add][0].mouse_n} corr:")
                for se, m1_s1_me in enumerate(m1_sess_mes):
                    logger.info(f"sess {sessions[n][se].sess_n}:",
                                extra={"spacing": TAB})
                    sess_corrs.append(
                        _log_unexp_corrs(m1_s1_me, m2_sess_mes[se], unexps,
                                         f"{TAB}{TAB}"))
                mouse_corrs.append(sess_corrs)
            all_mouse_corrs.append(mouse_corrs)

    finally:
        # reset logger level, even if the analysis failed
        logger.setLevel(prev_level)
Ejemplo n.º 3
0
def run_traces_by_qu_unexp_sess(sessions,
                                analysis,
                                analyspar,
                                sesspar,
                                stimpar,
                                quantpar,
                                figpar,
                                datatype="roi"):
    """
    run_traces_by_qu_unexp_sess(sessions, analysis, analyspar, sesspar, 
                                stimpar, quantpar, figpar)

    Computes trace statistics split by session x unexpected value x quantile 
    and plots the traces (across ROIs) by quantile and unexpected value, 
    one subplot per session.

    The analysis is run twice: first with a single quantile spanning the full 
    data, then with quantpar.n_quants quantiles, all retained.
    
    For each run, the results and relevant parameters are saved in a 
    dictionary (json), alongside the plot.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "t")
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - quantpar (QuantPar)  : named tuple containing quantile analysis 
                                 parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data (e.g., "roi", "run")
                          default: "roi"
    """

    sessstr_pr = sess_str_util.sess_par_str(
        sesspar.sess_n, stimpar.stimtype, sesspar.plane, stimpar.visflow_dir,
        stimpar.visflow_size, stimpar.gabk, "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar.dend, sesspar.plane, datatype, "print")
    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Analysing and plotting unexpected vs expected {datastr} "
        f"traces by quantile ({quantpar.n_quants}) \n({sessstr_pr}"
        f"{dendstr_pr}).",
        extra={"spacing": "\n"})

    # one pass on the full data (1 quantile), one retaining every quantile
    quantpars_to_run = [
        sess_ntuple_util.init_quantpar(1, 0),
        sess_ntuple_util.init_quantpar(quantpar.n_quants, "all"),
    ]

    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()

    for qu_par in quantpars_to_run:
        logger.info(f"{qu_par.n_quants} quant", extra={"spacing": "\n"})

        # stats (all), split by session, unexpected value and quantile
        trace_info = quant_analys.trace_stats_by_qu_sess(
            sessions, analyspar, stimpar, qu_par.n_quants, qu_par.qu_idx,
            byroi=False, by_exp=True, datatype=datatype)

        trace_stats = {
            "xrans": [xran.tolist() for xran in trace_info[0]],
            "all_stats": [sess_stats.tolist() for sess_stats in trace_info[1]],
            "all_counts": trace_info[2],
        }

        info = {
            "analyspar": analyspar._asdict(),
            "sesspar": sesspar._asdict(),
            "stimpar": stimpar._asdict(),
            "quantpar": qu_par._asdict(),
            "extrapar": {"analysis": analysis, "datatype": datatype},
            "sess_info": sess_gen_util.get_sess_info(
                sessions, analyspar.fluor, incl_roi=(datatype == "roi"),
                rem_bad=analyspar.rem_bad),
            "trace_stats": trace_stats,
        }

        fulldir, savename = gen_plots.plot_traces_by_qu_unexp_sess(
            figpar=figpar, **info)
        file_util.saveinfo(info, savename, fulldir, "json")
Ejemplo n.º 4
0
def run_traces_by_qu_lock_sess(sessions,
                               analysis,
                               seed,
                               analyspar,
                               sesspar,
                               stimpar,
                               quantpar,
                               figpar,
                               datatype="roi"):
    """
    run_traces_by_qu_lock_sess(sessions, analysis, seed, analyspar, sesspar, 
                               stimpar, quantpar, figpar)

    Retrieves trace statistics by session x quantile at the transition of
    expected to unexpected sequences (or v.v.) and plots traces across ROIs by 
    quantile with each session in a separate subplot.
    
    Also runs analysis for one quantile (full data) with different unexpected 
    lengths grouped separately.

    Every analysis is run twice: without a baseline, and with the baseline 
    set to stimpar.pre. The stimpar pre/post values are overridden to 
    2 s / 6 s for visual flow and 2 s / 8 s for Gabors. For comparison, 
    statistics locked to expected samples ("exp_samp" lock, 1 quantile) are 
    also retrieved for each analysis.
    
    Saves results and parameters relevant to analysis in a dictionary.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "l")
        - seed (int)           : seed value to use. (-1 treated as None)
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - quantpar (QuantPar)  : named tuple containing quantile analysis 
                                 parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data (e.g., "roi", "run")
                          default: "roi"
    """

    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
                                            sesspar.plane, stimpar.visflow_dir,
                                            stimpar.visflow_size, stimpar.gabk,
                                            "print")
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")

    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Analysing and plotting unexpected vs expected {datastr} "
        f"traces locked to unexpected onset by quantile ({quantpar.n_quants}) "
        f"\n({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    seed = rand_util.seed_all(seed, "cpu", log_seed=False)

    # modify quantpar to retain all quantiles
    quantpar_one = sess_ntuple_util.init_quantpar(1, 0)
    n_quants = quantpar.n_quants
    quantpar_mult = sess_ntuple_util.init_quantpar(n_quants, "all")

    if stimpar.stimtype == "visflow":
        pre_post = [2.0, 6.0]
    elif stimpar.stimtype == "gabors":
        pre_post = [2.0, 8.0]
    else:
        gen_util.accepted_values_error("stimpar.stimtype", stimpar.stimtype,
                                       ["visflow", "gabors"])
    logger.warning("Setting pre to {}s and post to {}s.".format(*pre_post))

    stimpar = sess_ntuple_util.get_modif_ntuple(stimpar, ["pre", "post"],
                                                pre_post)

    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()

    for baseline in [None, stimpar.pre]:
        basestr_pr = sess_str_util.base_par_str(baseline, "print")
        # loop variable named qu_par to avoid shadowing the quantpar argument
        for qu_par in [quantpar_one, quantpar_mult]:
            locks = ["unexp", "exp"]
            # the split by unexpected length is only run on the full data
            if qu_par.n_quants == 1:
                locks.append("unexp_split")
            # get the stats (all) separating by session and quantiles
            for lock in locks:
                logger.info(
                    f"{qu_par.n_quants} quant, {lock} lock{basestr_pr}",
                    extra={"spacing": "\n"})
                if lock == "unexp_split":
                    trace_info = quant_analys.trace_stats_by_exp_len_sess(
                        sessions,
                        analyspar,
                        stimpar,
                        qu_par.n_quants,
                        qu_par.qu_idx,
                        byroi=False,
                        nan_empty=True,
                        baseline=baseline,
                        datatype=datatype)
                else:
                    trace_info = quant_analys.trace_stats_by_qu_sess(
                        sessions,
                        analyspar,
                        stimpar,
                        qu_par.n_quants,
                        qu_par.qu_idx,
                        byroi=False,
                        lock=lock,
                        nan_empty=True,
                        baseline=baseline,
                        datatype=datatype)

                # for comparison, locking to middle of expected sample (1 quant)
                # NOTE(review): these arguments do not depend on lock/qu_par,
                # but the call is kept per-iteration, as it may draw random
                # samples (seeded above) - confirm before hoisting it.
                exp_samp = quant_analys.trace_stats_by_qu_sess(
                    sessions,
                    analyspar,
                    stimpar,
                    quantpar_one.n_quants,
                    quantpar_one.qu_idx,
                    byroi=False,
                    lock="exp_samp",
                    nan_empty=True,
                    baseline=baseline,
                    datatype=datatype)

                extrapar = {
                    "analysis": analysis,
                    "datatype": datatype,
                    "seed": seed,
                }

                xrans = [xran.tolist() for xran in trace_info[0]]
                all_stats = [sessst.tolist() for sessst in trace_info[1]]
                exp_stats = [expst.tolist() for expst in exp_samp[1]]
                trace_stats = {
                    "xrans": xrans,
                    "all_stats": all_stats,
                    "all_counts": trace_info[2],
                    "lock": lock,
                    "baseline": baseline,
                    "exp_stats": exp_stats,
                    "exp_counts": exp_samp[2]
                }

                if lock == "unexp_split":
                    trace_stats["unexp_lens"] = trace_info[3]

                sess_info = sess_gen_util.get_sess_info(
                    sessions,
                    analyspar.fluor,
                    incl_roi=(datatype == "roi"),
                    rem_bad=analyspar.rem_bad)

                info = {
                    "analyspar": analyspar._asdict(),
                    "sesspar": sesspar._asdict(),
                    "stimpar": stimpar._asdict(),
                    "quantpar": qu_par._asdict(),
                    "extrapar": extrapar,
                    "sess_info": sess_info,
                    "trace_stats": trace_stats
                }

                fulldir, savename = gen_plots.plot_traces_by_qu_lock_sess(
                    figpar=figpar, **info)
                file_util.saveinfo(info, savename, fulldir, "json")
Ejemplo n.º 5
0
def plot_glm_expl_var(analyspar, sesspar, stimpar, extrapar, glmpar,
                      sess_info, all_expl_var, figpar=None, savedir=None):
    """
    plot_glm_expl_var(analyspar, sesspar, stimpar, extrapar, glmpar, 
                      sess_info, all_expl_var)

    From dictionaries, plots explained variance for different variables for 
    each ROI, and logs the full and per-coefficient explained variance of the 
    fit across ROIs (the "rois" entry equal to -1).

    Only session sets with per-ROI results (i.e., "rois" neither [-1] nor 
    "all") are plotted, each in its own subplot column.

    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis 
                                parameters
            ["analysis"] (str): analysis type (e.g., "v")
        - glmpar (dict)       : dictionary with keys of GLMPar namedtuple
        - sess_info (dict)    : dictionary containing information from each
                                session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session

        - all_expl_var (list) : list of dictionaries with explained variance 
                                for each session set, with each glm 
                                coefficient as a key:
            ["full"] (list)    : list of full explained variance stats for 
                                 every ROI, structured as ROI x stats
            ["coef_all"] (dict): max explained variance for each ROI with each
                                 coefficient as a key, structured as ROI x stats
            ["coef_uni"] (dict): unique explained variance for each ROI with 
                                 each coefficient as a key, 
                                 structured as ROI x stats
            ["rois"] (list)    : ROI numbers (-1 for GLMs fit to 
                                 mean/median ROI activity)
    
    Optional args:
        - figpar (dict): dictionary containing the following figure parameter 
                         dictionaries. If None, it is initialized with 
                         sess_plot_util.init_figpar().
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None
    
    Returns:
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved

    Raises:
        - RuntimeError if a session set does not contain exactly one result 
          fit across all ROIs ("rois" value of -1)
        - an error (via gen_util.accepted_values_error) if 
          stimpar["stimtype"] is neither "gabors" nor "visflow"
    """

    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"], 
        stimpar["gabk"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], "roi", "print")

    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"]) 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], "roi")

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]

    nroi_strs = sess_str_util.get_nroi_strs(sess_info, style="par")

    # only session sets with per-ROI results get a subplot column
    plot_bools = [ev["rois"] not in [[-1], "all"] for ev in all_expl_var]
    n_sess = sum(plot_bools)

    if stimpar["stimtype"] == "gabors":
        xyzc_dims = ["unexpected", "gabor_frame", "run_data", "pup_diam_data"]
        log_dims = xyzc_dims + ["gabor_mean_orientation"]
    elif stimpar["stimtype"] == "visflow":
        xyzc_dims = [
            "unexpected", "main_flow_direction", "run_data", "pup_diam_data"
            ]
        log_dims = xyzc_dims
    else:
        gen_util.accepted_values_error(
            "stimpar['stimtype']", stimpar["stimtype"], ["gabors", "visflow"])
    
    # start plotting
    logger.info("Plotting GLM full and unique explained variance for "
        f"{', '.join(xyzc_dims)}.", extra={"spacing": "\n"})

    # figpar is needed even if nothing is plotted, as it provides the save 
    # directory and saving parameters
    if figpar is None:
        figpar = sess_plot_util.init_figpar()
    figpar = copy.deepcopy(figpar)

    if n_sess > 0:
        cmap = plot_util.linclab_colormap(nbins=100, no_white=True)

        if figpar["save"]["use_dt"] is None:
            figpar["save"]["use_dt"] = gen_util.create_time_str()
        figpar["init"]["ncols"] = n_sess
        figpar["init"]["sharex"] = False
        figpar["init"]["sharey"] = False
        figpar["init"]["gs"] = {"wspace": 0.2, "hspace": 0.35}
        figpar["save"]["fig_ext"] = "png"
        
        fig, ax = plot_util.init_fig(2 * n_sess, **figpar["init"], proj="3d")

        fig.suptitle("Explained variance per ROI", y=1)

        # get colormap range from the plotted (per-ROI) data only
        c_range = [np.inf, -np.inf]
        c_key = xyzc_dims[3]

        for expl_var, plot_bool in zip(all_expl_var, plot_bools):
            if not plot_bool:
                continue
            rs = np.where(np.asarray(expl_var["rois"]) != -1)[0]
            for var_type in ["coef_all", "coef_uni"]:
                if c_key not in expl_var[var_type].keys():
                    continue
                c_data = np.asarray(expl_var[var_type][c_key])[rs, 0]
                if c_data.size == 0:  # min/max of an empty array would fail
                    continue
                # adjust colormap range
                c_range[0] = np.min([c_range[0], min(c_data)])
                c_range[1] = np.max([c_range[1], max(c_data)])
        
        if not np.isfinite(sum(c_range)):
            c_range = [-0.5, 0.5] # dummy range
        else:
            c_range = plot_util.rounded_lims(c_range, out=True)

    else:
        logger.info("No plots, as only results across ROIs are included")
        fig = None

    dim_length = max([len(dim) for dim in log_dims])

    # subplot column for the next plotted session set (only plotted sets 
    # have a column, so this can differ from the session set index)
    col = 0
    for i, expl_var in enumerate(all_expl_var):
        plot_sess = plot_bools[i]

        # collect info for plotting and logging results across ROIs
        rs = np.where(np.asarray(expl_var["rois"]) != -1)[0]
        all_rs = np.where(np.asarray(expl_var["rois"]) == -1)[0]
        if len(all_rs) != 1:
            raise RuntimeError("Expected only one result for all ROIs.")
        all_rs = all_rs[0]
        full_ev = expl_var["full"][all_rs]

        title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}\n(sess {sess_ns[i]}, "
                f"{lines[i]} {planes[i]}{dendstr_pr},{nroi_strs[i]})")
        logger.info(title, extra={"spacing": "\n"})

        math_util.log_stats(full_ev, stat_str="\nFull explained variance")
        
        for v, var_type in enumerate(["coef_all", "coef_uni"]):
            if var_type == "coef_all":
                sub_title = "Explained variance per coefficient"
            elif var_type == "coef_uni":
                sub_title = "Unique explained variance\nper coefficient"
            logger.info(sub_title, extra={"spacing": "\n"})

            dims_all = []
            for key in log_dims:
                if key in xyzc_dims:
                    # get mean/med
                    if key not in expl_var[var_type].keys():
                        dims_all.append("dummy")
                        continue

                    dims_all.append(np.asarray(expl_var[var_type][key])[rs, 0])
                math_util.log_stats(
                    expl_var[var_type][key][all_rs], 
                    stat_str=key.ljust(dim_length), log_spacing=TAB
                    )

            if not plot_sess:
                continue

            if v == 0:
                y = 1.12
                subpl_title = f"{title}\n{sub_title}"
            else:
                y = 1.02
                subpl_title = sub_title

            # retrieve values and names for each dimension, including dummy 
            # dimensions
            use_xyzc_dims = []
            n_vals = None
            dummies = []
            pads = [16, 16, 20]
            for d, dim in enumerate(dims_all):
                dim_name = xyzc_dims[d].replace("_", " ")
                if " direction"  in dim_name:
                    dim_name = dim_name.replace(" direction", "\ndirection")
                    pads[d] = 24
                if isinstance(dim, str) and dim == "dummy":
                    dummies.append(d)
                    use_xyzc_dims.append(f"{dim_name} (dummy)")
                else:
                    n_vals = len(dim)
                    use_xyzc_dims.append(dim_name)
            
            # dummy dimensions are plotted as zeros
            for d in dummies:
                dims_all[d] = np.zeros(n_vals)

            [x_data, y_data, z_data, c_data] = dims_all

            sub_ax = ax[v, col]
            im = sub_ax.scatter(
                x_data, y_data, z_data, c=c_data, cmap=cmap, 
                vmin=c_range[0], vmax=c_range[1]
                )
            sub_ax.set_title(subpl_title, y=y)

            # adjust padding for z axis
            sub_ax.tick_params(axis='z', which='major', pad=10)

            # add labels
            sub_ax.set_xlabel(use_xyzc_dims[0], labelpad=pads[0])
            sub_ax.set_ylabel(use_xyzc_dims[1], labelpad=pads[1])
            sub_ax.set_zlabel(use_xyzc_dims[2], labelpad=pads[2])

            if v == 0:
                full_ev_lab = math_util.log_stats(
                    full_ev, stat_str="Full EV", ret_str_only=True
                    )
                sub_ax.plot([], [], c="k", label=full_ev_lab)
                sub_ax.legend()

        if plot_sess:
            col += 1

    if fig is not None:
        plot_util.add_colorbar(
            fig, im, n_sess, label=use_xyzc_dims[3],
            space_fact=np.max([2, n_sess])
            )

        # plot 0 planes, and lines
        for sub_ax in ax.reshape(-1):
            sub_ax.autoscale(False)
            all_lims = [sub_ax.get_xlim(), sub_ax.get_ylim(), sub_ax.get_zlim()]
            xs, ys, zs = [
                [vs[0] - (vs[1] - vs[0]) * 0.02, vs[1] + (vs[1] - vs[0]) * 0.02]
                for vs in all_lims
                ]
            
            for plane in ["x", "y", "z"]:
                if plane == "x":
                    xx, yy = np.meshgrid(xs, ys)
                    zz = xx * 0
                    x_flat = xs
                    y_flat, z_flat = [0, 0], [0, 0]
                elif plane == "y":
                    yy, zz = np.meshgrid(ys, zs)
                    xx = yy * 0
                    y_flat = ys
                    z_flat, x_flat = [0, 0], [0, 0]
                elif plane == "z":
                    zz, xx = np.meshgrid(zs, xs)
                    yy = zz * 0
                    z_flat = zs
                    x_flat, y_flat = [0, 0], [0, 0]
                
                sub_ax.plot_surface(xx, yy, zz, alpha=0.05, color="k")
                sub_ax.plot(
                    x_flat, y_flat, z_flat, alpha=0.4, color="k", ls=(0, (2, 2))
                    )

    if savedir is None:
        savedir = Path(
            figpar["dirs"]["roi"],
            figpar["dirs"]["glm"])

    savename = (f"roi_glm_ev_{sessstr}{dendstr}")

    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
Ejemplo n.º 6
0
def plot_mag_change(analyspar, sesspar, stimpar, extrapar, permpar, quantpar, 
                    sess_info, mags, figpar=None, savedir=None):
    """
    plot_mag_change(analyspar, sesspar, stimpar, extrapar, permpar, quantpar, 
                    sess_info, mags) 

    From dictionaries, plots magnitude of change in unexpected and expected
    responses across quantiles. One figure is produced for unscaled data and 
    one for scaled data (both are saved).

    Returns figure name and save directory path.

    Required args:
        - analyspar (dict): dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)  : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)  : dictionary with keys of StimPar namedtuple
        - extrapar (dict) : dictionary containing additional analysis 
                            parameters
            ["analysis"] (str): analysis type (e.g., "m")
            ["datatype"] (str): datatype (e.g., "run", "roi")
            ["seed"]     (int): seed value used
        - permpar (dict)  : dictionary with keys of PermPar namedtuple 
                            (not used by this function)
        - quantpar (dict) : dictionary with keys of QuantPar namedtuple
            ["n_quants"] (int): number of quantiles
            ["qu_idx"] (list) : indices of the 2 quantiles compared
        - sess_info (dict): dictionary containing information from each
                            session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session

        - mags (dict)     : dictionary containing magnitude data to plot 
                            (only "mag_st" and "mag_sig" are read here)
            ["L2"] (array-like)    : nested list containing L2 norms, 
                                     structured as: 
                                         sess x scaling x unexp
            ["L2_sig"] (list)      : L2 significance results for each session 
                                         ("hi", "lo" or "no")
            ["mag_sig"] (list)     : magnitude significance results for each 
                                     session 
                                         ("hi", "lo" or "no")
            ["mag_st"] (array-like): array or nested list containing magnitude 
                                     stats across ROIs, structured as: 
                                         sess x scaling x unexp x stats

    Optional args:
        - figpar (dict): dictionary containing the following figure parameter 
                         dictionaries
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str or Path): path of directory in which to save plots.
                                 default: None
    
    Returns:
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved (without the 
                          scaling suffix)

    Raises:
        - ValueError: if quantpar["qu_idx"] does not specify exactly 2 
                      quantiles
    """
    
    sessstr_pr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"], 
        "print")
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")
        
    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"]) 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])
     
    datatype = extrapar["datatype"]
    dimstr = sess_str_util.datatype_dim_str(datatype)
    
    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]
    nroi_strs = sess_str_util.get_nroi_strs(
        sess_info, empty=(datatype != "roi"), style="par"
        )    

    n_sess = len(mouse_ns)

    # convert quantile indices (possibly negative) to 1-based quantile numbers
    qu_ns = [gen_util.pos_idx(q, quantpar["n_quants"]) + 1 
        for q in quantpar["qu_idx"]]
    if len(qu_ns) != 2:
        raise ValueError(f"Expected 2 quantiles, not {len(qu_ns)}.")
    
    # sess x scaling x unexp x stats
    mag_st = np.asarray(mags["mag_st"])

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    # copy, so the caller's figpar is not modified below
    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()
 
    # widen the figure with the number of sessions plotted
    figpar["init"]["subplot_wid"] *= n_sess/2.0
    
    # one figure per scaling; position in this list indexes mag_st axis 1
    scales = [False, True]

    # get plot elements
    barw = 0.75
    leg = ["exp", "unexp"]    
    cent, bar_pos, xlims = plot_util.get_barplot_xpos(n_sess, len(leg), barw)   
    title = (u"Magnitude ({}) of difference in activity".format(statstr_pr) +
        f"\nbetween Q{qu_ns[0]} and Q{qu_ns[1]} across {dimstr} "
        f"\n({sessstr_pr})")
    labels = [f"Mouse {mouse_ns[i]} sess {sess_ns[i]},\n {lines[i]} {planes[i]}"
        f"{dendstr_pr}{nroi_strs[i]}" for i in range(n_sess)]

    figs, axs = [], []
    for sc, scale in enumerate(scales):
        scalestr_pr = sess_str_util.scale_par_str(scale, "print")
        fig, ax = plot_util.init_fig(1, **figpar["init"])
        figs.append(fig)
        axs.append(ax)
        sub_ax = ax[0, 0]
        # always set ticks (even again) before setting labels
        sub_ax.set_xticks(cent)
        sub_ax.set_xticklabels(labels)
        title_scale = u"{}{}".format(title, scalestr_pr)
        sess_plot_util.add_axislabels(
            sub_ax, fluor=analyspar["fluor"], area=True, scale=scale, x_ax="", 
            datatype=datatype)
        for s, lab in enumerate(leg):
            xpos = list(zip(*bar_pos))[s]
            plot_util.plot_bars(
                sub_ax, xpos, mag_st[:, sc, s, 0], err=mag_st[:, sc, s, 1:], 
                width=barw, xlims=xlims, xticks="None", label=lab, capsize=4,
                title=title_scale, hline=0)
    
    # add significance markers (to every figure) for significant sessions
    for i in range(n_sess):
        if mags["mag_sig"][i] not in ["hi", "lo"]:
            continue
        xpos = bar_pos[i]
        for sc, ax in enumerate(axs):
            yval = mag_st[i, sc, :, 0]
            yerr = mag_st[i, sc, :, 1:]
            plot_util.plot_barplot_signif(ax[0, 0], xpos, yval, yerr)
    
    # apply to every figure, not only the last one created
    for ax in axs:
        plot_util.turn_off_extra(ax, n_sess)

    # figure directory
    if savedir is None:
        savedir = Path(
            figpar["dirs"][datatype], 
            figpar["dirs"]["unexp_qu"], 
            figpar["dirs"]["mags"])
    
    savename = f"{datatype}_mag_diff_{sessstr}{dendstr}"
    for i, (fig, scale) in enumerate(zip(figs, scales)):
        # only log the save directory when saving the last figure
        log_dir = (i == len(figs) - 1)
        scalestr = sess_str_util.scale_par_str(scale)
        savename_full = f"{savename}{scalestr}"
        fulldir = plot_util.savefig(
            fig, savename_full, savedir, log_dir=log_dir, **figpar["save"])

    return fulldir, savename
Ejemplo n.º 7
0
def run_regr(args):
    """
    run_regr(args)

    Runs logistic regressions on the specified comparison, for each session 
    matching the requested mouse, session number and runtype. Figures are 
    conditionally closed after each session.
    
    Required args:
        - args (Argument parser): parser with analysis parameters as attributes 
                                  (a deep copy is modified, not args itself):
            alg (str)             : algorithm to use ("sklearn" or "pytorch")
            bal (bool)            : if True, classes are balanced
            batchsize (int)       : nbr of samples dataloader will load per 
                                    batch (for "pytorch" alg)
            visflow_dir (str)     : visual flow direction to analyse
            visflow_pre (float)   : number of seconds to include before visual 
                                    flow segments
            visflow_size (int or list): visual flow square sizes to include
            comp (str)            : type of comparison
            datadir (str)         : data directory (if None, DEFAULT_DATADIR 
                                    is used)
            dend (str)            : type of dendrites to use ("allen" or "dend")
            device (str)          : device name (i.e., "cuda" or "cpu")
            ep_freq (int)         : frequency at which to log loss to 
                                    console
            error (str)           : error to take, i.e., "std" (for std 
                                    or quantiles) or "sem" (for SEM or MAD)
            fluor (str)           : fluorescence trace type
            fontdir (str)         : directory in which additional fonts are 
                                    located
            gabfr (int)           : gabor frame of reference if comparison 
                                    is "unexp"
            gabk (int or list)    : gabor kappas to include
            gab_ori (list or str) : gabor orientations to include
            incl (str or list)    : sessions to include ("yes", "no", "all")
            lr (num)              : model learning rate (for "pytorch" alg)
            mouse_n (int)         : mouse number
            n_epochs (int)        : number of epochs
            n_reg (int)           : number of regular runs
            n_shuff (int)         : number of shuffled runs
            no_scale (bool)       : if True, ROIs are not scaled
            not_ctrl (bool)       : its negation is passed as the second 
                                    argument of init_logregpar (presumably 
                                    disables control analyses - confirm)
            output (str)          : general directory in which to save 
                                    output
            parallel (bool)       : if True, runs are done in parallel
            plt_bkend (str)       : pyplot backend to use
            q1v4 (bool)           : if True, analysis is trained on first and 
                                    tested on last quartiles
            exp_v_unexp (bool)    : if True, analysis is trained on 
                                    expected and tested on unexpected sequences
            runtype (str)         : type of run ("prod" or "pilot")
            seed (int)            : seed to seed random processes with 
                                    (None or "None" to reseed)
            sess_n (int)          : session number
            stats (str)           : stats to take, i.e., "mean" or "median"
            stimtype (str)        : stim to analyse ("gabors" or "visflow")
            train_p (list)        : proportion of dataset to allocate to 
                                    training
            uniqueid (str or int) : unique ID for analysis ("datetime" is 
                                    replaced with a time string, "None" or 
                                    "none" with None)
            wd (float)            : weight decay value (for "pytorch" arg)
    """

    # work on a copy, so the caller's namespace is left untouched
    args = copy.deepcopy(args)

    if args.datadir is None:
        args.datadir = DEFAULT_DATADIR
    else:
        args.datadir = Path(args.datadir)

    if args.uniqueid == "datetime":
        args.uniqueid = gen_util.create_time_str()
    elif args.uniqueid in ["None", "none"]:
        args.uniqueid = None

    reseed = args.seed in [None, "None"]

    # deal with parameters
    extrapar = {"uniqueid": args.uniqueid, "seed": args.seed}

    techpar = {
        "reseed": reseed,
        "device": args.device,
        "alg": args.alg,
        "parallel": args.parallel,
        "plt_bkend": args.plt_bkend,
        "fontdir": args.fontdir,
        "output": args.output,
        "ep_freq": args.ep_freq,
        "n_reg": args.n_reg,
        "n_shuff": args.n_shuff,
    }

    mouse_df = DEFAULT_MOUSE_DF_PATH

    stimpar = logreg.get_stimpar(args.comp,
                                 args.stimtype,
                                 args.visflow_dir,
                                 args.visflow_size,
                                 args.gabfr,
                                 args.gabk,
                                 gab_ori=args.gab_ori,
                                 visflow_pre=args.visflow_pre)

    analyspar = sess_ntuple_util.init_analyspar(args.fluor,
                                                stats=args.stats,
                                                error=args.error,
                                                scale=not args.no_scale,
                                                dend=args.dend)

    # either first vs last of 4 quantiles, or a single quantile
    if args.q1v4:
        quantpar = sess_ntuple_util.init_quantpar(4, [0, -1])
    else:
        quantpar = sess_ntuple_util.init_quantpar(1, 0)

    logregpar = sess_ntuple_util.init_logregpar(args.comp, not args.not_ctrl,
                                                args.q1v4, args.exp_v_unexp,
                                                args.n_epochs, args.batchsize,
                                                args.lr, args.train_p, args.wd,
                                                args.bal, args.alg)

    omit_sess, omit_mice = sess_gen_util.all_omit(stimpar.stimtype,
                                                  args.runtype,
                                                  stimpar.visflow_dir,
                                                  stimpar.visflow_size,
                                                  stimpar.gabk)

    sessids = sorted(
        sess_gen_util.get_sess_vals(mouse_df,
                                    "sessid",
                                    args.mouse_n,
                                    args.sess_n,
                                    args.runtype,
                                    incl=args.incl,
                                    omit_sess=omit_sess,
                                    omit_mice=omit_mice))

    if not sessids:
        logger.warning(
            f"No sessions found (mouse: {args.mouse_n}, sess: {args.sess_n}, "
            f"runtype: {args.runtype})")

    for sessid in sessids:
        sess = sess_gen_util.init_sessions(sessid,
                                           args.datadir,
                                           mouse_df,
                                           args.runtype,
                                           full_table=False,
                                           fluor=analyspar.fluor,
                                           dend=analyspar.dend,
                                           temp_log="warning")[0]
        logreg.run_regr(sess, analyspar, stimpar, logregpar, quantpar,
                        extrapar, techpar)

        plot_util.cond_close_figs()


def plot_pup_diff_corr(analyspar, sesspar, stimpar, extrapar, 
                       sess_info, corr_data, figpar=None, savedir=None):
    """
    plot_pup_diff_corr(analyspar, sesspar, stimpar, extrapar, 
                       sess_info, corr_data)

    From dictionaries, plots correlation between unexpected-locked changes in 
    pupil diameter and running or ROI data for each session. 
    
    Each session gets one column: the top subplot is a scatter of pupil vs 
    ROI/running differences, and the bottom subplot shows both differences 
    (min-max scaled) across unexpected event occurrences.

    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis 
                                parameters
            ["analysis"] (str): analysis type (e.g., "c")
            ["datatype"] (str): datatype ("run" or "roi")
        - sess_info (dict)    : dictionary containing information from each
                                session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session
        - corr_data (dict)    : dictionary containing data to plot:
            ["corrs"] (list): list of correlation values between pupil and 
                              running or ROI differences for each session
            ["diffs"] (list): list of differences for each session, structured
                                  as [pupil, ROI/run] x trials x frames
    
    Optional args:
        - figpar (dict) : dictionary containing the following figure parameter 
                          dictionaries
                          default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (Path): path of directory in which to save plots.
                          default: None
    
    Returns:
        - fulldir (Path): final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved

    Raises:
        - ValueError: if extrapar["datatype"] is neither "roi" nor "run"
    """
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"], 
        stimpar["gabk"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")

    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"]) 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])
    
    datatype = extrapar["datatype"]
    datastr = sess_str_util.datatype_par_str(datatype)

    if datatype == "roi":
        label_str = sess_str_util.fluor_par_str(
            analyspar["fluor"], str_type="print")
        full_label_str = u"{}, {} across ROIs".format(
            label_str, analyspar["stats"])
    elif datatype == "run":
        label_str = datastr
        full_label_str = datastr
    else:
        # otherwise, the labels below would be undefined
        raise ValueError(
            f"Expected datatype to be 'roi' or 'run', but found {datatype}.")
    
    lab_app = (f" ({analyspar['stats']} over "
        f"{stimpar['pre']}/{stimpar['post']} sec)")

    logger.info(f"Plotting pupil vs {datastr} changes.")
    
    delta = "\u0394"

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]

    n_sess = len(mouse_ns)
    nroi_strs = sess_str_util.get_nroi_strs(
        sess_info, empty=(datatype != "roi"), style="comma"
        ) 

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    # copy, so the caller's figpar is not modified below
    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()
    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["ncols"] = n_sess
    
    # one column per session, 2 subplots (rows) per column
    fig, ax = plot_util.init_fig(2 * n_sess, **figpar["init"])
    suptitle = (f"Relationship between pupil diam. and {datastr} changes, "
        "locked to unexpected events")
    
    for i, sess_diffs in enumerate(corr_data["diffs"]):
        sub_axs = ax[:, i]
        title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}, " + 
            u"{}".format(statstr_pr) + f"\n(sess {sess_ns[i]}, {lines[i]} "
            f"{planes[i]}{dendstr_pr}{nroi_strs[i]})")
        
        # top plot: correlations
        corr = f"Corr = {corr_data['corrs'][i]:.2f}"
        sub_axs[0].plot(
            sess_diffs[0], sess_diffs[1], marker=".", linestyle="None", 
            label=corr)
        sub_axs[0].set_title(title, y=1.01)
        sub_axs[0].set_xlabel(u"{} pupil diam.{}".format(delta, lab_app))
        if i == 0:
            sub_axs[0].set_ylabel(u"{} {}\n{}".format(
                delta, full_label_str, lab_app))
        sub_axs[0].legend()
        
        # bottom plot: differences across occurrences
        data_lab = u"{} {}".format(delta, label_str)   
        pup_lab = u"{} pupil diam.".format(delta)
        cols = []
        scaled = []
        for d, lab in enumerate([pup_lab, data_lab]):
            scaled.append(math_util.scale_data(
                np.asarray(sess_diffs[d]), sc_type="min_max")[0])
            art, = sub_axs[1].plot(scaled[-1], marker=".")
            cols.append(art.get_color())
            if i == n_sess - 1: # only label the last graph
                art.set_label(lab)
        if i == n_sess - 1:
            sub_axs[1].legend()
        sub_axs[1].set_xlabel("Unexpected event occurrence")
        if i == 0:
            sub_axs[1].set_ylabel(
                u"{} response locked\nto unexpected onset (scaled)".format(delta))
        # shade area between lines
        plot_util.plot_btw_traces(
            sub_axs[1], scaled[0], scaled[1], color=cols, alpha=0.4)

    fig.suptitle(suptitle, fontsize="xx-large", y=1)

    if savedir is None:
        savedir = Path(
            figpar["dirs"][datatype],
            figpar["dirs"]["pupil"])

    savename = f"{datatype}_diff_corr_{sessstr}{dendstr}"

    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename


def _subax_edge_flags(sub_ax):
    """
    _subax_edge_flags(sub_ax)

    Returns whether a subplot is in the last row and in the first column of 
    its grid, across Matplotlib versions. Axes.is_last_row()/is_first_col() 
    were deprecated in Matplotlib 3.4 in favour of the SubplotSpec methods 
    of the same names, and removed in later releases.

    Required args:
        - sub_ax (plt Axis subplot): subplot

    Returns:
        - is_last_row (bool) : if True, subplot is in the last grid row
        - is_first_col (bool): if True, subplot is in the first grid column
    """

    spec = sub_ax.get_subplotspec()
    if hasattr(spec, "is_last_row"): # Matplotlib >= 3.4
        return spec.is_last_row(), spec.is_first_col()
    return sub_ax.is_last_row(), sub_ax.is_first_col()


def plot_pup_roi_stim_corr(analyspar, sesspar, stimpar, extrapar, 
                           sess_info, corr_data, figpar=None, savedir=None):
    """
    plot_pup_roi_stim_corr(analyspar, sesspar, stimpar, extrapar, 
                           sess_info, corr_data)

    From dictionaries, plots correlation between unexpected-locked changes in 
    pupil diameter and each ROI, for gabors versus visual flow responses for 
    each session (one scatter subplot per session, one point per ROI).
    
    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis 
                                parameters
            ["analysis"] (str): analysis type (e.g., "r")
            ["datatype"] (str): datatype (e.g., "roi")
        - sess_info (dict)    : dictionary containing information from each
                                session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session
        - corr_data (dict)    : dictionary containing data to plot:
            ["stim_order"] (list): ordered list of stimtypes (2 expected: 
                                   first is plotted on x, second on y)
            ["roi_corrs"] (list) : nested list of correlations between pupil 
                                   and ROI responses changes locked to 
                                   unexpected, structured as 
                                       session x stimtype x ROI
            ["corrs"] (list)     : list of correlation between stimtype
                                   correlations for each session
    
    Optional args:
        - figpar (dict) : dictionary containing the following figure parameter 
                          dictionaries
                          default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (Path): path of directory in which to save plots.
                          default: None
    
    Returns:
        - fulldir (Path): final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved
    """

    stimstr_prs = []
    for stimtype in corr_data["stim_order"]:
        stimstr_pr = sess_str_util.stim_par_str(
            stimtype, stimpar["visflow_dir"], stimpar["visflow_size"], 
            stimpar["gabk"], "print")
        # singularize for axis labels (e.g., "gabors" -> "gabor")
        if stimstr_pr.endswith("s"):
            stimstr_pr = stimstr_pr[:-1]
        stimstr_prs.append(stimstr_pr)
        
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")

    sessstr = f"sess{sesspar['sess_n']}_{sesspar['plane']}" 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])

    label_str = sess_str_util.fluor_par_str(
        analyspar["fluor"], str_type="print")
    lab_app = (f" ({analyspar['stats']} over "
        f"{stimpar['pre']}/{stimpar['post']} sec)")

    logger.info("Plotting pupil-ROI difference correlations for "
        "{} vs {}.".format(*corr_data["stim_order"]))

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]

    n_sess = len(mouse_ns)
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, style="comma")

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    # copy, so the caller's figpar is not modified below
    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()
    figpar["init"]["sharex"] = True
    figpar["init"]["sharey"] = True
    
    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])
    suptitle = (u"Relationship between pupil diam. and {} changes, locked to "
        "unexpected events\n{} for each ROI ({} vs {})".format(
            label_str, lab_app, *corr_data["stim_order"]))
    
    for i, sess_roi_corrs in enumerate(corr_data["roi_corrs"]):
        sub_ax = plot_util.get_subax(ax, i)
        title = (f"Mouse {mouse_ns[i]} (sess {sess_ns[i]}, {lines[i]} "
            f"{planes[i]}{dendstr_pr}{nroi_strs[i]})")
        
        # scatter per-ROI correlations: first stimtype (x) vs second (y)
        corr = f"Corr = {corr_data['corrs'][i]:.2f}"
        sub_ax.plot(
            sess_roi_corrs[0], sess_roi_corrs[1], marker=".", linestyle="None", 
            label=corr)
        sub_ax.set_title(title, y=1.01)
        is_last_row, is_first_col = _subax_edge_flags(sub_ax)
        if is_last_row:
            sub_ax.set_xlabel(f"{stimstr_prs[0].capitalize()} correlations")
        if is_first_col:
            sub_ax.set_ylabel(f"{stimstr_prs[1].capitalize()} correlations")
        sub_ax.legend()

    plot_util.turn_off_extra(ax, n_sess)

    fig.suptitle(suptitle, fontsize="xx-large", y=1)

    if savedir is None:
        savedir = Path(
            figpar["dirs"]["roi"],
            figpar["dirs"]["pupil"])

    savename = f"roi_diff_corrbyroi_{sessstr}{dendstr}"

    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename