コード例 #1
0
def get_signif_rois(integ_data,
                    permpar,
                    stats="mean",
                    op="diff",
                    nanpol=None,
                    log_rois=True):
    """
    get_signif_rois(integ_data, permpar)

    Identifies ROIs whose expected and unexpected responses differ 
    significantly, using a permutation test over sequences.

    For each ROI, the real difference/ratio (op) between the statistic of 
    the expected (index 0) and unexpected (index 1) sequences is compared to 
    a null distribution obtained by permuting sequences across both groups.

    Required args:
        - integ_data (list): list of 2D arrays of ROI activity integrated 
                             across frames, structured as:
                                unexp (0, 1) x array[ROI x sequences]
        - permpar (PermPar): named tuple containing permutation parameters
                             (multcomp does not apply to identifying 
                             significant ROIs)

    Optional args:
        - stats (str)    : statistic parameter, i.e. "mean" or "median"
                           default: "mean"
        - op (str)       : operation to identify significant ROIs
                           default: "diff"
        - nanpol (str)   : policy for NaNs, "omit" or None when taking 
                           statistics
                           default: None
        - log_rois (bool): if True, the indices of significant ROIs and
                           their actual difference values are logged
                           default: True

    Returns: 
        - sign_rois (list): list of ROIs showing significant differences, or 
                            list of lists if 2-tailed analysis [lo, up].
    """

    # Size of the FIRST group (index 0, expected sequences) in the 
    # concatenation below. It is passed to the permutation as the size of the 
    # first group, so it must come from integ_data[0], not integ_data[1].
    n_exp = integ_data[0].shape[1]

    # calculate real values (statistic across seqs), one per ROI and group
    data = [
        math_util.mean_med(integ_data[0], stats, axis=1, nanpol=nanpol),
        math_util.mean_med(integ_data[1], stats, axis=1, nanpol=nanpol)
    ]
    # real diff/ratio between groups, per ROI
    qu_data_res = math_util.calc_op(np.asarray(data), op, dim=0)

    # concatenate exp and unexp sequences (ROI x seq), exp first
    qu_data_all = np.concatenate(integ_data, axis=1)

    # run permutation to identify significant ROIs
    all_rand_res = rand_util.permute_diff_ratio(qu_data_all, n_exp,
                                                permpar.n_perms, stats, nanpol,
                                                op)
    sign_rois = rand_util.id_elem(all_rand_res,
                                  qu_data_res,
                                  permpar.tails,
                                  permpar.p_val,
                                  log_elems=log_rois)
    return sign_rois
コード例 #2
0
def scale_across_rois(y_df, tr_idx, sessids, stats="mean"):
    """
    scale_across_rois(y_df, tr_idx, sessids)

    Robustly standardizes each column of y_df separately for each session, 
    then collapses each row across columns with the requested statistic.

    Within each session, the scaling factors are fit on the training rows 
    only, and then re-applied unchanged to that session's remaining (test) 
    rows.

    Required args:
        - y_df (pd.DataFrame): data to scale, one row per sample. Rows are 
                               selected with .loc using positions taken from 
                               sessids (NOTE(review): this assumes the index 
                               labels match row positions)
        - tr_idx (array-like): row indices belonging to the training set
        - sessids (1D array) : session ID for each row of y_df

    Optional args:
        - stats (str): statistic used across columns ("mean" or "median")
                       default: "mean"

    Returns:
        - y_df_new (pd.DataFrame): dataframe with only a "roi_data" column, 
                                   holding the scaled, collapsed values
    """

    result_df = copy.deepcopy(y_df)
    train_set = set(tr_idx)

    for sessid in np.unique(sessids):
        sess_rows = set(np.where(sessids == sessid)[0])
        train_rows = list(sess_rows.intersection(train_set))
        test_rows = list(sess_rows - train_set)

        # fit the robust standardization on this session's training rows
        train_scaled, sc_facts = math_util.scale_data(
            y_df.loc[train_rows].to_numpy(),
            axis=0,
            sc_type="stand_rob",
            nanpol="omit",
        )
        result_df.loc[train_rows, "roi_data"] = math_util.mean_med(
            train_scaled, stats=stats, axis=1, nanpol="omit")

        # re-use the training factors for the held-out rows
        test_scaled = math_util.scale_data(
            y_df.loc[test_rows].to_numpy(),
            axis=0,
            sc_type="stand_rob",
            facts=sc_facts,
            nanpol="omit",
        )
        result_df.loc[test_rows, "roi_data"] = math_util.mean_med(
            test_scaled, stats=stats, axis=1, nanpol="omit")

    return result_df[["roi_data"]]
コード例 #3
0
def get_df_stats(scores_df, analyspar):
    """
    get_df_stats(scores_df, analyspar)

    Computes a central statistic and its error for every column of a 
    dataframe, ignoring NaNs.

    Required args:
        - scores_df (pd.DataFrame):
            dataframe where each column contains data for which statistics 
            should be measured
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters (its "stats" and 
            "error" fields select the statistics computed)
    
    Returns:
        - stats_df (pd.DataFrame):
            dataframe with a single data row (index 0) holding, for each 
            original column, its statistic under "{col}_stat" and its error 
            under "{col}_err" (stored as a list if the error is an array)
    """

    stats_df = pd.DataFrame()

    for col in scores_df.columns:
        stat_name = f"{col}_stat"
        err_name = f"{col}_err"

        col_stat = math_util.mean_med(scores_df[col].to_numpy(),
                                      stats=analyspar.stats,
                                      nanpol="omit")
        col_err = math_util.error_stat(scores_df[col].to_numpy(),
                                       stats=analyspar.stats,
                                       error=analyspar.error,
                                       nanpol="omit")

        # an array-valued error is stored as a list in a single cell, which 
        # requires the error column to have object dtype
        if isinstance(col_err, np.ndarray):
            stats_df = gen_util.set_object_columns(
                stats_df, [err_name], in_place=True)
            col_err = col_err.tolist()

        stats_df.loc[0, stat_name] = col_stat
        stats_df.at[0, err_name] = col_err

    return stats_df
コード例 #4
0
def add_relative_resp_data(resp_data_df,
                           analyspar,
                           rel_sess=1,
                           in_place=False):
    """
    add_relative_resp_data(resp_data_df, analyspar)

    Adds relative response data to input dataframe for any column with "exp" 
    in the name (which includes "unexp" columns), optionally in place.

    For each mouse, every "exp"/"unexp" column value is divided by the 
    statistic (analyspar.stats) of the corresponding expected ("exp") column 
    in that mouse's rel_sess session.

    Required args:
        - resp_data_df (pd.DataFrame):
            dataframe with one row per session, and response stats 
            (2D array, ROI x stats) under keys for expected ("exp") and 
            unexpected ("unexp") data, separated by Gabor frame 
            (e.g., "exp_3", "unexp_G") if applicable.
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters

    Optional args:
        - rel_sess (int):
            number of session relative to which data should be scaled, for each 
            mouse
            default: 1
        - in_place (bool):
            if True, dataframe is modified in place
            default: False

    Returns:
        - resp_data_df (pd.DataFrame):
            input dataframe, sorted by "sess_ns", with "rel_{}" columns added 
            for each input column with "exp" in its name

    Raises:
        - RuntimeError: if any mouse does not have exactly one row for 
          session rel_sess
    """

    if not in_place:
        resp_data_df = resp_data_df.copy(deep=True)

    nanpol = None if analyspar.rem_bad else "omit"

    source_columns = [col for col in resp_data_df.columns if "exp" in col]
    rel_columns = [f"rel_{col}" for col in source_columns]
    resp_data_df = gen_util.set_object_columns(resp_data_df,
                                               rel_columns,
                                               in_place=True)

    # Sort once and in place: rebinding resp_data_df to the result of 
    # sort_values() would make every write below go to a new copy, so an 
    # in_place=True caller would never see the relative values.
    resp_data_df.sort_values("sess_ns", inplace=True)

    # calculate relative value for each
    for mouse_n, resp_mouse_df in resp_data_df.groupby("mouse_ns"):
        # find the reference session and check that there is only 1
        rel_sess_idx = resp_mouse_df.loc[resp_mouse_df["sess_ns"] ==
                                         rel_sess].index
        if len(rel_sess_idx) != 1:
            raise RuntimeError(
                f"Expected to find exactly one session {rel_sess} entry for "
                f"each mouse, but found {len(rel_sess_idx)} for mouse "
                f"{mouse_n}.")

        mouse_row = resp_mouse_df.loc[rel_sess_idx[0]]
        for source_col in source_columns:
            # unexpected data is also scaled relative to EXPECTED data
            rel_col = source_col.replace("unexp", "exp")
            rel_data = math_util.mean_med(mouse_row[rel_col],
                                          analyspar.stats,
                                          nanpol=nanpol)
            # cells may hold arrays, so they are divided one at a time
            for mouse_n_idx in resp_mouse_df.index:
                resp_data_df.at[mouse_n_idx, f"rel_{source_col}"] = \
                    resp_data_df.loc[mouse_n_idx, source_col] / rel_data

    return resp_data_df
