Example #1
def get_signif_rois(integ_data,
                    permpar,
                    stats="mean",
                    op="diff",
                    nanpol=None,
                    log_rois=True):
    """
    get_signif_rois(integ_data, permpar)

    Identifies ROIs showing significant differences between unexpected and 
    expected responses in the specified quantiles, using a permutation test.

    Required args:
        - integ_data (list): list of 2D arrays of ROI activity integrated 
                             across frames.
                                unexp (0, 1) x array[ROI x sequences]
        - permpar (PermPar): named tuple containing permutation parameters
                             (multcomp does not apply to identifying 
                             significant ROIs)

    Optional args:
        - stats (str)    : statistic parameter, i.e. "mean" or "median"
                           default: "mean"
        - op (str)       : operation to identify significant ROIs 
                           ("diff" or "ratio")
                           default: "diff"
        - nanpol (str)   : policy for NaNs, "omit" or None when taking 
                           statistics
                           default: None
        - log_rois (bool): if True, the indices of significant ROIs and
                           their actual difference values are logged
                           default: True

    Returns: 
        - sign_rois (list): list of ROIs showing significant differences, or 
                            list of lists if 2-tailed analysis [lo, up].
    """

    # number of expected sequences (exp comes first along the unexp dim)
    n_exp = integ_data[0].shape[1]
    # calculate real values (average across seqs)
    data = [
        math_util.mean_med(integ_data[0], stats, axis=1, nanpol=nanpol),
        math_util.mean_med(integ_data[1], stats, axis=1, nanpol=nanpol)
    ]
    # real difference/ratio per ROI
    qu_data_res = math_util.calc_op(np.asarray(data), op, dim=0)
    # concatenate exp and unexp sequences (ROI x seq)
    qu_data_all = np.concatenate(integ_data, axis=1)
    # run permutation to identify significant ROIs
    all_rand_res = rand_util.permute_diff_ratio(qu_data_all, n_exp,
                                                permpar.n_perms, stats, nanpol,
                                                op)
    sign_rois = rand_util.id_elem(all_rand_res,
                                  qu_data_res,
                                  permpar.tails,
                                  permpar.p_val,
                                  log_elems=log_rois)
    return sign_rois
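The permutation machinery delegated to rand_util.permute_diff_ratio and 
rand_util.id_elem can be illustrated with a minimal, self-contained numpy 
sketch (a hypothetical stand-in, not the library implementation): sequence 
labels are shuffled to build a per-ROI null distribution of exp/unexp 
differences, which is then thresholded at the two-tailed percentiles.

import numpy as np

def signif_rois_sketch(exp, unexp, n_perms=1000, p_val=0.05, seed=0):
    """exp, unexp: (n_rois, n_seqs) arrays. Returns [lo, up] ROI index lists."""
    rng = np.random.default_rng(seed)
    n_exp = exp.shape[1]
    # real per-ROI difference (unexp - exp), averaged across sequences
    real_diff = unexp.mean(axis=1) - exp.mean(axis=1)
    data = np.concatenate([exp, unexp], axis=1)  # ROI x all sequences
    rand_diffs = np.empty((data.shape[0], n_perms))
    for p in range(n_perms):
        shuff = data[:, rng.permutation(data.shape[1])]  # shuffle labels
        rand_diffs[:, p] = (
            shuff[:, n_exp:].mean(axis=1) - shuff[:, :n_exp].mean(axis=1)
        )
    # two-tailed thresholds per ROI
    lo_th = np.percentile(rand_diffs, 100 * p_val / 2, axis=1)
    up_th = np.percentile(rand_diffs, 100 * (1 - p_val / 2), axis=1)
    return [np.where(real_diff < lo_th)[0].tolist(),
            np.where(real_diff > up_th)[0].tolist()]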
Example #2
def grp_traces_by_qu_unexp_sess(trace_data, analyspar, roigrppar,
                                all_roi_grps):
    """
    grp_traces_by_qu_unexp_sess(trace_data, analyspar, roigrppar, all_roi_grps)
                               
    Required args:
        - trace_data (list)    : list of 4D arrays of mean/median traces 
                                 for each session, structured as:
                                    unexp x quantiles x ROIs x frames
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - roigrppar (RoiGrpPar): named tuple containing roi grouping parameters
        - all_roi_grps (list)  : list of sublists per session, each containing
                                 sublists per roi grp with ROI numbers included 
                                 in the group: session x roi_grp

    Returns:
        - grp_stats (list): nested list of statistics for ROI groups 
                            structured as:
                                sess x qu x ROI grp x stats x frame
    """

    # calculate diff/ratio or retrieve exp/unexp
    op = roigrppar.op
    if roigrppar.plot_vals in ["exp", "unexp"]:
        op = ["exp", "unexp"].index(roigrppar.plot_vals)
    data_me = [math_util.calc_op(sess_me, op, dim=0) for sess_me in trace_data]

    n_sesses = len(data_me)
    n_quants = data_me[0].shape[0]
    n_stats = 2 + (analyspar.stats == "median" and analyspar.error == "std")

    n_frames = [me.shape[2] for me in data_me]

    # sess x quantile (first/last) x ROI grp
    empties = [np.empty([n_stats, n_fr]) * np.nan for n_fr in n_frames]
    grp_stats = [[[] for _ in range(n_quants)] for _ in range(n_sesses)]
    for i, sess in enumerate(data_me):
        for q, quant in enumerate(sess):
            for g, grp_rois in enumerate(all_roi_grps[i]):
                # leave NaNs if no ROIs in group
                if len(grp_rois) != 0:
                    grp_st = math_util.get_stats(quant[grp_rois],
                                                 analyspar.stats,
                                                 analyspar.error,
                                                 axes=0)
                else:
                    grp_st = empties[i]
                grp_stats[i][q].append(grp_st.tolist())

    return grp_stats
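The op dispatch used above, where an integer retrieves one group along a 
dimension and "diff" or "ratio" combines its two entries, recurs throughout 
these examples. A hypothetical stand-in for math_util.calc_op covering only 
those cases (the real function also handles other ops, e.g. correlations):

import numpy as np

def calc_op_sketch(data, op, dim=0):
    """op: an int index into dim, or "diff"/"ratio" combining its 2 entries."""
    data = np.asarray(data)
    if isinstance(op, (int, np.integer)):  # retrieve "exp" (0) or "unexp" (1)
        return np.take(data, op, axis=dim)
    first = np.take(data, 0, axis=dim)   # exp
    second = np.take(data, 1, axis=dim)  # unexp
    if op == "diff":
        return second - first
    if op == "ratio":
        return second / first
    raise ValueError(f"op must be an int, 'diff' or 'ratio', but found '{op}'.")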
Example #3
def get_corr_data(sess_pair, data_df, analyspar, permpar, 
                  corr_type="corr", permute="sess", absolute=False, norm=True, 
                  return_data=False, return_rand=False, n_rand_ex=1, 
                  randst=None, raise_no_pair=True):
    """
    get_corr_data(sess_pair, data_df, analyspar, permpar)

    Returns correlation data for a session pair.

    Required args:
        - sess_pair (list):
            sessions to correlate, e.g. [1, 2]
        - data_df (pd.DataFrame):
            dataframe with one row per line/plane/session, and the following 
            columns, in addition to the basic sess_df columns:
            - roi_idxs (list): index for each ROI
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - permpar (PermPar): 
            named tuple containing permutation parameters.

    Optional args:
        - corr_type (str):
            type of correlation to run, i.e. "corr" or "R_sqr"
            default: "corr"
        - permute (str):
            type of permutation to use ("tracking", "sess" or "all")
            default: "sess"
        - absolute (bool):
            if True, absolute USIs are used for correlation calculation instead 
            of signed USIs
            default: False
        - norm (bool):
            if True, normalized correlation data is returned, if corr_type is 
            "diff_corr"
            default: True
        - return_data (bool):
            if True, data to correlate is returned
            default: False
        - return_rand (bool):
            if True, random normalized correlation values are returned, along 
            with random data to correlate for one example permutation
            default: False
        - n_rand_ex (int):
            number of examples to return, if return_rand is True
            default: 1
        - randst (int or np.random.RandomState): 
            random state or seed value to use. (-1 treated as None)
            default: None
        - raise_no_pair (bool):
            if True, if sess_pair session numbers are not found, an error is 
            raised. Otherwise, None is returned.
            default: True

    Returns:
        - roi_corr (float):
             (normalized) correlation between sessions
        - roi_corr_std (float):
            bootstrapped standard deviation for the (normalized) correlation 
            between sessions
        - null_CI (1D array):
            adjusted, null CI for the (normalized) correlation between sessions
        - p_val (float):
            uncorrected p-value for the correlation between sessions
        
        if return_data:
        - corr_data (2D array):
            data to correlate (grps (2) x datapoints)            
        
        if return_rand:
        - rand_corrs (1D array):
            (normalized) random correlation between sessions
        - rand_ex (3D array):
            example randomized data pairs to correlate 
            (grps (2) x datapoints x n_rand_ex)
        - rand_ex_corr (1D array):
            correlation for example randomized data pairs
    """

    nanpol = None if analyspar.rem_bad else "omit"

    if analyspar.stats != "mean" or analyspar.error != "std":
        raise NotImplementedError(
            "analyspar.stats must be set to 'mean', and "
            "analyspar.error must be set to 'std'."
            )

    roi_idxs = []
    for sess_n in sess_pair:
        row = data_df.loc[data_df["sess_ns"] == sess_n]
        if len(row) < 1:
            continue
        elif len(row) > 1:
            raise RuntimeError("Expected at most one row.")

        data = np.asarray(row.loc[row.index[0], "roi_idxs"])
        roi_idxs.append(data)

    if len(roi_idxs) != 2:
        if raise_no_pair:
            raise RuntimeError("Session pairs not found.")
        else:
            return None

    if roi_idxs[0].shape != roi_idxs[1].shape:
        raise RuntimeError(
            "Sessions should have the same number of ROI indices."
            )

    # get updated correlation parameters
    corr_type, paired, norm = get_corr_info(
        permpar, corr_type=corr_type, permute=permute, norm=norm
        )

    # stack the data to correlate
    corr_data = np.vstack([roi_idxs[0], roi_idxs[1]]) # 2 x datapoints

    if absolute:
        corr_data = np.absolute(corr_data)

    # get actual correlation
    roi_corr = math_util.calc_op(corr_data, nanpol=nanpol, op=corr_type)

    # get first set of random values
    if return_rand:
        use_randst = copy.deepcopy(randst)
        if paired:
            perm_data = corr_data.T # datapoints x groups (2)
        else:
            perm_data = corr_data.reshape(1, -1) # 2 groups, concatenated
        rand_exs = rand_util.run_permute(
            perm_data, n_perms=n_rand_ex, paired=paired, randst=use_randst
            )
        rand_exs = np.transpose(rand_exs, [1, 0, 2])
        if not paired:
            rand_exs = rand_exs.reshape(2, -1, n_rand_ex)
        rand_ex_corrs = math_util.calc_op(
            rand_exs, nanpol=nanpol, op=corr_type, axis=1
            )

    # get random correlation info
    returns = rand_util.get_op_p_val(
        corr_data, n_perms=permpar.n_perms, 
        stats=analyspar.stats, op=corr_type, return_CIs=True, 
        p_thresh=permpar.p_val, tails=permpar.tails, 
        multcomp=permpar.multcomp, paired=paired, nanpol=nanpol, 
        return_rand=return_rand, randst=randst
        )
    
    if return_rand:
        p_val, null_CI, rand_corrs = returns
    else:
        p_val, null_CI = returns

    med = null_CI[1] # null distribution median
    null_CI = np.asarray(null_CI)
    if norm:
        # normalize all data
        roi_corr = float(get_norm_corrs(roi_corr, med=med, corr_type=corr_type))
        null_CI = get_norm_corrs(null_CI, med=med, corr_type=corr_type)
    
    # get bootstrapped std over corr
    roi_corr_std = corr_bootstrapped_std(
        corr_data, n_samples=misc_analys.N_BOOTSTRP, randst=randst, 
        return_rand=False, nanpol=nanpol, norm=norm, med=med, 
        corr_type=corr_type
        )

    returns = [roi_corr, roi_corr_std, null_CI, p_val]
    
    if return_data:
        corr_data = np.vstack(corr_data)
        if "diff" in corr_type: # take diff
            corr_data[1] = corr_data[1] - corr_data[0]
        returns = returns + [corr_data]

    if return_rand:
        if norm:
            rand_corrs = get_norm_corrs(
                rand_corrs, med=med, corr_type=corr_type
                )
        if "diff" in corr_type: # take diff
            rand_exs[1] = rand_exs[1] - rand_exs[0]
        returns = returns + [rand_corrs, rand_exs, rand_ex_corrs]
    
    return returns
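Stripped of normalization, pairing and CI handling, the core of this function 
is a Pearson correlation between the two sessions' ROI indices plus a 
permutation p-value. A self-contained numpy sketch of that core, standing in 
for math_util.calc_op and rand_util.get_op_p_val (unpaired case):

import numpy as np

def corr_p_val_sketch(idxs_a, idxs_b, n_perms=1000, seed=0):
    """idxs_a, idxs_b: 1D arrays of per-ROI indices for each session."""
    rng = np.random.default_rng(seed)
    real_corr = np.corrcoef(idxs_a, idxs_b)[0, 1]
    pooled = np.concatenate([idxs_a, idxs_b])  # unpaired: pool both sessions
    n = len(idxs_a)
    rand_corrs = np.empty(n_perms)
    for p in range(n_perms):
        perm = rng.permutation(pooled)  # shuffle session labels
        rand_corrs[p] = np.corrcoef(perm[:n], perm[n:])[0, 1]
    # two-tailed p-value: fraction of null values at least as extreme
    p_val = np.mean(np.abs(rand_corrs) >= np.abs(real_corr))
    return real_corr, p_val, rand_corrs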
Example #4
def corr_bootstrapped_std(data, n_samples=1000, randst=None, corr_type="corr", 
                          return_rand=False, nanpol=None, med=0, norm=True):
    """
    corr_bootstrapped_std(data)
    
    Returns bootstrapped standard deviation for Pearson correlations.

    Required args:
        - data (2D array): 
            values to correlate for each of 2 groups (2, n)
    
    Optional args:
        - corr_type (str): 
            type of correlation to run, i.e. "corr" or "R_sqr"
            default: "corr"
        - n_samples (int): 
            number of samplings to take for bootstrapping
            default: 1000
        - randst (int or np.random.RandomState): 
            seed or random state to use when generating random values.
            default: None
        - return_rand (bool): if True, random correlations are returned
            default: False
        - nanpol (str): 
            policy for NaNs, "omit" or None
            default: None
        - med (float): 
            null distribution median for normalization, if norm is True
            default: 0
        - norm (bool):
            if True, normalized correlation data is returned
            default: True

    Returns:
        - bootstrapped_std (float): 
            bootstrapped standard deviation of correlations, 
            normalized if norm is True
        if return_rand:
        - rand_corrs (1D array): 
            randomly generated correlations, normalized if norm is True
    """

    randst = rand_util.get_np_rand_state(randst, set_none=True)

    n_samples = int(n_samples)

    data = np.asarray(data)

    if len(data.shape) != 2 or data.shape[0] != 2:
        raise ValueError(
            "data must have 2 dimensions, with the first having length 2."
            )

    n = data.shape[1]

    # indices of datapoint pairs to resample from
    choices = np.arange(n)

    # resample pairs with replacement and recompute the correlation for each 
    # bootstrap sample
    rand_corrs = math_util.calc_op(
        list(data[:, randst.choice(choices, (n, n_samples), replace=True)]), 
        op=corr_type, nanpol=nanpol, axis=0,
        )
    
    if norm:
        rand_corrs = get_norm_corrs(rand_corrs, med=med, corr_type=corr_type)

    bootstrapped_std = math_util.error_stat(
        rand_corrs, stats="mean", error="std", nanpol=nanpol
        )
    
    if return_rand:
        return bootstrapped_std, rand_corrs
    else:
        return bootstrapped_std
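Minus the normalization step, the bootstrap above reduces to resampling 
datapoint pairs with replacement and taking the standard deviation of the 
recomputed correlations. A minimal numpy-only sketch:

import numpy as np

def bootstrapped_corr_std_sketch(data, n_samples=1000, seed=0):
    """data: (2, n) array of paired values to correlate."""
    rng = np.random.default_rng(seed)
    n = data.shape[1]
    corrs = np.empty(n_samples)
    for s in range(n_samples):
        idx = rng.choice(n, n, replace=True)  # resample pairs jointly
        corrs[s] = np.corrcoef(data[0, idx], data[1, idx])[0, 1]
    return corrs.std()  # bootstrapped standard deviation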
Example #5
def qu_mags(all_data, permpar, mouse_ns, lines, stats="mean", error="sem", 
            nanpol=None, op_qu="diff", op_unexp="diff", log_vals=True):
    """
    qu_mags(all_data, permpar, mouse_ns, lines)

    Returns a dictionary containing the results of the magnitudes and L2 
    analysis, as well as the results of the permutation test.

    Specifically, magnitude and L2 norm are calculated as follows: 
        - Magnitude: for unexp and expected segments: 
                         mean/median across ROIs of
                             diff/ratio in average activity between 2 quantiles
        - L2 norm:   for unexp and expected segments: 
                         L2 norm across ROIs of
                             diff/ratio in average activity between 2 quantiles
    
    Significance is assessed based on the diff/ratio between unexpected and 
    expected magnitude/L2 norm results.

    Optionally, the magnitudes and L2 norms are logged for each session, with
    significance indicated.

    Required args:
        - all_data (list)  : nested list of data, structured as:
                                 session x unexp x qu x array[(ROI x) seqs]
        - permpar (PermPar): named tuple containing permutation parameters
        - mouse_ns (list)  : list of mouse numbers (1 per session)
        - lines (list)     : list of mouse lines (1 per session)

    Optional args:
        - stats (str)      : statistic to take across segments, then across 
                             ROIs ("mean" or "median")
                             default: "mean"
        - error (str)      : error statistic to take across segments, then 
                             across ROIs ("std" or "sem")
                             default: "sem"
        - nanpol (str)     : policy for NaNs, "omit" or None when taking 
                             statistics
                             default: None
        - op_qu (str)      : Operation to use in comparing the last vs first 
                             quantile ("diff" or "ratio")
                             default: "diff"       
        - op_unexp (str)   : Operation to use in comparing the unexpected vs 
                             expected data ("diff" or "ratio")
                             default: "diff" 
        - log_vals (bool)  : If True, the magnitudes and L2 norms are logged
                             for each session, with significance indicated.
                             default: True

    Returns:
        - mags (dict): dictionary containing magnitude and L2 data to plot.
            ["L2"] (3D array)        : L2 norms, structured as: 
                                           sess x scaled x unexp
            ["mag_st"] (4D array)    : magnitude stats, structured as: 
                                           sess x scaled x unexp x stats
            ["L2_rel_th"] (2D array) : L2 thresholds calculated from 
                                       permutation analysis, structured as:
                                           sess x tail(s)
            ["mag_rel_th"] (2D array): magnitude thresholds calculated from
                                       permutation analysis, structured as:
                                           sess x tail(s)
            ["L2_sig"] (list)        : L2 significance results for each session 
                                       ("hi", "lo" or "no")
            ["mag_sig"] (list)       : magnitude significance results for each 
                                       session 
                                           ("hi", "lo" or "no")
    """


    n_sess = len(all_data)
    n_qu   = len(all_data[0][0])
    scales = [False, True]
    unexps    = ["exp", "unexp"]
    stat_len = 2 + (stats == "median" and error == "std")
    tail_len = 1 + (str(permpar.tails) == "2")

    if n_qu != 2:
        raise ValueError(f"Expected 2 quantiles, but found {n_qu}.")
    if len(all_data[0]) != 2:
        raise ValueError("Expected a length 2 unexpected dim, "
            f"but found length {len(all_data[0])}.")
    
    mags = {"mag_st": np.empty([n_sess, len(scales), len(unexps), stat_len]),
            "L2"    : np.empty([n_sess, len(scales), len(unexps)])
           }
    
    for lab in ["mag_sig", "L2_sig"]:
        mags[lab] = []
    for lab in ["mag_rel_th", "L2_rel_th"]:
        mags[lab] = np.empty([n_sess, tail_len])

    all_data = copy.deepcopy(all_data)
    for i in range(n_sess):
        logger.info(f"Mouse {mouse_ns[i]}, {lines[i]}:", 
            extra={"spacing": "\n"})
        sess_data_me = []
        # number of expected sequences
        n_exps = [all_data[i][0][q].shape[-1] for q in range(n_qu)]
        for s in range(len(unexps)):
            # take the mean for each quantile
            data_me = np.asarray(
                [math_util.mean_med(all_data[i][s][q], stats, axis=-1, 
                nanpol=nanpol) for q in range(n_qu)])

            if len(data_me.shape) == 1:
                # add dummy ROI-like axis, e.g. for run data
                data_me = data_me[:, np.newaxis]
                all_data[i][s] = \
                    [qu_data[np.newaxis, :] for qu_data in all_data[i][s]]

            msgs = ["Degrees of freedom", "invalid value"]
            categs = [RuntimeWarning, RuntimeWarning]
            with gen_util.TempWarningFilter(msgs, categs):
                mags["mag_st"][i, 0, s] = math_util.calc_mag_change(
                    data_me, 0, 1, order="stats", op=op_qu, stats=stats, 
                    error=error)
                mags["L2"][i, 0, s] = math_util.calc_mag_change(
                    data_me, 0, 1, order=2, op=op_qu)
            sess_data_me.append(data_me)
        # scale
        sess_data_me = np.asarray(sess_data_me)
        mags["mag_st"][i, 1] = math_util.calc_mag_change(
            sess_data_me, 1, 2, order="stats", op=op_qu, stats=stats, 
            error=error, scale=True, axis=0, pos=0, sc_type="unit").T
        mags["L2"][i, 1] = math_util.calc_mag_change(
            sess_data_me, 1, 2, order=2, op=op_qu, stats=stats, scale=True, 
            axis=0, pos=0, sc_type="unit").T
        
        # diff/ratio for permutation test
        act_mag_rel = math_util.calc_op(mags["mag_st"][i, 0, :, 0], op=op_unexp)
        act_L2_rel  = math_util.calc_op(mags["L2"][i, 0, :], op=op_unexp)

        # concatenate expected and unexpected sequences for each quantile
        all_data_perm = [np.concatenate(
            [all_data[i][0][q], all_data[i][1][q]], axis=1) 
                for q in range(n_qu)]
        
        signif, ths = run_mag_permute(
            all_data_perm, act_mag_rel, act_L2_rel, n_exps, permpar, op_qu, 
            op_unexp, stats, nanpol)
        
        mags["mag_sig"].append(signif[0])
        mags["L2_sig"].append(signif[1])
        mags["mag_rel_th"][i] = np.asarray(ths[0])
        mags["L2_rel_th"][i] = np.asarray(ths[1])

        # log results
        if log_vals:
            sig_symb = ["", ""]
            for si, sig in enumerate(signif):
                if sig != "no":
                    sig_symb[si] = "*"

            vals = [mags["mag_st"][i, 0, :, 0], mags["L2"][i, 0, :]]
            names = [f"{stats} mag".capitalize(), "L2"]
            for v, (val, name) in enumerate(zip(vals, names)):
                for s, unexp in zip([0, 1], ["(exp) ", "(unexp)"]):
                    logger.info(f"{name} {unexp}: {val[s]:.4f}{sig_symb[v]}", 
                        extra={"spacing": TAB})
        
    return mags
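The magnitude and L2 definitions in the docstring reduce, for one session and 
one unexp value, to a per-ROI change between the two quantiles that is then 
summarized across ROIs. A small sketch, assuming from the "absolute change" 
comment in run_mag_permute below that calc_mag_change takes the absolute 
change before averaging:

import numpy as np

def mag_and_L2_sketch(qu_means, op_qu="diff"):
    """qu_means: (2, n_rois) array of per-quantile, per-ROI mean activity."""
    if op_qu == "diff":
        roi_change = qu_means[1] - qu_means[0]  # last vs first quantile
    else:  # "ratio"
        roi_change = qu_means[1] / qu_means[0]
    # assumption: absolute change is taken before the mean across ROIs
    magnitude = np.mean(np.abs(roi_change))
    l2_norm = np.linalg.norm(roi_change)        # L2 norm across ROIs
    return magnitude, l2_norm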
Example #6
def run_mag_permute(all_data_perm, act_mag_me_rel, act_L2_rel, n_exps, permpar, 
                    op_qu="diff", op_grp="diff", stats="mean", nanpol=None):
    """
    run_mag_permute(all_data_perm, act_mag_me_rel, act_L2_rel, n_exps, permpar)

    Returns the results of a permutation analysis of difference or ratio 
    between 2 quantiles of the magnitude change or L2 norm between expected and 
    unexpected activity.

    Required args:
        - all_data_perm (list)    : list of 2D arrays (one per quantile) of 
                                    data from both groups for permutation, 
                                    each structured as:
                                        ROI x seqs
        - act_mag_me_rel (num)    : Real mean/median magnitude difference
                                    between quantiles
        - act_L2_rel (num)        : Real L2 difference between quantiles
        - n_exps (list)           : List of number of expected sequences in
                                    each quantile
        - permpar (PermPar)       : named tuple containing permutation 
                                    parameters
    
    Optional args:
        - op_qu (str) : Operation to use in comparing the last vs first 
                        quantile ("diff" or "ratio")
                        default: "diff"       
        - op_grp (str): Operation to use in comparing groups 
                        (e.g., unexpected vs expected data) ("diff" or "ratio")
                        default: "diff" 
        - stats (str) : Statistic to take across group sequences, and then 
                        across magnitude differences ("mean" or "median")
                        default: "mean"
        - nanpol (str): Policy for NaNs, "omit" or None when taking statistics
                        default: None
    
    Returns:
        - signif (list) : list of significance results ("hi", "lo" or "no") for 
                          magnitude, L2
        - threshs (list): list of thresholds (1 if 1-tailed analysis, 
                          2 if 2-tailed) for magnitude, L2
    """

    if permpar.multcomp:
        # apply a Bonferroni correction to the p-value, then turn multcomp off
        permpar = sess_ntuple_util.get_modif_ntuple(
            permpar, ["multcomp", "p_val"], 
            [False, permpar.p_val / permpar.multcomp]
            )

    if len(all_data_perm) != 2 or len(n_exps) != 2:
        raise ValueError("all_data_perm and n_exps must have length of 2.")

    all_rand_vals = [] # qu x grp x ROI x perms
    # for each quantile
    for q, perm_data in enumerate(all_data_perm):
        qu_vals = rand_util.permute_diff_ratio(
            perm_data, n_exps[q], permpar.n_perms, stats, nanpol=nanpol, 
            op="none")
        all_rand_vals.append(qu_vals)

    all_rand_vals = np.asarray(all_rand_vals)
    # get absolute change stats and retain mean/median only
    rand_mag_me = math_util.calc_mag_change(
        all_rand_vals, 0, 2, order="stats", op=op_qu, stats=stats)[0]
    rand_L2 = math_util.calc_mag_change(all_rand_vals, 0, 2, order=2, op=op_qu)

    # take diff/ratio between grps
    rand_mag_rel = math_util.calc_op(rand_mag_me, op_grp, dim=0)
    rand_L2_rel  = math_util.calc_op(rand_L2, op_grp, dim=0)

    # check significance (returns list although only one result tested)
    mag_sign, mag_th = rand_util.id_elem(
        rand_mag_rel, act_mag_me_rel, permpar.tails, permpar.p_val, ret_th=True)
    L2_sign, L2_th   = rand_util.id_elem(
        rand_L2_rel, act_L2_rel, permpar.tails, permpar.p_val, ret_th=True)

    mag_signif, L2_signif = ["no", "no"]
    if str(permpar.tails) == "2":
        if len(mag_sign[0]) == 1:
            mag_signif = "lo"
        elif len(mag_sign[1]) == 1:
            mag_signif = "hi"
        if len(L2_sign[0]) == 1:
            L2_signif = "lo"
        elif len(L2_sign[1]) == 1:
            L2_signif = "hi"
    elif permpar.tails in ["lo", "hi"]:
        if len(mag_sign) == 1:
            mag_signif = permpar.tails
        if len(L2_sign) == 1:
            L2_signif = permpar.tails

    signif  = [mag_signif, L2_signif]
    threshs = [mag_th[0], L2_th[0]]

    return signif, threshs
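The tail bookkeeping above simplifies when the null distribution is a plain 
1D array and a single real value is tested. A compact sketch of the same 
logic, standing in for rand_util.id_elem with ret_th=True:

import numpy as np

def tail_signif_sketch(rand_vals, real_val, tails="2", p_val=0.05):
    """Returns ("lo", "hi" or "no", thresholds) for one real value."""
    if str(tails) == "2":
        lo_th, up_th = np.percentile(
            rand_vals, [100 * p_val / 2, 100 * (1 - p_val / 2)])
        if real_val < lo_th:
            return "lo", [lo_th, up_th]
        if real_val > up_th:
            return "hi", [lo_th, up_th]
        return "no", [lo_th, up_th]
    # 1-tailed: "lo" tests the bottom percentile, "hi" the top
    q = 100 * p_val if tails == "lo" else 100 * (1 - p_val)
    th = np.percentile(rand_vals, q)
    signif = real_val < th if tails == "lo" else real_val > th
    return (tails if signif else "no"), [th]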
Example #7
def grp_stats(integ_stats,
              grps,
              plot_vals="both",
              op="diff",
              stats="mean",
              error="std",
              scale=False):
    """
    grp_stats(integ_stats, grps)

    Calculate statistics (e.g. mean + sem) across quantiles for each group 
    and session.

    Required args:
        - integ_stats (list): list of 3D arrays of mean/medians of integrated
                              sequences, for each session structured as:
                                 unexp if by_exp x
                                 quantiles x
                                 ROIs if byroi
        - grps (list)       : list of sublists per session, each containing
                              sublists per roi grp with ROI numbers included in 
                              the group: session x roi_grp

    Optional args:
        - plot_vals (str): which values to return ("unexp", "exp" or "both")
                           default: "both"
        - op (str)       : operation to use to compare groups, if plot_vals
                           is "both"
                           i.e. "diff": grp1-grp2, or "ratio": grp1/grp2
                           default: "diff"
        - stats (str)    : statistic parameter, i.e. "mean" or "median"
                           default: "mean"
        - error (str)    : error statistic parameter, i.e. "std" or "sem"
                           default: "std"
        - scale (bool)   : if True, data is scaled using first quantile
    Returns:
        - all_grp_st (4D array): array of group stats (mean/median, error) 
                                 structured as:
                                  session x quantile x grp x stat 
        - all_ns (2D array)    : array of group ns, structured as:
                                  session x grp
    """

    n_sesses = len(integ_stats)
    n_quants = integ_stats[0].shape[1]
    n_stats = 2 + (stats == "median" and error == "std")
    n_grps = len(grps[0])

    all_grp_st = np.empty([n_sesses, n_quants, n_grps, n_stats])
    all_ns = np.empty([n_sesses, n_grps], dtype=int)

    for i, [sess_data, sess_grps] in enumerate(zip(integ_stats, grps)):
        # calculate diff/ratio or retrieve exp/unexp
        if plot_vals in ["exp", "unexp"]:
            op = ["exp", "unexp"].index(plot_vals)
        sess_data = math_util.calc_op(sess_data, op, dim=0)
        for g, grp in enumerate(sess_grps):
            all_ns[i, g] = len(grp)
            all_grp_st[i, :, g, :] = np.nan
            if len(grp) != 0:
                grp_data = sess_data[:, grp]
                if scale:
                    grp_data, _ = math_util.scale_data(grp_data,
                                                       axis=0,
                                                       pos=0,
                                                       sc_type="unit")
                all_grp_st[i, :, g] = math_util.get_stats(grp_data,
                                                          stats,
                                                          error,
                                                          axes=1).T

    return all_grp_st, all_ns
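The n_stats = 2 + (stats == "median" and error == "std") pattern used here 
and in earlier examples implies that math_util.get_stats returns the central 
statistic followed by its error value(s), with an extra row when medians are 
paired with "std" (presumably quartiles). A hypothetical sketch of that 
assumed layout:

import numpy as np

def get_stats_sketch(data, stats="mean", error="std", axis=0):
    """Returns a (n_stats, ...) array: central statistic, then error value(s)."""
    if stats == "mean":
        me = np.mean(data, axis=axis)
        err = np.std(data, axis=axis)
        if error == "sem":
            err = err / np.sqrt(data.shape[axis])
        return np.stack([me, err])            # 2 stats
    if stats == "median" and error == "std":
        me = np.median(data, axis=axis)
        q1, q3 = np.percentile(data, [25, 75], axis=axis)
        return np.stack([me, q1, q3])         # 3 stats: median and quartiles
    raise NotImplementedError("sketch covers mean, and median with 'std' only")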