Пример #1
0
def run_autocorr(sessions,
                 analysis,
                 analyspar,
                 sesspar,
                 stimpar,
                 autocorrpar,
                 figpar,
                 datatype="roi"):
    """
    run_autocorr(sessions, analysis, analyspar, sesspar, stimpar, autocorrpar, 
                 figpar)

    Calculates and plots autocorrelation during stimulus blocks.
    Saves results and parameters relevant to analysis in a dictionary.

    For each session, traces are retrieved for each (contiguous) stimulus 
    block matching stimpar, and autocorrelation statistics are computed over 
    these blocks. If autocorrpar.byitem is False, statistics for a 10x lag, 
    downsampled to the 1x lag resolution, are stacked with the 1x lag 
    statistics.

    Required args:
        - sessions (list)          : list of Session objects
        - analysis (str)           : analysis type (e.g., "a")
        - analyspar (AnalysPar)    : named tuple containing analysis parameters
        - sesspar (SessPar)        : named tuple containing session parameters
        - stimpar (StimPar)        : named tuple containing stimulus parameters
        - autocorrpar (AutocorrPar): named tuple containing autocorrelation 
                                     analysis parameters
        - figpar (dict)            : dictionary containing figure parameters

    Optional args:
        - datatype (str): type of data ("roi" or "run")
                          default: "roi"

    Raises:
        - RuntimeError       : if datatype is "roi" and a session's 
                               only_tracked_rois does not match 
                               analyspar.tracked, or if the 10x lag 
                               statistics cannot be downsampled to match the 
                               1x lag statistics
        - NotImplementedError: if the segments within a block are not 
                               contiguous
        - ValueError         : if datatype is "run" and autocorrpar.byitem is 
                               not False
        Unaccepted datatype values are reported via 
        gen_util.accepted_values_error.
    """

    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
                                            sesspar.plane, stimpar.visflow_dir,
                                            stimpar.visflow_size, stimpar.gabk,
                                            "print")
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")

    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Analysing and plotting {datastr} autocorrelations "
        f"({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    # Validate up front: an unsupported datatype would otherwise only surface 
    # later as an undefined `traces` variable.
    if datatype not in ["roi", "run"]:
        gen_util.accepted_values_error("datatype", datatype, ["roi", "run"])
    if datatype == "run" and autocorrpar.byitem != False:
        raise ValueError("autocorrpar.byitem must be False for "
                         "running data.")

    xrans = []
    stats = []
    for sess in sessions:
        if datatype == "roi" and (sess.only_tracked_rois != analyspar.tracked):
            raise RuntimeError(
                "sess.only_tracked_rois should match analyspar.tracked.")
        stim = sess.get_stim(stimpar.stimtype)
        all_segs = stim.get_segs_by_criteria(visflow_dir=stimpar.visflow_dir,
                                             visflow_size=stimpar.visflow_size,
                                             gabk=stimpar.gabk,
                                             by="block")
        sess_traces = []
        for segs in all_segs:
            if len(segs) == 0:
                continue
            segs = sorted(segs)
            # check that segs are contiguous (for a single-segment block, 
            # np.diff() is empty, which np.any() handles, unlike max())
            if np.any(np.diff(segs) > 1):
                raise NotImplementedError("Segments used for autocorrelation "
                                          "must be contiguous within blocks.")
            if datatype == "roi":
                frame_edges = stim.get_fr_by_seg(
                    [min(segs), max(segs)], fr_type="twop")
                fr = list(range(min(frame_edges[0]), max(frame_edges[1]) + 1))
                traces = gen_util.reshape_df_data(sess.get_roi_traces(
                    fr,
                    fluor=analyspar.fluor,
                    rem_bad=analyspar.rem_bad,
                    scale=analyspar.scale),
                                                  squeeze_cols=True)

            else:  # datatype == "run" (validated above)
                frame_edges = stim.get_fr_by_seg(
                    [min(segs), max(segs)], fr_type="stim")
                fr = list(range(min(frame_edges[0]), max(frame_edges[1]) + 1))

                # single "item": 1 x frames
                traces = sess.get_run_velocity_by_fr(
                    fr,
                    fr_type="stim",
                    rem_bad=analyspar.rem_bad,
                    scale=analyspar.scale).to_numpy().reshape(1, -1)

            sess_traces.append(traces)

        # Calculate autocorr stats (1x and, if applicable, 10x lag) while 
        # filtering some warnings
        # NOTE(review): running traces are retrieved in stimulus frames, but 
        # the lag is converted using sess.twop_fps - confirm this is intended.
        msgs = ["Degrees of freedom", "invalid value encountered"]
        categs = [RuntimeWarning, RuntimeWarning]
        with gen_util.TempWarningFilter(msgs, categs):
            xran, ac_st = math_util.autocorr_stats(sess_traces,
                                                   autocorrpar.lag_s,
                                                   sess.twop_fps,
                                                   byitem=autocorrpar.byitem,
                                                   stats=analyspar.stats,
                                                   error=analyspar.error)

            if not autocorrpar.byitem:  # also add a 10x lag
                # lag presumably interpreted in frames here, as no frame rate 
                # is passed
                lag_fr = 10 * int(autocorrpar.lag_s * sess.twop_fps)
                _, ac_st_10x = math_util.autocorr_stats(
                    sess_traces,
                    lag_fr,
                    byitem=autocorrpar.byitem,
                    stats=analyspar.stats,
                    error=analyspar.error)

        if not autocorrpar.byitem:
            # keep every 10th lag, to match the 1x lag resolution
            downsamp = range(0, ac_st_10x.shape[-1], 10)

            if len(downsamp) != ac_st.shape[-1]:
                raise RuntimeError("Failed to downsample correctly. "
                                   "Check implementation.")
            ac_st = np.stack([ac_st, ac_st_10x[:, downsamp]], axis=1)
        xrans.append(xran)
        stats.append(ac_st)

    autocorr_data = {
        "xrans": [xran.tolist() for xran in xrans],
        "stats": [stat.tolist() for stat in stats]
    }

    sess_info = sess_gen_util.get_sess_info(sessions,
                                            analyspar.fluor,
                                            incl_roi=(datatype == "roi"),
                                            rem_bad=analyspar.rem_bad)

    extrapar = {
        "analysis": analysis,
        "datatype": datatype,
    }

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "extrapar": extrapar,
        "autocorrpar": autocorrpar._asdict(),
        "autocorr_data": autocorr_data,
        "sess_info": sess_info
    }

    fulldir, savename = gen_plots.plot_autocorr(figpar=figpar, **info)

    file_util.saveinfo(info, savename, fulldir, "json")
