Example #1
0
def run_autocorr(sessions,
                 analysis,
                 analyspar,
                 sesspar,
                 stimpar,
                 autocorrpar,
                 figpar,
                 datatype="roi"):
    """
    run_autocorr(sessions, analysis, analyspar, sesspar, stimpar, autocorrpar, 
                 figpar)

    Calculates and plots autocorrelation during stimulus blocks.
    Saves results and parameters relevant to analysis in a dictionary.

    Required args:
        - sessions (list)          : list of Session objects
        - analysis (str)           : analysis type (e.g., "a")
        - analyspar (AnalysPar)    : named tuple containing analysis parameters
        - sesspar (SessPar)        : named tuple containing session parameters
        - stimpar (StimPar)        : named tuple containing stimulus parameters
        - autocorrpar (AutocorrPar): named tuple containing autocorrelation 
                                     analysis parameters
        - figpar (dict)            : dictionary containing figure parameters

    Optional args:
        - datatype (str): type of data ("roi" or "run")
                          default: "roi"

    Raises:
        - ValueError if datatype is "run" and autocorrpar.byitem is True
        - RuntimeError if datatype is "roi" and a session's 
          only_tracked_rois attribute does not match analyspar.tracked, or if 
          the 10x lag statistics cannot be downsampled to the shape of the 
          1x lag statistics
        - NotImplementedError if the segments within a block are not 
          contiguous
        (any other datatype value is rejected via 
        gen_util.accepted_values_error)
    """

    # validate up front, so an invalid datatype cannot reach the per-block 
    # branches below and leave `traces` undefined
    if datatype not in ["roi", "run"]:
        gen_util.accepted_values_error("datatype", datatype, ["roi", "run"])
    if datatype == "run" and autocorrpar.byitem:
        raise ValueError("autocorrpar.byitem must be False for "
                         "running data.")

    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
                                            sesspar.plane, stimpar.visflow_dir,
                                            stimpar.visflow_size, stimpar.gabk,
                                            "print")
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")

    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Analysing and plotting {datastr} autocorrelations "
        f"({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    # RuntimeWarnings raised during the autocorrelation statistics that are 
    # filtered out
    msgs = ["Degrees of freedom", "invalid value encountered"]
    categs = [RuntimeWarning, RuntimeWarning]

    xrans = []
    stats = []
    for sess in sessions:
        if datatype == "roi" and (sess.only_tracked_rois != analyspar.tracked):
            raise RuntimeError(
                "sess.only_tracked_rois should match analyspar.tracked.")
        stim = sess.get_stim(stimpar.stimtype)
        all_segs = stim.get_segs_by_criteria(visflow_dir=stimpar.visflow_dir,
                                             visflow_size=stimpar.visflow_size,
                                             gabk=stimpar.gabk,
                                             by="block")
        sess_traces = []
        for segs in all_segs:
            if len(segs) == 0:
                continue
            segs = sorted(segs)
            # check that segs are contiguous (np.any also handles 
            # single-segment blocks, for which np.diff() is empty and 
            # max() would raise a ValueError)
            if np.any(np.diff(segs) > 1):
                raise NotImplementedError("Segments used for autocorrelation "
                                          "must be contiguous within blocks.")
            if datatype == "roi":
                frame_edges = stim.get_fr_by_seg(
                    [min(segs), max(segs)], fr_type="twop")
                fr = list(range(min(frame_edges[0]), max(frame_edges[1]) + 1))
                traces = gen_util.reshape_df_data(sess.get_roi_traces(
                    fr,
                    fluor=analyspar.fluor,
                    rem_bad=analyspar.rem_bad,
                    scale=analyspar.scale),
                                                  squeeze_cols=True)

            else:  # datatype == "run" (validated above)
                frame_edges = stim.get_fr_by_seg(
                    [min(segs), max(segs)], fr_type="stim")
                fr = list(range(min(frame_edges[0]), max(frame_edges[1]) + 1))

                # single running trace, shaped 1 x frames
                traces = sess.get_run_velocity_by_fr(
                    fr,
                    fr_type="stim",
                    rem_bad=analyspar.rem_bad,
                    scale=analyspar.scale).to_numpy().reshape(1, -1)

            sess_traces.append(traces)

        # Calculate autocorr stats (1x and, if applicable, 10x lag) while 
        # filtering some warnings.
        # NOTE(review): sess.twop_fps is used as the frame rate for both 
        # datatypes, although running data is retrieved in stimulus frames 
        # above — confirm this is intended.
        with gen_util.TempWarningFilter(msgs, categs):
            xran, ac_st = math_util.autocorr_stats(sess_traces,
                                                   autocorrpar.lag_s,
                                                   sess.twop_fps,
                                                   byitem=autocorrpar.byitem,
                                                   stats=analyspar.stats,
                                                   error=analyspar.error)

            if not autocorrpar.byitem:  # also add a 10x lag
                lag_fr = 10 * int(autocorrpar.lag_s * sess.twop_fps)
                _, ac_st_10x = math_util.autocorr_stats(
                    sess_traces,
                    lag_fr,
                    byitem=autocorrpar.byitem,
                    stats=analyspar.stats,
                    error=analyspar.error)

        if not autocorrpar.byitem:
            # keep every 10th lag, so the 10x lag series aligns with the 
            # 1x lag series along the last axis
            downsamp = range(0, ac_st_10x.shape[-1], 10)

            if len(downsamp) != ac_st.shape[-1]:
                raise RuntimeError("Failed to downsample correctly. "
                                   "Check implementation.")
            ac_st = np.stack([ac_st, ac_st_10x[:, downsamp]], axis=1)
        xrans.append(xran)
        stats.append(ac_st)

    autocorr_data = {
        "xrans": [xran.tolist() for xran in xrans],
        "stats": [stat.tolist() for stat in stats]
    }

    sess_info = sess_gen_util.get_sess_info(sessions,
                                            analyspar.fluor,
                                            incl_roi=(datatype == "roi"),
                                            rem_bad=analyspar.rem_bad)

    extrapar = {
        "analysis": analysis,
        "datatype": datatype,
    }

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "extrapar": extrapar,
        "autocorrpar": autocorrpar._asdict(),
        "autocorr_data": autocorr_data,
        "sess_info": sess_info
    }

    fulldir, savename = gen_plots.plot_autocorr(figpar=figpar, **info)

    file_util.saveinfo(info, savename, fulldir, "json")
Example #2
0
def run_full_traces(sessions,
                    analysis,
                    analyspar,
                    sesspar,
                    figpar,
                    datatype="roi"):
    """
    run_full_traces(sessions, analysis, analyspar, sesspar, figpar)

    Plots full traces across an entire session. If ROI traces are plotted,
    each ROI is scaled and plotted separately and an average is plotted.
    
    Saves results and parameters relevant to analysis in a dictionary.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "f")
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data ("roi" or "run")
                          default: "roi"

    Raises:
        - RuntimeError if datatype is "roi" and a session's 
          only_tracked_rois attribute does not match analyspar.tracked
        (any other datatype value is rejected via 
        gen_util.accepted_values_error)
    """

    # Validate datatype once, up front, and select the frame type used to 
    # read block edges. Doing this outside the per-stimulus loop also ensures 
    # an invalid datatype is caught even for sessions without stimuli.
    if datatype == "roi":
        fr_type = "twop"
    elif datatype == "run":
        fr_type = "stim"
    else:
        gen_util.accepted_values_error("datatype", datatype, ["roi", "run"])

    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")

    sessstr_pr = (f"session: {sesspar.sess_n}, "
                  f"plane: {sesspar.plane}{dendstr_pr}")

    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Plotting {datastr} traces across an entire "
        f"session\n({sessstr_pr}).",
        extra={"spacing": "\n"})

    # copy, so that setting a shared datetime does not modify the caller's 
    # figpar
    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()

    all_tr, roi_tr, all_edges, all_pars = [], [], [], []
    for sess in sessions:
        # get the block edges and parameters
        edge_fr, par_descrs = [], []
        for stim in sess.stims:
            stim_str = "vis. flow" if stim.stimtype == "visflow" \
                else stim.stimtype
            for b in stim.block_params.index:
                row = stim.block_params.loc[b]
                edge_fr.append([
                    int(row[f"start_frame_{fr_type}"]),
                    int(row[f"stop_frame_{fr_type}"])
                ])
                # describe each block with (at most) its first 2 parameters
                par_vals = [row[param] for param in stim.stim_params]
                pars_str = "\n".join([str(par) for par in par_vals][0:2])

                par_descrs.append(
                    sess_str_util.pars_to_descr(
                        f"{stim_str.capitalize()}\n{pars_str}"))

        if datatype == "roi":
            if sess.only_tracked_rois != analyspar.tracked:
                raise RuntimeError(
                    "sess.only_tracked_rois should match analyspar.tracked.")
            # omit NaNs in statistics if bad ROIs are retained
            nanpol = None if analyspar.rem_bad else "omit"
            all_rois = gen_util.reshape_df_data(sess.get_roi_traces(
                None, analyspar.fluor, analyspar.rem_bad,
                analyspar.scale)["roi_traces"],
                                                squeeze_cols=True)
            full_tr = math_util.get_stats(all_rois,
                                          analyspar.stats,
                                          analyspar.error,
                                          axes=0,
                                          nanpol=nanpol).tolist()
            roi_tr.append(all_rois.tolist())
        else:  # datatype == "run" (validated above)
            full_tr = sess.get_run_velocity(
                rem_bad=analyspar.rem_bad,
                scale=analyspar.scale).to_numpy().squeeze().tolist()
            roi_tr = None
        all_tr.append(full_tr)
        all_edges.append(edge_fr)
        all_pars.append(par_descrs)

    extrapar = {
        "analysis": analysis,
        "datatype": datatype,
    }

    trace_info = {
        "all_tr": all_tr,
        "all_edges": all_edges,
        "all_pars": all_pars
    }

    sess_info = sess_gen_util.get_sess_info(sessions,
                                            analyspar.fluor,
                                            incl_roi=(datatype == "roi"),
                                            rem_bad=analyspar.rem_bad)

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "extrapar": extrapar,
        "sess_info": sess_info,
        "trace_info": trace_info
    }

    fulldir, savename = gen_plots.plot_full_traces(roi_tr=roi_tr,
                                                   figpar=figpar,
                                                   **info)
    file_util.saveinfo(info, savename, fulldir, "json")
