예제 #1
0
def get_sess_df_columns(session, analyspar, roi=True): 
    """
    get_sess_df_columns(session, analyspar)

    Returns the column names of the basic session dataframe built for a 
    single session.

    Required args:
        - session (Session):
            Session object
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters (the fluor and rem_bad 
            fields are used to build the session dataframe)

    Optional args:
        - roi (bool):
            if True, ROI data is included in the session dataframe
            default: True

    Returns:
        - sess_df_cols (list):
            session dataframe columns
    """

    # build a one-session dataframe only to read its column names
    sess_df = sess_gen_util.get_sess_info(
        [session], fluor=analyspar.fluor, incl_roi=roi, return_df=True, 
        rem_bad=analyspar.rem_bad
        )

    return sess_df.columns.tolist()
예제 #2
0
def run_autocorr(sessions,
                 analysis,
                 analyspar,
                 sesspar,
                 stimpar,
                 autocorrpar,
                 figpar,
                 datatype="roi"):
    """
    run_autocorr(sessions, analysis, analyspar, sesspar, stimpar, autocorrpar, 
                 figpar)

    Calculates and plots autocorrelation during stimulus blocks.
    Saves results and parameters relevant to analysis in a dictionary.

    For each session, traces are retrieved for each stimulus block (segments 
    within a block must be contiguous), and autocorrelation statistics are 
    computed across the blocks. If autocorrpar.byitem is False, statistics 
    for a 10x longer lag are also computed, downsampled by 10 and stacked 
    with the original statistics.

    Required args:
        - sessions (list)          : list of Session objects
        - analysis (str)           : analysis type (e.g., "a")
        - analyspar (AnalysPar)    : named tuple containing analysis parameters
        - sesspar (SessPar)        : named tuple containing session parameters
        - stimpar (StimPar)        : named tuple containing stimulus parameters
        - autocorrpar (AutocorrPar): named tuple containing autocorrelation 
                                     analysis parameters
        - figpar (dict)            : dictionary containing figure parameters

    Optional args:
        - datatype (str): type of data ("roi" or "run"; other values are 
                          rejected via gen_util.accepted_values_error)
                          default: "roi"

    Raises:
        - RuntimeError: if datatype is "roi" and a session's 
                        only_tracked_rois does not match analyspar.tracked, 
                        or if the 10x lag statistics cannot be downsampled 
                        to the original length
        - NotImplementedError: if segments within a block are not contiguous
        - ValueError: if datatype is "run" and autocorrpar.byitem is not 
                      False
    """

    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
                                            sesspar.plane, stimpar.visflow_dir,
                                            stimpar.visflow_size, stimpar.gabk,
                                            "print")
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")

    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Analysing and plotting {datastr} autocorrelations "
        f"({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    # only "roi" and "run" data are handled below: reject anything else up 
    # front, instead of failing later on an undefined (or stale) trace
    if datatype not in ["roi", "run"]:
        gen_util.accepted_values_error("datatype", datatype, ["roi", "run"])

    xrans = []
    stats = []
    for sess in sessions:
        if datatype == "roi" and (sess.only_tracked_rois != analyspar.tracked):
            raise RuntimeError(
                "sess.only_tracked_rois should match analyspar.tracked.")
        stim = sess.get_stim(stimpar.stimtype)
        all_segs = stim.get_segs_by_criteria(visflow_dir=stimpar.visflow_dir,
                                             visflow_size=stimpar.visflow_size,
                                             gabk=stimpar.gabk,
                                             by="block")
        sess_traces = []
        for segs in all_segs:
            if len(segs) == 0:
                continue
            segs = sorted(segs)
            # check that segs are contiguous (a single segment is trivially 
            # contiguous, and np.diff() of a single value is empty, which 
            # max() cannot handle)
            if len(segs) > 1 and max(np.diff(segs)) > 1:
                raise NotImplementedError("Segments used for autocorrelation "
                                          "must be contiguous within blocks.")
            if datatype == "roi":
                frame_edges = stim.get_fr_by_seg(
                    [min(segs), max(segs)], fr_type="twop")
                fr = list(range(min(frame_edges[0]), max(frame_edges[1]) + 1))
                traces = gen_util.reshape_df_data(sess.get_roi_traces(
                    fr,
                    fluor=analyspar.fluor,
                    rem_bad=analyspar.rem_bad,
                    scale=analyspar.scale),
                                                  squeeze_cols=True)

            elif datatype == "run":
                if autocorrpar.byitem != False:
                    raise ValueError("autocorrpar.byitem must be False for "
                                     "running data.")
                frame_edges = stim.get_fr_by_seg(
                    [min(segs), max(segs)], fr_type="stim")
                fr = list(range(min(frame_edges[0]), max(frame_edges[1]) + 1))

                traces = sess.get_run_velocity_by_fr(
                    fr,
                    fr_type="stim",
                    rem_bad=analyspar.rem_bad,
                    scale=analyspar.scale).to_numpy().reshape(1, -1)

            sess_traces.append(traces)

        # Calculate autocorr stats while filtering some warnings
        msgs = ["Degrees of freedom", "invalid value encountered"]
        categs = [RuntimeWarning, RuntimeWarning]
        with gen_util.TempWarningFilter(msgs, categs):
            xran, ac_st = math_util.autocorr_stats(sess_traces,
                                                   autocorrpar.lag_s,
                                                   sess.twop_fps,
                                                   byitem=autocorrpar.byitem,
                                                   stats=analyspar.stats,
                                                   error=analyspar.error)

        if not autocorrpar.byitem:  # also add a 10x lag
            lag_fr = 10 * int(autocorrpar.lag_s * sess.twop_fps)
            _, ac_st_10x = math_util.autocorr_stats(sess_traces,
                                                    lag_fr,
                                                    byitem=autocorrpar.byitem,
                                                    stats=analyspar.stats,
                                                    error=analyspar.error)
            # keep every 10th value, so the 10x lag stats align with ac_st
            downsamp = range(0, ac_st_10x.shape[-1], 10)

            if len(downsamp) != ac_st.shape[-1]:
                raise RuntimeError("Failed to downsample correctly. "
                                   "Check implementation.")
            ac_st = np.stack([ac_st, ac_st_10x[:, downsamp]], axis=1)
        xrans.append(xran)
        stats.append(ac_st)

    autocorr_data = {
        "xrans": [xran.tolist() for xran in xrans],
        "stats": [stat.tolist() for stat in stats]
    }

    sess_info = sess_gen_util.get_sess_info(sessions,
                                            analyspar.fluor,
                                            incl_roi=(datatype == "roi"),
                                            rem_bad=analyspar.rem_bad)

    extrapar = {
        "analysis": analysis,
        "datatype": datatype,
    }

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "extrapar": extrapar,
        "autocorrpar": autocorrpar._asdict(),
        "autocorr_data": autocorr_data,
        "sess_info": sess_info
    }

    fulldir, savename = gen_plots.plot_autocorr(figpar=figpar, **info)

    file_util.saveinfo(info, savename, fulldir, "json")