Пример #2
0
def _log_unexp_corrs(me_1, me_2, unexps, spacing):
    """
    _log_unexp_corrs(me_1, me_2, unexps, spacing)

    Computes and logs the Pearson correlation between two sets of traces for 
    each unexpected value. Significant correlations (p < 0.05) are marked 
    with "*" and the highest correlation with " +".

    Required args:
        - me_1 (nd array): traces, indexed first by unexpected value
        - me_2 (nd array): traces, indexed first by unexpected value
        - unexps (list)  : unexpected value labels, in the same order as the 
                           first dimension of me_1 and me_2
        - spacing (str)  : spacing passed to the logger

    Returns:
        - (list): Pearson correlation coefficient for each unexpected value
    """

    corrs = [st.pearsonr(me_1[s], me_2[s]) for s in range(len(unexps))]
    corr_max = np.argmax([corr[0] for corr in corrs])
    for s, (unexp, corr) in enumerate(zip(unexps, corrs)):
        sig_str = "*" if corr[1] < 0.05 else ""
        high_str = " +" if corr_max == s else ""
        logger.info(
            f"{unexp}: {corr[0]:.4f} "
            f"(p={corr[1]:.2f}{sig_str}){high_str}",
            extra={"spacing": spacing})

    return [corr[0] for corr in corrs]


def run_trace_corr_acr_sess(sessions,
                            analysis,
                            analyspar,
                            sesspar,
                            stimpar,
                            figpar,
                            datatype="roi"):
    """
    run_trace_corr_acr_sess(sessions, analysis, analyspar, sesspar, 
                            stimpar, figpar)

    Retrieves trace statistics by session x unexp val and calculates 
    correlations across sessions per unexp val: within each mouse (between 
    the first 2 sessions of its group), and between mice (session by 
    session).
    
    Currently only logs results to the console. Does NOT save results and 
    parameters relevant to analysis in a dictionary.

    If INFO messages would be suppressed, the logger level is temporarily 
    lowered to INFO. It is restored afterwards, even if an error occurs.

    Required args:
        - sessions (list)      : list of session groups, each a list of 
                                 Session objects (presumably one group per 
                                 mouse, with session numbers matching across 
                                 groups)
        - analysis (str)       : analysis type (e.g., "r")
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data (e.g., "roi", "run")
                          default: "roi"
    """

    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
                                            sesspar.plane, stimpar.visflow_dir,
                                            stimpar.visflow_size, stimpar.gabk,
                                            "print")

    datastr = sess_str_util.datatype_par_str(datatype)

    if sesspar.plane in ["any", "all"] and sesspar.runtype == "pilot":
        logger.warning("Planes may not match between sessions for a mouse!")

    logger.info(
        "Analysing and plotting correlations between unexpected vs "
        f"expected {datastr} traces between sessions ({sessstr_pr}).",
        extra={"spacing": "\n"})

    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()

    # the effective level also accounts for a level inherited from parent 
    # loggers (logger.level alone is 0 (NOTSET) in that case)
    prev_level = logger.level
    if logger.getEffectiveLevel() > logging.INFO:
        logger.setLevel(logging.INFO)
        logger.warning("Temporarily lowered log level for correlation "
                       "analysis results.")

    unexps = ["exp", "unexp"]

    try:
        # correlate average traces between sessions for each mouse and each
        # unexpected value
        all_me_tr = []
        logger.info("Intramouse correlations", extra={"spacing": "\n"})
        for sess_grp in sessions:
            logger.info(f"Mouse {sess_grp[0].mouse_n}, sess "
                        f"{sess_grp[0].sess_n} vs {sess_grp[1].sess_n} corr:")
            trace_info = quant_analys.trace_stats_by_qu_sess(
                sess_grp,
                analyspar,
                stimpar,
                1, [0],
                byroi=False,
                by_exp=True,
                datatype=datatype)
            # remove quant dim
            grp_stats = np.asarray(trace_info[1]).squeeze(2)
            # get mean/median per grp (sess x unexp_val x frame)
            grp_me = grp_stats[:, :, 0]
            _log_unexp_corrs(grp_me[0], grp_me[1], unexps, TAB)
            all_me_tr.append(grp_me)

        # mice x sess x unexp x frame
        all_me_tr = np.asarray(all_me_tr)
        logger.info("Intermouse correlations", extra={"spacing": "\n"})
        for n, m1_sess_mes in enumerate(all_me_tr):
            # compare each pair of mice only once
            for n2 in range(n + 1, len(all_me_tr)):
                m2_sess_mes = all_me_tr[n2]
                logger.info(f"Mouse {sessions[n][0].mouse_n} vs "
                            f"{sessions[n2][0].mouse_n} corr:")
                for se, m1_s1_me in enumerate(m1_sess_mes):
                    logger.info(f"sess {sessions[n][se].sess_n}:",
                                extra={"spacing": TAB})
                    _log_unexp_corrs(m1_s1_me, m2_sess_mes[se], unexps,
                                     f"{TAB}{TAB}")
    finally:
        # reset logger level, even if the analysis fails
        logger.setLevel(prev_level)