Example #3
0
def run_mag_change(sessions,
                   analysis,
                   seed,
                   analyspar,
                   sesspar,
                   stimpar,
                   permpar,
                   quantpar,
                   figpar,
                   datatype="roi"):
    """
    run_mag_change(sessions, analysis, seed, analyspar, sesspar, stimpar, 
                   permpar, quantpar, figpar)

    Computes and plots, for expected vs unexpected sequences, how much the 
    integrated activity changes between the first and last quantiles, then 
    saves the results and the parameters used in a dictionary.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "m")
        - seed (int)           : seed value to use. (-1 treated as None) 
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - permpar (PermPar)    : named tuple containing permutation parameters
                                 (if permpar.multcomp is set, it is replaced 
                                 by the number of sessions)
        - quantpar (QuantPar)  : named tuple containing quantile analysis 
                                 parameters
        - figpar (dict)        : dictionary containing figure parameters   

    Optional args:
        - datatype (str): type of data (e.g., "roi", "run")
                          default: "roi"
    """

    sessstr_pr = sess_str_util.sess_par_str(
        sesspar.sess_n, stimpar.stimtype, sesspar.plane, stimpar.visflow_dir,
        stimpar.visflow_size, stimpar.gabk, "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar.dend, sesspar.plane, datatype, "print")
    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Calculating and plotting the magnitude changes in {datastr} "
        f"activity across quantiles \n({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    # correct for multiple comparisons across all sessions, if requested
    if permpar.multcomp:
        permpar = sess_ntuple_util.get_modif_ntuple(
            permpar, "multcomp", len(sessions))

    # integrated data, retaining arrays; the last two returned items are the 
    # counts and the data (session x unexp x quants of interest x [ROI x seq])
    *_, all_counts, qu_data = quant_analys.trace_stats_by_qu_sess(
        sessions, analyspar, stimpar, quantpar.n_quants, quantpar.qu_idx,
        by_exp=True, integ=True, ret_arr=True, datatype=datatype)

    mouse_ns, lines = [], []
    for sess in sessions:
        mouse_ns.append(sess.mouse_n)
        lines.append(sess.line)

    # NaNs are only omitted when bad ROIs are not removed
    nanpol = None if analyspar.rem_bad else "omit"

    seed = rand_util.seed_all(seed, "cpu", log_seed=False)

    mags = quant_analys.qu_mags(
        qu_data, permpar, mouse_ns, lines, analyspar.stats, analyspar.error,
        nanpol=nanpol, op_qu="diff", op_unexp="diff")

    # work on a copy, and make the array items JSON-serializable
    mags = copy.deepcopy(mags)
    mags["all_counts"] = all_counts
    mags.update({
        key: mags[key].tolist() 
        for key in ("mag_st", "L2", "mag_rel_th", "L2_rel_th")
    })

    sess_info = sess_gen_util.get_sess_info(
        sessions, analyspar.fluor, incl_roi=(datatype == "roi"),
        rem_bad=analyspar.rem_bad)

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "extrapar": {"analysis": analysis, "datatype": datatype, "seed": seed},
        "permpar": permpar._asdict(),
        "quantpar": quantpar._asdict(),
        "mags": mags,
        "sess_info": sess_info,
    }

    fulldir, savename = gen_plots.plot_mag_change(figpar=figpar, **info)
    file_util.saveinfo(info, savename, fulldir, "json")
Example #4
0
def run_traces_by_qu_unexp_sess(sessions,
                                analysis,
                                analyspar,
                                sesspar,
                                stimpar,
                                quantpar,
                                figpar,
                                datatype="roi"):
    """
    run_traces_by_qu_unexp_sess(sessions, analysis, analyspar, sesspar, 
                                stimpar, quantpar, figpar)

    Retrieves trace statistics split by session x unexpected value x 
    quantile, and plots traces across ROIs by quantile/unexpected value, 
    with each session in its own subplot.

    The analysis is run twice: once with a single quantile (full data), and 
    once with all quantpar.n_quants quantiles. Results and the parameters 
    used are saved in a dictionary for each run.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "t")
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - quantpar (QuantPar)  : named tuple containing quantile analysis 
                                 parameters (only n_quants is used)
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data (e.g., "roi", "run")
                          default: "roi"
    """

    sessstr_pr = sess_str_util.sess_par_str(
        sesspar.sess_n, stimpar.stimtype, sesspar.plane, stimpar.visflow_dir,
        stimpar.visflow_size, stimpar.gabk, "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar.dend, sesspar.plane, datatype, "print")
    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Analysing and plotting unexpected vs expected {datastr} "
        f"traces by quantile ({quantpar.n_quants}) \n({sessstr_pr}"
        f"{dendstr_pr}).",
        extra={"spacing": "\n"})

    # one quantile (full data), then every quantile retained
    all_quantpars = [
        sess_ntuple_util.init_quantpar(1, 0),
        sess_ntuple_util.init_quantpar(quantpar.n_quants, "all"),
    ]

    # copy, so that setting a shared datetime does not modify the caller's 
    # figpar
    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()

    for curr_qp in all_quantpars:
        logger.info(f"{curr_qp.n_quants} quant", extra={"spacing": "\n"})

        # statistics across ROIs, split by session, unexpected value and 
        # quantile
        trace_info = quant_analys.trace_stats_by_qu_sess(
            sessions, analyspar, stimpar, curr_qp.n_quants, curr_qp.qu_idx,
            byroi=False, by_exp=True, datatype=datatype)

        # convert arrays to lists, so they can be saved to JSON
        trace_stats = {
            "xrans": [xr.tolist() for xr in trace_info[0]],
            "all_stats": [sess_st.tolist() for sess_st in trace_info[1]],
            "all_counts": trace_info[2],
        }

        sess_info = sess_gen_util.get_sess_info(
            sessions, analyspar.fluor, incl_roi=(datatype == "roi"),
            rem_bad=analyspar.rem_bad)

        info = {
            "analyspar": analyspar._asdict(),
            "sesspar": sesspar._asdict(),
            "stimpar": stimpar._asdict(),
            "quantpar": curr_qp._asdict(),
            "extrapar": {"analysis": analysis, "datatype": datatype},
            "sess_info": sess_info,
            "trace_stats": trace_stats,
        }

        fulldir, savename = gen_plots.plot_traces_by_qu_unexp_sess(
            figpar=figpar, **info)
        file_util.saveinfo(info, savename, fulldir, "json")