コード例 #5
0
def get_sess_ex_traces(sess, analyspar, stimpar, basepar, rolling_win=4):
    """
    get_sess_ex_traces(sess, analyspar, stimpar, basepar)

    Returns example traces selected for the session, based on SNR and Gabor 
    response pattern criteria. 

    Criteria (percentiles are computed over the ROIs passing the SNR 
    criterion):
    - Above median SNR
    - Sequence response correlation above 75th percentile.
    - Mean sequence standard deviation above 75th percentile.
    - Mean sequence skew above 75th percentile.

    Required args:
        - sess (Session):
            Session object
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - stimpar (StimPar): 
            named tuple containing stimulus parameters
        - basepar (BasePar): 
            named tuple containing baseline parameters

    Optional args:
        - rolling_win (int):
            window to use in rolling mean over individual trial traces before 
            computing correlation between trials (None for no smoothing)
            default: 4 

    Returns:
        - selected_roi_data (dict):
            ["time_values"] (1D array): values for each frame, in seconds, 
                as returned for the "by_exp" split
            ["roi_ns"] (1D array): selected ROI numbers (empty if no ROI 
                meets all criteria)
            ["roi_traces_sm"] (3D array): selected ROI sequence traces, 
                smoothed (unless rolling_win is None), with dims: 
                ROIs x seq x frames
            ["roi_trace_stats"] (2D array): selected ROI trace mean or median, 
                computed before smoothing, dims: ROIs x frames

    Raises:
        - NotImplementedError: if stimpar.stimtype is not "gabors"
    """

    nanpol = None if analyspar.rem_bad else "omit"

    if stimpar.stimtype != "gabors":
        raise NotImplementedError(
            "ROI selection criteria designed for Gabors, and based on their "
            "cyclical responses.")

    snr_analyspar = sess_ntuple_util.get_modif_ntuple(analyspar, "scale",
                                                      False)

    snrs = misc_analys.get_snr(sess, snr_analyspar, "snrs")
    snr_median = np.median(snrs)

    # identify ROIs that meet the SNR threshold
    snr_thr_rois = np.where(snrs > snr_median)[0]

    # collect all data, and compute statistics
    traces, time_values = basic_analys.get_split_data_by_sess(
        sess,
        analyspar=analyspar,
        stimpar=stimpar,
        split="by_exp",
        baseline=basepar.baseline,
    )

    traces_exp = np.asarray(traces[0])  # get expected split
    traces_exp_stat = math_util.mean_med(traces_exp,
                                         stats=analyspar.stats,
                                         axis=1,
                                         nanpol=nanpol)

    # smooth individual traces, then compute correlations
    if rolling_win is not None:
        traces_exp = math_util.rolling_mean(traces_exp, win=rolling_win)

    if len(snr_thr_rois) == 0:
        # no ROI passes the SNR criterion (percentiles of an empty set are 
        # undefined), so nothing is selected
        roi_ns = snr_thr_rois
    else:
        snr_traces = traces_exp[snr_thr_rois]
        snr_trace_stats = traces_exp_stat[snr_thr_rois]

        # median pairwise correlation between sequences, for each ROI, using 
        # the upper triangle of the seq x seq matrix (excluding the diagonal)
        triu_idx = np.triu_indices(snr_traces.shape[1], k=1)
        # stored as an array so that it can be compared elementwise to the 
        # threshold below (comparing a list to a float raises TypeError)
        corr_medians = np.asarray([
            np.median(np.corrcoef(roi_trace)[triu_idx])
            for roi_trace in snr_traces
        ])

        # calculate std and skew over trace statistics
        trace_stat_stds = np.std(snr_trace_stats, axis=1)
        trace_stat_skews = scist.skew(snr_trace_stats, axis=1)

        # identify ROIs that meet thresholds (from those with high enough SNR)
        std_thr = np.percentile(trace_stat_stds, 75)
        skew_thr = np.percentile(trace_stat_skews, 75)
        corr_thr = np.percentile(corr_medians, 75)

        selected_idx = np.where(
            (trace_stat_stds > std_thr)
            & (corr_medians > corr_thr)
            & (trace_stat_skews > skew_thr))[0]

        # re-index into all ROIs
        roi_ns = snr_thr_rois[selected_idx]

    selected_roi_data = {
        "time_values": time_values,
        "roi_ns": roi_ns,
        "roi_traces_sm": traces_exp[roi_ns],
        "roi_trace_stats": traces_exp_stat[roi_ns]
    }

    return selected_roi_data
コード例 #6
0
def get_sess_grped_diffs_df(sessions,
                            analyspar,
                            stimpar,
                            basepar,
                            permpar,
                            split="by_exp",
                            randst=None,
                            parallel=False):
    """
    get_sess_grped_diffs_df(sessions, analyspar, stimpar, basepar)

    Returns split difference statistics for specific sessions, grouped across 
    mice.

    Required args:
        - sessions (list): 
            session objects
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - stimpar (StimPar): 
            named tuple containing stimulus parameters
        - basepar (BasePar): 
            named tuple containing baseline parameters
        - permpar (PermPar): 
            named tuple containing permutation parameters

    Optional args:
        - split (str): 
            how to split data:
            "by_exp" (all exp, all unexp), 
            "unexp_lock" (unexp, preceding exp), 
            "exp_lock" (exp, preceding unexp),
            "stim_onset" (grayscr, stim on), 
            "stim_offset" (stim off, grayscr)
            default: "by_exp"
        - randst (int or np.random.RandomState): 
            random state or seed value to use. (-1 treated as None)
            default: None
        - parallel (bool): 
            if True, some of the analysis is run in parallel across CPU cores 
            default: False

    Returns:
        - diffs_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - diff_stats (list): split difference stats (me, err)
            - null_CIs (list): adjusted null CI for split differences 
            - raw_p_vals (float): uncorrected p-value for differences within 
                sessions
            - p_vals (float): p-value for differences within sessions, 
                corrected for multiple comparisons and tails
            for session comparisons, e.g. 1v2:
            - raw_p_vals_{}v{} (float): uncorrected p-value for differences
                between sessions 
            - p_vals_{}v{} (float): p-value for differences between sessions, 
                corrected for multiple comparisons and tails
    """

    nanpol = None if analyspar.rem_bad else "omit"

    if analyspar.tracked:
        misc_analys.check_sessions_complete(sessions, raise_err=True)

    sess_diffs_df = misc_analys.get_check_sess_df(sessions, None, analyspar)
    initial_columns = sess_diffs_df.columns.tolist()

    # retrieve ROI index information
    args_dict = {
        "analyspar": analyspar,
        "stimpar": stimpar,
        "basepar": basepar,
        "split": split,
        "return_data": True,
    }

    # sess x split x ROI
    split_stats, split_data = gen_util.parallel_wrap(get_sess_roi_split_stats,
                                                     sessions,
                                                     args_dict=args_dict,
                                                     parallel=parallel,
                                                     zip_output=True)

    misc_analys.get_check_sess_df(sessions, sess_diffs_df)
    sess_diffs_df["roi_split_stats"] = list(split_stats)
    sess_diffs_df["roi_split_data"] = list(split_data)

    columns = initial_columns + ["diff_stats", "null_CIs"]
    diffs_df = pd.DataFrame(columns=columns)

    group_columns = ["lines", "planes", "sess_ns"]
    aggreg_cols = [col for col in initial_columns if col not in group_columns]
    for lp_grp_vals, lp_grp_df in sess_diffs_df.groupby(["lines", "planes"]):
        lp_grp_df = lp_grp_df.sort_values(["sess_ns", "mouse_ns"])
        line, plane = lp_grp_vals
        lp_name = plot_helper_fcts.get_line_plane_name(line, plane)
        logger.info(f"Running permutation tests for {lp_name} sessions...",
                    extra={"spacing": TAB})

        # obtain ROI random split differences per session
        # done here to avoid OOM errors
        lp_rand_diffs = gen_util.parallel_wrap(
            get_rand_split_data,
            lp_grp_df["roi_split_data"].tolist(),
            args_list=[analyspar, permpar, randst],
            parallel=parallel,
            zip_output=False)

        # lp_rand_diffs follows the row order of lp_grp_df: build the list of 
        # index labels once, to map labels to positions for every session
        lp_grp_idxs = lp_grp_df.index.tolist()

        sess_diffs = []
        row_indices = []
        sess_ns = sorted(lp_grp_df["sess_ns"].unique())
        for sess_n in sess_ns:
            row_idx = len(diffs_df)
            row_indices.append(row_idx)
            sess_grp_df = lp_grp_df.loc[lp_grp_df["sess_ns"] == sess_n]

            grp_vals = list(lp_grp_vals) + [sess_n]
            for g, group_column in enumerate(group_columns):
                diffs_df.loc[row_idx, group_column] = grp_vals[g]

            # add aggregated values for initial columns
            diffs_df = misc_analys.aggreg_columns(sess_grp_df,
                                                  diffs_df,
                                                  aggreg_cols,
                                                  row_idx=row_idx,
                                                  in_place=True)

            # group ROI split stats across mice: split x ROIs
            split_stats = np.concatenate(
                sess_grp_df["roi_split_stats"].to_numpy(), axis=-1)

            # take diff and stats across ROIs
            diffs = split_stats[1] - split_stats[0]
            diff_stats = math_util.get_stats(diffs,
                                             stats=analyspar.stats,
                                             error=analyspar.error,
                                             nanpol=nanpol)
            diffs_df.at[row_idx, "diff_stats"] = diff_stats.tolist()
            sess_diffs.append(diffs)

            # group random ROI split diffs across mice, and take stat
            rand_idxs = [lp_grp_idxs.index(idx) for idx in sess_grp_df.index]
            rand_diffs = math_util.mean_med(np.concatenate(
                [lp_rand_diffs[r] for r in rand_idxs], axis=0),
                                            axis=0,
                                            stats=analyspar.stats,
                                            nanpol=nanpol)

            # get CIs and p-values
            p_val, null_CI = rand_util.get_p_val_from_rand(
                diff_stats[0],
                rand_diffs,
                return_CIs=True,
                p_thresh=permpar.p_val,
                tails=permpar.tails,
                multcomp=permpar.multcomp,
                nanpol=nanpol)
            diffs_df.loc[row_idx, "p_vals"] = p_val
            diffs_df.at[row_idx, "null_CIs"] = null_CI

        del lp_rand_diffs  # free up memory

        # calculate p-values between sessions (0-1, 0-2, 1-2...)
        p_vals = rand_util.comp_vals_acr_groups(sess_diffs,
                                                n_perms=permpar.n_perms,
                                                stats=analyspar.stats,
                                                paired=analyspar.tracked,
                                                nanpol=nanpol,
                                                randst=randst)

        # Each pairwise p-value is stored in the rows of BOTH sessions of the 
        # pair. j is a position in sess_ns (enumeration starts at i + 1), not 
        # a position in the slice sess_ns[i + 1:], so that the second session 
        # of the pair is addressed correctly for every i.
        p = 0
        for i, sess_n in enumerate(sess_ns):
            for j, sess_n2 in enumerate(sess_ns[i + 1:], start=i + 1):
                key = f"p_vals_{int(sess_n)}v{int(sess_n2)}"
                diffs_df.loc[row_indices[i], key] = p_vals[p]
                diffs_df.loc[row_indices[j], key] = p_vals[p]
                p += 1

    # add corrected p-values
    diffs_df = misc_analys.add_corr_p_vals(diffs_df, permpar)

    diffs_df["sess_ns"] = diffs_df["sess_ns"].astype(int)

    return diffs_df