Пример #3
0
def run_full_traces(sessions,
                    analysis,
                    analyspar,
                    sesspar,
                    figpar,
                    datatype="roi"):
    """
    run_full_traces(sessions, analysis, analyspar, sesspar, figpar)

    Plots full traces across an entire session. If ROI traces are plotted,
    each ROI is scaled and plotted separately and an average is plotted.
    
    Saves results and parameters relevant to analysis in a dictionary.

    For each session, the start/stop frames of every stimulus block are 
    collected along with a short description of its first 2 stimulus 
    parameters, to be marked on the plot.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "f")
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data ("roi" or "run")
                          default: "roi"

    Raises:
        - RuntimeError: if datatype is "roi" and a session's 
                        only_tracked_rois does not match analyspar.tracked
        Unaccepted datatype values are reported via 
        gen_util.accepted_values_error.
    """

    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")

    sessstr_pr = (f"session: {sesspar.sess_n}, "
                  f"plane: {sesspar.plane}{dendstr_pr}")

    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Plotting {datastr} traces across an entire "
        f"session\n({sessstr_pr}).",
        extra={"spacing": "\n"})

    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()

    # frame type used for block edges depends only on the datatype, so it is 
    # determined (and the datatype validated) once, before looping
    if datatype == "roi":
        fr_type = "twop"
    elif datatype == "run":
        fr_type = "stim"
    else:
        gen_util.accepted_values_error("datatype", datatype, ["roi", "run"])

    # NaN values are omitted from statistics when bad ROIs are retained
    nanpol = None if analyspar.rem_bad else "omit"

    all_tr, all_edges, all_pars = [], [], []
    # individual ROI traces are only collected for ROI data
    roi_tr = [] if datatype == "roi" else None
    for sess in sessions:
        # get the block edges and parameters
        edge_fr, par_descrs = [], []
        for stim in sess.stims:
            stim_str = stim.stimtype
            if stim.stimtype == "visflow":
                stim_str = "vis. flow"
            for b in stim.block_params.index:
                row = stim.block_params.loc[b]
                edge_fr.append([
                    int(row[f"start_frame_{fr_type}"]),
                    int(row[f"stop_frame_{fr_type}"])
                ])
                par_vals = [row[param] for param in stim.stim_params]
                # only the first 2 parameter values are described
                pars_str = "\n".join(str(par) for par in par_vals[:2])

                par_descrs.append(
                    sess_str_util.pars_to_descr(
                        f"{stim_str.capitalize()}\n{pars_str}"))

        if datatype == "roi":
            if sess.only_tracked_rois != analyspar.tracked:
                raise RuntimeError(
                    "sess.only_tracked_rois should match analyspar.tracked.")
            all_rois = gen_util.reshape_df_data(sess.get_roi_traces(
                None, analyspar.fluor, analyspar.rem_bad,
                analyspar.scale)["roi_traces"],
                                                squeeze_cols=True)
            # statistics across ROIs (axis 0)
            full_tr = math_util.get_stats(all_rois,
                                          analyspar.stats,
                                          analyspar.error,
                                          axes=0,
                                          nanpol=nanpol).tolist()
            roi_tr.append(all_rois.tolist())
        else:  # datatype == "run" (validated above)
            full_tr = sess.get_run_velocity(
                rem_bad=analyspar.rem_bad,
                scale=analyspar.scale).to_numpy().squeeze().tolist()
        all_tr.append(full_tr)
        all_edges.append(edge_fr)
        all_pars.append(par_descrs)

    extrapar = {
        "analysis": analysis,
        "datatype": datatype,
    }

    trace_info = {
        "all_tr": all_tr,
        "all_edges": all_edges,
        "all_pars": all_pars
    }

    sess_info = sess_gen_util.get_sess_info(sessions,
                                            analyspar.fluor,
                                            incl_roi=(datatype == "roi"),
                                            rem_bad=analyspar.rem_bad)

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "extrapar": extrapar,
        "sess_info": sess_info,
        "trace_info": trace_info
    }

    fulldir, savename = gen_plots.plot_full_traces(roi_tr=roi_tr,
                                                   figpar=figpar,
                                                   **info)
    file_util.saveinfo(info, savename, fulldir, "json")