Example #5
0
def run_traces_by_qu_lock_sess(sessions,
                               analysis,
                               seed,
                               analyspar,
                               sesspar,
                               stimpar,
                               quantpar,
                               figpar,
                               datatype="roi"):
    """
    run_traces_by_qu_lock_sess(sessions, analysis, seed, analyspar, sesspar, 
                               stimpar, quantpar, figpar)

    Retrieves trace statistics by session x quantile, locked to the 
    transition from expected to unexpected sequences (and vice versa), and 
    plots traces across ROIs by quantile, with each session in a separate 
    subplot.

    For one quantile (full data), the analysis is also run with unexpected 
    sequences of different lengths grouped separately ("unexp_split" lock).

    Every analysis is run with baseline=None and with baseline=stimpar.pre, 
    and is accompanied by reference statistics locked to expected samples 
    ("exp_samp" lock, one quantile). stimpar pre/post are overridden to 
    2s/6s for "visflow" and 2s/8s for "gabors".

    Saves results and parameters relevant to each analysis in a dictionary.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "l")
        - seed (int)           : seed value to use. (-1 treated as None)
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - quantpar (QuantPar)  : named tuple containing quantile analysis 
                                 parameters (only n_quants is used)
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data (e.g., "roi", "run")
                          default: "roi"
    """

    sessstr_pr = sess_str_util.sess_par_str(
        sesspar.sess_n, stimpar.stimtype, sesspar.plane, stimpar.visflow_dir,
        stimpar.visflow_size, stimpar.gabk, "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar.dend, sesspar.plane, datatype, "print")
    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Analysing and plotting unexpected vs expected {datastr} "
        f"traces locked to unexpected onset by quantile ({quantpar.n_quants}) "
        f"\n({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    seed = rand_util.seed_all(seed, "cpu", log_seed=False)

    # one quantile (full data), and every quantile retained
    quantpar_one = sess_ntuple_util.init_quantpar(1, 0)
    quantpar_mult = sess_ntuple_util.init_quantpar(quantpar.n_quants, "all")

    # fixed pre/post windows (in seconds) for each supported stimulus
    pre_post_by_stim = {
        "visflow": [2.0, 6.0],
        "gabors": [2.0, 8.0],
    }
    if stimpar.stimtype not in pre_post_by_stim:
        gen_util.accepted_values_error(
            "stimpar.stimtype", stimpar.stimtype, ["visflow", "gabors"])
    pre_post = pre_post_by_stim[stimpar.stimtype]
    logger.warning("Setting pre to {}s and post to {}s.".format(*pre_post))

    stimpar = sess_ntuple_util.get_modif_ntuple(
        stimpar, ["pre", "post"], pre_post)

    # copy, so that setting a shared datetime does not modify the caller's 
    # figpar
    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()

    for baseline in [None, stimpar.pre]:
        basestr_pr = sess_str_util.base_par_str(baseline, "print")
        # keyword arguments shared by every statistics call for this baseline
        shared_kwargs = {
            "byroi": False,
            "nan_empty": True,
            "baseline": baseline,
            "datatype": datatype,
        }
        for curr_qp in [quantpar_one, quantpar_mult]:
            n_qu = curr_qp.n_quants
            qu_idx = curr_qp.qu_idx

            locks = ["unexp", "exp"]
            if n_qu == 1:
                # unexpected lock, split by unexpected sequence length
                locks += ["unexp_split"]

            # get the stats (all) separating by session and quantiles
            for lock in locks:
                logger.info(
                    f"{n_qu} quant, {lock} lock{basestr_pr}",
                    extra={"spacing": "\n"})

                if lock == "unexp_split":
                    trace_info = quant_analys.trace_stats_by_exp_len_sess(
                        sessions, analyspar, stimpar, n_qu, qu_idx,
                        **shared_kwargs)
                else:
                    trace_info = quant_analys.trace_stats_by_qu_sess(
                        sessions, analyspar, stimpar, n_qu, qu_idx,
                        lock=lock, **shared_kwargs)

                # reference: locked to the middle of expected samples 
                # (1 quantile)
                exp_samp = quant_analys.trace_stats_by_qu_sess(
                    sessions, analyspar, stimpar, quantpar_one.n_quants,
                    quantpar_one.qu_idx, lock="exp_samp", **shared_kwargs)

                # convert arrays to lists, so they can be saved to JSON
                trace_stats = {
                    "xrans": [xr.tolist() for xr in trace_info[0]],
                    "all_stats": [st.tolist() for st in trace_info[1]],
                    "all_counts": trace_info[2],
                    "lock": lock,
                    "baseline": baseline,
                    "exp_stats": [st.tolist() for st in exp_samp[1]],
                    "exp_counts": exp_samp[2],
                }
                if lock == "unexp_split":
                    trace_stats["unexp_lens"] = trace_info[3]

                sess_info = sess_gen_util.get_sess_info(
                    sessions, analyspar.fluor, incl_roi=(datatype == "roi"),
                    rem_bad=analyspar.rem_bad)

                info = {
                    "analyspar": analyspar._asdict(),
                    "sesspar": sesspar._asdict(),
                    "stimpar": stimpar._asdict(),
                    "quantpar": curr_qp._asdict(),
                    "extrapar": {
                        "analysis": analysis,
                        "datatype": datatype,
                        "seed": seed,
                    },
                    "sess_info": sess_info,
                    "trace_stats": trace_stats,
                }

                fulldir, savename = gen_plots.plot_traces_by_qu_lock_sess(
                    figpar=figpar, **info)
                file_util.saveinfo(info, savename, fulldir, "json")
Example #6
0
def plot_autocorr(analyspar, sesspar, stimpar, extrapar, autocorrpar, 
                  sess_info, autocorr_data, figpar=None, savedir=None):
    """
    plot_autocorr(analyspar, sesspar, stimpar, extrapar, autocorrpar, 
                  sess_info, autocorr_data)

    From dictionaries, plots autocorrelation during stimulus blocks.

    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis 
                                parameters
            ["analysis"] (str): analysis type (e.g., "a")
            ["datatype"] (str): datatype ("run" or "roi")
        - autocorrpar (dict)  : dictionary with keys of AutocorrPar namedtuple
        - sess_info (dict)    : dictionary containing information from each
                                session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session

        - autocorr_data (dict): dictionary containing data to plot:
            ["xrans"] (list): list of lag values in seconds for each session
            ["stats"] (list): list of 3D arrays (or nested lists) of
                              autocorrelation statistics, structured as:
                                     sessions stats (me, err) 
                                     x ROI or 1x and 10x lag 
                                     x lag
    
    Optional args:
        - figpar (dict): dictionary containing the following figure parameter 
                         dictionaries
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None
    
    Returns:
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved

    Raises:
        - ValueError if extrapar["datatype"] is not "roi" or "run"
    """

    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"], 
        stimpar["gabk"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")

    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"]) 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])

    datatype = extrapar["datatype"]
    if datatype == "roi":
        fluorstr_pr = sess_str_util.fluor_par_str(
            analyspar["fluor"], str_type="print")
        if autocorrpar["byitem"]:
            title_str = f"{fluorstr_pr}\nautocorrelation"
        else:
            # include the fluorescence type, as in the by-ROI title (the 
            # previous format string had no placeholder, dropping it)
            title_str = f"{fluorstr_pr}\nautocorr. acr. ROIs"
    elif datatype == "run":
        datastr = sess_str_util.datatype_par_str(datatype)
        title_str = f"\n{datastr} autocorrelation"
    else:
        raise ValueError(
            f"extrapar['datatype'] must be 'roi' or 'run', but found "
            f"{datatype}.")

    if stimpar["stimtype"] == "gabors":
        seq_bars = [-1.5, 1.5] # light lines
    else:
        seq_bars = [-1.0, 1.0] # light lines

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, empty=(datatype!="roi")) 

    n_sess = len(mouse_ns)

    xrans = autocorr_data["xrans"]
    stats = [np.asarray(stat) for stat in autocorr_data["stats"]]

    lag_s = autocorrpar["lag_s"]
    # ticks from -lag_s to lag_s (1 s apart for integer lag_s); np.linspace 
    # requires an integer number of samples, so cast in case lag_s is a float
    xticks = np.linspace(-lag_s, lag_s, int(lag_s * 2) + 1)
    yticks = np.linspace(0, 1, 6)

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    byitemstr = ""
    if autocorrpar["byitem"]:
        byitemstr = "_byroi"

    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])
    for i in range(n_sess):
        sub_ax = plot_util.get_subax(ax, i)
        title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}, {statstr_pr} "
                 f"{title_str} (sess {sess_ns[i]}, {lines[i]} {planes[i]}"
                 f"{dendstr_pr}{nroi_strs[i]})")
        # transpose to ROI/lag x stats x series
        sess_stats = stats[i].transpose(1, 0, 2) 
        for s, sub_stats in enumerate(sess_stats):
            lab = None
            if not autocorrpar["byitem"]:
                lab = ["actual lag", "10x lag"][s]

            # first stat is the mean/median, the remainder the error
            plot_util.plot_traces(
                sub_ax, xrans[i], sub_stats[0], sub_stats[1:], xticks=xticks, 
                yticks=yticks, alpha=0.2, label=lab)

        plot_util.add_bars(sub_ax, hbars=seq_bars)
        sub_ax.set_ylim([0, 1])
        sub_ax.set_title(title, y=1.02)
        # NOTE(review): Axes.is_last_row() is deprecated in recent matplotlib 
        # versions; sub_ax.get_subplotspec().is_last_row() is the replacement 
        # if the matplotlib version in use no longer provides it.
        if sub_ax.is_last_row():
            sub_ax.set_xlabel("Lag (s)")

    plot_util.turn_off_extra(ax, n_sess)

    if savedir is None:
        savedir = Path(
            figpar["dirs"][datatype], 
            figpar["dirs"]["autocorr"])

    savename = (f"{datatype}_autocorr{byitemstr}_{sessstr}{dendstr}")

    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
Example #7
0
def plot_glm_expl_var(analyspar, sesspar, stimpar, extrapar, glmpar,
                      sess_info, all_expl_var, figpar=None, savedir=None):
    """
    plot_glm_expl_var(analyspar, sesspar, stimpar, extrapar, glmpar,
                      sess_info, all_expl_var)

    From dictionaries, plots explained variance for different variables for 
    each ROI.

    Statistics computed across ROIs (results stored under ROI number -1) are 
    logged for every session. Only sessions with per-ROI results are plotted, 
    each in its own subplot column (first row: explained variance per 
    coefficient, second row: unique explained variance per coefficient).

    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis 
                                parameters
            ["analysis"] (str): analysis type (e.g., "v")
        - glmpar (dict)       : dictionary with keys of GLMPar namedtuple
                                (not used directly in this function)
        - sess_info (dict)    : dictionary containing information from each
                                session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session

        - all_expl_var (list) : list of dictionaries with explained variance 
                                for each session set, with each glm 
                                coefficient as a key:
            ["full"] (list)    : list of full explained variance stats for 
                                 every ROI, structured as ROI x stats
            ["coef_all"] (dict): max explained variance for each ROI with each
                                 coefficient as a key, structured as ROI x stats
            ["coef_uni"] (dict): unique explained variance for each ROI with 
                                 each coefficient as a key, 
                                 structured as ROI x stats
            ["rois"] (list)    : ROI numbers (-1 for GLMs fit to 
                                 mean/median ROI activity)
    
    Optional args:
        - figpar (dict): dictionary containing the following figure parameter 
                         dictionaries
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None
    
    Returns:
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved

    Raises:
        - ValueError  : if stimpar["stimtype"] is not "gabors" or "visflow"
        - RuntimeError: if a session does not contain exactly one result 
                        across ROIs (ROI number -1)
    """

    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"], 
        stimpar["gabk"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], "roi", "print")

    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"]) 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], "roi")

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, style="par")

    # only sessions with per-ROI results are plotted (one column each)
    plot_bools = [ev["rois"] not in [[-1], "all"] for ev in all_expl_var]
    n_sess = sum(plot_bools)

    # x, y, z and color dimensions are only defined for these stimulus types
    if stimpar["stimtype"] == "gabors":
        xyzc_dims = ["unexpected", "gabor_frame", "run_data", "pup_diam_data"]
        log_dims = xyzc_dims + ["gabor_mean_orientation"]
    elif stimpar["stimtype"] == "visflow":
        xyzc_dims = [
            "unexpected", "main_flow_direction", "run_data", "pup_diam_data"
            ]
        log_dims = xyzc_dims
    else:
        raise ValueError(
            f"Stimulus type {stimpar['stimtype']} not recognized.")
    
    # start plotting
    logger.info("Plotting GLM full and unique explained variance for "
        f"{', '.join(xyzc_dims)}.", extra={"spacing": "\n"})

    # initialized even if nothing is plotted, as it is needed to save below
    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    if n_sess > 0:
        figpar = copy.deepcopy(figpar)
        cmap = plot_util.linclab_colormap(nbins=100, no_white=True)

        if figpar["save"]["use_dt"] is None:
            figpar["save"]["use_dt"] = gen_util.create_time_str()
        figpar["init"]["ncols"] = n_sess
        figpar["init"]["sharex"] = False
        figpar["init"]["sharey"] = False
        figpar["init"]["gs"] = {"wspace": 0.2, "hspace": 0.35}
        figpar["save"]["fig_ext"] = "png"
        
        fig, ax = plot_util.init_fig(2 * n_sess, **figpar["init"], proj="3d")

        fig.suptitle("Explained variance per ROI", y=1)

        # get colormap range from the per-ROI values of the color dimension
        c_range = [np.inf, -np.inf]
        c_key = xyzc_dims[3]

        for expl_var in all_expl_var:
            rs = np.where(np.asarray(expl_var["rois"]) != -1)[0]
            for var_type in ["coef_all", "coef_uni"]:
                if c_key not in expl_var[var_type].keys():
                    continue
                c_data = np.asarray(expl_var[var_type][c_key])[rs, 0]
                if len(c_data) == 0:
                    # no per-ROI values (min/max would fail on empty data)
                    continue
                # adjust colormap range
                c_range[0] = np.min([c_range[0], min(c_data)])
                c_range[1] = np.max([c_range[1], max(c_data)])
        
        if not np.isfinite(sum(c_range)):
            c_range = [-0.5, 0.5] # dummy range
        else:
            c_range = plot_util.rounded_lims(c_range, out=True)

    else:
        logger.info("No plots, as only results across ROIs are included")
        fig = None

    dim_length = max(len(dim) for dim in log_dims)

    col = 0 # subplot column of the next plotted session
    for i, expl_var in enumerate(all_expl_var):
        # collect info for plotting and logging results across ROIs
        rs = np.where(np.asarray(expl_var["rois"]) != -1)[0]
        all_rs = np.where(np.asarray(expl_var["rois"]) == -1)[0]
        if len(all_rs) != 1:
            raise RuntimeError("Expected only one result for all ROIs.")
        all_rs = all_rs[0]
        full_ev = expl_var["full"][all_rs]

        title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}\n(sess {sess_ns[i]}, "
                f"{lines[i]} {planes[i]}{dendstr_pr},{nroi_strs[i]})")
        logger.info(title, extra={"spacing": "\n"})

        math_util.log_stats(full_ev, stat_str="\nFull explained variance")

        for v, var_type in enumerate(["coef_all", "coef_uni"]):
            if var_type == "coef_all":
                sub_title = "Explained variance per coefficient"
            else:
                sub_title = "Unique explained variance\nper coefficient"
            logger.info(sub_title, extra={"spacing": "\n"})

            dims_all = []
            for key in log_dims:
                if key in xyzc_dims:
                    # missing plotting dimensions are replaced with dummies
                    if key not in expl_var[var_type].keys():
                        dims_all.append("dummy")
                        continue

                    dims_all.append(np.asarray(expl_var[var_type][key])[rs, 0])
                math_util.log_stats(
                    expl_var[var_type][key][all_rs], 
                    stat_str=key.ljust(dim_length), log_spacing=TAB
                    )

            # sessions with results across ROIs only are logged, not plotted
            if not plot_bools[i]:
                continue

            if v == 0:
                y = 1.12
                subpl_title = f"{title}\n{sub_title}"
            else:
                y = 1.02
                subpl_title = sub_title

            # retrieve values and names for each dimension, including dummy 
            # dimensions
            use_xyzc_dims = []
            dummies = []
            pads = [16, 16, 20]
            for d, dim in enumerate(dims_all):
                dim_name = xyzc_dims[d].replace("_", " ")
                if " direction"  in dim_name:
                    dim_name = dim_name.replace(" direction", "\ndirection")
                    pads[d] = 24
                if isinstance(dim, str) and dim == "dummy":
                    dummies.append(d)
                    use_xyzc_dims.append(f"{dim_name} (dummy)")
                else:
                    use_xyzc_dims.append(dim_name)
            
            # dummy dimensions are zeros, with one value per ROI
            for d in dummies:
                dims_all[d] = np.zeros(len(rs))

            [x_data, y_data, z_data, c_data] = dims_all

            sub_ax = ax[v, col]
            im = sub_ax.scatter(
                x_data, y_data, z_data, c=c_data, cmap=cmap, 
                vmin=c_range[0], vmax=c_range[1]
                )
            sub_ax.set_title(subpl_title, y=y)

            # adjust padding for z axis
            sub_ax.tick_params(axis='z', which='major', pad=10)

            # add labels
            sub_ax.set_xlabel(use_xyzc_dims[0], labelpad=pads[0])
            sub_ax.set_ylabel(use_xyzc_dims[1], labelpad=pads[1])
            sub_ax.set_zlabel(use_xyzc_dims[2], labelpad=pads[2])

            if v == 0:
                full_ev_lab = math_util.log_stats(
                    full_ev, stat_str="Full EV", ret_str_only=True
                    )
                sub_ax.plot([], [], c="k", label=full_ev_lab)
                sub_ax.legend()

        if plot_bools[i]:
            col += 1

    if fig is not None:
        plot_util.add_colorbar(
            fig, im, n_sess, label=use_xyzc_dims[3],
            space_fact=np.max([2, n_sess])
            )

        # plot 0 planes, and lines
        for sub_ax in ax.reshape(-1):
            sub_ax.autoscale(False)
            all_lims = [sub_ax.get_xlim(), sub_ax.get_ylim(), sub_ax.get_zlim()]
            # extend each axis range by 2% on both sides
            xs, ys, zs = [
                [vs[0] - (vs[1] - vs[0]) * 0.02, vs[1] + (vs[1] - vs[0]) * 0.02]
                for vs in all_lims
                ]
            
            for plane in ["x", "y", "z"]:
                if plane == "x":
                    xx, yy = np.meshgrid(xs, ys)
                    zz = xx * 0
                    x_flat = xs
                    y_flat, z_flat = [0, 0], [0, 0]
                elif plane == "y":
                    yy, zz = np.meshgrid(ys, zs)
                    xx = yy * 0
                    y_flat = ys
                    z_flat, x_flat = [0, 0], [0, 0]
                elif plane == "z":
                    zz, xx = np.meshgrid(zs, xs)
                    yy = zz * 0
                    z_flat = zs
                    x_flat, y_flat = [0, 0], [0, 0]
                
                sub_ax.plot_surface(xx, yy, zz, alpha=0.05, color="k")
                sub_ax.plot(
                    x_flat, y_flat, z_flat, alpha=0.4, color="k", ls=(0, (2, 2))
                    )

    if savedir is None:
        savedir = Path(
            figpar["dirs"]["roi"],
            figpar["dirs"]["glm"])

    savename = (f"roi_glm_ev_{sessstr}{dendstr}")

    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