コード例 #7
0
def get_ex_idx_corr_norm_df(sessions, analyspar, stimpar, basepar, idxpar, 
                            permpar, permute="sess", sig_only=False, n_bins=40, 
                            randst=None, parallel=False):
    """
    get_ex_idx_corr_norm_df(sessions, analyspar, stimpar, basepar, idxpar, 
                            permpar)

    Returns example correlation normalization data for a single session pair 
    from a single line/plane: the raw and normalized intersession ROI USI 
    correlations, one randomized example, and the binned random correlations.

    Required args:
        - sessions (list): 
            Session objects
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - stimpar (StimPar): 
            named tuple containing stimulus parameters
        - basepar (BasePar): 
            named tuple containing baseline parameters
        - idxpar (IdxPar): 
            named tuple containing index parameters
        - permpar (PermPar): 
            named tuple containing permutation parameters.
    
    Optional args:
        - permute (str):
            type of permutation to do ("tracking", "sess" or "all")
            default: "sess"
        - sig_only (bool):
            if True, only ROIs with significant USIs are included 
            (only possible if analyspar.tracked is True)
            default: False
        - n_bins (int):
            number of bins
            default: 40
        - randst (int or np.random.RandomState): 
            seed value to use. (-1 treated as None)
            default: None
        - parallel (bool): 
            if True, some of the analysis is run in parallel across CPU cores 
            default: False

    Returns:
        - idx_corr_norm_df (pd.DataFrame):
            dataframe with one row for a line/plane, and the 
            following columns, in addition to the basic sess_df columns:

            for a specific session comparison, e.g. 1v2
            - {}v{}_corrs (float): unnormalized intersession ROI index 
                correlations
            - {}v{}_norm_corrs (float): normalized intersession ROI index 
                correlations
            - {}v{}_rand_ex_corrs (float): unnormalized intersession 
                ROI index correlations for an example of randomized data
            - {}v{}_rand_corr_meds (float): median of randomized correlations

            - {}v{}_corr_data (list): intersession values to correlate
            - {}v{}_rand_ex (list): intersession values for an example of 
                randomized data
            - {}v{}_rand_corrs_binned (list): binned random unnormalized 
                intersession ROI index correlations
            - {}v{}_rand_corrs_bin_edges (list): [min, max] bounds of the 
                n_bins equal-width bins

    Raises:
        - ValueError: if the sessions do not come from a single line/plane, 
          or do not yield exactly one session pair
    """

    nanpol = None if analyspar.rem_bad else "omit"

    initial_columns = misc_analys.get_sess_df_columns(sessions[0], analyspar)

    lp_idx_df = get_lp_idx_df(
        sessions,
        analyspar=analyspar,
        stimpar=stimpar,
        basepar=basepar,
        idxpar=idxpar,
        permpar=permpar,
        sig_only=sig_only,
        randst=randst,
        parallel=parallel,
    )

    idx_corr_norm_df = get_basic_idx_corr_df(lp_idx_df, consec_only=False)
    if len(idx_corr_norm_df) != 1:
        raise ValueError("sessions should be from the same line/plane.")

    # exactly one session pair must be available
    corr_ns = get_corr_pairs(lp_idx_df)
    if len(corr_ns) != 1:
        raise ValueError("Sessions should allow only one pair.")
    sess_pair = corr_ns[0]
    prefix = f"{sess_pair[0]}v{sess_pair[1]}"

    # retain only the basic session columns
    extra_columns = [
        col for col in idx_corr_norm_df.columns if col not in initial_columns
    ]
    idx_corr_norm_df = idx_corr_norm_df.drop(columns=extra_columns)

    logger.info(
        "Calculating ROI USI correlations for a single session pair...",
        extra={"spacing": TAB},
    )

    corr_type = "diff_corr"
    (roi_corr, _, _, _, corr_data, rand_corrs, rand_exs,
     rand_ex_corrs) = get_corr_data(
        sess_pair,
        data_df=lp_idx_df,
        analyspar=analyspar,
        permpar=permpar,
        permute=permute,
        corr_type=corr_type,
        absolute=False,
        norm=False,
        return_data=True,
        return_rand=True,
        n_rand_ex=1,
        randst=randst,
    )

    # normalize the real correlation with the median random correlation
    rand_corr_med = math_util.mean_med(
        rand_corrs, stats="median", nanpol=nanpol)
    norm_roi_corr = float(
        get_norm_corrs(roi_corr, med=rand_corr_med, corr_type=corr_type))

    row_idx = idx_corr_norm_df.index[0]

    scalar_values = {
        "corrs": roi_corr,
        "rand_ex_corrs": rand_ex_corrs[0],
        "rand_corr_meds": rand_corr_med,
        "norm_corrs": norm_roi_corr,
    }
    for suffix, value in scalar_values.items():
        idx_corr_norm_df.loc[row_idx, f"{prefix}_{suffix}"] = value

    # list-valued cells require object columns
    list_suffixes = [
        "corr_data", "rand_ex", "rand_corrs_binned", "rand_corrs_bin_edges"
    ]
    idx_corr_norm_df = gen_util.set_object_columns(
        idx_corr_norm_df,
        [f"{prefix}_{suffix}" for suffix in list_suffixes],
        in_place=True,
    )

    idx_corr_norm_df.at[row_idx, f"{prefix}_corr_data"] = corr_data.tolist()
    idx_corr_norm_df.at[row_idx, f"{prefix}_rand_ex"] = \
        rand_exs[..., 0].tolist()

    # bin the random correlations between their extremes
    if nanpol is None:
        low, high = np.min(rand_corrs), np.max(rand_corrs)
    else:
        low, high = np.nanmin(rand_corrs), np.nanmax(rand_corrs)
    bin_counts = np.histogram(
        rand_corrs, bins=np.linspace(low, high, n_bins + 1))[0]

    idx_corr_norm_df.at[row_idx, f"{prefix}_rand_corrs_bin_edges"] = \
        [low, high]
    idx_corr_norm_df.at[row_idx, f"{prefix}_rand_corrs_binned"] = \
        bin_counts.tolist()

    return idx_corr_norm_df
コード例 #8
0
def qu_mags(all_data, permpar, mouse_ns, lines, stats="mean", error="sem", 
            nanpol=None, op_qu="diff", op_unexp="diff", log_vals=True):
    """
    qu_mags(all_data, permpar, mouse_ns, lines)

    Returns a dictionary containing the results of the magnitudes and L2 
    analysis, as well as the results of the permutation test.

    Specifically, magnitude and L2 norm are calculated as follows: 
        - Magnitude: for unexp and expected segments: 
                         mean/median across ROIs of
                             diff/ratio in average activity between 2 quantiles
        - L2 norm:   for unexp and expected segments: 
                         L2 norm across ROIs of
                             diff/ratio in average activity between 2 quantiles
    
    Significance is assessed based on the diff/ratio between unexpected and 
    expected magnitude/L2 norm results (unscaled values only).

    Optionally, the magnitudes and L2 norms are logged for each session, with
    significance indicated.

    Required args:
        - all_data (list)  : nested list of data, structured as:
                                 session x unexp x qu x array[(ROI x) seqs]
                             (not modified: a deep copy is used internally)
        - permpar (PermPar): named tuple containing permutation parameters
        - mouse_ns (list)  : list of mouse numbers (1 per session)
        - lines (list)     : list of mouse lines (1 per session)

    Optional args:
        - stats (str)      : statistic to take across segments, (then ROIs) 
                             ("mean" or "median")
                             default: "mean"
        - error (str)      : statistic to take across segments, (then ROIs) 
                             ("std" or "sem")
                             default: "sem"
        - nanpol (str)     : policy for NaNs, "omit" or None when taking 
                             statistics
                             default: None
        - op_qu (str)      : Operation to use in comparing the last vs first 
                             quantile ("diff" or "ratio")
                             default: "diff"       
        - op_unexp (str)   : Operation to use in comparing the unexpected vs 
                             expected, data ("diff" or "ratio")
                             default: "diff" 
        - log_vals (bool)  : If True, the magnitudes and L2 norms are logged
                             for each session, with significance indicated.
                             default: True

    Returns:
        - mags (dict): dictionary containing magnitude and L2 data to plot.
            ["L2"] (3D array)        : L2 norms, structured as: 
                                           sess x scaled x unexp
            ["mag_st"] (4D array)    : magnitude stats, structured as: 
                                           sess x scaled x unexp x stats
            ["L2_rel_th"] (2D array) : L2 thresholds calculated from 
                                       permutation analysis, structured as:
                                           sess x tail(s)
            ["mag_rel_th"] (2D array): magnitude thresholds calculated from
                                       permutation analysis, structured as:
                                           sess x tail(s)
            ["L2_sig"] (list)        : L2 significance results for each session 
                                       ("hi", "lo" or "no")
            ["mag_sig"] (list)       : magnitude significance results for each 
                                       session 
                                           ("hi", "lo" or "no")

    Raises:
        - ValueError: if the number of quantiles is not 2
    """


    n_sess = len(all_data)
    n_qu   = len(all_data[0][0])
    # index 0: unscaled, index 1: scaled magnitudes
    scales = [False, True]
    unexps    = ["exp", "unexp"]
    # number of values stored per magnitude stat: 2, or 3 for median + std
    # (NOTE(review): presumably [stat, err] vs [stat, err_lo, err_hi] --
    # confirm against math_util.calc_mag_change)
    stat_len = 2 + (stats == "median" and error == "std")
    # 2 threshold values for a two-tailed test, 1 otherwise
    tail_len = 1 + (str(permpar.tails) == "2")

    if n_qu != 2:
        raise ValueError(f"Expected 2 quantiles, but found {n_qu}.")
    if len(unexps) != 2:
        raise ValueError("Expected a length 2 unexpected dim, "
            f"but found length {len(unexps)}.")
    
    mags = {"mag_st": np.empty([n_sess, len(scales), len(unexps), stat_len]),
            "L2"    : np.empty([n_sess, len(scales), len(unexps)])
           }
    
    for lab in ["mag_sig", "L2_sig"]:
        mags[lab] = []
    for lab in ["mag_rel_th", "L2_rel_th"]:
        mags[lab] = np.empty([n_sess, tail_len])

    # work on a copy, since dummy axes may be added to the nested per-session 
    # lists below, and the caller's data should not be modified
    all_data = copy.deepcopy(all_data)
    for i in range(n_sess):
        logger.info(f"Mouse {mouse_ns[i]}, {lines[i]}:", 
            extra={"spacing": "\n"})
        sess_data_me = []
        # number of expected sequences (unexp index 0), per quantile
        n_exps = [all_data[i][0][q].shape[-1] for q in range(n_qu)]
        for s in range(len(unexps)):
            # take the mean for each quantile
            data_me = np.asarray(
                [math_util.mean_med(all_data[i][s][q], stats, axis=-1, 
                nanpol=nanpol) for q in range(n_qu)])

            if len(data_me.shape) == 1:
                # add dummy ROI-like axis, e.g. for run data
                # (the sequence data is also updated, so that the 
                # concatenation along axis 1 below works)
                data_me = data_me[:, np.newaxis]
                all_data[i][s] = \
                    [qu_data[np.newaxis, :] for qu_data in all_data[i][s]]

            # unscaled magnitudes (scales index 0), with these runtime 
            # warnings temporarily filtered
            msgs = ["Degrees of freedom", "invalid value"]
            categs = [RuntimeWarning, RuntimeWarning]
            with gen_util.TempWarningFilter(msgs, categs):
                mags["mag_st"][i, 0, s] = math_util.calc_mag_change(
                    data_me, 0, 1, order="stats", op=op_qu, stats=stats, 
                    error=error)
                mags["L2"][i, 0, s] = math_util.calc_mag_change(
                    data_me, 0, 1, order=2, op=op_qu)
            sess_data_me.append(data_me)
        # scale
        # scaled magnitudes (scales index 0 -> 1), computed jointly over the 
        # exp/unexp data with unit scaling; NOTE(review): .T presumably 
        # reorders the output to unexp x stats -- confirm
        sess_data_me = np.asarray(sess_data_me)
        mags["mag_st"][i, 1] = math_util.calc_mag_change(
            sess_data_me, 1, 2, order="stats", op=op_qu, stats=stats, 
            error=error, scale=True, axis=0, pos=0, sc_type="unit").T
        mags["L2"][i, 1] = math_util.calc_mag_change(
            sess_data_me, 1, 2, order=2, op=op_qu, stats=stats, scale=True, 
            axis=0, pos=0, sc_type="unit").T
        
        # diff/ratio for permutation test
        # (unscaled values; for magnitudes, only the first stat value is used)
        act_mag_rel = math_util.calc_op(mags["mag_st"][i, 0, :, 0], op=op_unexp)
        act_L2_rel  = math_util.calc_op(mags["L2"][i, 0, :], op=op_unexp)

        # concatenate expected and unexpected sequences for each quantile
        # (expected first, matching n_exps)
        all_data_perm = [np.concatenate(
            [all_data[i][0][q], all_data[i][1][q]], axis=1) 
                for q in range(n_qu)]
        
        signif, ths = run_mag_permute(
            all_data_perm, act_mag_rel, act_L2_rel, n_exps, permpar, op_qu, 
            op_unexp, stats, nanpol)
        
        mags["mag_sig"].append(signif[0])
        mags["L2_sig"].append(signif[1])
        mags["mag_rel_th"][i] = np.asarray(ths[0])
        mags["L2_rel_th"][i] = np.asarray(ths[1])

        # logs results 
        if log_vals:
            # mark significant results (anything other than "no") with "*"
            sig_symb = ["", ""]
            for si, sig in enumerate(signif):
                if sig != "no":
                    sig_symb[si] = "*"

            vals = [mags["mag_st"][i, 0, :, 0], mags["L2"][i, 0, :]]
            names = [f"{stats} mag".capitalize(), "L2"]
            for v, (val, name) in enumerate(zip(vals, names)):
                for s, unexp in zip([0, 1], ["(exp) ", "(unexp)"]):
                    logger.info(f"{name} {unexp}: {val[s]:.4f}{sig_symb[v]}", 
                        extra={"spacing": TAB})
        
    return mags