Пример #4
0
def run_mag_change(sessions,
                   analysis,
                   seed,
                   analyspar,
                   sesspar,
                   stimpar,
                   permpar,
                   quantpar,
                   figpar,
                   datatype="roi"):
    """
    run_mag_change(sessions, analysis, seed, analyspar, sesspar, stimpar, 
                   permpar, quantpar, figpar)

    Computes and plots, for expected vs unexpected sequences, how much 
    activity changes between the first and last quantiles. Results and the 
    parameters relevant to the analysis are saved in a dictionary.

    If permpar.multcomp is set, it is replaced with the number of sessions 
    before the permutation analysis is run.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "m")
        - seed (int)           : seed value to use. (-1 treated as None) 
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - permpar (PermPar)    : named tuple containing permutation parameters
        - quantpar (QuantPar)  : named tuple containing quantile analysis 
                                 parameters
        - figpar (dict)        : dictionary containing figure parameters   

    Optional args:
        - datatype (str): type of data (e.g., "roi", "run") 
                          default: "roi"
    """

    sessstr_pr = sess_str_util.sess_par_str(
        sesspar.sess_n, stimpar.stimtype, sesspar.plane, stimpar.visflow_dir,
        stimpar.visflow_size, stimpar.gabk, "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar.dend, sesspar.plane, datatype, "print")
    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Calculating and plotting the magnitude changes in {datastr} "
        f"activity across quantiles \n({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    if permpar.multcomp:
        permpar = sess_ntuple_util.get_modif_ntuple(
            permpar, "multcomp", len(sessions))

    # integrated data: session x unexp x quants of interest x [ROI x seq]
    integ_info = quant_analys.trace_stats_by_qu_sess(
        sessions, analyspar, stimpar, quantpar.n_quants, quantpar.qu_idx,
        by_exp=True, integ=True, ret_arr=True, datatype=datatype)
    all_counts, qu_data = integ_info[-2], integ_info[-1]

    # NaN values are omitted when bad ROIs are retained
    nanpol = None if analyspar.rem_bad else "omit"

    seed = rand_util.seed_all(seed, "cpu", log_seed=False)

    mags = quant_analys.qu_mags(
        qu_data,
        permpar,
        [sess.mouse_n for sess in sessions],
        [sess.line for sess in sessions],
        analyspar.stats,
        analyspar.error,
        nanpol=nanpol,
        op_qu="diff",
        op_unexp="diff")

    # work on a copy, converting array values to lists so they can be saved
    mags = copy.deepcopy(mags)
    mags["all_counts"] = all_counts
    mags.update({
        key: mags[key].tolist()
        for key in ("mag_st", "L2", "mag_rel_th", "L2_rel_th")
    })

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "extrapar": {
            "analysis": analysis,
            "datatype": datatype,
            "seed": seed,
        },
        "permpar": permpar._asdict(),
        "quantpar": quantpar._asdict(),
        "mags": mags,
        "sess_info": sess_gen_util.get_sess_info(
            sessions,
            analyspar.fluor,
            incl_roi=(datatype == "roi"),
            rem_bad=analyspar.rem_bad),
    }

    fulldir, savename = gen_plots.plot_mag_change(figpar=figpar, **info)

    file_util.saveinfo(info, savename, fulldir, "json")