Example #8
0
def plot_traces_by_qu_lock_sess(analyspar, sesspar, stimpar, extrapar, 
                                quantpar, sess_info, trace_stats, 
                                figpar=None, savedir=None, modif=False):
    """
    plot_traces_by_qu_lock_sess(analyspar, sesspar, stimpar, extrapar, 
                                quantpar, sess_info, trace_stats)

    From dictionaries, plots traces by quantile, locked to transitions from 
    unexpected to expected or v.v. with each session in a separate subplot.
    
    Returns figure name and save directory path.

    The input analyspar dictionary is not modified.
    
    Required args:
        - analyspar (dict)  : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)    : dictionary with keys of SessPar namedtuple
        - stimpar (dict)    : dictionary with keys of StimPar namedtuple
        - extrapar (dict)   : dictionary containing additional analysis 
                              parameters
            ["analysis"] (str): analysis type (e.g., "l")
            ["datatype"] (str): datatype (e.g., "run", "roi")
        - quantpar (dict)   : dictionary with keys of QuantPar namedtuple
        - sess_info (dict)  : dictionary containing information from each
                              session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session (only reported 
                                    if datatype is "roi")

        - trace_stats (dict): dictionary containing trace stats information
            ["xrans"] (list)           : time values for the 2p frames for each 
                                         session
            ["all_stats"] (list)       : list of 4D arrays or lists of trace 
                                         data statistics across ROIs for each 
                                         session, structured as:
                                            (unexp_len x) quantiles x
                                            stats (me, err) x frames
            ["all_counts"] (array-like): number of sequences, structured as:
                                                sess x (unexp_len x) quantiles
            ["lock"] (str)             : value to which segments are locked:
                                         "unexp", "exp" or "unexp_split"
            ["baseline"] (num)         : number of seconds used for baseline
            ["exp_stats"] (list)       : list of 3D arrays or lists of trace 
                                         data statistics across ROIs for
                                         expected sampled sequences, 
                                         structured as:
                                            quantiles (1) x stats (me, err) 
                                            x frames
            ["exp_counts"] (array-like): number of sequences corresponding to
                                         exp_stats, structured as:
                                            sess x quantiles (1)
            
            if data is by unexp_len:
            ["unexp_lens"] (list)      : number of consecutive segments for
                                         each unexp_len, structured by session
                
    Optional args:
        - figpar (dict): dictionary containing the following figure parameter 
                         dictionaries
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None   
        - modif (bool) : if True, modified (slimmed-down) plots are created
                         instead
                         default: False
    
    Returns:
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved

    Raises:
        - ValueError: if exp_stats contains more than one quantile for a 
                      session
    """

    # dendrite type is ignored for these plots: set it to None on a local 
    # copy, so that the caller's dictionary is not modified
    analyspar = dict(analyspar)
    analyspar["dend"] = None

    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"],
        stimpar["gabk"], "print")
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")
        
    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"])
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])
     
    basestr = sess_str_util.base_par_str(trace_stats["baseline"])
    basestr_pr = sess_str_util.base_par_str(trace_stats["baseline"], "print")

    datatype = extrapar["datatype"]
    dimstr = sess_str_util.datatype_dim_str(datatype)

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, empty=(datatype!="roi")) 

    n_sess = len(mouse_ns)

    xrans      = [np.asarray(xran) for xran in trace_stats["xrans"]]
    all_stats  = [np.asarray(sessst) for sessst in trace_stats["all_stats"]]
    exp_stats  = [np.asarray(expst) for expst in trace_stats["exp_stats"]]
    all_counts = trace_stats["all_counts"]
    exp_counts = trace_stats["exp_counts"]

    lock  = trace_stats["lock"]
    col_idx = 0
    if "unexp" in lock:
        lock = "unexp"
        col_idx = 1
    
    # plot unexp_lens default values
    # offset must be initialized BEFORE the Gabor-specific assignment, 
    # otherwise the Gabor frame offset is always discarded
    offset = 0
    if stimpar["stimtype"] == "gabors":
        DEFAULT_UNEXP_LEN = [3.0, 4.5, 6.0]
        if stimpar["gabfr"] not in ["any", "all"]:
            # used to shift unexp_lens, their labels and the frame pattern
            offset = sess_str_util.gabfr_nbrs(stimpar["gabfr"])
    else:
        DEFAULT_UNEXP_LEN = [2.0, 3.0, 4.0]
    
    unexp_lab, len_ext = "", ""
    unexp_lens = [[None]] * n_sess
    unexp_len_default = True
    if "unexp_lens" in trace_stats.keys():
        unexp_len_default = False
        unexp_lens = trace_stats["unexp_lens"]
        len_ext = "_bylen"
        
        if stimpar["stimtype"] == "gabors":
            # convert from number of segments to seconds, removing the offset
            unexp_lens = [
                [sl * 1.5/5 - 0.3 * offset for sl in sls] for sls in unexp_lens
                ]

    inv = 1 if lock == "unexp" else -1
    # RANGE TO PLOT
    if modif:
        st_val = -2.0
        end_val  = 6.0
        n_ticks = int((end_val - st_val) // 2 + 1)
    else:
        n_ticks = 21

    if figpar is None:
        figpar = sess_plot_util.init_figpar()
    figpar = copy.deepcopy(figpar)
    if modif:
        figpar["init"]["subplot_wid"] = 6.5
    else:
        figpar["init"]["subplot_wid"] *= 2

    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])
    exp_min, exp_max = np.inf, -np.inf
    for i, (stats, counts) in enumerate(zip(all_stats, all_counts)):
        sub_ax = plot_util.get_subax(ax, i)

        # plot expected data
        if exp_stats[i].shape[0] != 1:
            raise ValueError("Expected only one quantile for exp_stats.")

        n_lines = quantpar["n_quants"] * len(unexp_lens[i])
        cols = sess_plot_util.get_quant_cols(n_lines)[0][col_idx]
        if len(cols) < n_lines:
            cols = [None] * n_lines

        if modif:
            line = "2/3" if "23" in lines[i] else "5"
            plane = "somata" if "soma" in planes[i] else "dendrites"
            title = f"M{mouse_ns[i]} - layer {line} {plane}{dendstr_pr}"
            lab = "exp" if i == 0 else None
            y_ax = None if i == 0 else ""

            # restrict plotting to the [st_val, end_val] time range
            st, end = 0, len(xrans[i])
            st_vals = list(filter(
                lambda j: xrans[i][j] <= st_val, range(len(xrans[i]))
                ))
            end_vals = list(filter(
                lambda j: xrans[i][j] >= end_val, range(len(xrans[i]))
                ))
            if len(st_vals) != 0:
                st = st_vals[-1]
            if len(end_vals) != 0:
                end = end_vals[0] + 1
            time_slice = slice(st, end)

        else:
            title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}, "
                u"{} ".format(statstr_pr) + f"{lock} locked across {dimstr}"
                f"{basestr_pr}\n(sess {sess_ns[i]}, {lines[i]} {planes[i]}"
                f"{dendstr_pr}{nroi_strs[i]})")
            lab = f"exp (no lock) ({exp_counts[i][0]})"
            y_ax = None
            st = 0
            end = len(xrans[i])
            time_slice = slice(None) # use all

        # add length markers
        use_unexp_lens = unexp_lens[i]
        if unexp_len_default:
            use_unexp_lens = DEFAULT_UNEXP_LEN
        leng_col = sess_plot_util.get_quant_cols(1)[0][col_idx][0]
        for leng in use_unexp_lens:
            if leng is None:
                continue
            # clip shaded region to the plotted time range
            edge = leng * inv
            if edge < 0:
                edge = np.max([xrans[i][st], edge])
            elif edge > 0:
                edge = np.min([xrans[i][end - 1], edge])
            plot_util.add_vshade(
                sub_ax, 0, edge, color=leng_col, alpha=0.1)

        sess_plot_util.add_axislabels(
            sub_ax, fluor=analyspar["fluor"], datatype=datatype, y_ax=y_ax
            )
        plot_util.add_bars(sub_ax, hbars=0)
        alpha = np.min([0.4, 0.8 / n_lines])

        if stimpar["stimtype"] == "gabors":
            sess_plot_util.plot_gabfr_pattern(
                sub_ax, xrans[i], offset=offset, bars_omit=[0] + unexp_lens[i]
                )

        plot_util.plot_traces(
            sub_ax, xrans[i][time_slice], exp_stats[i][0][0, time_slice], 
            exp_stats[i][0][1:, time_slice], n_xticks=n_ticks,
            alpha=alpha, label=lab, alpha_line=0.8, color="darkgray", 
            xticks="auto")

        # get expected data range to adjust y lims
        exp_min = np.min([exp_min, np.nanmin(exp_stats[i][0][0])])
        exp_max = np.max([exp_max, np.nanmax(exp_stats[i][0][0])])

        n = 0 # count lines plotted
        for s, unexp_len in enumerate(unexp_lens[i]):
            if unexp_len is not None:
                counts, stats = all_counts[i][s], all_stats[i][s]       
                # remove offset   
                unexp_lab = f"unexp len {unexp_len + 0.3 * offset}"
            else:
                unexp_lab = "unexp" if modif else f"{lock} lock"
            for q, qu_idx in enumerate(quantpar["qu_idx"]):
                qu_lab = ""
                if quantpar["n_quants"] > 1:
                    qu_lab = "{} ".format(sess_str_util.quantile_str(
                        qu_idx, quantpar["n_quants"], str_type="print"
                        ))
                lab = f"{qu_lab}{unexp_lab}"
                if modif:
                    lab = lab if i == 0 else None
                else:
                    lab = f"{lab} ({counts[q]})"
                if n == 2 and cols[n] is None:
                    sub_ax.plot([], []) # to advance the color cycle (past gray)
                plot_util.plot_traces(sub_ax, xrans[i][time_slice], 
                    stats[q][0, time_slice], stats[q][1:, time_slice], title, 
                    alpha=alpha, label=lab, n_xticks=n_ticks, alpha_line=0.8, 
                    color=cols[n], xticks="auto")
                n += 1
            if unexp_len is not None:
                plot_util.add_bars(
                    sub_ax, hbars=unexp_len, color=sub_ax.lines[-1].get_color(), 
                    alpha=1)
    
    plot_util.turn_off_extra(ax, n_sess)

    if savedir is None:
        savedir = Path(
            figpar["dirs"][datatype], 
            figpar["dirs"]["unexp_qu"], 
            f"{lock}_lock", basestr.replace("_", ""))

    if not modif:
        if stimpar["stimtype"] == "visflow":
            plot_util.rel_confine_ylims(sub_ax, [exp_min, exp_max], 5)

    qu_str = f"_{quantpar['n_quants']}q"
    if quantpar["n_quants"] == 1:
        qu_str = ""
 
    savename = (f"{datatype}_av_{lock}_lock{len_ext}{basestr}_{sessstr}"
        f"{dendstr}{qu_str}")
    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