コード例 #9
0
def collect_base_data(sessions,
                      analyspar,
                      stimpar,
                      datatype="rel_unexp_resp",
                      rel_sess=1,
                      basepar=None,
                      idxpar=None,
                      abs_usi=True,
                      parallel=False):
    """
    collect_base_data(sessions, analyspar, stimpar)

    Gathers, for each session, the per-ROI base data (relative unexpected 
    responses or USIs) from which stimulus and session comparisons are later 
    computed.

    Required args:
        - sessions (list): 
            session objects
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - stimpar (StimPar): 
            named tuple containing stimulus parameters

    Optional args:
        - datatype (str):
            kind of data to collect, either "rel_unexp_resp" or "usis"
            default: "rel_unexp_resp"
        - rel_sess (int):
            number of the session relative to which each mouse's data is 
            scaled
            default: 1
        - basepar (BasePar): 
            named tuple containing baseline parameters 
            (required if datatype is "usis")
            default: None
        - idxpar (IdxPar): 
            named tuple containing index parameters 
            (required if datatype is "usis")
            default: None
        - abs_usi (bool): 
            if True, USIs are converted to absolute values (only used if 
            datatype is "usis")
            default: True
        - parallel (bool): 
            if True, part of the analysis is run in parallel across CPU cores 
            default: False

    Returns:
        - data_df (pd.DataFrame):
            dataframe with one row per session, holding the basic sess_df 
            columns and a datatype column:
            - {datatype} (1D array): data per ROI
    """

    # NaNs are omitted from statistics only if bad ROIs were not removed
    nanpol = "omit" if not analyspar.rem_bad else None

    # columns retained in the output, in addition to the datatype column
    keep_columns = misc_analys.get_sess_df_columns(sessions[0], analyspar)

    if datatype == "rel_unexp_resp":
        resp_df = seq_analys.get_resp_df(
            sessions, analyspar, stimpar, rel_sess=rel_sess, parallel=parallel
            )
        if stimpar.stimtype != "gabors":
            resp_df = resp_df.rename(columns={"rel_unexp": datatype})
        else:
            # one column per unexpected Gabor frame: combine them row by row
            frame_cols = [
                resp_df[f"rel_unexp_{gabfr}"] for gabfr in stimpar.gabfr[1]
                ]
            resp_df[datatype] = [
                math_util.mean_med(
                    row_vals, stats=analyspar.stats, axis=0, nanpol=nanpol
                    )
                for row_vals in zip(*frame_cols)
                ]
        data_df = resp_df

    elif datatype == "usis":
        if any(par is None for par in (basepar, idxpar)):
            raise ValueError(
                "If datatype is 'usis', must provide basepar and idxpar.")
        data_df = usi_analys.get_idx_only_df(
            sessions, analyspar, stimpar, basepar, idxpar, parallel=parallel
            ).rename(columns={"roi_idxs": datatype})
        if abs_usi:
            # take the absolute value of each session's USI array
            data_df[datatype] = data_df[datatype].map(np.absolute)

    else:
        gen_util.accepted_values_error(
            "datatype", datatype, ["rel_unexp_resp", "usis"]
            )

    return data_df[keep_columns + [datatype]]
コード例 #10
0
def plot_perc_sig_usis(perc_sig_df, analyspar, permpar, figpar, by_mouse=False, 
                       title=None):
    """
    plot_perc_sig_usis(perc_sig_df, analyspar, permpar, figpar)

    Plots percentage of significant USIs, separately for the low and high 
    tails, for each line/plane (optionally for each mouse), adds significance 
    markers relative to chance, and logs the associated p-values.

    Required args:
        - perc_sig_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns:
            for sig in ["lo", "hi"]: for low vs high ROI indices
            - perc_sig_{sig}_idxs (num): percent significant ROIs (0-100)
            - perc_sig_{sig}_idxs_stds (num): bootstrapped standard deviation 
                over percent significant ROIs
            - perc_sig_{sig}_idxs_CIs (list): adjusted CI for percent sig. ROIs 
            - perc_sig_{sig}_idxs_null_CIs (list): adjusted null CI for percent 
                sig. ROIs
            - perc_sig_{sig}_idxs_raw_p_vals (num): uncorrected p-value for 
                percent sig. ROIs
            - perc_sig_{sig}_idxs_p_vals (num): p-value for percent sig. 
                ROIs, corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - by_mouse (bool):
            if True, plotting is done per mouse
            default: False
        - title (str):
            plot title
            default: None
        
    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - NotImplementedError: if perc_sig_df contains more than one session 
          number
        - RuntimeError: if by_mouse is False and a line/plane has more than 
          one row
    """  

    # work on a deep copy of the input dataframe
    perc_sig_df = perc_sig_df.copy(deep=True)

    # NaNs are omitted from statistics only if bad values were not removed
    nanpol = None if analyspar["rem_bad"] else "omit"

    # only a single session number is supported
    sess_ns = perc_sig_df["sess_ns"].unique()
    if len(sess_ns) != 1:
        raise NotImplementedError(
            "Plotting function implemented for 1 session only."
            )

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="idx", n_sub=1, 
        sharex=True, sharey=True)

    # override the subplot layout (taller subplots when plotting per mouse)
    figpar["init"]["sharey"] = True
    figpar["init"]["subplot_wid"] = 3.4
    figpar["init"]["gs"] = {"wspace": 0.18}
    if by_mouse:
        figpar["init"]["subplot_hei"] = 8.4
    else:
        figpar["init"]["subplot_hei"] = 3.5

    # 2 subplots, one per tail (indexed as ax[0, t] below)
    fig, ax = plot_util.init_fig(2, **figpar["init"])
    if title is not None:
        y = 0.98 if by_mouse else 1.07
        fig.suptitle(title, y=y, weight="bold")

    tail_order = ["Low tail", "High tail"]
    tail_keys = ["lo", "hi"]
    # chance level (in %) for a single tail: half of the p-value threshold
    chance = permpar["p_val"] / 2 * 100

    ylims = get_perc_sig_ylims(perc_sig_df, high_pt_min=40)
    n_linpla = plot_helper_fcts.N_LINPLA

    comp_info = misc_analys.get_comp_info(permpar)
    
    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    for t, (tail, key) in enumerate(zip(tail_order, tail_keys)):
        sub_ax = ax[0, t]
        sub_ax.set_title(tail, fontweight="bold")
        sub_ax.set_ylim(ylims)

        # replace bottom spine with line at 0
        sub_ax.spines['bottom'].set_visible(False)
        sub_ax.axhline(y=0, c="k", lw=4.0)

        data_key = f"perc_sig_{key}_idxs"

        # null CI bounds and medians, one row per line/plane x position 
        # (only filled in if by_mouse is False)
        CIs = np.full((plot_helper_fcts.N_LINPLA, 2), np.nan)
        CI_meds = np.full(plot_helper_fcts.N_LINPLA, np.nan)

        tail_sig_str = f"{tail:9}:"
        # NOTE(review): tick labels are collected in groupby order, which is 
        # assumed to match the x_index order used for plotting - confirm
        linpla_names = []
        for (line, plane), lp_df in perc_sig_df.groupby(["lines", "planes"]):
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
            # x position of this line/plane on the subplot
            x_index = 2 * li + pl
            linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
            linpla_names.append(linpla_name)
            
            if len(lp_df) == 0:
                continue
            elif len(lp_df) > 1 and not by_mouse:
                raise RuntimeError("Expected a single row per line/plane.")

            lp_df = lp_df.sort_values("mouse_ns") # sort by mouse
            df_indices = lp_df.index.tolist()

            if by_mouse:
                # plot means or medians per mouse
                mouse_data = lp_df[data_key].to_numpy()
                mouse_cols = plot_util.get_hex_color_range(
                    len(lp_df), col=col, 
                    interval=plot_helper_fcts.MOUSE_COL_INTERVAL
                    )
                mouse_data_mean = math_util.mean_med(
                    mouse_data, stats=analyspar["stats"], nanpol=nanpol
                    )
                # zero-width "CI" (both bounds equal to the statistic), so 
                # only the statistic across mice is drawn
                CI_dummy = np.repeat(mouse_data_mean, 2)
                plot_util.plot_CI(sub_ax, CI_dummy, med=mouse_data_mean, 
                    x=x_index, width=0.6, med_col=col, med_rat=0.01)
            else:
                # collect confidence interval data
                # (null CIs presumably stored as [low, median, high] - 
                # elements 0 and 2 are used as bounds, 1 as the median)
                row = lp_df.loc[df_indices[0]]
                mouse_cols = [col]
                CIs[x_index] = np.asarray(row[f"{data_key}_null_CIs"])[
                    np.asarray([0, 2])
                    ]
                CI_meds[x_index] = row[f"{data_key}_null_CIs"][1]

            if by_mouse:
                perc_p_vals = []
                rel_y = 0.05
            else:
                tail_sig_str = f"{tail_sig_str}{TAB}{linpla_name}: "
                rel_y = 0.1

            for df_i, mouse_col in zip(df_indices, mouse_cols):
                # plot UFOs
                # (error bars only when not plotting per mouse)
                err = None
                no_line = True
                if not by_mouse:
                    err = perc_sig_df.loc[df_i, f"{data_key}_stds"]
                    no_line = False
                # indicate bootstrapped error with wider capsize
                plot_util.plot_ufo(
                    sub_ax, x_index, perc_sig_df.loc[df_i, data_key], err,
                    color=mouse_col, capsize=8, no_line=no_line
                    )

                # add significance markers
                p_val = perc_sig_df.loc[df_i, f"{data_key}_p_vals"]
                perc = perc_sig_df.loc[df_i, data_key]
                nrois = np.sum(perc_sig_df.loc[df_i, "nrois"])
                # side of chance on which the observed percentage falls
                side = np.sign(perc - chance)
                sensitivity = misc_analys.get_binom_sensitivity(
                    nrois, null_perc=chance, side=side
                    )                
                sig_str = misc_analys.get_sig_symbol(
                    p_val, sensitivity=sensitivity, side=side, 
                    tails=permpar["tails"], p_thresh=permpar["p_val"]
                    )

                if len(sig_str):
                    # place the marker above the error bar, if there is one
                    perc_high = perc + err if err is not None else perc
                    plot_util.add_signif_mark(sub_ax, x_index, perc_high, 
                        rel_y=rel_y, color=mouse_col, fontsize=24, 
                        mark=sig_str) 

                if by_mouse:
                    perc_p_vals.append(
                        (int(np.around(perc)), p_val, sig_str)
                    )
                else:
                    tail_sig_str = (
                        f"{tail_sig_str}{p_val:.5f}{sig_str:3}"
                        )

            if by_mouse: # sort p-value logging by percentage value
                tail_sig_str = f"{tail_sig_str}\n\t{linpla_name:6}: "
                order = np.argsort([vals[0] for vals in perc_p_vals])
                for i in order:
                    perc, p_val, sig_str = perc_p_vals[i]
                    perc_str = f"(~{perc}%)"
                    tail_sig_str = (
                        f"{tail_sig_str}{TAB}{perc_str:6} "
                        f"{p_val:.5f}{sig_str:3}"
                        )
                
        # add chance information
        # (dashed chance line per mouse, otherwise the null CIs)
        if by_mouse:
            sub_ax.axhline(
                y=chance, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5, 
                zorder=-12
                )
        else:
            plot_util.plot_CI(sub_ax, CIs.T, med=CI_meds, 
                x=np.arange(n_linpla), width=0.45, med_rat=0.025, zorder=-12)

        logger.info(tail_sig_str, extra={"spacing": TAB})
    
    # x ticks use the line/plane names collected for the last tail
    for sub_ax in fig.axes:
        sub_ax.tick_params(axis="x", which="both", bottom=False) 
        plot_util.set_ticks(
            sub_ax, min_tick=0, max_tick=n_linpla - 1, n=n_linpla, pad_p=0.2)
        sub_ax.set_xticklabels(linpla_names, rotation=90, weight="bold")

    ax[0, 0].set_ylabel("%", fontweight="bold")
    plot_util.set_interm_ticks(ax, 3, axis="y", weight="bold", share=True)

    # adjustment if tick interval is repeated in the negative
    if ax[0, 0].get_ylim()[0] < 0:
        ax[0, 0].set_ylim([ylims[0], ax[0, 0].get_ylim()[1]])

    return ax