Пример #5
0
def run_traces_by_qu_lock_sess(sessions,
                               analysis,
                               seed,
                               analyspar,
                               sesspar,
                               stimpar,
                               quantpar,
                               figpar,
                               datatype="roi"):
    """
    run_traces_by_qu_lock_sess(sessions, analysis, seed, analyspar, sesspar, 
                               stimpar, quantpar, figpar)

    Retrieves trace statistics by session x quantile at the transition of
    expected to unexpected sequences (or v.v.) and plots traces across ROIs by 
    quantile with each session in a separate subplot.
    
    Also runs the analysis for one quantile (full data), with the different 
    unexpected sequence lengths grouped separately.

    stimpar.pre and stimpar.post are overridden with fixed values depending 
    on the stimulus type. Each analysis is run both without a baseline and 
    with a baseline of stimpar.pre. For comparison, statistics locked to 
    expected sequences ("exp_samp" lock, 1 quantile) are also retrieved 
    and saved.
    
    Saves results and parameters relevant to analysis in a dictionary.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "l")
        - seed (int)           : seed value to use. (-1 treated as None)
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - quantpar (QuantPar)  : named tuple containing quantile analysis 
                                 parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data (e.g., "roi", "run")
                          default: "roi"
    """

    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
                                            sesspar.plane, stimpar.visflow_dir,
                                            stimpar.visflow_size, stimpar.gabk,
                                            "print")
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")

    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Analysing and plotting unexpected vs expected {datastr} "
        f"traces locked to unexpected onset by quantile ({quantpar.n_quants}) "
        f"\n({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    seed = rand_util.seed_all(seed, "cpu", log_seed=False)

    # analyse with a single quantile (full data), and with all quantiles
    quantpar_one = sess_ntuple_util.init_quantpar(1, 0)
    quantpar_mult = sess_ntuple_util.init_quantpar(quantpar.n_quants, "all")

    if stimpar.stimtype == "visflow":
        pre_post = [2.0, 6.0]
    elif stimpar.stimtype == "gabors":
        pre_post = [2.0, 8.0]
    else:
        gen_util.accepted_values_error("stimpar.stimtype", stimpar.stimtype,
                                       ["visflow", "gabors"])
    logger.warning("Setting pre to {}s and post to {}s.".format(*pre_post))

    stimpar = sess_ntuple_util.get_modif_ntuple(stimpar, ["pre", "post"],
                                                pre_post)

    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()

    for baseline in [None, stimpar.pre]:
        basestr_pr = sess_str_util.base_par_str(baseline, "print")
        # distinct loop name, to avoid shadowing the quantpar argument
        for curr_quantpar in [quantpar_one, quantpar_mult]:
            locks = ["unexp", "exp"]
            if curr_quantpar.n_quants == 1:
                # splitting by unexpected length is only done on full data
                locks.append("unexp_split")
            # get the stats (all) separating by session and quantiles
            for lock in locks:
                logger.info(
                    f"{curr_quantpar.n_quants} quant, {lock} lock"
                    f"{basestr_pr}",
                    extra={"spacing": "\n"})
                if lock == "unexp_split":
                    trace_info = quant_analys.trace_stats_by_exp_len_sess(
                        sessions,
                        analyspar,
                        stimpar,
                        curr_quantpar.n_quants,
                        curr_quantpar.qu_idx,
                        byroi=False,
                        nan_empty=True,
                        baseline=baseline,
                        datatype=datatype)
                else:
                    trace_info = quant_analys.trace_stats_by_qu_sess(
                        sessions,
                        analyspar,
                        stimpar,
                        curr_quantpar.n_quants,
                        curr_quantpar.qu_idx,
                        byroi=False,
                        lock=lock,
                        nan_empty=True,
                        baseline=baseline,
                        datatype=datatype)

                # for comparison, locking to middle of expected sample (1 quant)
                # NOTE(review): this only depends on baseline, but is kept per 
                # lock in case its retrieval consumes random state.
                exp_samp = quant_analys.trace_stats_by_qu_sess(
                    sessions,
                    analyspar,
                    stimpar,
                    quantpar_one.n_quants,
                    quantpar_one.qu_idx,
                    byroi=False,
                    lock="exp_samp",
                    nan_empty=True,
                    baseline=baseline,
                    datatype=datatype)

                extrapar = {
                    "analysis": analysis,
                    "datatype": datatype,
                    "seed": seed,
                }

                xrans = [xran.tolist() for xran in trace_info[0]]
                all_stats = [sessst.tolist() for sessst in trace_info[1]]
                exp_stats = [expst.tolist() for expst in exp_samp[1]]
                trace_stats = {
                    "xrans": xrans,
                    "all_stats": all_stats,
                    "all_counts": trace_info[2],
                    "lock": lock,
                    "baseline": baseline,
                    "exp_stats": exp_stats,
                    "exp_counts": exp_samp[2]
                }

                if lock == "unexp_split":
                    trace_stats["unexp_lens"] = trace_info[3]

                sess_info = sess_gen_util.get_sess_info(
                    sessions,
                    analyspar.fluor,
                    incl_roi=(datatype == "roi"),
                    rem_bad=analyspar.rem_bad)

                info = {
                    "analyspar": analyspar._asdict(),
                    "sesspar": sesspar._asdict(),
                    "stimpar": stimpar._asdict(),
                    "quantpar": curr_quantpar._asdict(),
                    "extrapar": extrapar,
                    "sess_info": sess_info,
                    "trace_stats": trace_stats
                }

                fulldir, savename = gen_plots.plot_traces_by_qu_lock_sess(
                    figpar=figpar, **info)
                file_util.saveinfo(info, savename, fulldir, "json")