Example #9
0
def plot_mag_change(analyspar, sesspar, stimpar, extrapar, permpar, quantpar, 
                    sess_info, mags, figpar=None, savedir=None):
    """
    plot_mag_change(analyspar, sesspar, stimpar, extrapar, permpar, quantpar, 
                    sess_info, mags) 

    From dictionaries, plots magnitude of change in unexpected and expected
    responses across quantiles.

    Two figures are created and saved: one with unscaled and one with scaled 
    magnitudes. Sessions with significant magnitude differences are marked.

    Returns figure name and save directory path.

    Required args:
        - analyspar (dict): dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)  : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)  : dictionary with keys of StimPar namedtuple
        - extrapar (dict) : dictionary containing additional analysis 
                            parameters
            ["analysis"] (str): analysis type (e.g., "m")
            ["datatype"] (str): datatype (e.g., "run", "roi")
            ["seed"]     (int): seed value used
        - permpar (dict)  : dictionary with keys of PermPar namedtuple 
        - quantpar (dict) : dictionary with keys of QuantPar namedtuple
        - sess_info (dict): dictionary containing information from each
                            session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session

        - mags (dict)     : dictionary containing magnitude data to plot
            ["L2"] (array-like)    : nested list containing L2 norms, 
                                     structured as: 
                                         sess x scaling x unexp
            ["L2_sig"] (list)      : L2 significance results for each session 
                                         ("hi", "lo" or "no")
            ["mag_sig"] (list)     : magnitude significance results for each 
                                     session 
                                         ("hi", "lo" or "no")
            ["mag_st"] (array-like): array or nested list containing magnitude 
                                     stats across ROIs, structured as: 
                                         sess x scaling x unexp x stats

    Optional args:
        - figpar (dict): dictionary containing the following figure parameter 
                         dictionaries
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None
    
    Returns:
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved (without the 
                          scaling suffix)

    Raises:
        - ValueError: if quantpar["qu_idx"] does not specify exactly 2 
                      quantiles
    """
    
    sessstr_pr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"], 
        "print")
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")
        
    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"]) 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])
     
    datatype = extrapar["datatype"]
    dimstr = sess_str_util.datatype_dim_str(datatype)
    
    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]
    nroi_strs = sess_str_util.get_nroi_strs(
        sess_info, empty=(datatype!="roi"), style="par"
        )    

    n_sess = len(mouse_ns)

    # 1-indexed quantile numbers, for the title
    qu_ns = [gen_util.pos_idx(q, quantpar["n_quants"]) + 1 
        for q in quantpar["qu_idx"]]
    if len(qu_ns) != 2:
        raise ValueError(f"Expected 2 quantiles, not {len(qu_ns)}.")
    
    mag_st = np.asarray(mags["mag_st"])

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()
 
    figpar["init"]["subplot_wid"] *= n_sess/2.0
    
    scales = [False, True]

    # get plot elements
    barw = 0.75
    # scaling strings for printing and filenames
    leg = ["exp", "unexp"]    
    cent, bar_pos, xlims = plot_util.get_barplot_xpos(n_sess, len(leg), barw)   
    title = (u"Magnitude ({}) of difference in activity".format(statstr_pr) +
        f"\nbetween Q{qu_ns[0]} and Q{qu_ns[1]} across {dimstr} "
        f"\n({sessstr_pr})")
    labels = [f"Mouse {mouse_ns[i]} sess {sess_ns[i]},\n {lines[i]} {planes[i]}"
        f"{dendstr_pr}{nroi_strs[i]}" for i in range(n_sess)]

    # one figure (with a single subplot) per scaling
    figs, axs = [], []
    for sc, scale in enumerate(scales):
        scalestr_pr = sess_str_util.scale_par_str(scale, "print")
        fig, ax = plot_util.init_fig(1, **figpar["init"])
        figs.append(fig)
        axs.append(ax)
        sub_ax = ax[0, 0]
        # always set ticks (even again) before setting labels
        sub_ax.set_xticks(cent)
        sub_ax.set_xticklabels(labels)
        title_scale = u"{}{}".format(title, scalestr_pr)
        sess_plot_util.add_axislabels(
            sub_ax, fluor=analyspar["fluor"], area=True, scale=scale, x_ax="", 
            datatype=datatype)
        for s, lab in enumerate(leg):
            xpos = list(zip(*bar_pos))[s]
            plot_util.plot_bars(
                sub_ax, xpos, mag_st[:, sc, s, 0], err=mag_st[:, sc, s, 1:], 
                width=barw, xlims=xlims, xticks="None", label=lab, capsize=4,
                title=title_scale, hline=0)
    
    # add significance markers
    for i in range(n_sess):
        signif = mags["mag_sig"][i]
        if signif in ["hi", "lo"]:
            xpos = bar_pos[i]
            for sc, (ax, scale) in enumerate(zip(axs, scales)):
                yval = mag_st[i, sc, :, 0]
                yerr = mag_st[i, sc, :, 1:]
                plot_util.plot_barplot_signif(ax[0, 0], xpos, yval, yerr)
    
    # NOTE(review): ax refers to the last figure created above, which holds a 
    # single subplot - confirm this call is still needed
    plot_util.turn_off_extra(ax, n_sess)

    # figure directory
    if savedir is None:
        savedir = Path(
            figpar["dirs"][datatype], 
            figpar["dirs"]["unexp_qu"], 
            figpar["dirs"]["mags"])
    
    savename = f"{datatype}_mag_diff_{sessstr}{dendstr}"
    for i, (fig, scale) in enumerate(zip(figs, scales)):
        # only log the save directory for the last figure
        log_dir = (i == len(figs) - 1)
        scalestr = sess_str_util.scale_par_str(scale)
        savename_full = f"{savename}{scalestr}"
        fulldir = plot_util.savefig(
            fig, savename_full, savedir, log_dir=log_dir, **figpar["save"])

    return fulldir, savename