예제 #3
0
def run_mag_change(sessions,
                   analysis,
                   seed,
                   analyspar,
                   sesspar,
                   stimpar,
                   permpar,
                   quantpar,
                   figpar,
                   datatype="roi"):
    """
    run_mag_change(sessions, analysis, seed, analyspar, sesspar, stimpar, 
                   permpar, quantpar, figpar)

    Computes and plots how much activity changes in magnitude between the 
    first and last quantiles, for expected vs unexpected sequences. Results 
    and the parameters relevant to the analysis are saved in a dictionary 
    (JSON).

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "m")
        - seed (int)           : seed value to use. (-1 treated as None) 
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - permpar (PermPar)    : named tuple containing permutation parameters
        - quantpar (QuantPar)  : named tuple containing quantile analysis 
                                 parameters
        - figpar (dict)        : dictionary containing figure parameters   

    Optional args:
        - datatype (str): type of data (e.g., "roi", "run") 
                          default: "roi"
    """

    sessstr_pr = sess_str_util.sess_par_str(
        sesspar.sess_n, stimpar.stimtype, sesspar.plane, stimpar.visflow_dir,
        stimpar.visflow_size, stimpar.gabk, "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar.dend, sesspar.plane, datatype, "print")
    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Calculating and plotting the magnitude changes in {datastr} "
        f"activity across quantiles \n({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    # when multiple comparison correction is requested, its value is set to 
    # the number of sessions
    if permpar.multcomp:
        permpar = sess_ntuple_util.get_modif_ntuple(
            permpar, "multcomp", len(sessions))

    # integrated data: the last two returned items are the counts and the 
    # data (session x unexp x quants of interest x [ROI x seq])
    *_, all_counts, qu_data = quant_analys.trace_stats_by_qu_sess(
        sessions, analyspar, stimpar, quantpar.n_quants, quantpar.qu_idx,
        by_exp=True, integ=True, ret_arr=True, datatype=datatype)

    # per-session mouse numbers and lines
    sess_mouse_ns, sess_lines = [], []
    for sess in sessions:
        sess_mouse_ns.append(sess.mouse_n)
        sess_lines.append(sess.line)

    nanpol = None if analyspar.rem_bad else "omit"

    seed = rand_util.seed_all(seed, "cpu", log_seed=False)

    mags = copy.deepcopy(quant_analys.qu_mags(
        qu_data, permpar, sess_mouse_ns, sess_lines, analyspar.stats,
        analyspar.error, nanpol=nanpol, op_qu="diff", op_unexp="diff"))
    mags["all_counts"] = all_counts
    # array items are converted to lists so they can be saved to JSON
    for arr_key in ("mag_st", "L2", "mag_rel_th", "L2_rel_th"):
        mags[arr_key] = mags[arr_key].tolist()

    sess_info = sess_gen_util.get_sess_info(
        sessions, analyspar.fluor, incl_roi=(datatype == "roi"),
        rem_bad=analyspar.rem_bad)

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "extrapar": {
            "analysis": analysis, 
            "datatype": datatype, 
            "seed": seed,
        },
        "permpar": permpar._asdict(),
        "quantpar": quantpar._asdict(),
        "mags": mags,
        "sess_info": sess_info,
    }

    fulldir, savename = gen_plots.plot_mag_change(figpar=figpar, **info)

    file_util.saveinfo(info, savename, fulldir, "json")