Пример #6
0
def run_traces_by_qu_unexp_sess(sessions,
                                analysis,
                                analyspar,
                                sesspar,
                                stimpar,
                                quantpar,
                                figpar,
                                datatype="roi"):
    """
    run_traces_by_qu_unexp_sess(sessions, analysis, analyspar, sesspar, 
                                stimpar, quantpar, figpar)

    Retrieves trace statistics split by session x unexpected value x 
    quantile, and plots traces across ROIs by quantile/unexpected value, with 
    each session in its own subplot.

    The analysis is run twice: once treating the full data as a single 
    quantile, and once with all quantpar.n_quants quantiles. Each run saves 
    its results and the relevant parameters in a dictionary.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "t")
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - quantpar (QuantPar)  : named tuple containing quantile analysis 
                                 parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data (e.g., "roi", "run")
                          default: "roi"
    """

    sessstr_pr = sess_str_util.sess_par_str(
        sesspar.sess_n, stimpar.stimtype, sesspar.plane, stimpar.visflow_dir,
        stimpar.visflow_size, stimpar.gabk, "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar.dend, sesspar.plane, datatype, "print")
    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Analysing and plotting unexpected vs expected {datastr} "
        f"traces by quantile ({quantpar.n_quants}) \n({sessstr_pr}"
        f"{dendstr_pr}).",
        extra={"spacing": "\n"})

    # single quantile (full data) first, then all quantiles
    qu_pars = [sess_ntuple_util.init_quantpar(1, 0)]
    qu_pars.append(sess_ntuple_util.init_quantpar(quantpar.n_quants, "all"))

    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()

    for qu_par in qu_pars:
        logger.info(f"{qu_par.n_quants} quant", extra={"spacing": "\n"})

        # stats (all) separated by session, unexpected value and quantile
        trace_info = quant_analys.trace_stats_by_qu_sess(
            sessions, analyspar, stimpar, qu_par.n_quants, qu_par.qu_idx,
            byroi=False, by_exp=True, datatype=datatype)

        trace_stats = dict(
            xrans=[xran.tolist() for xran in trace_info[0]],
            all_stats=[sess_st.tolist() for sess_st in trace_info[1]],
            all_counts=trace_info[2],
        )

        sess_info = sess_gen_util.get_sess_info(
            sessions, analyspar.fluor, incl_roi=(datatype == "roi"),
            rem_bad=analyspar.rem_bad)

        info = dict(
            analyspar=analyspar._asdict(),
            sesspar=sesspar._asdict(),
            stimpar=stimpar._asdict(),
            quantpar=qu_par._asdict(),
            extrapar=dict(analysis=analysis, datatype=datatype),
            sess_info=sess_info,
            trace_stats=trace_stats,
        )

        fulldir, savename = gen_plots.plot_traces_by_qu_unexp_sess(
            figpar=figpar, **info)
        file_util.saveinfo(info, savename, fulldir, "json")
Пример #7
0
def plot_autocorr(analyspar, sesspar, stimpar, extrapar, autocorrpar, 
                  sess_info, autocorr_data, figpar=None, savedir=None):
    """
    plot_autocorr(analyspar, sesspar, stimpar, extrapar, autocorrpar, 
                  sess_info, autocorr_data)

    From dictionaries, plots autocorrelation during stimulus blocks.

    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis 
                                parameters
            ["analysis"] (str): analysis type (e.g., "a")
            ["datatype"] (str): datatype ("run" or "roi")
        - autocorrpar (dict)  : dictionary with keys of AutocorrPar namedtuple
        - sess_info (dict)    : dictionary containing information from each
                                session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session

        - autocorr_data (dict): dictionary containing data to plot:
            ["xrans"] (list): list of lag values in seconds for each session
            ["stats"] (list): list of 3D arrays (or nested lists) of
                              autocorrelation statistics, structured as:
                                     sessions stats (me, err) 
                                     x ROI or 1x and 10x lag 
                                     x lag
    
    Optional args:
        - figpar (dict): dictionary containing the following figure parameter 
                         dictionaries. If None, initialized with 
                         sess_plot_util.init_figpar().
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (Path or str): path of directory in which to save plots. If 
                                 None, it is built from 
                                 figpar["dirs"][datatype] and 
                                 figpar["dirs"]["autocorr"].
                                 default: None
    
    Returns:
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved

    Raises:
        - ValueError: if extrapar["datatype"] is neither "roi" nor "run"
    """

    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"], 
        stimpar["gabk"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")

    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"]) 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])

    datatype = extrapar["datatype"]
    if datatype == "roi":
        fluorstr_pr = sess_str_util.fluor_par_str(
            analyspar["fluor"], str_type="print")
        if autocorrpar["byitem"]:
            title_str = u"{}\nautocorrelation".format(fluorstr_pr)
        else:
            # include the fluorescence type here too (the format string 
            # previously had no placeholder, so it was silently dropped)
            title_str = u"{}\nautocorr. acr. ROIs".format(fluorstr_pr)

    elif datatype == "run":
        datastr = sess_str_util.datatype_par_str(datatype)
        title_str = u"\n{} autocorrelation".format(datastr)

    else:
        # fail early, instead of with an unbound title_str further down
        raise ValueError(
            f"datatype must be 'roi' or 'run', but found {datatype}.")

    # positions of the horizontal bars added to each subplot
    if stimpar["stimtype"] == "gabors":
        seq_bars = [-1.5, 1.5] # light lines
    else:
        seq_bars = [-1.0, 1.0] # light lines

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, empty=(datatype!="roi")) 

    n_sess = len(mouse_ns)

    xrans = autocorr_data["xrans"]
    stats = [np.asarray(stat) for stat in autocorr_data["stats"]]

    lag_s = autocorrpar["lag_s"]
    # np.linspace requires an integer number of samples (lag_s may be a float)
    xticks = np.linspace(-lag_s, lag_s, int(lag_s * 2 + 1))
    yticks = np.linspace(0, 1, 6)

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    byitemstr = ""
    if autocorrpar["byitem"]:
        byitemstr = "_byroi"

    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])
    for i in range(n_sess):
        sub_ax = plot_util.get_subax(ax, i)
        title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}, {statstr_pr} "
            f"{title_str} (sess {sess_ns[i]}, {lines[i]} {planes[i]}"
            f"{dendstr_pr}{nroi_strs[i]})")
        # transpose to ROI/lag x stats x series
        sess_stats = stats[i].transpose(1, 0, 2) 
        for s, sub_stats in enumerate(sess_stats):
            lab = None
            if not autocorrpar["byitem"]:
                lab = ["actual lag", "10x lag"][s]

            # first stat row is the central value, remaining rows the error
            plot_util.plot_traces(
                sub_ax, xrans[i], sub_stats[0], sub_stats[1:], xticks=xticks, 
                yticks=yticks, alpha=0.2, label=lab)

        plot_util.add_bars(sub_ax, hbars=seq_bars)
        sub_ax.set_ylim([0, 1])
        sub_ax.set_title(title, y=1.02)
        # NOTE(review): Axes.is_last_row() is deprecated in recent matplotlib
        # releases; consider sub_ax.get_subplotspec().is_last_row().
        if sub_ax.is_last_row():
            sub_ax.set_xlabel("Lag (s)")

    plot_util.turn_off_extra(ax, n_sess)

    if savedir is None:
        savedir = Path(
            figpar["dirs"][datatype], 
            figpar["dirs"]["autocorr"])

    savename = f"{datatype}_autocorr{byitemstr}_{sessstr}{dendstr}"

    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