Example #10
0
def plot_full_traces(analyspar, sesspar, extrapar, sess_info, trace_info, 
                     roi_tr=None, figpar=None, savedir=None):
    """
    plot_full_traces(analyspar, sesspar, extrapar, sess_info, trace_info)

    Draws the complete trace of each session, one column of subplots per 
    session, from the input dictionaries, and saves the figure.

    For ROI data, each column has two rows: the individual ROI traces on top 
    and the statistic across ROIs (with error) below. For other datatypes, 
    each column holds a single row with the session trace. Stimulus block 
    edges are drawn on every row, and block parameter labels are added to the 
    bottom row.

    Required args:
        - analyspar (dict)  : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)    : dictionary with keys of SessPar namedtuple
        - extrapar (dict)   : dictionary containing additional analysis 
                              parameters
            ["analysis"] (str): analysis type (e.g., "f")
            ["datatype"] (str): datatype (e.g., "run", "roi")
        - sess_info (dict)  : dictionary containing information from each
                              session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session
        - trace_info (dict): dictionary containing trace information
            ["all_tr"] (nested list): trace values structured as
                                          sess x 
                                          (me/err if datatype is "roi" x)
                                          frames
            ["all_edges"] (list)    : frame edge values for each parameter, 
                                      structured as sess x block x 
                                                    edges ([start, end])
            ["all_pars"] (list)     : stimulus parameter strings structured as 
                                                    sess x block

    Optional args:
        - roi_tr (list): trace values for each ROI, structured as 
                         sess x ROI x frames (needed if datatype is "roi")
                         default: None
        - figpar (dict): dictionary containing the following figure parameter 
                         dictionaries
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None    

    Returns:
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved

    Raises:
        - ValueError: if datatype is "roi" and roi_tr is None
    """

    datatype = extrapar["datatype"]
    is_roi = datatype == "roi"

    # strings used in titles ("print" versions) and in the file name
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], datatype, "print")
    sessstr = f"sess{sesspar['sess_n']}_{sesspar['plane']}"
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], datatype)

    # per-session information
    mouse_ns, sess_ns, lines, planes = [
        sess_info[key] for key in ("mouse_ns", "sess_ns", "lines", "planes")]
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, empty=not is_roi)
    n_sess = len(mouse_ns)

    # work on a copy, so the caller's figpar is left unchanged
    figpar = copy.deepcopy(
        sess_plot_util.init_figpar() if figpar is None else figpar)
    figpar["init"]["subplot_wid"] = 10
    figpar["init"]["ncols"] = n_sess

    if is_roi:
        figpar["save"]["fig_ext"] = "jpg"
        figpar["init"]["sharex"] = False
        figpar["init"]["sharey"] = False
        # two rows per session: a tall ROI plot over a short summary plot, 
        # with reduced vertical spacing between them
        gs = {"height_ratios": [5, 1], "hspace": 0.1}
        n_rows = 2
        if roi_tr is None:
            raise ValueError("Cannot plot data as ROI traces are missing "
                "(not recorded in dictionary, likely due to size).")
        if not figpar["save"]["save_fig"]:
            warnings.warn("Figure plotting is being skipped. Since full ROI "
                "traces are not saved to dictionary, to actually plot traces, "
                "analysis will have to be rerun with 'save_fig' set to True.", 
                stacklevel=1)
    else:
        gs = None
        n_rows = 1

    fig, ax = plot_util.init_fig(n_rows * n_sess, gs=gs, **figpar["init"])

    if is_roi:
        fig.subplots_adjust(top=0.92) # trim white space above the subplots
        label_height = 0.55
    else:
        label_height = 0.8

    for i in range(n_sess):
        col_axs = ax[:, i]
        col_axs[0].set_title(
            (f"Mouse {mouse_ns[i]} (sess {sess_ns[i]}, {lines[i]} "
             f"{planes[i]}{dendstr_pr}{nroi_strs[i]})"), 
            y=1.02)

        sess_tr = trace_info["all_tr"][i]
        if is_roi:
            frame_ran = range(len(sess_tr[1]))
            # top row: every ROI trace, offset from one another
            plot_util.plot_sep_data(col_axs[0], np.asarray(roi_tr[i]), 0.1)
            sess_plot_util.add_axislabels(
                col_axs[0], fluor=analyspar["fluor"], scale=True, 
                datatype=datatype, x_ax="")

            # bottom row: statistic across ROIs, with its error
            stat_tr = np.asarray(sess_tr)
            plot_util.plot_traces(
                col_axs[1], frame_ran, stat_tr[0], stat_tr[1:], lw=0.2, 
                xticks="auto", title=u"{} across ROIs".format(statstr_pr))
        else:
            col_axs[0].plot(np.asarray(sess_tr), lw=0.2)

        bottom_ax = col_axs[-1]
        for b, block in enumerate(trace_info["all_edges"][i]):
            # block parameter labels are placed on the bottom row only
            plot_util.add_labels(
                bottom_ax, trace_info["all_pars"][i][b], np.mean(block), 
                label_height, color="k")
            sess_plot_util.add_axislabels(
                bottom_ax, fluor=analyspar["fluor"], datatype=datatype, 
                x_ax="")
            plot_util.remove_ticks(bottom_ax, True, False)
            plot_util.remove_graph_bars(bottom_ax, bars="horiz")
            # block edges are marked on every row
            for r in range(n_rows):
                plot_util.add_bars(col_axs[r], bars=block)

    if savedir is None:
        savedir = Path(figpar["dirs"][datatype], figpar["dirs"]["full"])

    suptitle_y = 1 if datatype == "run" else 0.98
    fig.suptitle(
        "Full traces across sessions", fontsize="xx-large", y=suptitle_y)

    savename = f"{datatype}_tr_{sessstr}{dendstr}"
    fulldir = plot_util.savefig(
        fig, savename, savedir, dpi=400, **figpar["save"])

    return fulldir, savename
Example #11
0
def plot_traces_by_qu_unexp_sess(analyspar, sesspar, stimpar, extrapar, 
                                quantpar, sess_info, trace_stats, figpar=None, 
                                savedir=None, modif=False):
    """
    plot_traces_by_qu_unexp_sess(analyspar, sesspar, stimpar, extrapar, 
                                quantpar, sess_info, trace_stats)

    Plots the traces from the input dictionaries, split by quantile and by 
    expected/unexpected, with one subplot per session, and saves the figure.

    Returns figure name and save directory path.

    Required args:
        - analyspar (dict)  : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)    : dictionary with keys of SessPar namedtuple
        - stimpar (dict)    : dictionary with keys of StimPar namedtuple
        - extrapar (dict)   : dictionary containing additional analysis 
                              parameters
            ["analysis"] (str): analysis type (e.g., "t")
            ["datatype"] (str): datatype (e.g., "run", "roi")
        - quantpar (dict)   : dictionary with keys of QuantPar namedtuple
        - sess_info (dict)  : dictionary containing information from each
                              session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session (if 
                                    extrapar["datatype"] is "roi")
        - trace_stats (dict): dictionary containing trace stats information
            ["xrans"] (list)           : time values for the frames, for each 
                                         session
            ["all_stats"] (list)       : list of 4D arrays or lists of trace 
                                         data statistics across ROIs for each
                                         session, structured as:
                                            sess x unexp x quantiles x
                                            stats (me, err) x frames
            ["all_counts"] (array-like): number of sequences, structured as:
                                                sess x unexp x quantiles

    Optional args:
        - figpar (dict): dictionary containing the following figure parameter 
                         dictionaries
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None    
        - modif (bool) : if True, slimmed-down plots are created instead 
                         (shorter titles, legend and y axis label on the 
                         first subplot only)
                         default: False

    Returns:
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved
    """

    datatype = extrapar["datatype"]

    # strings used in titles ("print" versions) and in the file name
    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"],
        stimpar["gabk"], "print")
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], datatype, "print")
    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"])
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], datatype)
    dimstr = sess_str_util.datatype_dim_str(datatype)

    # per-session information
    mouse_ns, sess_ns, lines, planes = [
        sess_info[key] for key in ("mouse_ns", "sess_ns", "lines", "planes")]
    nroi_strs = sess_str_util.get_nroi_strs(
        sess_info, empty=(datatype != "roi"))
    n_sess = len(mouse_ns)

    xrans = list(map(np.asarray, trace_stats["xrans"]))
    all_stats = list(map(np.asarray, trace_stats["all_stats"]))
    all_counts = trace_stats["all_counts"]

    n_quants = quantpar["n_quants"]
    cols, lab_cols = sess_plot_util.get_quant_cols(n_quants)
    # fainter shading as more quantiles are overlaid
    alpha = np.min([0.4, 0.8 / n_quants])

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])
    for i in range(n_sess):
        sub_ax = plot_util.get_subax(ax, i)
        for s, (col, leg_ext) in enumerate(zip(cols, ["exp", "unexp"])):
            for q, qu_idx in enumerate(quantpar["qu_idx"]):
                if n_quants > 1:
                    qu_name = sess_str_util.quantile_str(
                        qu_idx, n_quants, str_type="print")
                    qu_lab = f"{qu_name} "
                else:
                    qu_lab = ""

                if not modif:
                    title=(f"Mouse {mouse_ns[i]} - {stimstr_pr}, " 
                        u"{}\n".format(statstr_pr) + f"across {dimstr} (sess "
                        f"{sess_ns[i]}, {lines[i]} {planes[i]}{dendstr_pr}"
                        f"{nroi_strs[i]})")
                    leg = f"{qu_lab}{leg_ext} ({all_counts[i][s][q]})"
                    y_ax = None
                else:
                    layer = "2/3" if "23" in lines[i] else "5"
                    compartment = (
                        "somata" if "soma" in planes[i] else "dendrites")
                    title = (f"M{mouse_ns[i]} - layer {layer} "
                        f"{compartment}{dendstr_pr}")
                    # only the first subplot carries a legend and a y label
                    is_first = i == 0
                    leg = f"{qu_lab}{leg_ext}" if is_first else None
                    y_ax = None if is_first else ""

                plot_util.plot_traces(
                    sub_ax, xrans[i], all_stats[i][s, q, 0], 
                    all_stats[i][s, q, 1:], title, color=col[q], alpha=alpha, 
                    label=leg, n_xticks=6, xticks="auto")
                sess_plot_util.add_axislabels(
                    sub_ax, fluor=analyspar["fluor"], datatype=datatype, 
                    y_ax=y_ax)

    plot_util.turn_off_extra(ax, n_sess)

    if stimpar["stimtype"] == "gabors": 
        sess_plot_util.plot_labels(
            ax, stimpar["gabfr"], "both", pre=stimpar["pre"], 
            post=stimpar["post"], cols=lab_cols, 
            sharey=figpar["init"]["sharey"])

    if savedir is None:
        savedir = Path(figpar["dirs"][datatype], figpar["dirs"]["unexp_qu"])

    # quantile count only appears in the file name if there is more than one
    qu_str = "" if n_quants == 1 else f"_{n_quants}q"

    savename = f"{datatype}_av_{sessstr}{dendstr}{qu_str}"
    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