コード例 #11
0
def summ_subplot(ax, arr, sh_arr, data_title, mouse_ns, sess_ns, line, plane, 
                 stat="mean", error="sem", CI=0.95, q1v4=False, evu=False, 
                 split_oris=False, modif=False):
    """
    summ_subplot(ax, arr, sh_arr, data_title, mouse_ns, sess_ns, line, plane)

    Plots summary data in the specific subplot for a line and plane.

    Required args:
        - ax (plt Axis subplot): subplot
        - arr (3D array)       : array of session information, structured as 
                                 mice x sessions x vals, where vals
                                 are: mean/med, sem/low_perc, sem/hi_perc, 
                                      (x2 if q1v4 and test accuracy)
                                      n_rois, n_runs
        - sh_arr (3D array)    : array of session information, structured as 
                                 mice (1) x sessions x vals, where vals
                                 are: mean/med, sem/low_perc, sem/hi_perc, 
                                      (x2 if q1v4 and test accuracy), n_runs
        - data_title (str)     : name of type of data plotted (must contain 
                                 "data"), i.e. for epochs or test accuracy
        - mouse_ns (list)      : mouse numbers (-1 for shuffled data)
        - sess_ns (list)       : session numbers
        - line (str)           : transgenic line name
        - plane (str)          : plane name
    
    Optional args:
        - stat (str)       : stats to take for non shuffled data, 
                             i.e., "mean" or "median" 
                             default: "mean"
        - error (str)      : error stats to take across mice for non shuffled 
                             data (used if modif is True; otherwise error 
                             values are read from arr), i.e., "std", "sem"
                             default: "sem"
        - CI (num)         : CI for shuffled data (e.g., 0.95)
                             default: 0.95
        - q1v4 (bool)      : if True, analysis is separated across first and 
                             last quartiles
                             default: False
        - evu (bool)       : if True, the first dataset will include expected 
                             sequences and the second will include unexpected 
                             sequences
                             default: False
        - split_oris (list): if not False, the dataset will include 
                             orientations from the first set of Gabor frame, 
                             and the second, from the second set.
                             default: False 
        - modif (bool)     : if True, plots are made in a modified (simplified 
                             way)
                             default: False

    """

    if modif:
        # only plot first few sessions
        limit = 3
        arr = arr[:, :limit]
        sh_arr = sh_arr[:, :limit]
        x_label = [x + 1 for x in range(arr.shape[1])]
    else:
        # second to last value is n_rois
        x_label = rois_x_label(sess_ns, arr[:, :, -2])

    if "acc" in data_title.lower():
        if (not modif or ax.is_first_row()) and ax.is_first_col():
            if q1v4:
                ax.set_ylabel("Accuracy in Q4 (%)")
            elif evu:
                ax.set_ylabel("Accuracy for unexp (%)")
            else:
                ax.set_ylabel("Accuracy (%)")
            plot_util.set_ticks(ax, "y", 0, 100, 6, pad_p=0)

    elif "epoch" in data_title.lower():
        q1v4 = False # treated as if no Q4
        evu = False # treated as if no exp v unexp
        split_oris = False # treated as if no 2 sets of Gabor frames
        if (not modif or ax.is_first_row()) and ax.is_first_col():
            ax.set_ylabel("Nbr epochs")
        plot_util.set_ticks(ax, "y", 0, 1000, 6, pad_p=0)

    # indices of the mean/median values in arr (2 datasets if split)
    if q1v4 or evu or split_oris:
        mean_ids = [0, 3]
        alphas = [0.3, 0.8]
        if modif:
            alphas = [0.5, 0.8]
        if q1v4:
            add_leg = [" (Q1)", " (Q4)"]
        elif evu:
            add_leg = [" (exp)", " (unexp)"]
        else:
            add_leg = [f" ({split_oris[0]})", f" ({split_oris[1]})"]
    else:
        mean_ids = [0]
        alphas = [0.5]
        if modif:
            alphas = [0.8]
        add_leg = [""]
    
    # NOTE(review): split_oris is not forwarded here (only q1v4 + evu); 
    # confirm plot_CI does not need it
    plot_CI(ax, x_label, sh_arr, CI, q1v4 + evu, modif)
    
    # plot non shuffle data
    main_col = "blue"
    if plane == "dend":
        main_col = "green"

    if line == "L23":
        line = "L2/3"

    if not modif:
        cols = plot_util.get_color_range(len(mouse_ns), main_col)
        for m, mouse_n in enumerate(mouse_ns):
            # last value is n_runs
            leg = mouse_runs_leg(arr[m, :, -1], mouse_n, False)
            for i, m_i in enumerate(mean_ids):
                # insert the dataset label before the first line break
                leg_i = leg.index("\n")
                leg_m = f"{leg[: leg_i]}{add_leg[i]}{leg[leg_i :]}"
                ax.errorbar(x_label, arr[m, :, m_i], yerr=arr[m, :, m_i + 1], 
                    fmt="-o", markersize=12, capthick=4, label=leg_m, 
                    alpha=alphas[i], lw=3, color=cols[m])
        # mark the statistic across mice with a thin black bar per session
        for i in range(len(x_label)):
            for m, m_i in enumerate(mean_ids):
                if not np.isnan(arr[:, i, m_i]).all():
                    med = math_util.mean_med(
                        arr[:, i, m_i], axis=0, stats=stat, nanpol="omit")
                    y_lim = ax.get_ylim()
                    med_th = 0.0075 * (y_lim[1]-y_lim[0])
                    ax.bar(x_label[i], height=med_th, bottom=med - med_th/2., 
                        color="black", width=0.5, alpha=alphas[m])
        title_pre = data_title[: data_title.index("data")]
        title_post = data_title[data_title.index("data") :] 
        title = f"{title_pre}{line} {plane} {title_post}"

        ax.set_title(title, y=1.03)

    # add a mean line
    else:
        col = plot_util.get_color(main_col, ret="single")
        for m, m_i in enumerate(mean_ids):
            if not np.isnan(arr[:, :, m_i]).all():
                meds = math_util.mean_med(arr[:, :, m_i], axis=0, stats=stat, 
                    nanpol="omit")
                # use the requested error statistic (previously hard-coded 
                # to "sem", which silently ignored the error argument)
                errs = math_util.error_stat(arr[:, :, m_i], axis=0, 
                    stats=stat, error=error, nanpol="omit")
                plot_util.plot_errorbars(ax, meds, err=errs, x=x_label, 
                    color=col, alpha=alphas[m], xticks="auto")
    
    if not modif:
        ax.legend()