예제 #4
0
def run_full_traces(sessions,
                    analysis,
                    analyspar,
                    sesspar,
                    figpar,
                    datatype="roi"):
    """
    run_full_traces(sessions, analysis, analyspar, sesspar, figpar)

    Plots full traces across an entire session. If ROI traces are plotted,
    each ROI is scaled and plotted separately and an average is plotted.
    
    Saves results and parameters relevant to analysis in a dictionary.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "f")
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data ("roi" or "run"; other values are 
                          rejected via gen_util.accepted_values_error)
                          default: "roi"

    Raises:
        - RuntimeError: if datatype is "roi" and a session's 
                        only_tracked_rois does not match analyspar.tracked
    """

    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")

    sessstr_pr = (f"session: {sesspar.sess_n}, "
                  f"plane: {sesspar.plane}{dendstr_pr}")

    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Plotting {datastr} traces across an entire "
        f"session\n({sessstr_pr}).",
        extra={"spacing": "\n"})

    # copy figpar so the caller's dictionary is not modified
    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()

    # the frame type used for block edges depends only on the datatype, so 
    # it is validated once here: otherwise, sessions without stimuli would 
    # skip validation and leave full_tr undefined (or stale) below
    if datatype == "roi":
        fr_type = "twop"
    elif datatype == "run":
        fr_type = "stim"
    else:
        gen_util.accepted_values_error("datatype", datatype, ["roi", "run"])

    # NaN policy for ROI statistics: NaNs are omitted unless bad ROIs are 
    # removed (analyspar.rem_bad)
    nanpol = None if analyspar.rem_bad else "omit"

    all_tr, roi_tr, all_edges, all_pars = [], [], [], []
    for sess in sessions:
        # get the block edges and parameters
        edge_fr, par_descrs = [], []
        for stim in sess.stims:
            stim_str = stim.stimtype
            if stim.stimtype == "visflow":
                stim_str = "vis. flow"
            for b in stim.block_params.index:
                row = stim.block_params.loc[b]
                edge_fr.append([
                    int(row[f"start_frame_{fr_type}"]),
                    int(row[f"stop_frame_{fr_type}"])
                ])
                par_vals = [row[param] for param in stim.stim_params]
                # only the first two parameter values are described
                pars_str = "\n".join([str(par) for par in par_vals][0:2])

                par_descrs.append(
                    sess_str_util.pars_to_descr(
                        f"{stim_str.capitalize()}\n{pars_str}"))

        if datatype == "roi":
            if sess.only_tracked_rois != analyspar.tracked:
                raise RuntimeError(
                    "sess.only_tracked_rois should match analyspar.tracked.")
            all_rois = gen_util.reshape_df_data(sess.get_roi_traces(
                None, analyspar.fluor, analyspar.rem_bad,
                analyspar.scale)["roi_traces"],
                                                squeeze_cols=True)
            # statistics across ROIs (axis 0) for each frame
            full_tr = math_util.get_stats(all_rois,
                                          analyspar.stats,
                                          analyspar.error,
                                          axes=0,
                                          nanpol=nanpol).tolist()
            roi_tr.append(all_rois.tolist())
        elif datatype == "run":
            full_tr = sess.get_run_velocity(
                rem_bad=analyspar.rem_bad,
                scale=analyspar.scale).to_numpy().squeeze().tolist()
            roi_tr = None
        all_tr.append(full_tr)
        all_edges.append(edge_fr)
        all_pars.append(par_descrs)

    extrapar = {
        "analysis": analysis,
        "datatype": datatype,
    }

    trace_info = {
        "all_tr": all_tr,
        "all_edges": all_edges,
        "all_pars": all_pars
    }

    sess_info = sess_gen_util.get_sess_info(sessions,
                                            analyspar.fluor,
                                            incl_roi=(datatype == "roi"),
                                            rem_bad=analyspar.rem_bad)

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "extrapar": extrapar,
        "sess_info": sess_info,
        "trace_info": trace_info
    }

    fulldir, savename = gen_plots.plot_full_traces(roi_tr=roi_tr,
                                                   figpar=figpar,
                                                   **info)
    file_util.saveinfo(info, savename, fulldir, "json")