def plot_pup_diff_corr(analyspar, sesspar, stimpar, extrapar, 
                       sess_info, corr_data, figpar=None, savedir=None):
    """
    plot_pup_diff_corr(analyspar, sesspar, stimpar, extrapar, 
                       sess_info, corr_data)

    From dictionaries, plots correlation between unexpected-locked changes in 
    pupil diameter and running or ROI data for each session.

    Each session gets one column with two subplots: the top one is a scatter 
    plot of pupil vs data changes (labelled with their correlation), the 
    bottom one shows both series, min-max scaled, across unexpected event 
    occurrences, with the area between them shaded.

    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis 
                                parameters
            ["analysis"] (str): analysis type (e.g., "c")
            ["datatype"] (str): datatype ("run" or "roi")
        - sess_info (dict)    : dictionary containing information from each
                                session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session
        - corr_data (dict)    : dictionary containing data to plot:
            ["corrs"] (list): list of correlation values between pupil and 
                              running or ROI differences for each session
            ["diffs"] (list): list of differences for each session, structured
                              as [pupil, ROI/run] x trials (one value per 
                              unexpected event occurrence)
    
    Optional args:
        - figpar (dict) : dictionary containing the following figure parameter 
                          dictionaries. It is deep-copied, so not modified.
                          default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (Path): path of directory in which to save plots. If None, 
                          it is built from figpar["dirs"][datatype] and 
                          figpar["dirs"]["pupil"].
                          default: None
    
    Returns:
        - fulldir (Path): final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved
    """
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"], 
        stimpar["gabk"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")

    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"]) 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])
    
    datatype = extrapar["datatype"]
    datastr = sess_str_util.datatype_par_str(datatype)

    if datatype == "roi":
        label_str = sess_str_util.fluor_par_str(
            analyspar["fluor"], str_type="print")
        full_label_str = u"{}, {} across ROIs".format(
            label_str, analyspar["stats"])
    elif datatype == "run":
        label_str = datastr
        full_label_str = datastr
    else:
        # fail early, instead of with unbound label strings further down
        gen_util.accepted_values_error("datatype", datatype, ["roi", "run"])
    
    lab_app = (f" ({analyspar['stats']} over "
        f"{stimpar['pre']}/{stimpar['post']} sec)")

    logger.info(f"Plotting pupil vs {datastr} changes.")
    
    delta = "\u0394"

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]

    n_sess = len(mouse_ns)
    nroi_strs = sess_str_util.get_nroi_strs(
        sess_info, empty=(datatype!="roi"), style="comma"
        ) 

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    # copy, as the figure parameters are modified below
    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()
    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["ncols"] = n_sess
    
    # 2 rows (correlation, differences) x n_sess columns
    fig, ax = plot_util.init_fig(2 * n_sess, **figpar["init"])
    suptitle = (f"Relationship between pupil diam. and {datastr} changes, "
        "locked to unexpected events")
    
    for i, sess_diffs in enumerate(corr_data["diffs"]):
        sub_axs = ax[:, i]
        title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}, {statstr_pr}"
            f"\n(sess {sess_ns[i]}, {lines[i]} "
            f"{planes[i]}{dendstr_pr}{nroi_strs[i]})")
        
        # top plot: correlations
        corr = f"Corr = {corr_data['corrs'][i]:.2f}"
        sub_axs[0].plot(
            sess_diffs[0], sess_diffs[1], marker=".", linestyle="None", 
            label=corr)
        sub_axs[0].set_title(title, y=1.01)
        sub_axs[0].set_xlabel(u"{} pupil diam.{}".format(delta, lab_app))
        if i == 0:
            sub_axs[0].set_ylabel(u"{} {}\n{}".format(
                delta, full_label_str, lab_app))
        sub_axs[0].legend()
        
        # bottom plot: differences across occurrences
        data_lab = u"{} {}".format(delta, label_str)   
        pup_lab = u"{} pupil diam.".format(delta)
        last_sess = (i == n_sess - 1)
        cols = []
        scaled = []
        for diff, lab in zip(sess_diffs, [pup_lab, data_lab]):
            scaled.append(math_util.scale_data(
                np.asarray(diff), sc_type="min_max")[0])
            art, = sub_axs[1].plot(scaled[-1], marker=".")
            cols.append(art.get_color())
            if last_sess: # only label the last graph
                art.set_label(lab)
        if last_sess:
            sub_axs[1].legend()
        sub_axs[1].set_xlabel("Unexpected event occurrence")
        if i == 0:
            sub_axs[1].set_ylabel(
                u"{} response locked\nto unexpected onset (scaled)".format(delta))
        # shade area between lines
        plot_util.plot_btw_traces(
            sub_axs[1], scaled[0], scaled[1], color=cols, alpha=0.4)

    fig.suptitle(suptitle, fontsize="xx-large", y=1)

    if savedir is None:
        savedir = Path(
            figpar["dirs"][datatype],
            figpar["dirs"]["pupil"])

    savename = f"{datatype}_diff_corr_{sessstr}{dendstr}"

    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