コード例 #12
0
def peristim_data(sess,
                  stimpar,
                  ran_s=None,
                  datatype="both",
                  returns="diff",
                  fluor="dff",
                  stats="mean",
                  rem_bad=True,
                  scale=False,
                  first_unexp=True,
                  trans_all=False):
    """
    peristim_data(sess, stimpar)

    Returns pupil, ROI and run data around unexpected onset, or the difference 
    between post and pre unexpected onset, or both.

    Required args:
        - sess (Session)   : session object
        - stimpar (StimPar): named tuple containing stimulus parameters

    Optional args:
        - ran_s (dict, list or num): duration to take before and after 
                                     unexpected onset for each datatype (ROI, 
                                     run, pupil) (in sec).  
                                         If dictionary, expected keys are:
                                            "pup_pre", "pup_post", 
                                            ("roi_pre", "roi_post"), 
                                            ("run_pre", "run_post"), 
                                        If list, should be structured as 
                                        [pre, post] and the same values will be 
                                        used for all datatypes. 
                                        If num, the same value will be used 
                                            for all keys. 
                                        If None, the values are taken from the
                                            stimpar pre and post attributes.
                                     default: None
        - datatype (str)           : type of data to include with pupil data, 
                                     "roi", "run" or "both"
                                     default: "both" 
        - returns (str)            : type of data to return: "data" (data 
                                     around unexpected onset), "diff" 
                                     (difference between post and pre 
                                     unexpected onset) or "both"
                                     default: "diff"
        - fluor (str)              : if "dff", dF/F is used, if "raw", ROI 
                                     traces
                                     default: "dff"
        - stats (str)              : measure on which to take the pre and post
                                     unexpected difference: either mean ("mean") 
                                     or median ("median")
                                     default: "mean"
        - rem_bad (bool)           : if True, removes ROIs with NaN/Inf values 
                                     anywhere in session and running array with
                                     NaNs linearly interpolated is used. If 
                                     False, NaNs are ignored in calculating 
                                     statistics for the ROI and running data 
                                     (always ignored for pupil data)
                                     default: True
        - scale (bool)             : if True, data is scaled
                                     default: False
        - first_unexp (bool)       : if True, only the first of consecutive 
                                     unexpecteds are retained
                                     default: True
        - trans_all (bool)         : if True, only ROIs with transients are 
                                     retained
                                     default: False

    Returns:
        if returns == "data" or "both":
        - datasets (list): list of 2-3D data arrays, structured as
                               datatype (pupil, (ROI), (running)) x 
                               [trial x frames (x ROI)]
        if returns == "diff" or "both":
        - diffs (list)   : list of 1-2D data difference arrays, structured as
                               datatype (pupil, (ROI), (running)) x 
                               [trial (x ROI)]    
        (if returns is "both", the tuple (datasets, diffs) is returned)

    Raises:
        - ValueError: if returns is "diff" or "both" and a pre value is 0
    """

    accepted_returns = ["data", "diff", "both"]
    # validate before any (potentially slow) data retrieval
    if returns not in accepted_returns:
        gen_util.accepted_values_error("returns", returns, accepted_returns)

    stim = sess.get_stim(stimpar.stimtype)

    # initialize ran_s dictionary if needed
    if ran_s is None:
        ran_s = [stimpar.pre, stimpar.post]
    ran_s = get_ran_s(ran_s, datatype)

    get_diffs = returns in ["diff", "both"]
    if get_diffs:
        # the pre/post split requires a non-zero pre period; checked before 
        # any data is retrieved
        for key in ran_s.keys():
            if "pre" in key and ran_s[key] == 0:
                raise ValueError(
                    "Cannot set pre to 0 if returns is 'diff' or 'both'.")

    if first_unexp:
        unexp_segs = stim.get_segs_by_criteria(
            visflow_dir=stimpar.visflow_dir,
            visflow_size=stimpar.visflow_size,
            gabk=stimpar.gabk,
            unexp=1,
            remconsec=True,
            by="seg")
        if stimpar.stimtype == "gabors":
            # shift segments to the requested Gabor frame
            unexp_segs = [seg + stimpar.gabfr for seg in unexp_segs]
    else:
        unexp_segs = stim.get_segs_by_criteria(
            visflow_dir=stimpar.visflow_dir,
            visflow_size=stimpar.visflow_size,
            gabk=stimpar.gabk,
            gabfr=stimpar.gabfr,
            unexp=1,
            remconsec=False,
            by="seg")

    # two-photon frames (pupil, ROI data) and stimulus frames (running data)
    unexp_twopfr = stim.get_fr_by_seg(unexp_segs, start=True,
                                      fr_type="twop")["start_frame_twop"]
    unexp_stimfr = stim.get_fr_by_seg(unexp_segs, start=True,
                                      fr_type="stim")["start_frame_stim"]
    # get data dataframes
    pup_data = gen_util.reshape_df_data(stim.get_pup_diam_data(
        unexp_twopfr,
        ran_s["pup_pre"],
        ran_s["pup_post"],
        rem_bad=rem_bad,
        scale=scale)["pup_diam"],
                                        squeeze_cols=True)

    datasets = [pup_data]
    datanames = ["pup"]
    if datatype in ["roi", "both"]:
        # ROI x trial x fr
        roi_data = gen_util.reshape_df_data(stim.get_roi_data(
            unexp_twopfr,
            ran_s["roi_pre"],
            ran_s["roi_post"],
            fluor=fluor,
            integ=False,
            rem_bad=rem_bad,
            scale=scale,
            transients=trans_all)["roi_traces"],
                                            squeeze_cols=True)
        datasets.append(roi_data.transpose([1, 2, 0]))  # ROIs last
        datanames.append("roi")
    if datatype in ["run", "both"]:
        run_data = gen_util.reshape_df_data(stim.get_run_data(
            unexp_stimfr,
            ran_s["run_pre"],
            ran_s["run_post"],
            rem_bad=rem_bad,
            scale=scale),
                                            squeeze_cols=True)
        datasets.append(run_data)
        datanames.append("run")

    # NaN policy for ROI and running data (pupil NaNs are always omitted)
    nanpolgen = None if rem_bad else "omit"

    if get_diffs:
        # get avg for first and second halves
        diffs = []
        for dataset, name in zip(datasets, datanames):
            if name == "pup":
                nanpol = "omit"
            else:
                nanpol = nanpolgen
            n_fr = dataset.shape[1]
            pre_s = ran_s[f"{name}_pre"]
            post_s = ran_s[f"{name}_post"]
            split = int(np.round(pre_s / (pre_s + post_s) *
                                 n_fr))  # find 0 mark
            pre = math_util.mean_med(dataset[:, :split], stats, 1, nanpol)
            post = math_util.mean_med(dataset[:, split:], stats, 1, nanpol)
            diffs.append(post - pre)

    if returns == "data":
        return datasets
    elif returns == "diff":
        return diffs
    elif returns == "both":
        return datasets, diffs
    else:
        gen_util.accepted_values_error("returns", returns, accepted_returns)
コード例 #13
0
def run_pupil_diff_corr(sessions,
                        analysis,
                        analyspar,
                        sesspar,
                        stimpar,
                        figpar,
                        datatype="roi"):
    """
    run_pupil_diff_corr(sessions, analysis, analyspar, sesspar, 
                        stimpar, figpar)
    
    Calculates, for each session, the change (post - pre unexpected onset) in 
    pupil diameter and in ROI/running data locked to each unexpected onset, 
    as well as the correlation between the two, and plots them.

    Saves results and parameters relevant to analysis in a dictionary.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "c")
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data ("roi" or "run")
                          default: "roi"

    Raises:
        - RuntimeError: if datatype is "roi" and a session's 
          only_tracked_rois does not match analyspar.tracked
    """

    accepted_datatypes = ["roi", "run"]
    # validate up front: otherwise an unsupported datatype only surfaces 
    # after data retrieval (e.g., as an unpacking error)
    if datatype not in accepted_datatypes:
        gen_util.accepted_values_error("datatype", datatype, 
                                       accepted_datatypes)

    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
                                            sesspar.plane, stimpar.visflow_dir,
                                            stimpar.visflow_size, stimpar.gabk,
                                            "print")
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")

    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        "Analysing and plotting correlations between unexpected vs "
        f"expected {datastr} traces between sessions ({sessstr_pr}"
        f"{dendstr_pr}).",
        extra={"spacing": "\n"})

    # NaN policy for the statistic taken across ROIs
    nanpol = None if analyspar.rem_bad else "omit"

    sess_diffs = []
    sess_corr = []

    for sess in sessions:
        if datatype == "roi" and (sess.only_tracked_rois != analyspar.tracked):
            raise RuntimeError(
                "sess.only_tracked_rois should match analyspar.tracked.")
        # forward the analysis parameters, so that the retrieved data matches 
        # those recorded in the saved info (fluor, rem_bad, stats)
        pup_diff, data_diff = peristim_data(sess,
                                            stimpar,
                                            datatype=datatype,
                                            returns="diff",
                                            fluor=analyspar.fluor,
                                            stats=analyspar.stats,
                                            rem_bad=analyspar.rem_bad,
                                            scale=analyspar.scale,
                                            first_unexp=True)
        # trials (x ROIs)
        if datatype == "roi":
            # collapse across ROIs -> one value per trial
            data_diff = math_util.mean_med(data_diff,
                                           analyspar.stats,
                                           axis=-1,
                                           nanpol=nanpol)
        sess_corr.append(np.corrcoef(pup_diff, data_diff)[0, 1])
        sess_diffs.append([diff.tolist() for diff in [pup_diff, data_diff]])

    extrapar = {
        "analysis": analysis,
        "datatype": datatype,
    }

    corr_data = {"corrs": sess_corr, "diffs": sess_diffs}

    sess_info = sess_gen_util.get_sess_info(sessions,
                                            analyspar.fluor,
                                            incl_roi=(datatype == "roi"),
                                            rem_bad=analyspar.rem_bad)

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "extrapar": extrapar,
        "sess_info": sess_info,
        "corr_data": corr_data
    }

    fulldir, savename = pup_plots.plot_pup_diff_corr(figpar=figpar, **info)

    file_util.saveinfo(info, savename, fulldir, "json")