예제 #5
0
def run_traces_by_qu_lock_sess(sessions,
                               analysis,
                               seed,
                               analyspar,
                               sesspar,
                               stimpar,
                               quantpar,
                               figpar,
                               datatype="roi"):
    """
    run_traces_by_qu_lock_sess(sessions, analysis, seed, analyspar, sesspar, 
                               stimpar, quantpar, figpar)

    Retrieves trace statistics by session x quantile at the transition of
    expected to unexpected sequences (or v.v.) and plots traces across ROIs by 
    quantile with each session in a separate subplot.
    
    Also runs analysis for one quantile (full data) with sequences grouped 
    separately by unexpected length ("unexp_split" lock).

    stimpar.pre and stimpar.post are overridden (2 s / 6 s for "visflow", 
    2 s / 8 s for "gabors"). Each analysis is run both without a baseline 
    and with a baseline of stimpar.pre. For comparison, statistics locked to 
    expected samples ("exp_samp" lock, 1 quantile) are also computed and 
    stored with each result.
    
    Saves results and parameters relevant to analysis in a dictionary.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "l")
        - seed (int)           : seed value to use. (-1 treated as None)
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - quantpar (QuantPar)  : named tuple containing quantile analysis 
                                 parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data (e.g., "roi", "run")
                          default: "roi"

    """

    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
                                            sesspar.plane, stimpar.visflow_dir,
                                            stimpar.visflow_size, stimpar.gabk,
                                            "print")
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")

    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Analysing and plotting unexpected vs expected {datastr} "
        f"traces locked to unexpected onset by quantile ({quantpar.n_quants}) "
        f"\n({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    seed = rand_util.seed_all(seed, "cpu", log_seed=False)

    # one quantile (full data), and all quantiles retained
    # (n_quants is stored now, as quantpar is rebound in the loop below)
    quantpar_one = sess_ntuple_util.init_quantpar(1, 0)
    n_quants = quantpar.n_quants
    quantpar_mult = sess_ntuple_util.init_quantpar(n_quants, "all")

    if stimpar.stimtype == "visflow":
        pre_post = [2.0, 6.0]
    elif stimpar.stimtype == "gabors":
        pre_post = [2.0, 8.0]
    else:
        gen_util.accepted_values_error("stimpar.stimtype", stimpar.stimtype,
                                       ["visflow", "gabors"])
    logger.warning("Setting pre to {}s and post to {}s.".format(*pre_post))

    stimpar = sess_ntuple_util.get_modif_ntuple(stimpar, ["pre", "post"],
                                                pre_post)

    # copy figpar so the caller's dictionary is not modified, and share one 
    # datetime string across all plots saved here
    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()

    for baseline in [None, stimpar.pre]:
        basestr_pr = sess_str_util.base_par_str(baseline, "print")
        for quantpar in [quantpar_one, quantpar_mult]:
            locks = ["unexp", "exp"]
            # splitting by unexpected length is only done on the full data
            if quantpar.n_quants == 1:
                locks.append("unexp_split")
            # get the stats (all) separating by session and quantiles
            for lock in locks:
                logger.info(
                    f"{quantpar.n_quants} quant, {lock} lock{basestr_pr}",
                    extra={"spacing": "\n"})
                if lock == "unexp_split":
                    trace_info = quant_analys.trace_stats_by_exp_len_sess(
                        sessions,
                        analyspar,
                        stimpar,
                        quantpar.n_quants,
                        quantpar.qu_idx,
                        byroi=False,
                        nan_empty=True,
                        baseline=baseline,
                        datatype=datatype)
                else:
                    trace_info = quant_analys.trace_stats_by_qu_sess(
                        sessions,
                        analyspar,
                        stimpar,
                        quantpar.n_quants,
                        quantpar.qu_idx,
                        byroi=False,
                        lock=lock,
                        nan_empty=True,
                        baseline=baseline,
                        datatype=datatype)

                # for comparison, locking to middle of expected sample (1 quant)
                exp_samp = quant_analys.trace_stats_by_qu_sess(
                    sessions,
                    analyspar,
                    stimpar,
                    quantpar_one.n_quants,
                    quantpar_one.qu_idx,
                    byroi=False,
                    lock="exp_samp",
                    nan_empty=True,
                    baseline=baseline,
                    datatype=datatype)

                extrapar = {
                    "analysis": analysis,
                    "datatype": datatype,
                    "seed": seed,
                }

                xrans = [xran.tolist() for xran in trace_info[0]]
                all_stats = [sessst.tolist() for sessst in trace_info[1]]
                exp_stats = [expst.tolist() for expst in exp_samp[1]]
                trace_stats = {
                    "xrans": xrans,
                    "all_stats": all_stats,
                    "all_counts": trace_info[2],
                    "lock": lock,
                    "baseline": baseline,
                    "exp_stats": exp_stats,
                    "exp_counts": exp_samp[2]
                }

                if lock == "unexp_split":
                    trace_stats["unexp_lens"] = trace_info[3]

                sess_info = sess_gen_util.get_sess_info(
                    sessions,
                    analyspar.fluor,
                    incl_roi=(datatype == "roi"),
                    rem_bad=analyspar.rem_bad)

                info = {
                    "analyspar": analyspar._asdict(),
                    "sesspar": sesspar._asdict(),
                    "stimpar": stimpar._asdict(),
                    "quantpar": quantpar._asdict(),
                    "extrapar": extrapar,
                    "sess_info": sess_info,
                    "trace_stats": trace_stats
                }

                fulldir, savename = gen_plots.plot_traces_by_qu_lock_sess(
                    figpar=figpar, **info)
                file_util.saveinfo(info, savename, fulldir, "json")
