Code example #1
def get_pupil_run_block_diffs_df(sessions, analyspar, stimpar, parallel=False):
    """
    get_pupil_run_block_diffs_df(sessions, analyspar, stimpar)

    Returns pupil and running statistic differences (unexp - exp) by block.

    Required args:
        - sessions (list): 
            session objects
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - stimpar (StimPar): 
            named tuple containing stimulus parameters

    Returns:
        - block_df (pd.DataFrame):
            dataframe with a row for each session, and the following 
            columns, in addition to the basic sess_df columns: 
            - run_block_diffs (1D array):
                split differences per block
            - run_block_stats (3D array): 
                block statistics (split x block x stats (me, err))
            - pupil_block_diffs (1D array):
                split differences per block
            - pupil_block_stats (3D array): 
                block statistics (split x block x stats (me, err))
    """

    block_df = misc_analys.get_check_sess_df(sessions,
                                             None,
                                             analyspar,
                                             roi=False)

    # gather arguments for retrieving block data
    args_dict = {
        "analyspar": analyspar,
        "stimpar": stimpar,
    }

    # double-check that sessions are in the expected order
    misc_analys.get_check_sess_df(sessions, block_df)
    for datatype in ["pupil", "run"]:
        args_dict["datatype"] = datatype
        # sess x split x block x stats
        block_stats = gen_util.parallel_wrap(basic_analys.get_block_data,
                                             sessions,
                                             args_dict=args_dict,
                                             parallel=parallel)

        block_diffs = []
        for sess_block_data in block_stats:
            # take difference (unexp - exp statistic) for each block
            stat_diffs = (
                sess_block_data[1, ..., 0] - sess_block_data[0, ..., 0])
            block_diffs.append(stat_diffs)

        block_df[f"{datatype}_block_stats"] = block_stats
        block_df[f"{datatype}_block_diffs"] = block_diffs

    return block_df
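A minimal standalone sketch of the difference step above, on a dummy split x block x stats array (all names and values here are illustrative, not from the source):

import numpy as np

# dummy statistics for one session: split (exp, unexp) x block x stats (me, err)
sess_block_data = np.arange(2 * 3 * 2, dtype=float).reshape(2, 3, 2)

# unexp - exp difference of the mean statistic (stat index 0), per block
stat_diffs = sess_block_data[1, ..., 0] - sess_block_data[0, ..., 0]
print(stat_diffs)  # [6. 6. 6.]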
Code example #2
def main(args):
    """
    main(args)

    Runs analyses with parser arguments.

    Required args:
        - args (argparse.Namespace): parser argument namespace
    """

    # set logger to the specified level
    logger_util.set_level(level=args.log_level)

    args.device = "cpu"

    if args.datadir is None: 
        args.datadir = DEFAULT_DATADIR
    else:
        args.datadir = Path(args.datadir)
    args.mouse_df = DEFAULT_MOUSE_DF_PATH
    args.runtype = "prod"
    args.plane = "soma"
    args.stimtype = "gabors"

    args.omit_sess, args.omit_mice = sess_gen_util.all_omit(
        args.stimtype, args.runtype
        )

    all_sessids = sess_gen_util.get_sess_vals(
        args.mouse_df, "sessid", runtype=args.runtype, sess_n=[1, 2, 3], 
        plane=args.plane, min_rois=1, pass_fail="P", omit_sess=args.omit_sess, 
        omit_mice=args.omit_mice)

    gen_util.parallel_wrap(
        run_sess_lstm, all_sessids, args_list=[args], parallel=args.parallel)
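Since main() accesses attributes on args, a plain argparse.Namespace is enough for a dry run; a hypothetical minimal setup (attribute names inferred from the code above, the full parser likely defines more):

from argparse import Namespace

# hypothetical minimal namespace; only attributes accessed above are set
args = Namespace(
    log_level="info",
    datadir=None,      # None falls back to DEFAULT_DATADIR
    parallel=False,
)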
Code example #3
def estim_vm_by_roi(oris, roi_data, hist_n=1000, parallel=False):
    """
    estim_vm_by_roi(oris, roi_data)

    Returns estimates of von Mises distributions for each ROI based on the 
    input orientations and ROI activations.

    Required args:
        - oris (array-like)  : array of orientations
        - roi_data (2D array): corresponding array of ROI values, 
                               structured as ROI x ori
    
    Optional args:
        - hist_n (int)   : number of values to build the histogram with
                           default: 1000 
        - parallel (bool): if True, some of the analysis is parallelized across 
                           CPU cores
                           default: False

    Returns:
        - tc_oris (list)         : list of orientation values (sorted and 
                                   unique)
        - tc_data (list)         : list of mean integrated fluorescence data 
                                   per unique orientation, structured 
                                   as ROI x orientation
        - tc_vm_pars (nd array)  : tuning curve (von Mises) parameter estimates, 
                                   structured as ROI x par (kappa, mean, scale)
        - tc_hist_pars (nd array): parameters to create histograms, structured 
                                   as ROI x par (sub, mult)
    """

    # sort by gab orientation
    x_sort_idx = np.argsort(oris)
    xsort = oris[x_sort_idx]
    roi_data_sort = roi_data[:, x_sort_idx]

    # get a cumulative group index for each unique x value
    x_cuml = np.insert((np.diff(xsort) != 0).cumsum(), 0, 0)
    tc_oris = np.unique(xsort).tolist()

    # optionally run in parallel
    args_list = [x_cuml, tc_oris, hist_n]
    returns = gen_util.parallel_wrap(estim_vm,
                                     roi_data_sort,
                                     args_list,
                                     parallel=parallel,
                                     zip_output=True)

    tc_data = [list(ret) for ret in returns[0]]
    tc_vm_pars = np.asarray([list(ret) for ret in returns[1]])
    tc_hist_pars = np.asarray([list(ret) for ret in returns[2]])

    return tc_oris, tc_data, tc_vm_pars, tc_hist_pars
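The cumulative-index trick above tags each sorted orientation with its unique-value group; a self-contained illustration on dummy data:

import numpy as np

oris = np.array([90.0, 0.0, 45.0, 0.0, 90.0, 45.0])
x_sort_idx = np.argsort(oris)
xsort = oris[x_sort_idx]                                  # [ 0.  0. 45. 45. 90. 90.]

# group index per sorted value: increments whenever the value changes
x_cuml = np.insert((np.diff(xsort) != 0).cumsum(), 0, 0)  # [0 0 1 1 2 2]
tc_oris = np.unique(xsort).tolist()                       # [0.0, 45.0, 90.0]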
Code example #4
def mean_signal_sess123(sessions, analyspar, sesspar, figpar, parallel=False):
    """
    mean_signal_sess123(sessions, analyspar, sesspar, figpar)

    Retrieves ROI mean signal values for sessions 1 to 3.
        
    Saves results and parameters relevant to analysis in a dictionary.

    Required args:
        - sessions (list):
            Session objects
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - sesspar (SessPar): 
            named tuple containing session parameters
        - figpar (dict): 
            dictionary containing figure parameters
    
    Optional args:
        - parallel (bool): 
            if True, some of the analysis is run in parallel across CPU cores 
            default: False
    """

    logger.info("Compiling ROI signal means from session 1 to 3.",
                extra={"spacing": "\n"})

    logger.info("Calculating ROI signal means for each session...",
                extra={"spacing": TAB})
    all_signal_means = gen_util.parallel_wrap(misc_analys.get_snr,
                                              sessions,
                                              [analyspar, "signal_means"],
                                              parallel=parallel)

    sig_mean_df = misc_analys.get_check_sess_df(sessions, analyspar=analyspar)
    sig_mean_df["signal_means"] = [
        sig_mean.tolist() for sig_mean in all_signal_means
    ]

    extrapar = dict()

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "extrapar": extrapar,
        "sig_mean_df": sig_mean_df.to_dict()
    }

    helper_fcts.plot_save_all(info, figpar)
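The info dictionary relies on namedtuple._asdict() and DataFrame.to_dict() to make parameters and results serializable; a sketch with made-up fields (AnalysPar's real fields live in the project's parameter definitions):

from collections import namedtuple

import pandas as pd

AnalysPar = namedtuple("AnalysPar", ["fluor", "stats"])  # made-up fields
analyspar = AnalysPar(fluor="dff", stats="mean")

sig_mean_df = pd.DataFrame({"sessids": [1, 2], "signal_means": [[0.1], [0.2]]})

info = {
    "analyspar": analyspar._asdict(),      # plain dict, JSON-serializable
    "sig_mean_df": sig_mean_df.to_dict(),  # nested dict keyed by column
}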
Code example #5
def init_sessions(sessids, datadir, mouse_df, analyspar, runtype="prod", 
                  full_table=False, parallel=False):
    """
    init_sessions(sessids, datadir, mouse_df, analyspar)

    Initializes sessions.

    Required args:
        - sessids (list)       : IDs of sessions to load
        - datadir (Path)       : data directory
        - mouse_df (pandas df) : path name of dataframe containing information 
                                 on each session
        - analyspar (AnalysPar): named tuple containing analysis parameters

    Optional args:
        - runtype (str)    : runtype ("pilot" or "prod")
                             default: "prod"
        - full_table (bool): if True, full stimulus dataframe is loaded
                             default: False
        - parallel (bool)  : if True, some analyses are parallelized 
                             across CPU cores 
                             default: False

    Returns:
        - sessions (list): list of sessions
    """

    args_dict = {
        "datadir"   : datadir,
        "mouse_df"  : mouse_df,
        "runtype"   : runtype,
        "full_table": full_table,
        "fluor"     : analyspar.fluor,
        "dend"      : analyspar.dend,
        "omit"      : True,
        "temp_log"  : "warning",
    }

    sessions = gen_util.parallel_wrap(
        sess_gen_util.init_sessions, sessids, args_dict=args_dict, 
        parallel=parallel, use_tqdm=True
        )

    # flatten list of sessions
    sessions = [sess for singles in sessions for sess in singles]

    return sessions
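sess_gen_util.init_sessions returns one list per session ID, so the result is flattened with a nested comprehension; in isolation:

# each inner list holds the session(s) loaded for one session ID
sessions = [["sessA"], ["sessB", "sessC"], []]
flat = [sess for singles in sessions for sess in singles]
print(flat)  # ['sessA', 'sessB', 'sessC']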
Code example #6
def get_resp_df(sessions, analyspar, stimpar, rel_sess=1, parallel=False):
    """
    get_resp_df(sessions, analyspar, stimpar)

    Returns relative response dataframe for requested sessions.

    Required args:
        - sessions (list): 
            session objects
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - stimpar (StimPar): 
            named tuple containing stimulus parameters

    Optional args:
        - rel_sess (int):
            number of session relative to which data should be scaled, for each 
            mouse. If None, relative data is not added.
            default: 1
        - parallel (bool): 
            if True, some of the analysis is run in parallel across CPU cores 
            default: False

    Returns:
        - resp_data_df (pd.DataFrame):
            dataframe with response stats (2D array, ROI x stats) under 
            columns for expected ("exp") and unexpected ("unexp") data, 
            separated by Gabor frame (e.g., "exp_3", "unexp_G") 
            if stimpar.stimtype == "gabors", and 
            with "rel_{}" columns added for each input column with "exp" in its 
            name if rel_sess is not None.
    """

    if analyspar.tracked:
        misc_analys.check_sessions_complete(sessions, raise_err=True)

    sessids = [sess.sessid for sess in sessions]
    resp_data_df = misc_analys.get_check_sess_df(sessions, analyspar=analyspar)

    # double check that sessions are in correct order
    if resp_data_df["sessids"].tolist() != sessids:
        raise NotImplementedError(
            "Implementation error. Sessions must appear in correct order in "
            "resp_data_df.")

    logger.info(f"Loading data for each session...", extra={"spacing": TAB})
    data_dicts = gen_util.parallel_wrap(get_sess_integ_resp_dict,
                                        sessions,
                                        args_list=[analyspar, stimpar],
                                        parallel=parallel)

    # add data to df
    misc_analys.get_check_sess_df(sessions, resp_data_df)
    for i, idx in enumerate(resp_data_df.index):
        for key, value in data_dicts[i].items():
            if i == 0:
                resp_data_df = gen_util.set_object_columns(
                    resp_data_df, [key], in_place=True)
            # retain the statistic only, not the error
            resp_data_df.at[idx, key] = value[:, 0]

    # add relative data
    if rel_sess is not None:
        resp_data_df = add_relative_resp_data(resp_data_df,
                                              analyspar,
                                              rel_sess=rel_sess,
                                              in_place=True)

    return resp_data_df
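Storing arrays in individual DataFrame cells requires object-dtype columns, which is presumably what gen_util.set_object_columns prepares; a plain-pandas sketch of the pattern:

import numpy as np
import pandas as pd

df = pd.DataFrame({"sessids": [101, 102]})
df["exp_3"] = pd.Series(dtype=object)  # object column can hold arrays/lists

stats_per_sess = [np.array([[1.0, 0.1]]), np.array([[2.0, 0.2]])]
for idx, stats in zip(df.index, stats_per_sess):
    df.at[idx, "exp_3"] = stats[:, 0]  # keep the statistic, drop the error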
Code example #7
def get_sess_grped_diffs_df(sessions,
                            analyspar,
                            stimpar,
                            basepar,
                            permpar,
                            split="by_exp",
                            randst=None,
                            parallel=False):
    """
    get_sess_grped_diffs_df(sessions, analyspar, stimpar, basepar)

    Returns split difference statistics for specific sessions, grouped across 
    mice.

    Required args:
        - sessions (list): 
            session objects
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - stimpar (StimPar): 
            named tuple containing stimulus parameters
        - basepar (BasePar): 
            named tuple containing baseline parameters
        - permpar (PermPar): 
            named tuple containing permutation parameters

    Optional args:
        - split (str): 
            how to split data:
            "by_exp" (all exp, all unexp), 
            "unexp_lock" (unexp, preceeding exp), 
            "exp_lock" (exp, preceeding unexp),
            "stim_onset" (grayscr, stim on), 
            "stim_offset" (stim off, grayscr)
            default: "by_exp"
        - randst (int or np.random.RandomState): 
            random state or seed value to use. (-1 treated as None)
            default: None
        - parallel (bool): 
            if True, some of the analysis is run in parallel across CPU cores 
            default: False

    Returns:
        - diffs_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - diff_stats (list): split difference stats (me, err)
            - null_CIs (list): adjusted null CI for split differences 
            - raw_p_vals (float): uncorrected p-value for differences within 
                sessions
            - p_vals (float): p-value for differences within sessions, 
                corrected for multiple comparisons and tails
            for session comparisons, e.g. 1v2:
            - raw_p_vals_{}v{} (float): uncorrected p-value for differences
                between sessions 
            - p_vals_{}v{} (float): p-value for differences between sessions, 
                corrected for multiple comparisons and tails
    """

    nanpol = None if analyspar.rem_bad else "omit"

    if analyspar.tracked:
        misc_analys.check_sessions_complete(sessions, raise_err=True)

    sess_diffs_df = misc_analys.get_check_sess_df(sessions, None, analyspar)
    initial_columns = sess_diffs_df.columns.tolist()

    # gather arguments for retrieving ROI split statistics and data
    args_dict = {
        "analyspar": analyspar,
        "stimpar": stimpar,
        "basepar": basepar,
        "split": split,
        "return_data": True,
    }

    # sess x split x ROI
    split_stats, split_data = gen_util.parallel_wrap(get_sess_roi_split_stats,
                                                     sessions,
                                                     args_dict=args_dict,
                                                     parallel=parallel,
                                                     zip_output=True)

    misc_analys.get_check_sess_df(sessions, sess_diffs_df)
    sess_diffs_df["roi_split_stats"] = list(split_stats)
    sess_diffs_df["roi_split_data"] = list(split_data)

    columns = initial_columns + ["diff_stats", "null_CIs"]
    diffs_df = pd.DataFrame(columns=columns)

    group_columns = ["lines", "planes", "sess_ns"]
    aggreg_cols = [col for col in initial_columns if col not in group_columns]
    for lp_grp_vals, lp_grp_df in sess_diffs_df.groupby(["lines", "planes"]):
        lp_grp_df = lp_grp_df.sort_values(["sess_ns", "mouse_ns"])
        line, plane = lp_grp_vals
        lp_name = plot_helper_fcts.get_line_plane_name(line, plane)
        logger.info(f"Running permutation tests for {lp_name} sessions...",
                    extra={"spacing": TAB})

        # obtain ROI random split differences per session
        # done here to avoid OOM errors
        lp_rand_diffs = gen_util.parallel_wrap(
            get_rand_split_data,
            lp_grp_df["roi_split_data"].tolist(),
            args_list=[analyspar, permpar, randst],
            parallel=parallel,
            zip_output=False)

        sess_diffs = []
        row_indices = []
        sess_ns = sorted(lp_grp_df["sess_ns"].unique())
        for sess_n in sess_ns:
            row_idx = len(diffs_df)
            row_indices.append(row_idx)
            sess_grp_df = lp_grp_df.loc[lp_grp_df["sess_ns"] == sess_n]

            grp_vals = list(lp_grp_vals) + [sess_n]
            for g, group_column in enumerate(group_columns):
                diffs_df.loc[row_idx, group_column] = grp_vals[g]

            # add aggregated values for initial columns
            diffs_df = misc_analys.aggreg_columns(sess_grp_df,
                                                  diffs_df,
                                                  aggreg_cols,
                                                  row_idx=row_idx,
                                                  in_place=True)

            # group ROI split stats across mice: split x ROIs
            split_stats = np.concatenate(
                sess_grp_df["roi_split_stats"].to_numpy(), axis=-1)

            # take diff and stats across ROIs
            diffs = split_stats[1] - split_stats[0]
            diff_stats = math_util.get_stats(diffs,
                                             stats=analyspar.stats,
                                             error=analyspar.error,
                                             nanpol=nanpol)
            diffs_df.at[row_idx, "diff_stats"] = diff_stats.tolist()
            sess_diffs.append(diffs)

            # group random ROI split diffs across mice, and take stat
            rand_idxs = [
                lp_grp_df.index.tolist().index(idx)
                for idx in sess_grp_df.index
            ]
            rand_diffs = math_util.mean_med(np.concatenate(
                [lp_rand_diffs[r] for r in rand_idxs], axis=0),
                                            axis=0,
                                            stats=analyspar.stats,
                                            nanpol=nanpol)

            # get CIs and p-values
            p_val, null_CI = rand_util.get_p_val_from_rand(
                diff_stats[0],
                rand_diffs,
                return_CIs=True,
                p_thresh=permpar.p_val,
                tails=permpar.tails,
                multcomp=permpar.multcomp,
                nanpol=nanpol)
            diffs_df.loc[row_idx, "p_vals"] = p_val
            diffs_df.at[row_idx, "null_CIs"] = null_CI

        del lp_rand_diffs  # free up memory

        # calculate p-values between sessions (0-1, 0-2, 1-2...)
        p_vals = rand_util.comp_vals_acr_groups(sess_diffs,
                                                n_perms=permpar.n_perms,
                                                stats=analyspar.stats,
                                                paired=analyspar.tracked,
                                                nanpol=nanpol,
                                                randst=randst)
        p = 0
        for i, sess_n in enumerate(sess_ns):
            for j, sess_n2 in enumerate(sess_ns[i + 1:]):
                key = f"p_vals_{int(sess_n)}v{int(sess_n2)}"
                diffs_df.loc[row_indices[i], key] = p_vals[p]
                diffs_df.loc[row_indices[i + j + 1], key] = p_vals[p]
                p += 1

    # add corrected p-values
    diffs_df = misc_analys.add_corr_p_vals(diffs_df, permpar)

    diffs_df["sess_ns"] = diffs_df["sess_ns"].astype(int)

    return diffs_df
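The null distributions and p-values above come from the project's rand_util; as a generic illustration only (not the project's implementation), a paired sign-flip permutation test for a mean split difference can be written as:

import numpy as np

rng = np.random.default_rng(0)
diffs = rng.normal(0.3, 1.0, size=50)  # per-ROI split differences
obs = diffs.mean()

# null distribution: randomly flip the sign of each ROI's difference
null = np.array([
    (diffs * rng.choice([-1.0, 1.0], size=diffs.size)).mean()
    for _ in range(10_000)
])
p_val = (np.abs(null) >= np.abs(obs)).mean()    # two-tailed p-value
null_CI = np.percentile(null, [2.5, 50, 97.5])  # null CI (low, median, high)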
Code example #8
def get_ex_traces_df(sessions,
                     analyspar,
                     stimpar,
                     basepar,
                     n_ex=6,
                     rolling_win=4,
                     randst=None,
                     parallel=False):
    """
    get_ex_traces_df(sessions, analyspar, stimpar, basepar)

    Returns example ROI traces dataframe.

    Required args:
        - sessions (list):
            Session objects
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - stimpar (StimPar): 
            named tuple containing stimulus parameters
        - basepar (BasePar): 
            named tuple containing baseline parameters
    
    Optional args:
        - n_ex (int):
            number of example traces to retain
            default: 6
        - rolling_win (int):
            window to use in rolling mean over individual trial traces
            default: 4 
        - randst (int or np.random.RandomState): 
            random state or seed value to use. (-1 treated as None)
            default: None
        - parallel (bool): 
            if True, some of the analysis is run in parallel across CPU cores 
            default: False

    Returns:
        - selected_roi_data (pd.DataFrame):
            dataframe with a row for each ROI, and the following columns, 
            in addition to the basic sess_df columns: 
            - time_values (list): values for each frame, in seconds
                (only 0 to stimpar.post, unless split is "by_exp")
            - roi_ns (list): selected ROI number
            - traces_sm (list): selected ROI sequence traces, smoothed, with 
                dims: seq x frames
            - trace_stats (list): selected ROI trace mean or median
    """

    retained_traces_df = misc_analys.get_check_sess_df(sessions, None,
                                                       analyspar)
    initial_columns = retained_traces_df.columns

    logger.info(f"Identifying example ROIs for each session...",
                extra={"spacing": TAB})

    retained_roi_data = gen_util.parallel_wrap(
        get_sess_ex_traces,
        sessions, [analyspar, stimpar, basepar, rolling_win],
        parallel=parallel)

    randst = rand_util.get_np_rand_state(randst, set_none=True)

    # add data to dataframe
    new_columns = list(retained_roi_data[0])
    retained_traces_df = gen_util.set_object_columns(retained_traces_df,
                                                     new_columns,
                                                     in_place=True)

    for i, sess in enumerate(sessions):
        row_idx = retained_traces_df.loc[
            retained_traces_df["sessids"] == sess.sessid].index

        if len(row_idx) != 1:
            raise RuntimeError(
                "Expected exactly one dataframe row to match session ID.")
        row_idx = row_idx[0]

        for column, value in retained_roi_data[i].items():
            retained_traces_df.at[row_idx, column] = value

    # select a few ROIs per line/plane/session
    columns = retained_traces_df.columns.tolist()
    columns = [column.replace("roi_trace", "trace") for column in columns]
    selected_traces_df = pd.DataFrame(columns=columns)

    group_columns = ["lines", "planes", "sess_ns"]
    for _, trace_grp_df in retained_traces_df.groupby(group_columns):
        trace_grp_df = trace_grp_df.sort_values("mouse_ns")
        grp_indices = trace_grp_df.index
        n_per = np.asarray([len(roi_ns) for roi_ns in trace_grp_df["roi_ns"]])
        roi_ns = np.concatenate(trace_grp_df["roi_ns"].tolist())
        concat_idxs = np.sort(randst.choice(len(roi_ns), n_ex, replace=False))

        for concat_idx in concat_idxs:
            row_idx = len(selected_traces_df)
            sess_idx = np.where(concat_idx < np.cumsum(n_per))[0][0]
            source_row = trace_grp_df.loc[grp_indices[sess_idx]]
            for column in initial_columns:
                selected_traces_df.at[row_idx, column] = source_row[column]

            selected_traces_df.at[row_idx, "time_values"] = \
                source_row["time_values"].tolist()

            roi_idx = concat_idx - n_per[:sess_idx].sum()
            for col in ["roi_ns", "traces_sm", "trace_stats"]:
                source_col = col.replace("trace", "roi_trace")
                selected_traces_df.at[row_idx, col] = \
                    source_row[source_col][roi_idx].tolist()

    for column in [
            "mouse_ns", "mouseids", "sess_ns", "sessids", "nrois", "roi_ns"
    ]:
        selected_traces_df[column] = selected_traces_df[column].astype(int)

    return selected_traces_df
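Mapping an index in the concatenated ROI list back to a (session, within-session ROI) pair uses the cumulative sum of group sizes, as above; a standalone check:

import numpy as np

n_per = np.array([3, 2, 4])  # ROIs retained per session in the group
concat_idx = 6               # index into the concatenated ROI list

sess_idx = np.where(concat_idx < np.cumsum(n_per))[0][0]  # -> 2 (third session)
roi_idx = concat_idx - n_per[:sess_idx].sum()             # -> 1 (second ROI)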
Code example #9
def get_all_correlations(sessions, analyspar, n_bins=40, rolling_win=4, 
                         parallel=False):
    """
    get_all_correlations(sessions, analyspar)

    Returns ROI correlation data for each line/plane/session.

    Required args:
        - sessions (list):
            Session objects
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
    
    Optional args:
        - n_bins (int):
            number of bins for correlation data
            default: 40
        - rolling_win (int):
            window to use in rolling mean over individual traces before 
            computing correlation between ROIs (None for no smoothing)
            default: 4 
        - parallel (bool):
            if True, some of the analysis is run in parallel across CPU cores 
            default: False
        
    Returns:
        - binned_cc_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the 
            following columns, in addition to the basic sess_df columns:
            - bin_edges (list): first and last bin edge
            - corrs_binned (list): number of correlation values per bin
    """

    all_corrs = gen_util.parallel_wrap(
        get_correlation, sessions, [analyspar, rolling_win], parallel=parallel
        )

    cc_df = get_check_sess_df(sessions, analyspar=analyspar)
    initial_columns = cc_df.columns

    cc_df["corrs"] = [cc.tolist() for cc in all_corrs]

    # group within line/plane
    group_columns = ["lines", "planes", "sess_ns"]

    columns = initial_columns.tolist() + ["bin_edges", "corrs_binned"]
    binned_cc_df = pd.DataFrame(columns=columns)
    aggreg_cols = [col for col in initial_columns if col not in group_columns]
    for grp_vals, grp_df in cc_df.groupby(group_columns):
        grp_df = grp_df.sort_values("mouse_ns")
        row_idx = len(binned_cc_df)
        for g, group_column in enumerate(group_columns):
            binned_cc_df.loc[row_idx, group_column] = grp_vals[g]

        # add aggregated values for initial columns
        binned_cc_df = aggreg_columns(
            grp_df, binned_cc_df, aggreg_cols, row_idx=row_idx, in_place=True
            )

        cc_data = np.concatenate(grp_df["corrs"].tolist())

        cc_data_binned, bin_edges = np.histogram(
            cc_data, bins=np.linspace(-1, 1, n_bins + 1)
            )

        binned_cc_df.at[row_idx, "corrs_binned"] = cc_data_binned.tolist()
        binned_cc_df.at[row_idx, "bin_edges"] = [bin_edges[0], bin_edges[-1]]

    binned_cc_df["sess_ns"] = binned_cc_df["sess_ns"].astype(int)

    return binned_cc_df
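The binning step reduces each group's correlation values to counts over a fixed [-1, 1] grid, so only the outer edges need storing; in isolation:

import numpy as np

rng = np.random.default_rng(0)
cc_data = np.clip(rng.normal(0.2, 0.3, size=500), -1, 1)  # dummy correlations

n_bins = 40
counts, bin_edges = np.histogram(cc_data, bins=np.linspace(-1, 1, n_bins + 1))

# the grid is regular, so the first and last edge fully describe it
stored_edges = [bin_edges[0], bin_edges[-1]]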
Code example #10
def get_sess_roi_trace_df(sessions,
                          analyspar,
                          stimpar,
                          basepar,
                          split="by_exp",
                          parallel=False):
    """
    get_sess_roi_trace_df(sessions, analyspar, stimpar, basepar)

    Returns ROI trace statistics for specific sessions, split as requested.

    Required args:
        - sessions (list): 
            session objects
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - stimpar (StimPar): 
            named tuple containing stimulus parameters
        - basepar (BasePar): 
            named tuple containing baseline parameters

    Optional args:
        - split (str): 
            how to split data:
            "by_exp" (all exp, all unexp), 
            "unexp_lock" (unexp, preceeding exp), 
            "exp_lock" (exp, preceeding unexp),
            "stim_onset" (grayscr, stim on), 
            "stim_offset" (stim off, grayscr)
            default: "by_exp"
        - parallel (bool): 
            if True, some of the analysis is run in parallel across CPU cores 
            default: False

    Returns:
        - trace_df (pd.DataFrame):
            dataframe with a row for each session, and the following 
            columns, in addition to the basic sess_df columns: 
            - roi_trace_stats (list): 
                ROI trace stats (split x ROIs x frames x stat (me, err))
            - time_values (list):
                values for each frame, in seconds
                (only 0 to stimpar.post, unless split is "by_exp")
    """

    trace_df = misc_analys.get_check_sess_df(sessions, None, analyspar)

    # gather arguments for retrieving ROI trace statistics
    args_dict = {
        "analyspar": analyspar,
        "stimpar": stimpar,
        "basepar": basepar,
        "split": split,
    }

    # sess x split x ROIs x frames
    roi_trace_stats, all_time_values = gen_util.parallel_wrap(
        basic_analys.get_sess_roi_trace_stats,
        sessions,
        args_dict=args_dict,
        parallel=parallel,
        zip_output=True)

    misc_analys.get_check_sess_df(sessions, trace_df)
    trace_df["roi_trace_stats"] = [stats.tolist() for stats in roi_trace_stats]
    trace_df["time_values"] = [
        time_values.tolist() for time_values in all_time_values
    ]

    return trace_df
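parallel_wrap with zip_output=True regroups the per-session (stats, time_values) tuples into one tuple per output; the pure-Python equivalent of that regrouping:

# one (stats, time_values) tuple per session
per_sess = [("stats_a", "times_a"), ("stats_b", "times_b")]

# zip(*...) transposes: one tuple per output, across sessions
roi_trace_stats, all_time_values = zip(*per_sess)
print(roi_trace_stats)   # ('stats_a', 'stats_b')
print(all_time_values)   # ('times_a', 'times_b')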
Code example #11
def prep_analyses(sess_n, args, mouse_df):
    """
    prep_analyses(sess_n, args, mouse_df)

    Prepares named tuples and sessions for which to run analyses, based on the 
    arguments passed.

    Required args:
        - sess_n (int)          : session number to run analyses on, or 
                                  combination of session numbers to compare, 
                                  e.g. "1v2"
        - args (Argument parser): parser containing all parameters
        - mouse_df (pandas df)  : path name of dataframe containing information 
                                  on each session

    Returns:
        - sessions (list)     : list of sessions, or nested list per mouse 
                                if sess_n is a combination to compare
        - analysis_dict (dict): dictionary of analysis parameters 
                                (see init_param_cont())
    """

    args = copy.deepcopy(args)

    args.sess_n = sess_n

    analysis_dict = init_param_cont(args)
    sesspar, stimpar = [analysis_dict[key] for key in ["sesspar", "stimpar"]]

    # get session IDs and load Sessions
    sessids = sess_gen_util.sess_per_mouse(mouse_df,
                                           omit_sess=args.omit_sess,
                                           omit_mice=args.omit_mice,
                                           **sesspar._asdict())

    logger.info(f"Loading {len(sessids)} session(s)...",
                extra={"spacing": "\n"})

    args_dict = {
        "datadir": args.datadir,
        "mouse_df": mouse_df,
        "runtype": sesspar.runtype,
        "full_table": False,
        "roi": False,
        "run": True,
        "temp_log": "warning",
    }

    sessions = gen_util.parallel_wrap(sess_gen_util.init_sessions,
                                      sessids,
                                      args_dict=args_dict,
                                      parallel=args.parallel,
                                      use_tqdm=True)

    # flatten list of sessions
    sessions = [sess for singles in sessions for sess in singles]

    runtype_str = ""
    if sesspar.runtype != "prod":
        runtype_str = f" ({sesspar.runtype} data)"

    stim_str = stimpar.stimtype
    if stimpar.stimtype == "gabors":
        stim_str = "gabor"
    elif stimpar.stimtype == "visflow":
        stim_str = "visual flow"

    logger.info(
        f"Analysis of {sesspar.plane} responses to {stim_str} "
        f"stimuli{runtype_str}.\nSession {sesspar.sess_n}",
        extra={"spacing": "\n"})

    return sessions, analysis_dict
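prep_analyses deep-copies args before mutating it, so the caller's namespace survives reuse (e.g., across parallel calls with different sess_n values); a minimal demonstration:

import copy
from argparse import Namespace

args = Namespace(sess_n=None, omit_sess=[])

local_args = copy.deepcopy(args)
local_args.sess_n = "1v2"
local_args.omit_sess.append(123)

print(args.sess_n)     # None: the original is untouched
print(args.omit_sess)  # []: nested objects were copied too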
Code example #12
def main(args):
    """
    main(args)

    Runs analyses with parser arguments.

    Required args:
        - args (argparse.Namespace): parser argument namespace
    """

    # set logger to the specified level
    logger_util.set_level(level=args.log_level)

    args.fontdir = DEFAULT_FONTDIR if DEFAULT_FONTDIR.is_dir() else None

    if args.dict_path is not None:
        source = "modif" if args.modif else "run"
        plot_dicts.plot_from_dicts(Path(args.dict_path),
                                   source=source,
                                   plt_bkend=args.plt_bkend,
                                   fontdir=args.fontdir,
                                   parallel=args.parallel,
                                   datetime=not (args.no_datetime),
                                   overwrite=args.overwrite)
    else:
        args = reformat_args(args)
        if args.datadir is None:
            args.datadir = DEFAULT_DATADIR
        else:
            args.datadir = Path(args.datadir)
        mouse_df = DEFAULT_MOUSE_DF_PATH

        # get numbers of sessions to analyse
        if args.sess_n == "all":
            all_sess_ns = sess_gen_util.get_sess_vals(mouse_df,
                                                      "sess_n",
                                                      runtype=args.runtype,
                                                      plane=args.plane,
                                                      line=args.line,
                                                      min_rois=args.min_rois,
                                                      pass_fail=args.pass_fail,
                                                      incl=args.incl,
                                                      omit_sess=args.omit_sess,
                                                      omit_mice=args.omit_mice,
                                                      sort=True)
        else:
            all_sess_ns = gen_util.list_if_not(args.sess_n)

        # get analysis parameters for each session number
        all_analys_pars = gen_util.parallel_wrap(prep_analyses,
                                                 all_sess_ns,
                                                 args_list=[args, mouse_df],
                                                 parallel=args.parallel)

        # disable parallelization in debug mode
        args.parallel = bool(args.parallel * (not args.debug))

        # split parallel from sequential analyses
        if args.parallel:
            run_seq = ""  # should be run parallel within analysis
            all_analyses = gen_util.remove_lett(args.analyses, run_seq)
            sess_parallels = [True, False]
            analyses_parallels = [False, True]
        else:
            all_analyses = [args.analyses]
            sess_parallels, analyses_parallels = [False], [False]

        for analyses, sess_parallel, analyses_parallel in zip(
                all_analyses, sess_parallels, analyses_parallels):
            if len(analyses) == 0:
                continue
            args_dict = {
                "analyses": analyses,
                "seed": args.seed,
                "parallel": analyses_parallel,
            }

            # run analyses for each parameter set
            gen_util.parallel_wrap(run_analyses,
                                   all_analys_pars,
                                   args_dict=args_dict,
                                   parallel=sess_parallel,
                                   mult_loop=True)
Code example #13
def plot_from_dicts(direc,
                    source="roi",
                    plt_bkend=None,
                    fontdir=None,
                    plot_tc=True,
                    parallel=False,
                    datetime=True,
                    overwrite=False,
                    pattern="",
                    depth=0):
    """
    plot_from_dicts(direc)

    Plots data from dictionaries containing analysis parameters and results, or 
    path to results.

    Required args:
        - direc (Path): path to directory in which dictionaries to plot data 
                        from are located or path to a single json file
    
    Optional args:
        - source (str)    : plotting source ("roi", "run", "gen", "pup", 
                            "modif", "logreg", "glm", "acr_sess")
                            default: "roi"
        - plt_bkend (str) : mpl backend to use for plotting (e.g., "agg")
                            default: None
        - fontdir (Path)  : directory in which additional fonts are stored
                            default: None
        - plot_tc (bool)  : if True, tuning curves are plotted for each ROI 
                            default: True
        - parallel (bool) : if True, some of the analysis is parallelized 
                            across CPU cores
                            default: False
        - datetime (bool) : figpar["save"] datetime parameter (whether to 
                            place figures in a datetime folder)
                            default: True
        - overwrite (bool): figpar["save"] overwrite parameter (whether to 
                            overwrite figures)
                            default: False
        - pattern (str)   : pattern based on which to include json files in 
                            direc if direc is a directory
                            default: ""
        - depth (int)     : maximum depth at which to check for json files if 
                            direc is a directory
                            default: 0
    """

    file_util.checkexists(direc)

    direc = Path(direc)

    if direc.is_dir():
        if source == "logreg":
            targ_file = "hyperparameters.json"
        else:
            targ_file = "*.json"

        dict_paths = []
        for d in range(depth + 1):
            dict_paths.extend(
                glob.glob(str(Path(direc, *(["*"] * d), targ_file))))

        dict_paths = list(filter(lambda x: pattern in str(x), dict_paths))

        if source == "logreg":
            dict_paths = [Path(dp).parent for dp in dict_paths]

        if len(dict_paths) == 0:
            raise OSError(f"No jsons found in {direc} at "
                          f"depth {depth} with pattern '{pattern}'.")

    elif ".json" not in str(direc):
        raise ValueError("If providing a file, must be a json file.")
    else:
        if (source == "logreg"
                and not str(direc).endswith("hyperparameters.json")):
            raise ValueError("For logreg source, must provide path to "
                             "a hyperparameters json file.")

        dict_paths = [direc]

    dict_paths = sorted(dict_paths)
    if len(dict_paths) > 1:
        logger.info(f"Plotting from {len(dict_paths)} dictionaries.")

    fontdir = Path(fontdir) if fontdir is not None else fontdir
    args_dict = {
        "plt_bkend": plt_bkend,
        "fontdir": fontdir,
        "plot_tc": plot_tc,
        "datetime": datetime,
        "overwrite": overwrite,
    }

    pass_parallel = True
    sources = [
        "roi", "run", "gen", "modif", "pup", "logreg", "glm", "acr_sess"
    ]
    if source == "roi":
        fct = roi_plots.plot_from_dict
    elif source in ["run", "gen"]:
        fct = gen_plots.plot_from_dict
    elif source in ["pup", "pupil"]:
        fct = pup_plots.plot_from_dict
    elif source == "modif":
        fct = mod_plots.plot_from_dict
    elif source == "logreg":
        pass_parallel = False
        fct = logreg_plots.plot_from_dict
    elif source == "glm":
        pass_parallel = False
        fct = glm_plots.plot_from_dict
    elif source == "acr_sess":
        fct = acr_sess_plots.plot_from_dict
    else:
        gen_util.accepted_values_error("source", source, sources)

    args_dict = gen_util.keep_dict_keys(args_dict,
                                        inspect.getfullargspec(fct).args)
    gen_util.parallel_wrap(fct,
                           dict_paths,
                           args_dict=args_dict,
                           parallel=parallel,
                           pass_parallel=pass_parallel)

    plot_util.cond_close_figs()
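The depth loop builds one glob pattern per directory level under direc; stripped of project specifics, it behaves like this:

import glob
from pathlib import Path

direc, targ_file, depth = Path("results"), "*.json", 2

dict_paths = []
for d in range(depth + 1):
    # d=0: results/*.json, d=1: results/*/*.json, d=2: results/*/*/*.json
    pattern = str(Path(direc, *(["*"] * d), targ_file))
    dict_paths.extend(glob.glob(pattern))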
Code example #14
def get_idx_corrs_df(sessions, analyspar, stimpar, basepar, idxpar, permpar, 
                     consec_only=True, permute="sess", corr_type="corr", 
                     sig_only=False, randst=None, parallel=False):
    """
    get_idx_corrs_df(sessions, analyspar, stimpar, basepar, idxpar, permpar)

    Returns ROI index correlation data for each line/plane/session comparison.

    Required args:
        - sessions (list): 
            Session objects
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - stimpar (StimPar): 
            named tuple containing stimulus parameters
        - basepar (BasePar): 
            named tuple containing baseline parameters
        - idxpar (IdxPar): 
            named tuple containing index parameters
        - permpar (PermPar): 
            named tuple containing permutation parameters.
    
    Optional args:
        - consec_only (bool):
            if True, only consecutive session numbers are correlated
            default: True
        - permute (str):
            type of permutation to use ("tracking", "sess" or "all")
            default: "sess"
        - corr_type (str):
            type of correlation to run, i.e. "corr" or "R_sqr"
            default: "corr"
        - sig_only (bool):
            if True, only ROIs with significant USIs are included 
            (only possible if analyspar.tracked is True)
            default: False
        - randst (int or np.random.RandomState): 
            seed value to use. (-1 treated as None)
            default: None
        - parallel (bool): 
            if True, some of the analysis is run in parallel across CPU cores 
            default: False
   
    Returns:
        - idx_corr_df (pd.DataFrame):
            dataframe with one row per line/plane, and the 
            following columns, in addition to the basic sess_df columns:

            for correlation data (normalized if corr_type is "diff_corr") for 
            session comparisons, e.g. 1v2
            - {}v{}{norm_str}_corrs (float): intersession ROI index correlations
            - {}v{}{norm_str}_corr_stds (float): bootstrapped intersession ROI 
                index correlation standard deviation
            - {}v{}_null_CIs (list): adjusted null CI for intersession ROI 
                index correlations
            - {}v{}_raw_p_vals (float): p-value for intersession correlations
            - {}v{}_p_vals (float): p-value for intersession correlations, 
                corrected for multiple comparisons and tails
    """
    
    lp_idx_df = get_lp_idx_df(
        sessions, 
        analyspar=analyspar, 
        stimpar=stimpar, 
        basepar=basepar, 
        idxpar=idxpar,
        permpar=permpar,
        sig_only=sig_only,
        randst=randst,
        parallel=parallel,
        )

    idx_corr_df = get_basic_idx_corr_df(lp_idx_df, consec_only=consec_only)

    # get correlation pairs
    corr_ns = get_corr_pairs(lp_idx_df, consec_only=consec_only)

    # get norm information
    norm = False
    if permute in ["sess", "all"]:
        corr_type = f"diff_{corr_type}"
        if corr_type == "diff_corr":
            norm = True
    norm_str = "_norm" if norm else ""

    logger.info(
        ("Calculating ROI USI correlations across sessions..."), 
        extra={"spacing": TAB}
        )
    group_columns = ["lines", "planes"]
    for grp_vals, grp_df in lp_idx_df.groupby(group_columns):
        grp_df = grp_df.sort_values("sess_ns") # mice already aggregated
        line, plane = grp_vals
        row_idx = idx_corr_df.loc[
            (idx_corr_df["lines"] == line) &
            (idx_corr_df["planes"] == plane)
        ].index

        if len(row_idx) != 1:
            raise RuntimeError("Expected exactly one row to match.")
        row_idx = row_idx[0]
    
        use_randst = copy.deepcopy(randst) # reset each time

        # obtain correlation data
        args_dict = {
            "data_df"  : grp_df,
            "analyspar": analyspar,
            "permpar"  : permpar,
            "permute"  : permute,
            "corr_type": corr_type,
            "absolute" : False,
            "norm"     : norm,
            "randst"   : use_randst,
        }

        all_corr_data = gen_util.parallel_wrap(
            get_corr_data, 
            corr_ns, 
            args_dict=args_dict, 
            parallel=parallel, 
            zip_output=False
        )

        # add to dataframe
        for sess_pair, corr_data in zip(corr_ns, all_corr_data):

            if corr_data is None:
                continue
            
            corr_name = f"{sess_pair[0]}v{sess_pair[1]}"
            roi_corr, roi_corr_std, null_CI, p_val = corr_data

            idx_corr_df.loc[row_idx, f"{corr_name}{norm_str}_corrs"] = roi_corr
            idx_corr_df.loc[row_idx, f"{corr_name}{norm_str}_corr_stds"] = \
                roi_corr_std
            idx_corr_df.at[row_idx, f"{corr_name}_null_CIs"] = null_CI.tolist()
            idx_corr_df.loc[row_idx, f"{corr_name}_p_vals"] = p_val

    # corrected p-values
    idx_corr_df = misc_analys.add_corr_p_vals(idx_corr_df, permpar)

    return idx_corr_df
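The corr_stds column holds a bootstrapped standard deviation of the correlation; generically (not the project's rand_util implementation), that estimate looks like:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=100)            # e.g., session 1 USIs
y = 0.5 * x + rng.normal(size=100)  # e.g., session 2 USIs

n_boot = 1000
boot_corrs = np.empty(n_boot)
for b in range(n_boot):
    idx = rng.integers(0, len(x), len(x))  # resample ROI pairs with replacement
    boot_corrs[b] = np.corrcoef(x[idx], y[idx])[0, 1]

roi_corr_std = boot_corrs.std()  # bootstrapped correlation standard deviation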
Code example #15
def prep_analyses(sess_n, args, mouse_df, parallel=False):
    """
    prep_analyses(sess_n, args, mouse_df)

    Prepares named tuples and sessions for which to run analyses, based on the 
    arguments passed.

    Required args:
        - sess_n (int)          : session number to run analyses on, or 
                                  combination of session numbers to compare, 
                                  e.g. "1v2"
        - args (Argument parser): parser containing all parameters
        - mouse_df (pandas df)  : path name of dataframe containing information 
                                  on each session

    Optional args:
        - parallel (bool): if True, sessions are initialized in parallel 
                           across CPU cores 
                           default: False

    Returns:
        - sessions (list)      : list of sessions, or nested list per mouse 
                                 if sess_n is a combination
        - analysis_dict (dict): dictionary of analysis parameters 
                                (see init_param_cont())
    """

    args = copy.deepcopy(args)

    args.sess_n = sess_n

    analysis_dict = init_param_cont(args)
    analyspar, sesspar, stimpar = [
        analysis_dict[key] for key in ["analyspar", "sesspar", "stimpar"]
    ]

    roi = (args.datatype == "roi")
    run = (args.datatype == "run")

    sesspar_dict = sesspar._asdict()
    _ = sesspar_dict.pop("closest")

    [all_mouse_ns,
     all_sess_ns] = sess_gen_util.get_sess_vals(mouse_df,
                                                ["mouse_n", "sess_n"],
                                                omit_sess=args.omit_sess,
                                                omit_mice=args.omit_mice,
                                                unique=False,
                                                **sesspar_dict)

    if args.sess_n in ["any", "all"]:
        all_sess_ns = [n + 1 for n in range(max(all_sess_ns))]
    else:
        all_sess_ns = args.sess_n

    # get session IDs and create Sessions
    all_mouse_ns = sorted(set(all_mouse_ns))

    logger.info(f"Loading sessions for {len(all_mouse_ns)} mice...",
                extra={"spacing": "\n"})
    args_list = [
        all_sess_ns, sesspar, mouse_df, args.datadir, args.omit_sess,
        analyspar.fluor, analyspar.dend, roi, run
    ]
    sessions = gen_util.parallel_wrap(init_mouse_sess,
                                      all_mouse_ns,
                                      args_list=args_list,
                                      parallel=parallel,
                                      use_tqdm=True)

    check_all = set([sess for m_sess in sessions for sess in m_sess])
    if len(sessions) == 0 or check_all == {None}:
        raise RuntimeError("No sessions meet the criteria.")

    runtype_str = ""
    if sesspar.runtype != "prod":
        runtype_str = f" ({sesspar.runtype} data)"

    stim_str = stimpar.stimtype
    if stimpar.stimtype == "gabors":
        stim_str = "gabor"
    elif stimpar.stimtype == "visflow":
        stim_str = "visual flow"

    logger.info(
        f"Analysis of {sesspar.plane} responses to {stim_str} "
        f"stimuli{runtype_str}.\nSession {sesspar.sess_n}",
        extra={"spacing": "\n"})

    return sessions, analysis_dict
Code example #16
def get_pupil_run_trace_df(sessions,
                           analyspar,
                           stimpar,
                           basepar,
                           split="by_exp",
                           parallel=False):
    """
    get_pupil_run_trace_df(sessions, analyspar, stimpar, basepar)

    Returns pupil and running traces for specific sessions, split as 
    requested.

    Required args:
        - sessions (list): 
            session objects
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - stimpar (StimPar): 
            named tuple containing stimulus parameters
        - basepar (BasePar): 
            named tuple containing baseline parameters

    Optional args:
        - split (str): 
            how to split data:
            "by_exp" (all exp, all unexp), 
            "unexp_lock" (unexp, preceeding exp), 
            "exp_lock" (exp, preceeding unexp),
            "stim_onset" (grayscr, stim on), 
            "stim_offset" (stim off, grayscr)
            default: "by_exp"
        - parallel (bool): 
            if True, some of the analysis is run in parallel across CPU cores 
            default: False

    Returns:
        - trace_df (pd.DataFrame):
            dataframe with a row for each session, and the following 
            columns, in addition to the basic sess_df columns: 
            - run_traces (list): 
                running velocity traces (split x seqs x frames)
            - run_time_values (list):
                values for each frame, in seconds
                (only 0 to stimpar.post, unless split is "by_exp")
            - pupil_traces (list): 
                pupil diameter traces (split x seqs x frames)
            - pupil_time_values (list):
                values for each frame, in seconds
                (only 0 to stimpar.post, unless split is "by_exp")    
    """

    trace_df = misc_analys.get_check_sess_df(sessions,
                                             None,
                                             analyspar,
                                             roi=False)

    # gather arguments for retrieving pupil and running split traces
    args_dict = {
        "analyspar": analyspar,
        "stimpar": stimpar,
        "baseline": basepar.baseline,
        "split": split,
    }

    misc_analys.get_check_sess_df(sessions, trace_df)
    for datatype in ["pupil", "run"]:
        args_dict["datatype"] = datatype
        # sess x split x seq x frames
        split_traces, all_time_values = gen_util.parallel_wrap(
            basic_analys.get_split_data_by_sess,
            sessions,
            args_dict=args_dict,
            parallel=parallel,
            zip_output=True)

        # add columns to dataframe
        trace_df[f"{datatype}_traces"] = list(split_traces)
        trace_df[f"{datatype}_time_values"] = list(all_time_values)

    return trace_df
Code example #17
def init_sessions(analyspar,
                  sesspar,
                  mouse_df,
                  datadir,
                  sessions=None,
                  roi=True,
                  run=False,
                  pupil=False,
                  parallel=False):
    """
    init_sessions(sesspar, mouse_df, datadir)

    Initializes sessions.

    Required args:
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - sesspar (SessPar): 
            named tuple containing session parameters
        - mouse_df (pandas df): 
            path name of dataframe containing information on each session
        - datadir (Path): 
            path to data directory
    
    Optional args:
        - sessions (list): 
            preloaded sessions
            default: None
        - roi (bool): 
            if True, ROI data is loaded
            default: True
        - run (bool): 
            if True, running data is loaded
            default: False
        - pupil (bool): 
            if True, pupil data is loaded
            default: False
        - parallel (bool): 
            if True, session loading and checking are run in parallel across 
            CPU cores 
            default: False

    Returns:
        - sessions (list): 
            Session objects 
    """

    sesspar_dict = sesspar._asdict()
    sesspar_dict.pop("closest")

    # identify sessions needed
    sessids = sorted(
        sess_gen_util.get_sess_vals(mouse_df, "sessid", **sesspar_dict))

    if len(sessids) == 0:
        raise ValueError("No sessions meet the criteria.")

    # check for preloaded sessions, and only load new ones
    if sessions is not None:
        loaded_sessids = [session.sessid for session in sessions]
        ext_str = " additional"
    else:
        sessions = []
        loaded_sessids = []
        ext_str = ""

    # identify new sessions to load
    load_sessids = list(
        filter(lambda sessid: sessid not in loaded_sessids, sessids))

    # remove sessions that are not needed
    if len(sessions):
        sessions = [
            session for session in sessions if session.sessid in sessids
        ]

        # check that previously loaded sessions have roi/run/pupil data loaded
        args_list = [roi, run, pupil, analyspar.fluor, analyspar.dend]
        with logger_util.TempChangeLogLevel(level="warning"):
            sessions = gen_util.parallel_wrap(sess_gen_util.check_session,
                                              sessions,
                                              args_list=args_list,
                                              parallel=parallel)

    # load new sessions
    if len(load_sessids):
        logger.info(f"Loading {len(load_sessids)}{ext_str} session(s)...",
                    extra={"spacing": "\n"})

        args_dict = {
            "datadir": datadir,
            "mouse_df": mouse_df,
            "runtype": sesspar.runtype,
            "full_table": False,
            "fluor": analyspar.fluor,
            "dend": analyspar.dend,
            "roi": roi,
            "run": run,
            "pupil": pupil,
            "temp_log": "critical"  # suppress almost all logs 
        }

        new_sessions = gen_util.parallel_wrap(sess_gen_util.init_sessions,
                                              load_sessids,
                                              args_dict=args_dict,
                                              parallel=parallel,
                                              use_tqdm=True)

        # flatten list of new sessions, and add to full sessions list
        new_sessions = [sess for singles in new_sessions for sess in singles]
        sessions = sessions + new_sessions

    # combine session lists, and sort by session ID order
    sessions = sorted(
        sessions, key=lambda session: sessids.index(session.sessid))

    # update ROI tracking parameters
    for sess in sessions:
        sess.set_only_tracked_rois(analyspar.tracked)

    return sessions
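The final reordering must place each session at the rank of its ID in the sorted sessids list; sorting by key does this for any permutation (a bare gather with the rank list would scramble non-trivial ones):

class Sess:
    def __init__(self, sessid):
        self.sessid = sessid

sessids = [100, 101, 102]                     # sorted target order
sessions = [Sess(102), Sess(100), Sess(101)]  # order in which they were loaded

sessions = sorted(sessions, key=lambda sess: sessids.index(sess.sessid))
print([sess.sessid for sess in sessions])     # [100, 101, 102]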
Code example #18
def get_roi_tracking_df(sessions,
                        analyspar,
                        reg_only=False,
                        proj=False,
                        crop_info=False,
                        parallel=False):
    """
    get_roi_tracking_df(sessions, analyspar)

    Returns ROI tracking information for the requested sessions.

    Required args:
        - sessions (list): 
            Session objects
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters

    Optional args:
        - reg_only (bool):
            if True, only registered masks, and projections if proj is True, 
            are included in the output dataframe
            default: False
        - proj (bool):
            if True, max projections are included in the output dataframe
            default: False
        - crop_info (bool or str):
            if not False, the type of cropping information to include 
            ("small" for the small plots, "large" for the large plots)
            default: False
        - parallel (bool): 
            if True, some of the analysis is run in parallel across CPU cores 
            default: False

    Returns:
        - roi_mask_df (pd.DataFrame in dict format):
            dataframe with a row for each mouse, and the following 
            columns, in addition to the basic sess_df columns: 

            - "registered_roi_mask_idxs" (list): list of mask indices, 
                registered across sessions, for each session 
                (flattened across ROIs) ((sess, hei, wid) x val)
            - "roi_mask_shapes" (list): shape into which ROI mask indices index 
                (sess x hei x wid)

            if not reg_only:
            - "roi_mask_idxs" (list): list of mask indices for each session, 
                and each ROI (sess x ((ROI, hei, wid) x val)) (not registered)

            if proj:
            - "registered_max_projections" (list): pixel intensities of maximum 
                projection for the plane (hei x wid), after registration across 
                sessions

            if proj and not reg_only:
            - "max_projections" (list): pixel intensities of maximum projection 
                for the plane (hei x wid)
                
            if crop_info:
            - "crop_fact" (num): factor by which to crop masks (> 1) 
            - "shift_prop_hei" (float): proportion by which to shift cropped 
                mask center vertically from left edge [0, 1]
            - "shift_prop_wid" (float): proportion by which to shift cropped 
                mask center horizontally from left edge [0, 1]
    """

    if not analyspar.tracked:
        raise ValueError("analyspar.tracked must be True for this analysis.")

    misc_analys.check_sessions_complete(sessions, raise_err=True)

    sess_df = misc_analys.get_check_sess_df(sessions, analyspar=analyspar)

    # if cropping, retrieve the preset parameters right away, and check that 
    # they exist for every mouse
    if crop_info:
        if crop_info == "small":
            crop_dict = SMALL_CROP_DICT
        elif crop_info == "large":
            crop_dict = LARGE_CROP_DICT
        else:
            gen_util.accepted_values_error("crop_info", crop_info,
                                           ["small", "large"])
        for mouse_n in sess_df["mouse_ns"].unique():
            if int(mouse_n) not in crop_dict.keys():
                raise NotImplementedError(
                    f"No preset cropping information found for mouse {mouse_n}."
                )

    # collect ROI mask data
    sess_dicts = gen_util.parallel_wrap(get_sess_reg_mask_info,
                                        sessions,
                                        args_list=[analyspar, True, proj],
                                        parallel=parallel)
    all_sessids = [sess.sessid for sess in sessions]

    group_columns = ["planes", "lines", "mouse_ns"]
    initial_columns = sess_df.columns.tolist()
    obj_columns = ["registered_roi_mask_idxs", "roi_mask_shapes"]
    if not reg_only:
        obj_columns.append("roi_mask_idxs")
    if proj:
        obj_columns.append("registered_max_projections")
        if not reg_only:
            obj_columns.append("max_projections")

    roi_mask_df = pd.DataFrame(columns=initial_columns + obj_columns)

    aggreg_cols = [col for col in initial_columns if col not in group_columns]
    for grp_vals, grp_df in sess_df.groupby(group_columns):
        row_idx = len(roi_mask_df)
        for g, group_column in enumerate(group_columns):
            roi_mask_df.loc[row_idx, group_column] = grp_vals[g]

        # add aggregated values for initial columns
        roi_mask_df = misc_analys.aggreg_columns(grp_df,
                                                 roi_mask_df,
                                                 aggreg_cols,
                                                 row_idx=row_idx,
                                                 in_place=True,
                                                 by_mouse=True)

        sessids = sorted(grp_df["sessids"].tolist())
        reg_roi_masks, roi_mask_idxs = [], []
        if proj:
            reg_max_projs, max_projs = [], []

        roi_mask_shape = None
        for sessid in sessids:
            sess_dict = sess_dicts[all_sessids.index(sessid)]
            reg_roi_mask = sess_dict["registered_roi_masks"]
            # flatten masks across ROIs
            reg_roi_masks.append(np.max(reg_roi_mask, axis=0))
            if roi_mask_shape is None:
                roi_mask_shape = reg_roi_mask.shape
            elif roi_mask_shape != reg_roi_mask.shape:
                raise RuntimeError(
                    "ROI mask shapes should match across sessions for the "
                    "same mouse.")
            if not reg_only:
                roi_mask_idxs.append([
                    idxs.tolist() for idxs in np.where(sess_dict["roi_masks"])
                ])
            if proj:
                reg_max_projs.append(
                    sess_dict["registered_max_projection"].tolist())
                if not reg_only:
                    max_projs.append(sess_dict["max_projection"].tolist())

        # add to the dataframe
        roi_mask_df.at[row_idx, "registered_roi_mask_idxs"] = \
            [idxs.tolist() for idxs in np.where(reg_roi_masks)]
        roi_mask_df.at[row_idx, "roi_mask_shapes"] = roi_mask_shape

        if not reg_only:
            roi_mask_df.at[row_idx, "roi_mask_idxs"] = roi_mask_idxs
        if proj:
            roi_mask_df.at[row_idx, "registered_max_projections"] = \
                reg_max_projs
            if not reg_only:
                roi_mask_df.at[row_idx, "max_projections"] = max_projs

        # add cropping info
        if crop_info:
            mouse_n = grp_vals[group_columns.index("mouse_ns")]
            crop_fact, shift_prop_hei, shift_prop_wid = crop_dict[mouse_n]
            roi_mask_df.at[row_idx, "crop_fact"] = crop_fact
            roi_mask_df.at[row_idx, "shift_prop_hei"] = shift_prop_hei
            roi_mask_df.at[row_idx, "shift_prop_wid"] = shift_prop_wid

    roi_mask_df["mouse_ns"] = roi_mask_df["mouse_ns"].astype(int)

    return roi_mask_df
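
Since "registered_roi_mask_idxs" stores the output of np.where, flattened across ROIs, the boolean masks can be rebuilt from the dataframe. A hypothetical downstream sketch, assuming roi_mask_df is the dataframe returned above:

import numpy as np

# rebuild the registered ROI masks for one mouse from the flattened
# (sess, hei, wid) indices stored in the dataframe
row = roi_mask_df.loc[0]  # assumes roi_mask_df from get_roi_tracking_df
reg_masks = np.zeros(row["roi_mask_shapes"], dtype=bool)
idxs = tuple(np.asarray(i) for i in row["registered_roi_mask_idxs"])
reg_masks[idxs] = True  # mark the recorded (sess, hei, wid) positions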
Code example #19
def run_glms(sessions,
             analysis,
             seed,
             analyspar,
             sesspar,
             stimpar,
             glmpar,
             figpar,
             parallel=False):
    """
    run_glms(sessions, analysis, seed, analyspar, sesspar, stimpar, glmpar, 
             figpar)
    """

    seed = rand_util.seed_all(seed, "cpu", log_seed=False)

    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
                                            sesspar.plane, stimpar.visflow_dir,
                                            stimpar.visflow_size, stimpar.gabk,
                                            "print")
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            "roi", "print")

    logger.info(
        "Analysing and plotting explained variance in ROI activity "
        f"({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    if glmpar.each_roi:  # must do each session separately
        glm_type = "per_ROI_per_sess"
        sess_batches = sessions
        logger.info("Per ROI, each session separately.")
    else:
        glm_type = "across_sess"
        sess_batches = [sessions]
        logger.info(f"Across ROIs, {len(sessions)} sessions together.")

    # optionally run in parallel here, or propagate parallelization to the 
    # next level
    parallel_here = (parallel and not glmpar.each_roi
                     and len(sess_batches) != 1)
    parallel_after = parallel and not parallel_here

    args_list = [analyspar, sesspar, stimpar, glmpar]
    args_dict = {
        "parallel": parallel_after,  # proactively set next parallel 
        "seed": seed,
    }
    all_expl_var = gen_util.parallel_wrap(run_glm,
                                          sess_batches,
                                          args_list,
                                          args_dict,
                                          parallel=parallel_here)

    if glmpar.each_roi:
        sessions = sess_batches
    else:
        sessions = sess_batches[0]

    sess_info = sess_gen_util.get_sess_info(sessions,
                                            analyspar.fluor,
                                            rem_bad=analyspar.rem_bad)

    extrapar = {
        "analysis": analysis,
        "seed": seed,
        "glm_type": glm_type,
    }

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "glmpar": glmpar._asdict(),
        "extrapar": extrapar,
        "all_expl_var": all_expl_var,
        "sess_info": sess_info
    }

    fulldir, savename = glm_plots.plot_glm_expl_var(figpar=figpar, **info)

    file_util.saveinfo(info, savename, fulldir, "json")

    return
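
The saved output can be reloaded for re-plotting without rerunning the GLMs. A sketch, under the assumption that file_util.saveinfo wrote a plain JSON file named {savename}.json under fulldir:

import json
from pathlib import Path

# hypothetical reload of the saved analysis info (assumes a .json file)
with open(Path(fulldir, f"{savename}.json"), "r") as f:
    info = json.load(f)
all_expl_var = info["all_expl_var"]  # explained variance per session batch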