def plot_pup_diff_corr(analyspar, sesspar, stimpar, extrapar, 
                       sess_info, corr_data, figpar=None, savedir=None):
    """
    plot_pup_diff_corr(analyspar, sesspar, stimpar, extrapar, 
                       sess_info, corr_data)

    From dictionaries, plots correlation between unexpected-locked changes in 
    pupil diameter and running or ROI data for each session.

    Each session gets one column of two subplots: a scatter plot of the pupil 
    differences against the running/ROI differences (top), and both 
    difference series, min-max scaled, across unexpected event occurrences, 
    with the area between them shaded (bottom).

    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis 
                                parameters
            ["analysis"] (str): analysis type (e.g., "c")
            ["datatype"] (str): datatype ("run" or "roi")
        - sess_info (dict)    : dictionary containing information from each
                                session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session
        - corr_data (dict)    : dictionary containing data to plot:
            ["corrs"] (list): list of correlation values between pupil and 
                              running or ROI differences for each session
            ["diffs"] (list): list of differences for each session, structured
                                  as [pupil, ROI/run] x trials x frames

    Optional args:
        - figpar (dict) : dictionary containing the following figure parameter 
                          dictionaries
                          default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (Path): path of directory in which to save plots.
                          default: None

    Returns:
        - fulldir (Path): final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved

    Raises:
        - ValueError: if extrapar["datatype"] is neither "roi" nor "run"
    """

    datatype = extrapar["datatype"]
    # axis labels are only defined for these datatypes: fail early and 
    # clearly instead of with a NameError partway through plotting
    if datatype not in ["roi", "run"]:
        raise ValueError(
            f"datatype must be 'roi' or 'run', but found '{datatype}'.")

    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"], 
        stimpar["gabk"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], datatype, "print")

    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"],stimpar["visflow_size"], stimpar["gabk"]) 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], datatype)

    datastr = sess_str_util.datatype_par_str(datatype)

    if datatype == "roi":
        label_str = sess_str_util.fluor_par_str(
            analyspar["fluor"], str_type="print")
        full_label_str = u"{}, {} across ROIs".format(
            label_str, analyspar["stats"])
    else: # "run"
        label_str = datastr
        full_label_str = datastr

    lab_app = (f" ({analyspar['stats']} over "
        f"{stimpar['pre']}/{stimpar['post']} sec)")

    logger.info(f"Plotting pupil vs {datastr} changes.")

    delta = "\u0394"

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]

    n_sess = len(mouse_ns)
    nroi_strs = sess_str_util.get_nroi_strs(
        sess_info, empty=(datatype!="roi"), style="comma"
        ) 

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    # work on a copy, so the caller's figpar is left unchanged
    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()
    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["ncols"] = n_sess

    # 2 rows (scatter, then scaled traces) x n_sess columns
    fig, ax = plot_util.init_fig(2 * n_sess, **figpar["init"])
    suptitle = (f"Relationship between pupil diam. and {datastr} changes, "
        "locked to unexpected events")

    for i, sess_diffs in enumerate(corr_data["diffs"]):
        sub_axs = ax[:, i]
        title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}, " + 
            u"{}".format(statstr_pr) + f"\n(sess {sess_ns[i]}, {lines[i]} "
            f"{planes[i]}{dendstr_pr}{nroi_strs[i]})")

        # top plot: pupil vs data differences, labelled with their correlation
        corr = f"Corr = {corr_data['corrs'][i]:.2f}"
        sub_axs[0].plot(
            sess_diffs[0], sess_diffs[1], marker=".", linestyle="None", 
            label=corr)
        sub_axs[0].set_title(title, y=1.01)
        sub_axs[0].set_xlabel(u"{} pupil diam.{}".format(delta, lab_app))
        if i == 0:
            sub_axs[0].set_ylabel(u"{} {}\n{}".format(
                delta, full_label_str, lab_app))
        sub_axs[0].legend()

        # bottom plot: min-max scaled differences across occurrences
        data_lab = u"{} {}".format(delta, label_str)   
        pup_lab = u"{} pupil diam.".format(delta)
        is_last_sess = (i == n_sess - 1) # legend only on the last column
        cols = []
        scaled = []
        for d, lab in enumerate([pup_lab, data_lab]):
            scaled.append(math_util.scale_data(
                np.asarray(sess_diffs[d]), sc_type="min_max")[0])
            art, = sub_axs[1].plot(scaled[-1], marker=".")
            cols.append(art.get_color())
            if is_last_sess:
                art.set_label(lab)
        if is_last_sess:
            sub_axs[1].legend()
        sub_axs[1].set_xlabel("Unexpected event occurrence")
        if i == 0:
            sub_axs[1].set_ylabel(
                u"{} response locked\nto unexpected onset (scaled)".format(delta))
        # shade area between the two scaled traces
        plot_util.plot_btw_traces(
            sub_axs[1], scaled[0], scaled[1], color=cols, alpha=0.4)

    fig.suptitle(suptitle, fontsize="xx-large", y=1)

    if savedir is None:
        savedir = Path(
            figpar["dirs"][datatype],
            figpar["dirs"]["pupil"])

    savename = f"{datatype}_diff_corr_{sessstr}{dendstr}"

    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
def plot_pup_roi_stim_corr(analyspar, sesspar, stimpar, extrapar, 
                           sess_info, corr_data, figpar=None, savedir=None):
    """
    plot_pup_roi_stim_corr(analyspar, sesspar, stimpar, extrapar, 
                           sess_info, corr_data)

    From dictionaries, plots correlation between unexpected-locked changes in 
    pupil diameter and each ROI, for gabors versus visual flow responses for 
    each session.

    Each session gets one subplot: a scatter plot of per-ROI correlations for 
    the first stimulus type (x) against the second (y).

    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis 
                                parameters
            ["analysis"] (str): analysis type (e.g., "r")
            ["datatype"] (str): datatype (e.g., "roi")
        - sess_info (dict)    : dictionary containing information from each
                                session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session
        - corr_data (dict)    : dictionary containing data to plot:
            ["stim_order"] (list): ordered list of stimtypes
            ["roi_corrs"] (list) : nested list of correlations between pupil 
                                   and ROI responses changes locked to 
                                   unexpected, structured as 
                                       session x stimtype x ROI
            ["corrs"] (list)     : list of correlation between stimtype
                                   correlations for each session

    Optional args:
        - figpar (dict) : dictionary containing the following figure parameter 
                          dictionaries
                          default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (Path): path of directory in which to save plots.
                          default: None

    Returns:
        - fulldir (Path): final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved
    """

    def _is_last_row_first_col(sub_ax):
        """
        Returns whether sub_ax is in the last row, and whether it is in the 
        first column of its grid. Axes.is_last_row/is_first_col were 
        deprecated in matplotlib 3.4 (and later removed), so the SubplotSpec 
        methods are used when available.
        """
        spec = sub_ax.get_subplotspec()
        if hasattr(spec, "is_last_row"):
            return spec.is_last_row(), spec.is_first_col()
        return sub_ax.is_last_row(), sub_ax.is_first_col() # older matplotlib

    stimstr_prs = []
    for stimtype in corr_data["stim_order"]:
        stimstr_pr = sess_str_util.stim_par_str(
            stimtype, stimpar["visflow_dir"], stimpar["visflow_size"], 
            stimpar["gabk"], "print")
        # drop a trailing "s" for use in "... correlations" axis labels
        # (endswith also copes with an empty string)
        if stimstr_pr.endswith("s"):
            stimstr_pr = stimstr_pr[:-1]
        stimstr_prs.append(stimstr_pr)

    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")

    sessstr = f"sess{sesspar['sess_n']}_{sesspar['plane']}" 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])

    label_str = sess_str_util.fluor_par_str(
        analyspar["fluor"], str_type="print")
    lab_app = (f" ({analyspar['stats']} over "
        f"{stimpar['pre']}/{stimpar['post']} sec)")

    logger.info("Plotting pupil-ROI difference correlations for "
        "{} vs {}.".format(*corr_data["stim_order"]))

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]

    n_sess = len(mouse_ns)
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, style="comma")

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    # work on a copy, so the caller's figpar is left unchanged
    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()
    figpar["init"]["sharex"] = True
    figpar["init"]["sharey"] = True

    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])
    suptitle = (u"Relationship between pupil diam. and {} changes, locked to "
        "unexpected events\n{} for each ROI ({} vs {})".format(
            label_str, lab_app, *corr_data["stim_order"]))

    for i, sess_roi_corrs in enumerate(corr_data["roi_corrs"]):
        sub_ax = plot_util.get_subax(ax, i)
        title = (f"Mouse {mouse_ns[i]} (sess {sess_ns[i]}, {lines[i]} "
            f"{planes[i]}{dendstr_pr}{nroi_strs[i]})")

        # scatter: per-ROI correlations, first vs second stimulus type
        corr = f"Corr = {corr_data['corrs'][i]:.2f}"
        sub_ax.plot(
            sess_roi_corrs[0], sess_roi_corrs[1], marker=".", linestyle="None", 
            label=corr)
        sub_ax.set_title(title, y=1.01)
        is_last_row, is_first_col = _is_last_row_first_col(sub_ax)
        if is_last_row:
            sub_ax.set_xlabel(f"{stimstr_prs[0].capitalize()} correlations")
        if is_first_col:
            sub_ax.set_ylabel(f"{stimstr_prs[1].capitalize()} correlations")
        sub_ax.legend()

    plot_util.turn_off_extra(ax, n_sess)

    fig.suptitle(suptitle, fontsize="xx-large", y=1)

    if savedir is None:
        savedir = Path(
            figpar["dirs"]["roi"],
            figpar["dirs"]["pupil"])

    savename = f"roi_diff_corrbyroi_{sessstr}{dendstr}"

    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