예제 #6
0
def run_traces_by_qu_unexp_sess(sessions,
                                analysis,
                                analyspar,
                                sesspar,
                                stimpar,
                                quantpar,
                                figpar,
                                datatype="roi"):
    """
    run_traces_by_qu_unexp_sess(sessions, analysis, analyspar, sesspar, 
                                stimpar, quantpar, figpar)

    Computes trace statistics split by session, unexpected value and 
    quantile, then plots traces across ROIs by quantile/unexpected value, 
    with one subplot per session.

    The analysis is run twice: once on the full data (a single quantile), 
    and once retaining all quantpar.n_quants quantiles. Results and the 
    parameters relevant to each run are saved in a dictionary (JSON).

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "t")
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - quantpar (QuantPar)  : named tuple containing quantile analysis 
                                 parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data (e.g., "roi", "run")
                          default: "roi"
    """

    sessstr_pr = sess_str_util.sess_par_str(
        sesspar.sess_n, stimpar.stimtype, sesspar.plane, stimpar.visflow_dir,
        stimpar.visflow_size, stimpar.gabk, "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar.dend, sesspar.plane, datatype, "print")
    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Analysing and plotting unexpected vs expected {datastr} "
        f"traces by quantile ({quantpar.n_quants}) \n({sessstr_pr}"
        f"{dendstr_pr}).",
        extra={"spacing": "\n"})

    # full data (1 quantile) first, then every quantile retained
    quantpars_to_run = [
        sess_ntuple_util.init_quantpar(1, 0),
        sess_ntuple_util.init_quantpar(quantpar.n_quants, "all"),
    ]

    # copy figpar so the caller's dictionary is not modified, and share one 
    # datetime string across all plots saved here
    figpar = copy.deepcopy(figpar)
    save_dict = figpar["save"]
    if save_dict["use_dt"] is None:
        save_dict["use_dt"] = gen_util.create_time_str()

    for quantpar_run in quantpars_to_run:
        logger.info(f"{quantpar_run.n_quants} quant", extra={"spacing": "\n"})

        # stats (all) separated by session, unexpected value and quantile
        trace_info = quant_analys.trace_stats_by_qu_sess(
            sessions, analyspar, stimpar, quantpar_run.n_quants,
            quantpar_run.qu_idx, byroi=False, by_exp=True, datatype=datatype)
        sess_xrans, sess_stats, sess_counts = trace_info[:3]

        trace_stats = dict(
            xrans=[xran.tolist() for xran in sess_xrans],
            all_stats=[sess_st.tolist() for sess_st in sess_stats],
            all_counts=sess_counts,
        )

        sess_info = sess_gen_util.get_sess_info(
            sessions, analyspar.fluor, incl_roi=(datatype == "roi"),
            rem_bad=analyspar.rem_bad)

        info = {
            "analyspar": analyspar._asdict(),
            "sesspar": sesspar._asdict(),
            "stimpar": stimpar._asdict(),
            "quantpar": quantpar_run._asdict(),
            "extrapar": {"analysis": analysis, "datatype": datatype},
            "sess_info": sess_info,
            "trace_stats": trace_stats,
        }

        fulldir, savename = gen_plots.plot_traces_by_qu_unexp_sess(
            figpar=figpar, **info)
        file_util.saveinfo(info, savename, fulldir, "json")