コード例 #14
0
def add_stim_pop_stats(stim_stats_df,
                       sessions,
                       analyspar,
                       stimpar,
                       permpar,
                       comp_sess=(1, 3),
                       in_place=False,
                       randst=None):
    """
    add_stim_pop_stats(stim_stats_df, sessions, analyspar, stimpar, permpar)

    Adds to dataframe comparison of absolute fractional data changes 
    between sessions for different stimuli, calculated for population 
    statistics.

    Required args:
        - stim_stats_df (pd.DataFrame):
            dataframe with one row per line/plane, and the basic sess_df 
            columns, as well as stimulus columns for each comp_sess:
            - {stimpar.stimtype}_s{comp_sess[0]}: 
                first comp_sess data for each ROI
            - {stimpar.stimtype}_s{comp_sess[1]}: 
                second comp_sess data for each ROI
        - sessions (list): 
            session objects
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters 
            (stats must be "mean" and error must be "std")
        - stimpar (StimPar): 
            named tuple containing stimulus parameters. stimpar.stimtype must 
            contain exactly 2 stimulus types, as the p-values are computed 
            for the second stimulus type's statistic minus the first's.
        - permpar (PermPar): 
            named tuple containing permutation parameters

    Optional args:
        - comp_sess (list or tuple):
            sessions for which to obtain absolute fractional change 
            [x, y] => |(y - x) / x|
            default: (1, 3)
        - in_place (bool):
            if True, stim_stats_df is modified in place. Otherwise, a deep 
            copy is modified. stim_stats_df is returned in either case.
            default: False
        - randst (int or np.random.RandomState): 
            random state or seed value to use. (-1 treated as None)
            default: None

    Returns:
        - stim_stats_df (pd.DataFrame):
            dataframe with one row per line/plane and one for all line/planes 
            together, and the basic sess_df columns, in addition to the input 
            columns, and for each stimtype:
            - {stimtype} (list): absolute fractional change statistics (me, err)
            - p_vals (float): p-value for data differences between stimulus 
                types, corrected for multiple comparisons and tails

    Raises:
        - NotImplementedError: 
            if analyspar.stats is not "mean" or analyspar.error is not "std"
        - ValueError: 
            if stimpar.stimtype does not contain exactly 2 stimulus types
    """

    # validate parameters before anything is copied or modified
    if analyspar.stats != "mean" or analyspar.error != "std":
        raise NotImplementedError("For population statistics analysis, "
                                  "analyspar.stats must be set to 'mean', and "
                                  "analyspar.error must be set to 'std'.")

    # always work with a list of stimulus types, so that a single stimtype 
    # string is never iterated character by character
    stimtypes = gen_util.list_if_not(stimpar.stimtype)
    if len(stimtypes) != 2:
        raise ValueError(
            "Exactly 2 stimulus types are needed to compare population "
            f"statistics between stimuli, but {len(stimtypes)} were provided."
            )

    nanpol = None if analyspar.rem_bad else "omit"

    if analyspar.tracked:
        misc_analys.check_sessions_complete(sessions, raise_err=False)

    if not in_place:
        stim_stats_df = stim_stats_df.copy(deep=True)

    stim_stats_df = gen_util.set_object_columns(stim_stats_df,
                                                stimtypes,
                                                in_place=True)

    # initialize arrays for all data
    n_linpla = len(stim_stats_df)
    n_stims = len(stimtypes)
    n_bootstrp = misc_analys.N_BOOTSTRP

    all_stats = np.full((n_linpla, n_stims), np.nan)
    all_btstrap_stats = np.full((n_linpla, n_stims, n_bootstrp), np.nan)
    all_rand_stat_diffs = np.full((n_linpla, permpar.n_perms), np.nan)

    for i, row_idx in enumerate(stim_stats_df.index):
        # full (non-summarized) data, structured as stim x sess
        full_comp_data = [[] for _ in stimtypes]
        for s, stimtype in enumerate(stimtypes):
            comp_data, btstrap_comp_data = [], []
            choices = None
            for n in comp_sess:
                data_col = f"{stimtype}_s{n}"

                # get data
                data = stim_stats_df.loc[row_idx, data_col]

                # get session stats
                comp_data.append(
                    math_util.mean_med(data, analyspar.stats, nanpol=nanpol))

                # get bootstrapped data
                returns = rand_util.bootstrapped_std(
                    data,
                    randst=randst,
                    n_samples=n_bootstrp,
                    return_rand=True,
                    return_choices=analyspar.tracked,
                    choices=choices,
                    nanpol=nanpol)

                btstrap_data = returns[1]
                if analyspar.tracked:
                    # reuse the same bootstrap choices across sessions 
                    # (presumably to keep tracked ROIs paired)
                    choices = returns[-1]

                btstrap_comp_data.append(btstrap_data)
                full_comp_data[s].append(data)  # retain full data

            # compute absolute fractional change stats (bootstrapped std)
            all_stats[i, s] = abs_fractional_diff(comp_data)
            all_btstrap_stats[i, s] = abs_fractional_diff(btstrap_comp_data)
            error = np.std(all_btstrap_stats[i, s])

            # add to dataframe
            stim_stats_df.at[row_idx, stimtype] = [all_stats[i, s], error]

        # real difference between stimtypes (second minus first)
        stim_stat_diff = all_stats[i, 1] - all_stats[i, 0]

        # permute data for each session across stimtypes
        sess_rand_stats = []  # sess x stim
        for j in range(len(comp_sess)):
            rand_concat = [stim_data[j] for stim_data in full_comp_data]
            rand_concat = np.stack(rand_concat).T
            rand_stats = rand_util.permute_diff_ratio(
                rand_concat,
                div=None,
                n_perms=permpar.n_perms,
                stats=analyspar.stats,
                op="none",
                paired=True,  # pair stimuli
                nanpol=nanpol,
                randst=randst)
            sess_rand_stats.append(rand_stats)

        # obtain stats per stimtypes, then differences between stimtypes
        stim_rand_stats = list(zip(*sess_rand_stats))  # stim x sess
        all_rand_stats = [
            abs_fractional_diff(rand_stats) for rand_stats in stim_rand_stats
        ]
        all_rand_stat_diffs[i] = all_rand_stats[1] - all_rand_stats[0]

        # calculate p-value
        p_val = rand_util.get_p_val_from_rand(stim_stat_diff,
                                              all_rand_stat_diffs[i],
                                              tails=permpar.tails,
                                              nanpol=nanpol)
        stim_stats_df.loc[row_idx, "p_vals"] = p_val

    # collect stats for all line/planes in a new row
    # NOTE(review): assumes the index is 0..n-1, so len() is a fresh label
    row_idx = len(stim_stats_df)
    for col in stim_stats_df.columns:
        stim_stats_df.loc[row_idx, col] = "all"

    # average across line/planes
    all_data = []
    for data in [all_stats, all_btstrap_stats, all_rand_stat_diffs]:
        all_data.append(
            math_util.mean_med(data, analyspar.stats, nanpol=nanpol, axis=0))
    stat, btstrap_stats, rand_stat_diffs = all_data

    for s, stimtype in enumerate(stimtypes):
        error = np.std(btstrap_stats[s])
        stim_stats_df.at[row_idx, stimtype] = [stat[s], error]

    p_val = rand_util.get_p_val_from_rand(stat[1] - stat[0],
                                          rand_stat_diffs,
                                          tails=permpar.tails,
                                          nanpol=nanpol)
    stim_stats_df.loc[row_idx, "p_vals"] = p_val

    return stim_stats_df
コード例 #15
0
def get_rel_resp_stats_df(sessions,
                          analyspar,
                          stimpar,
                          permpar,
                          rel_sess=1,
                          randst=None,
                          parallel=False):
    """
    get_rel_resp_stats_df(sessions, analyspar, stimpar, permpar)

    Returns relative response stats dataframe for requested sessions.

    Required args:
        - sessions (list): 
            session objects
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - stimpar (StimPar): 
            named tuple containing stimulus parameters
        - permpar (PermPar): 
            named tuple containing permutation parameters

    Optional args:
        - rel_sess (int):
            number of session relative to which data should be scaled, for each 
            mouse
            default: 1
        - randst (int or np.random.RandomState): 
            random state or seed value to use. (-1 treated as None)
            default: None
        - parallel (bool): 
            if True, some of the analysis is run in parallel across CPU cores 
            default: False

    Returns:
        - rel_resp_data_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - rel_reg or rel_exp (list): data stats for regular data (me, err)
            - rel_unexp (list): data stats for unexpected data (me, err)
            for reg/exp/unexp data types, session comparisons, e.g. 1v2:
            - {data_type}_raw_p_vals_{}v{} (float): uncorrected p-value for 
                data differences between sessions 
            - {data_type}_p_vals_{}v{} (float): p-value for data between 
                sessions, corrected for multiple comparisons and tails
    """

    nanpol = None if analyspar.rem_bad else "omit"

    initial_columns = misc_analys.get_sess_df_columns(sessions[0], analyspar)

    resp_data_df = get_resp_df(sessions,
                               analyspar,
                               stimpar,
                               rel_sess=rel_sess,
                               parallel=parallel)

    # prepare target dataframe
    source_cols = ["rel_exp", "rel_unexp"]
    if stimpar.stimtype == "gabors":
        # regular means only A, B, C are included
        targ_cols = ["rel_reg", "rel_unexp"]
    else:
        targ_cols = ["rel_exp", "rel_unexp"]
    rel_resp_data_df = pd.DataFrame(columns=initial_columns + targ_cols)

    group_columns = ["lines", "planes"]
    aggreg_cols = [
        col for col in initial_columns
        if col not in group_columns + ["sess_ns"]
    ]
    for grp_vals, resp_grp_df in resp_data_df.groupby(group_columns):
        sess_ns = sorted(resp_grp_df["sess_ns"].unique())

        # take stats across frame types
        for e, (data_col, source_col) in enumerate(zip(targ_cols,
                                                       source_cols)):
            sess_data = []
            if e == 0:
                # target row index for each session (same order as sess_ns), 
                # created on the first pass and reused afterwards
                row_indices = []
            for s, sess_n in enumerate(sess_ns):
                sess_grp_df = resp_grp_df.loc[resp_grp_df["sess_ns"] == sess_n]
                sess_grp_df = sess_grp_df.sort_values("mouse_ns")
                if e == 0:
                    row_idx = len(rel_resp_data_df)
                    row_indices.append(row_idx)
                    rel_resp_data_df.loc[row_idx, "sess_ns"] = sess_n
                    for g, group_column in enumerate(group_columns):
                        rel_resp_data_df.loc[row_idx,
                                             group_column] = grp_vals[g]

                    # add aggregated values for initial columns
                    rel_resp_data_df = misc_analys.aggreg_columns(
                        sess_grp_df,
                        rel_resp_data_df,
                        aggreg_cols,
                        row_idx=row_idx,
                        in_place=True)
                else:
                    row_idx = row_indices[s]

                if stimpar.stimtype == "gabors":
                    # average across Gabor frames included in reg or unexp data
                    cols = [f"{source_col}_{fr}" for fr in stimpar.gabfr[e]]
                    data = sess_grp_df[cols].values.tolist()
                    # sess x frs x ROIs -> sess x ROIs
                    data = [
                        math_util.mean_med(sub,
                                           stats=analyspar.stats,
                                           axis=0,
                                           nanpol=nanpol) for sub in data
                    ]
                else:
                    # sess x ROIs
                    data = sess_grp_df[source_col].tolist()

                data = np.concatenate(data, axis=0)

                # take stats across ROIs, grouped
                rel_resp_data_df.at[row_idx, data_col] = \
                    math_util.get_stats(
                        data,
                        stats=analyspar.stats,
                        error=analyspar.error,
                        nanpol=nanpol
                        ).tolist()

                sess_data.append(data)  # for p-value calculation

            # calculate p-values between sessions (0-1, 0-2, 1-2...)
            p_vals = rand_util.comp_vals_acr_groups(sess_data,
                                                    n_perms=permpar.n_perms,
                                                    stats=analyspar.stats,
                                                    paired=analyspar.tracked,
                                                    nanpol=nanpol,
                                                    randst=randst)

            # store each p-value in the rows of both sessions compared. j is 
            # the index of sess_n2 in sess_ns (not in the slice), so that the 
            # correct second row is targeted for every pair.
            p = 0
            for i, sess_n in enumerate(sess_ns):
                for j, sess_n2 in enumerate(sess_ns[i + 1:], start=i + 1):
                    key = f"{data_col}_p_vals_{int(sess_n)}v{int(sess_n2)}"
                    rel_resp_data_df.loc[row_indices[i], key] = p_vals[p]
                    rel_resp_data_df.loc[row_indices[j], key] = p_vals[p]
                    p += 1

    rel_resp_data_df["sess_ns"] = rel_resp_data_df["sess_ns"].astype(int)

    # corrected p-values
    rel_resp_data_df = misc_analys.add_corr_p_vals(rel_resp_data_df, permpar)

    return rel_resp_data_df