Пример #9
0
def run_pupil_diff_corr(sessions,
                        analysis,
                        analyspar,
                        sesspar,
                        stimpar,
                        figpar,
                        datatype="roi"):
    """
    run_pupil_diff_corr(sessions, analysis, analyspar, sesspar, 
                        stimpar, figpar)
    
    Calculates and plots changes in pupil diameter and in ROI/running data 
    locked to each unexpected event, as well as the correlation between them.

    Saves results and parameters relevant to analysis in a dictionary.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "c")
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data ("roi" or "run")
                          default: "roi"

    Raises:
        - RuntimeError: if, for ROI data, a session's only_tracked_rois 
                        attribute does not match analyspar.tracked
    """

    # validate datatype up front, before any data is retrieved
    if datatype not in ["roi", "run"]:
        gen_util.accepted_values_error("datatype", datatype, ["roi", "run"])

    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
                                            sesspar.plane, stimpar.visflow_dir,
                                            stimpar.visflow_size, stimpar.gabk,
                                            "print")
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")

    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Analysing and plotting correlations between pupil and {datastr} "
        f"changes locked to unexpected events ({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    # NaN policy used when averaging across ROIs: NaNs are omitted unless 
    # bad ROIs were removed (analyspar.rem_bad)
    nanpol = None if analyspar.rem_bad else "omit"

    sess_diffs = []
    sess_corr = []

    for sess in sessions:
        if datatype == "roi" and (sess.only_tracked_rois != analyspar.tracked):
            raise RuntimeError(
                "sess.only_tracked_rois should match analyspar.tracked.")
        pup_diff, data_diff = peristim_data(sess,
                                            stimpar,
                                            datatype=datatype,
                                            returns="diff",
                                            scale=analyspar.scale,
                                            first_unexp=True)
        # trials (x ROIs): collapse ROI data across ROIs -> trials
        if datatype == "roi":
            data_diff = math_util.mean_med(data_diff,
                                           analyspar.stats,
                                           axis=-1,
                                           nanpol=nanpol)
        sess_corr.append(np.corrcoef(pup_diff, data_diff)[0, 1])
        sess_diffs.append([diff.tolist() for diff in [pup_diff, data_diff]])

    extrapar = {
        "analysis": analysis,
        "datatype": datatype,
    }

    corr_data = {"corrs": sess_corr, "diffs": sess_diffs}

    sess_info = sess_gen_util.get_sess_info(sessions,
                                            analyspar.fluor,
                                            incl_roi=(datatype == "roi"),
                                            rem_bad=analyspar.rem_bad)

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "extrapar": extrapar,
        "sess_info": sess_info,
        "corr_data": corr_data
    }

    fulldir, savename = pup_plots.plot_pup_diff_corr(figpar=figpar, **info)

    file_util.saveinfo(info, savename, fulldir, "json")