예제 #7
0
def run_glms(sessions,
             analysis,
             seed,
             analyspar,
             sesspar,
             stimpar,
             glmpar,
             figpar,
             parallel=False):
    """
    run_glms(sessions, analysis, seed, analyspar, sesspar, stimpar, glmpar, 
             figpar)

    Runs GLMs (via run_glm) on ROI activity, plots the explained variance 
    and saves results and parameters relevant to analysis in a dictionary.

    If glmpar.each_roi is True, GLMs are run per ROI, for each session 
    separately. Otherwise, they are run across ROIs, with all sessions 
    together.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type
        - seed (int)           : seed value to use. (-1 treated as None)
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - glmpar (GLMPar)      : named tuple containing GLM parameters 
                                 (each_roi is read here)
        - figpar (dict)        : dictionary containing figure parameters

    Optional args:
        - parallel (bool): if True, parallel processing is requested, either 
                           at this level or propagated to run_glm
                           default: False
    """

    seed = rand_util.seed_all(seed, "cpu", log_seed=False)

    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
                                            sesspar.plane, stimpar.visflow_dir,
                                            stimpar.visflow_size, stimpar.gabk,
                                            "print")
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            "roi", "print")

    logger.info(
        "Analysing and plotting explained variance in ROI activity "
        f"({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    if glmpar.each_roi:  # must do each session separately
        glm_type = "per_ROI_per_sess"
        sess_batches = sessions
        logger.info("Per ROI, each session separately.")
    else:
        glm_type = "across_sess"
        sess_batches = [sessions]
        logger.info(f"Across ROIs, {len(sessions)} sessions together.")

    # optionally runs in parallel, or propagates parallel to next level
    # NOTE(review): as written, parallel_here is always falsy: it requires 
    # glmpar.each_roi to be False, in which case sess_batches holds a single 
    # batch. Parallelism is therefore always propagated to run_glm. Confirm 
    # this is intended.
    parallel_here = (parallel and not (glmpar.each_roi)
                     and (len(sess_batches) != 1))
    parallel_after = bool(parallel and not parallel_here)

    args_list = [analyspar, sesspar, stimpar, glmpar]
    args_dict = {
        "parallel": parallel_after,  # proactively set next parallel 
        "seed": seed,
    }
    all_expl_var = gen_util.parallel_wrap(run_glm,
                                          sess_batches,
                                          args_list,
                                          args_dict,
                                          parallel=parallel_here)

    # sess_batches is either sessions itself or [sessions], so the original 
    # session list is used directly here
    sess_info = sess_gen_util.get_sess_info(sessions,
                                            analyspar.fluor,
                                            rem_bad=analyspar.rem_bad)

    extrapar = {
        "analysis": analysis,
        "seed": seed,
        "glm_type": glm_type,
    }

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "glmpar": glmpar._asdict(),
        "extrapar": extrapar,
        "all_expl_var": all_expl_var,
        "sess_info": sess_info
    }

    fulldir, savename = glm_plots.plot_glm_expl_var(figpar=figpar, **info)

    file_util.saveinfo(info, savename, fulldir, "json")