コード例 #16
0
def get_block_data(sess, analyspar, stimpar, datatype="roi", integ=False):
    """
    get_block_data(sess, analyspar, stimpar)

    Returns data statistics split by expected/unexpected sequences, and by 
    blocks, where one block is defined as consecutive expected sequences, and 
    the subsequent consecutive unexpected sequences.

    Required args:
        - sess (Session):
            Session object
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - stimpar (StimPar): 
            named tuple containing stimulus parameters

    Optional args:
        - datatype (str):
            type of data to return ("roi", "run" or "pupil")
            default: "roi"
        - integ (bool):
            if True, data is integrated across frames, instead of a statistic 
            being taken
            default: False

    Returns:
        - block_data (3 or 4D array):
            data statistics across sequences per block, NaN-initialized
            dims: split x block (x ROIs, if datatype is "roi") x stats
            (2 stats, e.g. me, err, or 3 if analyspar.stats is "median" and 
            analyspar.error is "std")
    """

    stim = sess.get_stim(stimpar.stimtype)
    nanpol = "omit" if not analyspar.rem_bad else None
    frame_lims = [stimpar.pre, stimpar.post]

    # segment criteria shared by the expected and unexpected splits
    seg_kwargs = {
        "gabfr": stimpar.gabfr,
        "gabk": stimpar.gabk,
        "gab_ori": stimpar.gab_ori,
        "visflow_dir": stimpar.visflow_dir,
        "visflow_size": stimpar.visflow_size,
        "remconsec": False,
        "by": "seg",
    }

    # collect frame numbers and per-sequence data for each split 
    # (0: expected, 1: unexpected)
    split_fr_ns, split_data = [], []
    for unexp in (0, 1):
        segs = stim.get_segs_by_criteria(unexp=unexp, **seg_kwargs)

        fr_ns, ref_type = get_frame_numbers(
            stim, segs, ch_fl=frame_lims, ref_type="segs", datatype=datatype
            )
        split_fr_ns.append(np.asarray(fr_ns))

        seq_data, _ = get_data(
            stim, fr_ns, analyspar, pre=stimpar.pre, post=stimpar.post,
            integ=integ, datatype=datatype, ref_type=ref_type
            )

        # unless already integrated, summarize each sequence over its frames
        if not integ:
            with gen_util.TempWarningFilter("Mean of empty", RuntimeWarning):
                seq_data = math_util.mean_med(
                    seq_data, stats=analyspar.stats, axis=-1, nanpol=nanpol
                    )

        split_data.append(seq_data)

    # sequence indices per split, grouped into blocks
    block_idxs = split_seqs_by_block(split_fr_ns)

    lead_dims = (len(split_data), len(block_idxs[0]))
    n_stats = (
        3 if analyspar.stats == "median" and analyspar.error == "std" else 2
    )
    roi_dims = ()
    if datatype == "roi":
        roi_dims = (sess.get_nrois(analyspar.rem_bad, analyspar.fluor), )

    block_data = np.full(lead_dims + roi_dims + (n_stats, ), np.nan)

    for b, block_seq_idxs in enumerate(zip(*block_idxs)):
        for d, seq_idxs in enumerate(block_seq_idxs):
            # statistics across the sequences of split d within block b 
            # (sequences are along the last axis)
            seq_stats = math_util.get_stats(
                split_data[d][..., seq_idxs],
                stats=analyspar.stats,
                error=analyspar.error,
                nanpol=nanpol,
                axes=-1,
            )
            # transpose to place the stats dimension last
            block_data[d, b] = seq_stats.T

    return block_data
コード例 #17
0
def plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar, 
                               title=None, seed=None):
    """
    plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar)

    Plots pupil and running block differences.

    Required args:
        - block_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - run_block_diffs (list): 
                running velocity differences per block
            - run_raw_p_vals (float):
                uncorrected p-value for differences within sessions
            - run_p_vals (float):
                p-value for differences within sessions, 
                corrected for multiple comparisons and tails
            - pupil_block_diffs (list): 
                for pupil diameter differences per block
            - pupil_raw_p_vals (list):
                uncorrected p-value for differences within sessions
            - pupil_p_vals (list):
                p-value for differences within sessions, 
                corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - seed (int): 
            seed value to use. (-1 treated as None)
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - NotImplementedError: 
            if the data is scaled, or block_df contains more than one session 
            number
        - RuntimeError: 
            if a line/plane group does not contain exactly one row
    """

    if analyspar["scale"]:
        raise NotImplementedError(
            "Expected running and pupil data to not be scaled."
            )

    if len(block_df["sess_ns"].unique()) != 1:
        raise NotImplementedError(
            "'block_df' should only contain one session number."
        )

    nanpol = None if analyspar["rem_bad"] else "omit"
    
    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    datatypes = ["run", "pupil"]
    datatype_strs = ["Running velocity", "Pupil diameter"]
    ax_titles = {
        "run": "Running velocity (cm/s)",
        "pupil": "Pupil diameter (mm)",
    }
    n_datatypes = len(datatypes)

    fig, ax = plt.subplots(
        1, n_datatypes, figsize=(12.7, 4), squeeze=False, 
        gridspec_kw={"wspace": 0.22}
        )

    if title is not None:
        fig.suptitle(title, y=1.2, weight="bold")

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    corr_str = "corr." if permpar["multcomp"] else "raw"

    for d, datatype in enumerate(datatypes):
        datatype_sig_str = f"{datatype_strs[d]:16}:"
        sub_ax = ax[0, d]

        lp_names = [None for _ in range(plot_helper_fcts.N_LINPLA)]
        xs, all_data, cols, dashes, p_val_texts = [], [], [], [], []
        for (line, plane), lp_df in block_df.groupby(["lines", "planes"]):
            x, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane, flat=True
                )
            line_plane_name = plot_helper_fcts.get_line_plane_name(
                line, plane
                )
            lp_names[int(x)] = line_plane_name

            # exactly one row is expected (also guards against an empty group 
            # silently reusing the previous group's row index)
            if len(lp_df) != 1:
                raise RuntimeError("Expected 1 row per line/plane/session.")
            row_idx = lp_df.index[0]
        
            lp_data = lp_df.loc[row_idx, f"{datatype}_block_diffs"]
            
            # get p-value information
            p_val_corr = lp_df.loc[row_idx, f"{datatype}_p_vals"]

            side = np.sign(
                math_util.mean_med(
                    lp_data, stats=analyspar["stats"], nanpol=nanpol
                    )
                )
            sig_str = misc_analys.get_sig_symbol(
                p_val_corr, sensitivity=sensitivity, side=side, 
                tails=permpar["tails"], p_thresh=permpar["p_val"]
                )

            p_val_text = f"{p_val_corr:.2f}{sig_str}"

            datatype_sig_str = (
                f"{datatype_sig_str}{TAB}{line_plane_name}: "
                f"{p_val_corr:.5f}{sig_str:3}"
                )

            # collect information
            xs.append(x)
            all_data.append(lp_data)
            cols.append(col)
            dashes.append(dash)
            p_val_texts.append(p_val_text)

        plot_violin_data(
            sub_ax, xs, all_data, palette=cols, dashes=dashes, seed=seed
            )
        
        # edit ticks
        sub_ax.set_xticks(range(plot_helper_fcts.N_LINPLA))
        sub_ax.set_xticklabels(lp_names, fontweight="bold")
        sub_ax.tick_params(axis="x", which="both", bottom=False) 

        plot_util.expand_lims(sub_ax, axis="y", prop=0.1)

        plot_util.set_interm_ticks(
            np.asarray(sub_ax), n_ticks=3, axis="y", share=False, 
            fontweight="bold", update_ticks=True
            )

        # p-value texts sit just above the final y limits (computed once, as 
        # a scalar, since adding text does not change the limits)
        y_min, y_max = sub_ax.get_ylim()
        y = y_max + (y_max - y_min) * 0.08
        for i, (x, p_val_text) in enumerate(zip(xs, p_val_texts)):
            ha = "center"
            if d == 0 and i == 0:
                x += 0.2
                p_val_text = f"{corr_str} p-val. {p_val_text}"
                ha = "right"
            sub_ax.text(
                x, y, p_val_text, fontsize=20, weight="bold", ha=ha
                )

        logger.info(datatype_sig_str, extra={"spacing": TAB})

    # add labels/titles
    for d, datatype in enumerate(datatypes):
        sub_ax = ax[0, d]
        sub_ax.axhline(
            y=0, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5
            )
     
        if d == 0:
            ylabel = "Trial differences\nU-G - D-G"
            sub_ax.set_ylabel(ylabel, weight="bold")    

        # per-axis title (kept separate from the figure `title` argument)
        sub_ax.set_title(ax_titles[datatype], weight="bold", y=1.2)
        
    return ax