Example #14
0
def run_glms(sessions,
             analysis,
             seed,
             analyspar,
             sesspar,
             stimpar,
             glmpar,
             figpar,
             parallel=False):
    """
    run_glms(sessions, analysis, seed, analyspar, sesspar, stimpar, glmpar, 
             figpar)

    Runs GLMs (via run_glm) on the sessions' ROI data and plots the explained 
    variance in ROI activity. Saves results and parameters relevant to the 
    analysis in a JSON dictionary next to the figure.

    If glmpar.each_roi is True, run_glm is called for each session separately 
    ("per_ROI_per_sess"). Otherwise, it is called once for all sessions 
    together ("across_sess").

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type
        - seed (int)           : random seed, passed to rand_util.seed_all 
                                 (the value it returns is the one used and 
                                 recorded)
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - glmpar (namedtuple)  : named tuple containing GLM parameters 
                                 (including each_roi)
        - figpar (dict)        : dictionary containing figure parameters

    Optional args:
        - parallel (bool): if True, some of the analysis is parallelized 
                           default: False
    """

    seed = rand_util.seed_all(seed, "cpu", log_seed=False)

    sessstr_pr = sess_str_util.sess_par_str(
        sesspar.sess_n, stimpar.stimtype, sesspar.plane, stimpar.visflow_dir,
        stimpar.visflow_size, stimpar.gabk, "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar.dend, sesspar.plane, "roi", "print")

    logger.info(
        "Analysing and plotting explained variance in ROI activity "
        f"({sessstr_pr}{dendstr_pr}).", extra={"spacing": "\n"})

    each_roi = glmpar.each_roi
    if each_roi:
        # per-ROI GLMs: each session forms its own batch
        glm_type, sess_batches = "per_ROI_per_sess", sessions
        logger.info("Per ROI, each session separately.")
    else:
        glm_type, sess_batches = "across_sess", [sessions]
        logger.info(f"Across ROIs, {len(sessions)} sessions together.")

    # parallelize across batches here, or hand the parallel flag to run_glm
    # NOTE(review): as written, parallel_here can never be truthy (each_roi 
    # makes "not each_roi" False; otherwise there is exactly 1 batch), so 
    # parallelism is always delegated to run_glm. Confirm this is intended.
    parallel_here = parallel and not each_roi and len(sess_batches) != 1
    parallel_after = bool(parallel and not parallel_here)

    all_expl_var = gen_util.parallel_wrap(
        run_glm, 
        sess_batches, 
        [analyspar, sesspar, stimpar, glmpar], 
        {"parallel": parallel_after, "seed": seed}, 
        parallel=parallel_here)

    # either layout refers back to the same list of sessions
    sessions = sess_batches if each_roi else sess_batches[0]

    sess_info = sess_gen_util.get_sess_info(
        sessions, analyspar.fluor, rem_bad=analyspar.rem_bad)

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "glmpar": glmpar._asdict(),
        "extrapar": {
            "analysis": analysis, 
            "seed": seed, 
            "glm_type": glm_type,
        },
        "all_expl_var": all_expl_var,
        "sess_info": sess_info,
    }

    fulldir, savename = glm_plots.plot_glm_expl_var(figpar=figpar, **info)

    file_util.saveinfo(info, savename, fulldir, "json")
Example #15
0
def run_pup_roi_stim_corr(sessions,
                          analysis,
                          analyspar,
                          sesspar,
                          stimpar,
                          figpar,
                          datatype="roi",
                          parallel=False):
    """
    run_pup_roi_stim_corr(sessions, analysis, analyspar, sesspar, stimpar, 
                          figpar)

    Calculates and plots correlation between pupil and ROI changes locked to
    unexpected for gabors vs visflow.

    For each session and each stimulus type, each ROI's unexpected-locked 
    differences are correlated with the pupil differences. The per-ROI 
    correlations for gabors and visflow are then correlated with one another.

    Saves results and parameters relevant to analysis in a dictionary.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "r")
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
                                 (stimtype may be "gabors", "visflow" or 
                                 "both")
        - figpar (dict)        : dictionary containing figure parameters

    Optional args:
        - datatype (str) : type of data (only "roi" is implemented)
                           default: "roi"
        - parallel (bool): if True, some of the analysis is parallelized across 
                           CPU cores
                           default: False

    Raises:
        - NotImplementedError: if datatype is not "roi"
        - ValueError         : if stimpar.stimtype is not "gabors", "visflow" 
                               or "both"
        - RuntimeError       : if a session's only_tracked_rois does not match 
                               analyspar.tracked
    """

    if datatype != "roi":
        raise NotImplementedError(
            "Analysis only implemented for roi datatype.")

    stimtypes = ["gabors", "visflow"]
    if stimpar.stimtype != "both":
        non_stimtype = stimtypes[1 - stimtypes.index(stimpar.stimtype)]
        warnings.warn(
            "stimpar.stimtype will be set to 'both', but non default "
            f"{non_stimtype} parameters are lost.",
            category=RuntimeWarning,
            stacklevel=1)

    # The parameter dictionary is needed for every stimtype, including "both" 
    # (building it only when stimtype != "both" left it undefined in that 
    # case). "none" values are dropped, so init_stimpar does not receive them.
    stimpar_dict = {
        key: val for key, val in stimpar._asdict().items() if val != "none"}

    sessstr_pr = f"session: {sesspar.sess_n}, plane: {sesspar.plane}"
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")
    stimstr_pr = []
    stimpars = []
    for stimtype in stimtypes:
        stimpar_dict["stimtype"] = stimtype
        stimpar_dict["gabfr"] = 3
        stimpars.append(sess_ntuple_util.init_stimpar(**stimpar_dict))
        stimstr_pr.append(
            sess_str_util.stim_par_str(stimtype, stimpars[-1].visflow_dir,
                                       stimpars[-1].visflow_size,
                                       stimpars[-1].gabk, "print"))
    # parameters recorded for the analysis: first stimtype's, marked "both"
    stimpar_dict = stimpars[0]._asdict()
    stimpar_dict["stimtype"] = "both"

    logger.info(
        "Analysing and plotting correlations between unexpected vs "
        f"expected ROI traces between sessions ({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})
    sess_corrs = []
    sess_roi_corrs = []
    for sess in sessions:
        if datatype == "roi" and (sess.only_tracked_rois != analyspar.tracked):
            raise RuntimeError(
                "sess.only_tracked_rois should match analyspar.tracked.")
        stim_corrs = []
        for sub_stimpar in stimpars:
            diffs = peristim_data(sess,
                                  sub_stimpar,
                                  datatype="roi",
                                  returns="diff",
                                  first_unexp=True,
                                  rem_bad=analyspar.rem_bad,
                                  scale=analyspar.scale)
            [pup_diff, roi_diff] = diffs
            nrois = roi_diff.shape[-1]
            # correlate each ROI's differences with the pupil differences, 
            # optionally in parallel
            if parallel and nrois > 1:
                n_jobs = gen_util.get_n_jobs(nrois)
                with gen_util.ParallelLogging():
                    corrs = Parallel(n_jobs=n_jobs)(
                        delayed(np.corrcoef)(roi_diff[:, r], pup_diff)
                        for r in range(nrois))
                corrs = np.asarray([corr[0, 1] for corr in corrs])
            else:
                corrs = np.empty(nrois)
                for r in range(nrois):  # cycle through ROIs
                    corrs[r] = np.corrcoef(roi_diff[:, r], pup_diff)[0, 1]
            stim_corrs.append(corrs)
        # correlation between the two stimtypes' per-ROI correlations
        sess_corrs.append(np.corrcoef(stim_corrs[0], stim_corrs[1])[0, 1])
        sess_roi_corrs.append([corrs.tolist() for corrs in stim_corrs])

    extrapar = {
        "analysis": analysis,
        "datatype": datatype,
    }

    corr_data = {
        "stim_order": stimtypes,
        "roi_corrs": sess_roi_corrs,
        "corrs": sess_corrs
    }

    sess_info = sess_gen_util.get_sess_info(sessions,
                                            analyspar.fluor,
                                            incl_roi=(datatype == "roi"),
                                            rem_bad=analyspar.rem_bad)

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar_dict,
        "extrapar": extrapar,
        "sess_info": sess_info,
        "corr_data": corr_data
    }

    fulldir, savename = pup_plots.plot_pup_roi_stim_corr(figpar=figpar, **info)

    file_util.saveinfo(info, savename, fulldir, "json")
Example #16
0
def run_pupil_diff_corr(sessions,
                        analysis,
                        analyspar,
                        sesspar,
                        stimpar,
                        figpar,
                        datatype="roi"):
    """
    run_pupil_diff_corr(sessions, analysis, analyspar, sesspar, 
                        stimpar, figpar)
    
    Calculates and plots the pupil changes and the ROI/running changes 
    around unexpected events, as well as the correlation between the two, 
    for each session.

    The changes are those returned by `peristim_data(..., returns="diff", 
    first_unexp=True)` (presumably one value per unexpected event onset - 
    confirm against peristim_data). For ROI data, the changes are averaged 
    (or medianed, per analyspar.stats) across ROIs before correlating.

    Saves results and parameters relevant to analysis in a dictionary.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "c")
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data (e.g., "roi", "run")

    Raises:
        - RuntimeError: if datatype is "roi" and a session's 
                        only_tracked_rois does not match analyspar.tracked
        - an error (via gen_util.accepted_values_error) if datatype is 
          neither "roi" nor "run"
    """

    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
                                            sesspar.plane, stimpar.visflow_dir,
                                            stimpar.visflow_size, stimpar.gabk,
                                            "print")
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")

    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        "Analysing and plotting correlations between unexpected vs "
        f"expected {datastr} traces between sessions ({sessstr_pr}"
        f"{dendstr_pr}).",
        extra={"spacing": "\n"})

    # NaN policy used when averaging across ROIs: NaNs are only omitted when 
    # bad ROIs are kept (presumably bad ROIs are the source of NaNs - confirm)
    nanpol = None if analyspar.rem_bad else "omit"

    sess_diffs = []
    sess_corr = []

    for sess in sessions:
        if datatype == "roi" and (sess.only_tracked_rois != analyspar.tracked):
            raise RuntimeError(
                "sess.only_tracked_rois should match analyspar.tracked.")
        # rem_bad is passed explicitly (as in the other peristim_data call 
        # in this module) so that the data retrieved agrees with the nanpol 
        # chosen above, instead of relying on peristim_data's own default
        pup_diff, data_diff = peristim_data(sess,
                                            stimpar,
                                            datatype=datatype,
                                            returns="diff",
                                            first_unexp=True,
                                            rem_bad=analyspar.rem_bad,
                                            scale=analyspar.scale)
        if datatype == "roi":
            # trials x ROIs -> trials: collapse across ROIs (last axis)
            data_diff = math_util.mean_med(data_diff,
                                           analyspar.stats,
                                           axis=-1,
                                           nanpol=nanpol)
        elif datatype != "run":
            gen_util.accepted_values_error("datatype", datatype,
                                           ["roi", "run"])
        # Pearson correlation between pupil and data changes for the session
        sess_corr.append(np.corrcoef(pup_diff, data_diff)[0, 1])
        sess_diffs.append([diff.tolist() for diff in [pup_diff, data_diff]])

    extrapar = {
        "analysis": analysis,
        "datatype": datatype,
    }

    corr_data = {"corrs": sess_corr, "diffs": sess_diffs}

    sess_info = sess_gen_util.get_sess_info(sessions,
                                            analyspar.fluor,
                                            incl_roi=(datatype == "roi"),
                                            rem_bad=analyspar.rem_bad)

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "extrapar": extrapar,
        "sess_info": sess_info,
        "corr_data": corr_data
    }

    fulldir, savename = pup_plots.plot_pup_diff_corr(figpar=figpar, **info)

    file_util.saveinfo(info, savename, fulldir, "json")