예제 #8
0
def run_pup_roi_stim_corr(sessions,
                          analysis,
                          analyspar,
                          sesspar,
                          stimpar,
                          figpar,
                          datatype="roi",
                          parallel=False):
    """
    run_pup_roi_stim_corr(sessions, analysis, analyspar, sesspar, stimpar, 
                          figpar)
    
    Calculates and plots correlation between pupil and ROI changes locked to
    unexpected for gabors vs visflow.

    For each session and each stimulus type (gabors, then visflow), the 
    correlation between each ROI's change (roi_diff[:, r]) and the pupil 
    change is computed. The per-session summary is the correlation between 
    the gabors and the visflow ROI correlation vectors.
    
    Saves results and parameters relevant to analysis in a dictionary.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "r")
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
                                 (stimtype is overridden: both "gabors" and 
                                 "visflow" are analysed, with gabfr set to 3)
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str) : type of data (only "roi" is implemented)
                           default: "roi"
        - parallel (bool): if True, some of the analysis is parallelized across 
                           CPU cores
                           default: False

    Raises:
        - NotImplementedError: if datatype is not "roi"
        - RuntimeError: if a session's only_tracked_rois does not match 
          analyspar.tracked
    """

    if datatype != "roi":
        raise NotImplementedError(
            "Analysis only implemented for roi datatype.")

    stimtypes = ["gabors", "visflow"]
    if stimpar.stimtype != "both":
        non_stimtype = stimtypes[1 - stimtypes.index(stimpar.stimtype)]
        warnings.warn(
            "stimpar.stimtype will be set to 'both', but non default "
            f"{non_stimtype} parameters are lost.",
            category=RuntimeWarning,
            stacklevel=1)

    # Base parameters from which each per-stimulus StimPar is built. This must
    # exist for every stimtype, including "both" (otherwise the loop below
    # fails with a NameError). "none" entries are dropped, presumably so that
    # init_stimpar() substitutes its own values for them.
    stimpar_dict = stimpar._asdict()
    for key in list(stimpar_dict.keys()):  # remove any "none"s
        if stimpar_dict[key] == "none":
            stimpar_dict.pop(key)

    sessstr_pr = f"session: {sesspar.sess_n}, plane: {sesspar.plane}"
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")
    stimstr_pr = []
    stimpars = []
    for stimtype in stimtypes:
        stimpar_dict["stimtype"] = stimtype
        stimpar_dict["gabfr"] = 3
        stimpars.append(sess_ntuple_util.init_stimpar(**stimpar_dict))
        stimstr_pr.append(
            sess_str_util.stim_par_str(stimtype, stimpars[-1].visflow_dir,
                                       stimpars[-1].visflow_size,
                                       stimpars[-1].gabk, "print"))
    # parameters recorded in the saved info: those of the first stimulus, with
    # stimtype relabelled as "both"
    stimpar_dict = stimpars[0]._asdict()
    stimpar_dict["stimtype"] = "both"

    logger.info(
        "Analysing and plotting correlations between unexpected vs "
        f"expected ROI traces between sessions ({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})
    sess_corrs = []
    sess_roi_corrs = []
    for sess in sessions:
        if datatype == "roi" and (sess.only_tracked_rois != analyspar.tracked):
            raise RuntimeError(
                "sess.only_tracked_rois should match analyspar.tracked.")
        stim_corrs = []
        for sub_stimpar in stimpars:
            diffs = peristim_data(sess,
                                  sub_stimpar,
                                  datatype="roi",
                                  returns="diff",
                                  first_unexp=True,
                                  rem_bad=analyspar.rem_bad,
                                  scale=analyspar.scale)
            [pup_diff, roi_diff] = diffs
            nrois = roi_diff.shape[-1]
            # optionally runs in parallel
            if parallel and nrois > 1:
                n_jobs = gen_util.get_n_jobs(nrois)
                with gen_util.ParallelLogging():
                    corrs = Parallel(n_jobs=n_jobs)(
                        delayed(np.corrcoef)(roi_diff[:, r], pup_diff)
                        for r in range(nrois))
                # keep only the off-diagonal (ROI vs pupil) coefficient
                corrs = np.asarray([corr[0, 1] for corr in corrs])
            else:
                corrs = np.empty(nrois)
                for r in range(nrois):  # cycle through ROIs
                    corrs[r] = np.corrcoef(roi_diff[:, r], pup_diff)[0, 1]
            stim_corrs.append(corrs)
        # correlation between the gabors and visflow ROI-pupil correlations
        sess_corrs.append(np.corrcoef(stim_corrs[0], stim_corrs[1])[0, 1])
        sess_roi_corrs.append([corrs.tolist() for corrs in stim_corrs])

    extrapar = {
        "analysis": analysis,
        "datatype": datatype,
    }

    corr_data = {
        "stim_order": stimtypes,
        "roi_corrs": sess_roi_corrs,
        "corrs": sess_corrs
    }

    sess_info = sess_gen_util.get_sess_info(sessions,
                                            analyspar.fluor,
                                            incl_roi=(datatype == "roi"),
                                            rem_bad=analyspar.rem_bad)

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar_dict,
        "extrapar": extrapar,
        "sess_info": sess_info,
        "corr_data": corr_data
    }

    fulldir, savename = pup_plots.plot_pup_roi_stim_corr(figpar=figpar, **info)

    file_util.saveinfo(info, savename, fulldir, "json")
예제 #9
0
def run_pupil_diff_corr(sessions, analysis, analyspar, sesspar, stimpar,
                        figpar, datatype="roi"):
    """
    run_pupil_diff_corr(sessions, analysis, analyspar, sesspar, 
                        stimpar, figpar)

    Computes, for each session, the pupil change and the ROI (averaged across 
    ROIs) or running change locked to each unexpected event, correlates the 
    two, and plots the results.

    The parameters, session information and correlation data are gathered in 
    a dictionary that is plotted and saved as JSON.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "c")
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data ("roi" or "run")
                          default: "roi"

    Raises:
        - RuntimeError: if datatype is "roi" and a session's 
          only_tracked_rois does not match analyspar.tracked
        - gen_util.accepted_values_error is called if datatype is neither 
          "roi" nor "run" (presumably raising an error)
    """

    sessstr_pr = sess_str_util.sess_par_str(
        sesspar.sess_n, stimpar.stimtype, sesspar.plane, stimpar.visflow_dir,
        stimpar.visflow_size, stimpar.gabk, "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar.dend, sesspar.plane, datatype, "print")
    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        "Analysing and plotting correlations between unexpected vs "
        f"expected {datastr} traces between sessions ({sessstr_pr}"
        f"{dendstr_pr}).",
        extra={"spacing": "\n"})

    # NaN policy for averaging across ROIs: NaNs are only skipped when bad
    # ROIs are kept
    roi_nanpol = None if analyspar.rem_bad else "omit"

    corr_per_sess = []
    diffs_per_sess = []
    for sess in sessions:
        tracked_mismatch = sess.only_tracked_rois != analyspar.tracked
        if datatype == "roi" and tracked_mismatch:
            raise RuntimeError(
                "sess.only_tracked_rois should match analyspar.tracked.")

        pup_diff, data_diff = peristim_data(
            sess, stimpar, datatype=datatype, returns="diff",
            scale=analyspar.scale, first_unexp=True)

        # ROI data is trials x ROIs: collapse the ROI axis
        if datatype == "roi":
            data_diff = math_util.mean_med(
                data_diff, analyspar.stats, axis=-1, nanpol=roi_nanpol)
        elif datatype != "run":
            gen_util.accepted_values_error(
                "datatype", datatype, ["roi", "run"])

        corr_per_sess.append(np.corrcoef(pup_diff, data_diff)[0, 1])
        diffs_per_sess.append(
            [arr.tolist() for arr in (pup_diff, data_diff)])

    sess_info = sess_gen_util.get_sess_info(
        sessions, analyspar.fluor, incl_roi=(datatype == "roi"),
        rem_bad=analyspar.rem_bad)

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "extrapar": {"analysis": analysis, "datatype": datatype},
        "sess_info": sess_info,
        "corr_data": {"corrs": corr_per_sess, "diffs": diffs_per_sess},
    }

    fulldir, savename = pup_plots.plot_pup_diff_corr(figpar=figpar, **info)
    file_util.saveinfo(info, savename, fulldir, "json")
예제 #10
0
def get_check_sess_df(sessions, sess_df=None, analyspar=None, roi=True):
    """
    get_check_sess_df(sessions)

    Checks a dataframe against existing sessions (that they match and are in 
    the same order), or returns a dataframe with session information if sess_df 
    is None.

    Required args:
        - sessions (list):
            Session objects (passed through gen_util.list_if_not, so 
            presumably a single Session is also accepted)

    Optional args:
        - sess_df (pd.DataFrame):
            dataframe containing session information (see keys under Returns)
            default: None
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters, used if sess_df is 
            None (required in that case if roi is True)
            default: None
        - roi (bool):
            if True, ROI data is included in sess_df, used if sess_df is None
            default: True

    Returns:
        - sess_df (pd.DataFrame):
            dataframe containing session information under the following keys:
            "mouse_ns", "mouseids", "sess_ns", "sessids", "lines", "planes"
            if roi:
                "nrois", "twop_fps"
            if not rem_bad: 
                "bad_rois_{}" (depending on fluor)

    Raises:
        - ValueError: if sess_df is None, roi is True and analyspar is None
        - ValueError: if sess_df is provided and its number of rows, its 
          session IDs, or the order of its session IDs do not match sessions
    """

    sessions = gen_util.list_if_not(sessions)

    if sess_df is None:
        if analyspar is None and roi:
            raise ValueError("If sess_df is None, must pass analyspar.")

        roi_kwargs = dict()
        if analyspar is not None:
            roi_kwargs["fluor"] = analyspar.fluor
            roi_kwargs["rem_bad"] = analyspar.rem_bad

        return sess_gen_util.get_sess_info(
            sessions, incl_roi=roi, return_df=True, **roi_kwargs
            )

    # The row count check also guarantees that the two session ID arrays 
    # below have the same length, so they can be compared elementwise.
    if len(sess_df) != len(sessions):
        raise ValueError(
            "'sess_df' should have as many rows as 'sessions'.")

    sessids = np.asarray([sess.sessid for sess in sessions]).astype(int)
    sess_df_sessids = sess_df.sessids.to_numpy().astype(int)

    # sorted comparison: same sessions, regardless of order
    if (np.sort(sessids) != np.sort(sess_df_sessids)).any():
        raise ValueError("Sessions do not match ids in 'sess_df'.")

    # direct comparison: same sessions, in the same order
    if (sessids != sess_df_sessids).any():
        raise ValueError("Sessions do not appear in order in 'sess_df'.")

    return sess_df