Example #1
0
def trialbytrial_drive(
        mouse,
        trace_type='zscore_day',
        method='ncp_hals',
        cs='',
        warp=False,
        word='tray',
        group_by='all2',
        nan_thresh=0.85,
        score_threshold=0.8,
        drive_type='visual'):
    """
    Build a cells x trials array of log inverse p-values, one per trial,
    comparing each trial's response to the distribution of that cell's
    baseline response. Statistics use a two-tailed KS test.

    This wrapper turns the mouse name into a flow Mouse object so that the
    memoizer and MongoDB database receive the correct argument type; all
    other arguments are forwarded unchanged to trialbytrial_drive_sub.
    """
    sub_kwargs = dict(
        trace_type=trace_type,
        method=method,
        cs=cs,
        warp=warp,
        word=word,
        group_by=group_by,
        nan_thresh=nan_thresh,
        score_threshold=score_threshold,
        drive_type=drive_type,
    )
    return trialbytrial_drive_sub(flow.Mouse(mouse=mouse), **sub_kwargs)
Example #2
0
def run_psytrack(mouse, pars=pars_simp_th, include_pavlovian=False,
                 force=False):
    """
    Run the psytrack behavioral model on a single mouse.

    Parameters
    ----------
    mouse : str
        Mouse name, passed to flow.Mouse.
    pars : dict
        Psytrack parameter dict. A shallow copy is updated and used, so the
        caller's dict (including the module-level default) is not mutated.
    include_pavlovian : bool
        Stored in the copied parameters under 'include_pavlovian'.
    force : bool
        Forwarded to Mouse.psytracker as ``force``. Default False —
        presumably "reuse a cached fit if one exists"; TODO confirm against
        psytracker.

    Returns
    -------
    The object returned by flow.Mouse(mouse=mouse).psytracker(...).
    """
    # copy first: updating the default dict in place would leak
    # 'include_pavlovian' into every later call that relies on the default
    pars = dict(pars)
    pars['include_pavlovian'] = include_pavlovian

    psy = flow.Mouse(mouse=mouse).psytracker(verbose=True,
                                             force=force,
                                             pars=pars)

    return psy
Example #3
0
def groupmouse_fit_disengaged_sated_mean_per_comp(
        mice=['OA27', 'OA26', 'OA67', 'VF226', 'CC175'],
        trace_type='zscore_day',
        method='mncp_hals',
        cs='',
        warp=False,
        words=['orlando', 'already', 'already', 'already', 'already'],
        group_by='all',
        nan_thresh=0.85,
        score_threshold=None,
        random_state=None,
        init='rand',
        rank=18,
        verbose=False):
    """
    Run fit_disengaged_sated_mean_per_comp for several mice and stack the
    per-mouse dataframes into a single dataframe.

    ``mice`` and ``words`` are paired by position (via zip), so each mouse is
    fit with the word at the same index; trailing entries of the longer list
    are ignored. Every other argument is forwarded unchanged.

    Returns
    -------
    pandas.DataFrame
        Row-wise (axis=0) concatenation of the per-mouse results.
    """
    shared_kwargs = dict(
        rank=rank,
        nan_thresh=nan_thresh,
        score_threshold=score_threshold,
        method=method,
        cs=cs,
        warp=warp,
        group_by=group_by,
        trace_type=trace_type,
        random_state=random_state,
        init=init,
        verbose=verbose,
    )

    per_mouse_frames = [
        fit_disengaged_sated_mean_per_comp(flow.Mouse(mouse=name),
                                           word=tca_word,
                                           **shared_kwargs)
        for name, tca_word in zip(mice, words)
    ]

    return pd.concat(per_mouse_frames, axis=0)
Example #4
0
def is_center_of_mass_visual(mouse,
                             trace_type='zscore_day',
                             method='mncp_hals',
                             cs='',
                             warp=False,
                             word='restaurant',
                             group_by='all',
                             nan_thresh=0.85,
                             score_threshold=0.8,
                             rank_num=15):
    """
    Flag temporal factors whose center of mass lies within the visual
    stimulus window.

    Returns
    -------
    numpy.ndarray of bool
        One entry per row returned by center_of_mass_tempofac; True where
        'center_of_mass' <= 15 * off_time, i.e. at or before the stimulus
        offset, and False where it falls after the offset.

    Notes
    -----
    off_time is the stimulus length plus the one second preceding stimulus
    onset: 2 + 1 for the mice listed below, 3 + 1 for all others. The factor
    of 15 presumably converts seconds to imaging frames — TODO confirm.
    ``trace_type`` is accepted but not forwarded.
    """
    two_sec_stim_mice = ('OA32', 'OA34', 'OA37', 'OA36', 'CB173', 'AS20',
                         'AS41')
    stim_len = 2 if mouse in two_sec_stim_mice else 3
    baseline_len = 1  # the second before stimulus onset
    off_time = stim_len + baseline_len

    cm_df = center_of_mass_tempofac(mouse=flow.Mouse(mouse=mouse),
                                    method=method,
                                    cs=cs,
                                    warp=warp,
                                    word=word,
                                    group_by=group_by,
                                    nan_thresh=nan_thresh,
                                    score_threshold=score_threshold,
                                    rank_num=rank_num)

    return cm_df['center_of_mass'].values <= 15 * off_time
Example #5
0
def varex_norm_bycomp_byday(mouse,
                            trace_type='zscore_day',
                            method='mncp_hals',
                            cs='',
                            warp=False,
                            word='orlando',
                            group_by='all',
                            nan_thresh=0.85,
                            rank_num=18,
                            rectified=True,
                            verbose=False):
    """
    Plot, per day, the fraction of that day's TCA variance explained that is
    accounted for by each component, and save the figure as a PDF.

    For rank ``rank_num``, per-component per-day variance explained
    (calc.var.groupday_varex_byday_bycomp) is divided by the total per-day
    variance explained (calc.var.groupday_varex_byday). Dates are relabelled
    0..n-1 in sorted order ('day_num'); days 0-100 are plotted.

    Parameters
    ----------
    mouse : str
        Mouse name.
    trace_type, cs, warp, group_by :
        Used to build the plot directory via paths.tca_plots.
    word : str
        TCA run word, forwarded to paths.tca_plots and the calc.var functions.
    nan_thresh : float or None
        Only affects the output directory and file names.
    rank_num : int
        TCA rank to plot.
    rectified : bool
        Only affects the output directory and file names.
    method, verbose :
        Accepted for signature compatibility; not used.

    Returns
    -------
    None. The figure is saved to
    <tca_plots>/varex<tags>/byday_bycomp/<mouse>_rank<R>_norm_varex_by_day<tags>.pdf
    """
    # necessary parameters for determining type of analysis
    pars = {'trace_type': trace_type, 'cs': cs, 'warp': warp}
    group_pars = {'group_by': group_by}

    # name tags: cells removed for too many nan trials, and rectification
    nt_save_tag = ' nantrial ' + str(nan_thresh) if nan_thresh else ''
    if rectified:
        r_tag = ' rectified'
        r_save_tag = '_rectified'
    else:
        r_tag = ''
        r_save_tag = ''

    # build (creating as needed) the save directory and output path
    save_dir = paths.tca_plots(mouse,
                               'group',
                               pars=pars,
                               word=word,
                               group_pars=group_pars)
    save_dir = os.path.join(save_dir, 'varex' + nt_save_tag + r_tag)
    if not os.path.isdir(save_dir):
        os.mkdir(save_dir)
    save_dir = os.path.join(save_dir, 'byday_bycomp')
    if not os.path.isdir(save_dir):
        os.mkdir(save_dir)
    var_path = os.path.join(
        save_dir,
        str(mouse) + '_rank' + str(rank_num) + '_norm_varex_by_day' +
        r_save_tag + nt_save_tag + '.pdf')

    # variance explained per component per day, and total per day
    comp_varex = calc.var.groupday_varex_byday_bycomp(
        flow.Mouse(mouse=mouse, exclude_tags=['bad']), word=word)
    daily_varex = calc.var.groupday_varex_byday(
        flow.Mouse(mouse=mouse, exclude_tags=['bad']), word=word)

    # .copy(): new columns are added below, and writing into a .loc slice
    # may silently modify a temporary copy (SettingWithCopy)
    comp_var_df = comp_varex.loc[comp_varex['rank'] == rank_num, :].copy()
    daily_var_df = daily_varex.loc[daily_varex['rank'] == rank_num, :]

    comp_dates = comp_var_df['date'].values
    comp_var = comp_var_df['variance_explained_tcamodel'].values
    daily_dates = daily_var_df['date'].values
    daily_var = daily_var_df['variance_explained_tcamodel'].values

    # day_num is the index of each row's date among the sorted unique dates;
    # computing it with return_inverse avoids overwriting the array being
    # matched and always yields integers, whatever the date dtype
    unique_days, day_num = np.unique(comp_dates, return_inverse=True)

    # normalize each component's variance by the total for the same day
    # (assumes exactly one daily-total row per date at this rank)
    norm_var = comp_var.copy()
    for day in unique_days:
        comp_mask = comp_dates == day
        norm_var[comp_mask] = (comp_var[comp_mask] /
                               daily_var[daily_dates == day])
    comp_var_df['norm_varex'] = norm_var
    comp_var_df['day_num'] = day_num

    sns.relplot(x="day_num",
                y="norm_varex",
                hue="component",
                data=comp_var_df.loc[(comp_var_df['day_num'] >= 0)
                                     & (comp_var_df['day_num'] <= 100), :],
                legend='full',
                kind='line',
                alpha=0.8,
                palette=sns.color_palette('muted', rank_num),
                marker='o')
    plt.title(str(mouse) +
              ': Fraction of total daily variance explained per component')
    plt.savefig(var_path, bbox_inches='tight')
Example #6
0
def _splice_data_inputs(psydata,
                        mouse,
                        trace_type='zscore_day',
                        method='mncp_hals',
                        cs='',
                        warp=False,
                        word=None,
                        group_by='all',
                        nan_thresh=0.85,
                        score_threshold=0.8,
                        rank_num=18,
                        verbose=True):
    """
    Create dict used for fitting Pillow model. Main purpose of function
    is to align indices from Pillow and TCA since they often sub-select
    different trials. This forces Pillow 'y' and 'answer' to have same
    trials as TCA and uses TCA trial factors as 'inputs'.

    Trials are matched run by run. For each (date, run) present in the TCA
    meta, the first min(n_psytracker, n_tca) non-blank trials are kept, and
    their orientations must agree. Blank psytracker trials, and trials
    without a TCA counterpart, are removed from ``psydata``.

    Parameters
    ----------
    psydata : dict
        Psytracker data dict, modified in place and also returned.
        - Read and overwritten: 'y', 'answer', 'correct',
          'dateRunTrials', 'days'.
        - Set: 'dayLength', 'runLength', 'dateRuns', 'inputs'.
        NOTE(review): the keep mask is built from
        ``flow.Mouse(mouse).psytracker()``, so this presumably assumes the
        per-trial arrays of psydata are aligned with that object's trials.
        TODO confirm.
    mouse : str
        Mouse name.
    trace_type : str
        Accepted but not used in this function.
    method, cs, warp, group_by, nan_thresh, score_threshold :
        Forwarded to load.groupday_tca_model / load.groupday_tca_meta.
    word : str or None
        TCA run word. If falsy, it becomes 'tray' for OA27 and
        'obligations' for every other mouse.
    rank_num : int
        TCA rank whose trial factors become the new inputs.
    verbose : bool
        Print progress and a fitting summary.

    Returns
    -------
    dict
        ``psydata``, with 'inputs' replaced by
        {'factor_1': ..., ..., 'factor_<rank_num>': ...}. Each value is
        a column vector (``fac[:, None]``).

    Raises
    ------
    AssertionError
        Raised in two cases:
        - orientations disagree within a matched run;
        - the recomputed run or day bookkeeping has inconsistent lengths.
    """

    # default TCA params to use
    if not word:
        if mouse == 'OA27':
            word = 'tray'
        else:
            word = 'obligations'  # default word for every mouse except OA27
        if verbose:
            print('Creating dataframe for ' + mouse + '-' + word)

    ms = flow.Mouse(mouse)
    psy = ms.psytracker(verbose=True)
    dateRuns = psy.data['dateRuns']
    trialRuns = psy.data['runLength']

    # create your trial indices per day and run (0..n-1 restarting each run)
    trial_idx = []
    for i in trialRuns:
        trial_idx.extend(range(i))

    # get date and run vectors (one entry per trial)
    date_vec = []
    run_vec = []
    for c, i in enumerate(dateRuns):
        date_vec.extend([i[0]] * trialRuns[c])
        run_vec.extend([i[1]] * trialRuns[c])

    # create your data dict, transform from log odds to odds ratio
    data = {}
    for c, i in enumerate(psy.weight_labels):
        # multiply by the (presumably binary) input column so the odds ratio
        # only appears on trials where that input was present
        data[i] = np.exp(psy.fits[c, :]) * psy.inputs[:, c].T
    ori_0_in = [i[0] for i in psy.data['inputs']['ori_0']]
    ori_135_in = [i[0] for i in psy.data['inputs']['ori_135']]
    ori_270_in = [i[0] for i in psy.data['inputs']['ori_270']]
    # blank_in is 1 unless the three orientation inputs sum to exactly 1
    blank_in = [
        0 if i == 1 else 1
        for i in np.sum((ori_0_in, ori_135_in, ori_270_in), axis=0)
    ]

    # loop through psy data create a binary vectors for trial history
    # (column 0 of each input -> '<cat>_th', column 1 -> '<cat>_th_prev')
    binary_cat = ['ori_0', 'ori_135', 'ori_270', 'prev_reward', 'prev_punish']
    for cat in binary_cat:
        data[cat + '_th'] = [i[0] for i in psy.data['inputs'][cat]]
        data[cat + '_th_prev'] = [i[1] for i in psy.data['inputs'][cat]]

    # create a single list of orientations to match format of meta
    # (the first of ori_0/ori_135/ori_270/blank equal to 1 wins; -1 = blank)
    ori_order = [0, 135, 270, -1]
    data['orientation'] = [
        ori_order[np.where(np.isin(i, 1))[0][0]]
        for i in zip(ori_0_in, ori_135_in, ori_270_in, blank_in)
    ]

    # create your index out of relevant variables
    index = pd.MultiIndex.from_arrays(
        [[mouse] * len(trial_idx), date_vec, run_vec, trial_idx],
        names=['mouse', 'date', 'run', 'trial_idx'])

    # make master dataframe
    dfr = pd.DataFrame(data, index=index)

    # load TCA data
    load_kwargs = {
        'mouse': mouse,
        'method': method,
        'cs': cs,
        'warp': warp,
        'word': word,
        'group_by': group_by,
        'nan_thresh': nan_thresh,
        'score_threshold': score_threshold,
        'rank': rank_num
    }
    tensor, _, _ = load.groupday_tca_model(**load_kwargs)
    meta = load.groupday_tca_meta(**load_kwargs)

    # add in continuous dprime so psytracker data frame
    dp = pool.calc.psytrack.dprime(flow.Mouse(mouse))
    dfr['dprime'] = dp

    # add in non continuous dprime to meta dataframe
    meta = utils.add_dprime_to_meta(meta)

    # filter out blank trials
    # NOTE: despite its name, blank_trials_bool is True for NON-blank
    # trials; further down it is overwritten in place to become the keep mask
    blank_trials_bool = (dfr['orientation'] >= 0)
    psy_df = dfr.loc[blank_trials_bool, :]

    # check that all runs have matched trial orientations
    new_psy_df_list = []
    new_meta_df_list = []
    # despite its name, drop_trials_bin marks with 1 the psy_df rows that
    # are KEPT (matched to a TCA trial); unmatched rows stay 0
    drop_trials_bin = np.zeros((len(psy_df)))
    dates = meta.reset_index()['date'].unique()
    for d in dates:
        psy_day_bool = psy_df.reset_index()['date'].isin([d]).values
        meta_day_bool = meta.reset_index()['date'].isin([d]).values
        psy_day_df = psy_df.iloc[psy_day_bool, :]
        meta_day_df = meta.iloc[meta_day_bool, :]
        runs = meta_day_df.reset_index()['run'].unique()
        # positions of this day's trials within psy_df
        drop_pos_day = np.where(psy_day_bool)[0]
        for r in runs:
            psy_run_bool = psy_day_df.reset_index()['run'].isin([r]).values
            meta_run_bool = meta_day_df.reset_index()['run'].isin([r]).values
            psy_run_df = psy_day_df.iloc[psy_run_bool, :]
            meta_run_df = meta_day_df.iloc[meta_run_bool, :]
            psy_run_idx = psy_run_df.reset_index()['trial_idx'].values
            meta_run_idx = meta_run_df.reset_index()['trial_idx'].values

            # drop extra trials from trace2P that don't have associated imaging
            max_trials = np.min([len(psy_run_idx), len(meta_run_idx)])

            # get just your orientations for checking that trials are matched
            meta_ori = meta_run_df['orientation'].iloc[:max_trials]
            psy_ori = psy_run_df['orientation'].iloc[:max_trials]

            # make sure all oris match between vectors of the same length each day
            assert np.all(psy_ori.values == meta_ori.values)

            # record which psy_df positions are kept (first max_trials of run)
            drop_pos_run = drop_pos_day[psy_run_bool][:max_trials]
            drop_trials_bin[drop_pos_run] = 1

            # if everything looks good, copy meta index into psy
            meta_new = meta_run_df.iloc[:max_trials]
            psy_new = psy_run_df.iloc[:max_trials]
            data = {}
            for i in psy_new.columns:
                data[i] = psy_new[i].values
            new_psy_df_list.append(
                pd.DataFrame(data=data, index=meta_new.index))
            new_meta_df_list.append(meta_new)

    # NOTE(review): meta1 and psy1 are not used below in this function
    meta1 = pd.concat(new_meta_df_list, axis=0)
    psy1 = pd.concat(new_psy_df_list, axis=0)

    # one input per component; factors[2] is presumably the trial-factor
    # matrix (trials x components) — TODO confirm
    tca_data = {}
    for comp_num in range(1, rank_num + 1):
        fac = tensor.results[rank_num][0].factors[2][:, comp_num - 1]
        tca_data['factor_' + str(comp_num)] = fac[:, None]

    # map kept flags back onto all psytracker trials: afterwards keep_bool is
    # True only for non-blank trials that were matched to a TCA trial
    blank_trials_bool[blank_trials_bool] = (drop_trials_bin == 1)
    keep_bool = blank_trials_bool

    # keep only matched, non-blank trials (blank trials are dropped)
    psydata['y'] = psydata['y'][keep_bool]  # 1-2 binary not 0-1
    psydata['answer'] = psydata['answer'][keep_bool]
    psydata['correct'] = psydata['correct'][keep_bool]
    psydata['dateRunTrials'] = psydata['dateRunTrials'][keep_bool]

    # recalculate dayLength and runLength and dateRuns
    # (column 0 of dateRunTrials is the date, column 1 the run)
    new_runLength = []
    new_dayLength = []
    new_dateRuns = []
    for di in np.unique(psydata['dateRunTrials'][:, 0]):
        day_bool = psydata['dateRunTrials'][:, 0] == di
        new_dayLength.append(np.sum(day_bool))
        day_runs = psydata['dateRunTrials'][day_bool, 1]
        for ri in np.unique(day_runs):
            run_bool = day_runs == ri
            new_runLength.append(np.sum(run_bool))
            new_dateRuns.append([di, ri])
    psydata['dayLength'] = new_dayLength
    psydata['runLength'] = new_runLength
    psydata['dateRuns'] = np.array(new_dateRuns)

    # update days (drop days with no remaining trials)
    clean_days = np.unique(psydata['dateRunTrials'][:, 0])
    clean_day_bool = np.isin(psydata['days'], clean_days)
    psydata['days'] = psydata['days'][clean_day_bool]

    # ensure that you still have the same number of runs
    assert len(psydata['runLength']) == len(psydata['dateRuns'])

    # ensure that you still have the same number of days
    assert len(psydata['dayLength']) == len(psydata['days'])

    # reset inputs
    psydata['inputs'] = tca_data

    if verbose:
        print('Successful sync of psytracker and TCA data :)')
        print(' Fitting {} days'.format(len(psydata['days'])))
        print(' Fitting {} runs'.format(len(psydata['dateRuns'])))
        print(' Fitting {} trials'.format(len(psydata['dateRunTrials'])))
        print(' Fitting {} total hyper-parameters'.format(rank_num))

    return psydata
Example #7
0
def _splice_data_y(psydata,
                   mouse,
                   trace_type='zscore_day',
                   method='mncp_hals',
                   cs='',
                   warp=False,
                   word=None,
                   group_by='all',
                   nan_thresh=0.85,
                   score_threshold=0.8,
                   rank_num=18,
                   comp_num=1,
                   verbose=True):
    """
    Replace psytracker 'y' and 'answer' with a binarized TCA trial factor.

    Psytracker trials are matched run-by-run to TCA trials (same procedure
    as the inputs splice): within each (date, run) in the TCA meta, the
    first min(n_psytracker, n_tca) non-blank trials are kept and their
    orientations must agree. The trial factor for component ``comp_num`` at
    rank ``rank_num`` is then binarized:
      2 where factor > 1 * nanstd(factor), otherwise 1.
    Kept trials receive these values. All other trials (blank or
    unmatched) are set to 1.

    Parameters
    ----------
    psydata : dict
        Psytracker data dict. Its 'y' and 'answer' arrays are modified in
        place via item assignment, and the dict is returned. NOTE(review):
        the masks are built from ``flow.Mouse(mouse).psytracker()``, so this
        presumably assumes those arrays are aligned with that object's
        trials. TODO confirm.
    mouse : str
        Mouse name.
    trace_type : str
        Accepted but not used in this function.
    method, cs, warp, group_by, nan_thresh, score_threshold :
        Forwarded to load.groupday_tca_model / load.groupday_tca_meta.
    word : str or None
        TCA run word. If falsy, it becomes 'tray' for OA27 and
        'obligations' for every other mouse.
    rank_num : int
        TCA rank to load.
    comp_num : int
        1-based component whose trial factor is binarized.
    verbose : bool
        Print the word chosen when ``word`` was not given. Note that the
        final 'cleared :)' message is printed regardless of ``verbose``.

    Returns
    -------
    dict
        ``psydata`` with the updated 'y' and 'answer' arrays.

    Raises
    ------
    AssertionError
        If orientations disagree within a matched run.
    """

    # default TCA params to use
    if not word:
        if mouse == 'OA27':
            word = 'tray'
        else:
            word = 'obligations'  # default word for every mouse except OA27
        if verbose:
            print('Creating dataframe for ' + mouse + '-' + word)

    ms = flow.Mouse(mouse)
    psy = ms.psytracker(verbose=True)
    dateRuns = psy.data['dateRuns']
    trialRuns = psy.data['runLength']

    # create your trial indices per day and run (0..n-1 restarting each run)
    trial_idx = []
    for i in trialRuns:
        trial_idx.extend(range(i))

    # get date and run vectors (one entry per trial)
    date_vec = []
    run_vec = []
    for c, i in enumerate(dateRuns):
        date_vec.extend([i[0]] * trialRuns[c])
        run_vec.extend([i[1]] * trialRuns[c])

    # create your data dict, transform from log odds to odds ratio
    data = {}
    for c, i in enumerate(psy.weight_labels):
        # multiply by the (presumably binary) input column so the odds ratio
        # only appears on trials where that input was present
        data[i] = np.exp(psy.fits[c, :]) * psy.inputs[:, c].T
    ori_0_in = [i[0] for i in psy.data['inputs']['ori_0']]
    ori_135_in = [i[0] for i in psy.data['inputs']['ori_135']]
    ori_270_in = [i[0] for i in psy.data['inputs']['ori_270']]
    # blank_in is 1 unless the three orientation inputs sum to exactly 1
    blank_in = [
        0 if i == 1 else 1
        for i in np.sum((ori_0_in, ori_135_in, ori_270_in), axis=0)
    ]

    # loop through psy data create a binary vectors for trial history
    # (column 0 of each input -> '<cat>_th', column 1 -> '<cat>_th_prev')
    binary_cat = ['ori_0', 'ori_135', 'ori_270', 'prev_reward', 'prev_punish']
    for cat in binary_cat:
        data[cat + '_th'] = [i[0] for i in psy.data['inputs'][cat]]
        data[cat + '_th_prev'] = [i[1] for i in psy.data['inputs'][cat]]

    # create a single list of orientations to match format of meta
    # (the first of ori_0/ori_135/ori_270/blank equal to 1 wins; -1 = blank)
    ori_order = [0, 135, 270, -1]
    data['orientation'] = [
        ori_order[np.where(np.isin(i, 1))[0][0]]
        for i in zip(ori_0_in, ori_135_in, ori_270_in, blank_in)
    ]

    # create your index out of relevant variables
    index = pd.MultiIndex.from_arrays(
        [[mouse] * len(trial_idx), date_vec, run_vec, trial_idx],
        names=['mouse', 'date', 'run', 'trial_idx'])

    # make master dataframe
    dfr = pd.DataFrame(data, index=index)

    # load TCA data
    load_kwargs = {
        'mouse': mouse,
        'method': method,
        'cs': cs,
        'warp': warp,
        'word': word,
        'group_by': group_by,
        'nan_thresh': nan_thresh,
        'score_threshold': score_threshold,
        'rank': rank_num
    }
    tensor, _, _ = load.groupday_tca_model(**load_kwargs)
    meta = load.groupday_tca_meta(**load_kwargs)

    # add in continuous dprime so psytracker data frame
    dp = pool.calc.psytrack.dprime(flow.Mouse(mouse))
    dfr['dprime'] = dp

    # add in non continuous dprime to meta dataframe
    meta = utils.add_dprime_to_meta(meta)

    # filter out blank trials
    # NOTE: despite its name, blank_trials_bool is True for NON-blank
    # trials; further down it is overwritten in place to become the keep mask
    blank_trials_bool = (dfr['orientation'] >= 0)
    psy_df = dfr.loc[blank_trials_bool, :]

    # check that all runs have matched trial orientations
    new_psy_df_list = []
    new_meta_df_list = []
    # despite its name, drop_trials_bin marks with 1 the psy_df rows that
    # are KEPT (matched to a TCA trial); unmatched rows stay 0
    drop_trials_bin = np.zeros((len(psy_df)))
    dates = meta.reset_index()['date'].unique()
    for d in dates:
        psy_day_bool = psy_df.reset_index()['date'].isin([d]).values
        meta_day_bool = meta.reset_index()['date'].isin([d]).values
        psy_day_df = psy_df.iloc[psy_day_bool, :]
        meta_day_df = meta.iloc[meta_day_bool, :]
        runs = meta_day_df.reset_index()['run'].unique()
        # positions of this day's trials within psy_df
        drop_pos_day = np.where(psy_day_bool)[0]
        for r in runs:
            psy_run_bool = psy_day_df.reset_index()['run'].isin([r]).values
            meta_run_bool = meta_day_df.reset_index()['run'].isin([r]).values
            psy_run_df = psy_day_df.iloc[psy_run_bool, :]
            meta_run_df = meta_day_df.iloc[meta_run_bool, :]
            psy_run_idx = psy_run_df.reset_index()['trial_idx'].values
            meta_run_idx = meta_run_df.reset_index()['trial_idx'].values

            # drop extra trials from trace2P that don't have associated imaging
            max_trials = np.min([len(psy_run_idx), len(meta_run_idx)])

            # get just your orientations for checking that trials are matched
            meta_ori = meta_run_df['orientation'].iloc[:max_trials]
            psy_ori = psy_run_df['orientation'].iloc[:max_trials]

            # make sure all oris match between vectors of the same length each day
            assert np.all(psy_ori.values == meta_ori.values)

            # record which psy_df positions are kept (first max_trials of run)
            drop_pos_run = drop_pos_day[psy_run_bool][:max_trials]
            drop_trials_bin[drop_pos_run] = 1

            # if everything looks good, copy meta index into psy
            meta_new = meta_run_df.iloc[:max_trials]
            psy_new = psy_run_df.iloc[:max_trials]
            data = {}
            for i in psy_new.columns:
                data[i] = psy_new[i].values
            new_psy_df_list.append(
                pd.DataFrame(data=data, index=meta_new.index))
            new_meta_df_list.append(meta_new)

    # NOTE(review): psy1 and fac_df below are built but not used afterwards;
    # meta1 is only used for fac_df's index
    meta1 = pd.concat(new_meta_df_list, axis=0)
    psy1 = pd.concat(new_psy_df_list, axis=0)

    # trial factor of the requested component; factors[2] is presumably the
    # trial-factor matrix (trials x components) — TODO confirm
    tca_data = {}
    fac = tensor.results[rank_num][0].factors[2][:, comp_num - 1]
    tca_data['factor_' + str(comp_num)] = fac
    fac_df = pd.DataFrame(data=tca_data, index=meta1.index)

    # binarize the factor into psytracker's 1/2 coding: 2 where the factor
    # exceeds one (nan-ignoring) standard deviation of itself, otherwise 1
    clever_binary = np.ones((len(fac)))
    thresh = np.nanstd(fac) * 1
    clever_binary[fac > thresh] = 2

    # map kept flags back onto all psytracker trials: afterwards keep_bool is
    # True only for non-blank trials matched to a TCA trial, and drop_bool
    # is its complement (blank or unmatched trials)
    blank_trials_bool[blank_trials_bool] = (drop_trials_bin == 1)
    keep_bool = blank_trials_bool
    drop_bool = blank_trials_bool == False
    # keep_bool = np.logical_and(
    #     drop_trials_bin == 0, blank_trials_bool.values == True)
    # drop_bool = np.logical_or(
    #     drop_trials_bin == 1, blank_trials_bool.values == False)

    # trials outside the match (including blanks) get the constant value 1;
    # matched trials get the binarized factor (len(fac) is assumed to equal
    # keep_bool.sum())
    psydata['y'][drop_bool] = 1
    psydata['answer'][drop_bool] = 1  # 1-2 binary not 0-1
    psydata['y'][keep_bool] = clever_binary
    psydata['answer'][keep_bool] = clever_binary

    print('cleared :)')

    return psydata
Example #8
0
def groupmouse_correlate_pillow_tca(
        mice=['OA27', 'OA32', 'OA34', 'CC175', 'OA36', 'OA26', 'OA67',
              'VF226'],
        words=['orlando', 'already', 'already', 'already', 'already',
               'already', 'already', 'already'],
        trace_type='zscore_day',
        method='mncp_hals',
        cs='',
        warp=False,
        group_by='all',
        nan_thresh=0.85,
        score_threshold=None,
        save_pls=False):
    """
    Correlate psytracker (Pillow) behavioral weights with TCA trial factors
    across mice and draw the results as heatmaps.

    Per mouse, the following are aligned trial by trial to the rank-18 TCA
    trial factors:
    - psytracker weights, as exp(fitted log-odds weight) * input for that
      trial;
    - a continuous d-prime.
    Blank trials are excluded, and within each run only the first
    min(n_psytracker, n_tca) trials are used.

    Pearson R and p-values (scipy pearsonr) are computed between every pair
    of columns. The first 8 columns of each mouse's matrices are placed
    side by side. Those are the psytracker-derived columns, which come
    first; the count of 8 is hard-coded — TODO confirm it matches the
    number of weights plus dprime.

    Three heatmaps are drawn: R, p-value, and log10(p-value).

    Parameters
    ----------
    mice : list of str
        Mouse names. The first mouse's column names label the y-axis.
    words : list of str
        TCA run word per mouse, paired with ``mice`` by position. Must be at
        least as long as ``mice``; extra trailing entries are ignored.
    trace_type : str
        Accepted for signature compatibility; not used.
    method, cs, warp, group_by, nan_thresh, score_threshold :
        Forwarded to load.groupday_tca_model / load.groupday_tca_meta.
    save_pls : bool
        Default False, which matches the previously hard-coded behavior
        (figures are left open and not saved). If True:
        - the three heatmaps are saved as PDFs under the last mouse's
          '<tca_plots>/psytrack-vs-tca/rank 18/' directory;
        - all figures are then closed.

    Raises
    ------
    ValueError
        If ``mice`` is empty or ``words`` is shorter than ``mice``.
    AssertionError
        If psytracker and TCA orientations disagree within a run.
    """
    # fail early with a clear message instead of a NameError after the loop
    # (empty mice) or silently skipping mice via zip truncation (short words)
    if len(mice) == 0:
        raise ValueError('mice must contain at least one mouse.')
    if len(words) < len(mice):
        raise ValueError('words must have at least one entry per mouse.')

    # preallocate
    corr_list = []
    pmat_list = []
    x_labels = []

    # LOADING
    for mouse, word in zip(mice, words):
        # Load or run your psytracker behavioral model
        ms = flow.Mouse(mouse)
        psy = ms.psytracker(verbose=True)
        dateRuns = psy.data['dateRuns']
        trialRuns = psy.data['runLength']

        # create your trial indices per day and run
        trial_idx = []
        for i in trialRuns:
            trial_idx.extend(range(i))

        # get date and run vectors
        date_vec = []
        run_vec = []
        for c, i in enumerate(dateRuns):
            date_vec.extend([i[0]] * trialRuns[c])
            run_vec.extend([i[1]] * trialRuns[c])

        # create your data dict, transform from log odds to odds ratio
        data = {}
        for c, i in enumerate(psy.weight_labels):
            # multiply by the (presumably binary) input for that trial
            data[i] = np.exp(psy.fits[c, :]) * psy.inputs[:, c].T
        ori_0_in = [i[0] for i in psy.data['inputs']['ori_0']]
        ori_135_in = [i[0] for i in psy.data['inputs']['ori_135']]
        ori_270_in = [i[0] for i in psy.data['inputs']['ori_270']]
        # 1 unless the three orientation inputs sum to exactly 1
        blank_in = [
            0 if i == 1 else 1
            for i in np.sum((ori_0_in, ori_135_in, ori_270_in), axis=0)
        ]

        # create a single list of orientations to match format of meta
        ori_order = [0, 135, 270, -1]
        data['orientation'] = [
            ori_order[np.where(np.isin(i, 1))[0][0]]
            for i in zip(ori_0_in, ori_135_in, ori_270_in, blank_in)
        ]

        # create your index out of relevant variables
        index = pd.MultiIndex.from_arrays(
            [[mouse] * len(trial_idx), date_vec, run_vec, trial_idx],
            names=['mouse', 'date', 'run', 'trial_idx'])

        dfr = pd.DataFrame(data, index=index)

        # Load TCA results
        load_kwargs = {
            'mouse': mouse,
            'method': method,
            'cs': cs,
            'warp': warp,
            'word': word,
            'group_by': group_by,
            'nan_thresh': nan_thresh,
            'score_threshold': score_threshold
        }
        tensor, _, _ = load.groupday_tca_model(**load_kwargs)
        meta = load.groupday_tca_meta(**load_kwargs)

        savepath = paths.tca_plots(mouse,
                                   'group',
                                   word=word,
                                   group_pars={'group_by': group_by})
        savepath = os.path.join(savepath, 'psytrack-vs-tca')
        if not os.path.isdir(savepath):
            os.mkdir(savepath)

        # Load smooth dprime
        dfr['dprime'] = pool.calc.psytrack.dprime(flow.Mouse(mouse))

        # CHECK THAT TRIAL INDICES ARE MATCHED AND HAVE MATCHED ORIS

        # filter out blank trials
        psy_df = dfr.loc[(dfr['orientation'] >= 0), :]

        # check that all runs have matched trial orientations
        new_psy_df_list = []
        new_meta_df_list = []
        dates = meta.reset_index()['date'].unique()
        for d in dates:
            psy_day_bool = psy_df.reset_index()['date'].isin([d]).values
            meta_day_bool = meta.reset_index()['date'].isin([d]).values
            psy_day_df = psy_df.iloc[psy_day_bool, :]
            meta_day_df = meta.iloc[meta_day_bool, :]
            runs = meta_day_df.reset_index()['run'].unique()
            for r in runs:
                psy_run_bool = psy_day_df.reset_index()['run'].isin([r]).values
                meta_run_bool = meta_day_df.reset_index()['run'].isin(
                    [r]).values
                psy_run_df = psy_day_df.iloc[psy_run_bool, :]
                meta_run_df = meta_day_df.iloc[meta_run_bool, :]
                psy_run_idx = psy_run_df.reset_index()['trial_idx'].values
                meta_run_idx = meta_run_df.reset_index()['trial_idx'].values

                # drop extra trials from trace2P that don't have associated imaging
                max_trials = np.min([len(psy_run_idx), len(meta_run_idx)])

                # get just your orientations for checking that trials are matched
                meta_ori = meta_run_df['orientation'].iloc[:max_trials]
                psy_ori = psy_run_df['orientation'].iloc[:max_trials]

                # make sure all oris match between vectors of the same length each day
                assert np.all(psy_ori.values == meta_ori.values)

                # if everything looks good, copy meta index into psy
                meta_new = meta_run_df.iloc[:max_trials]
                psy_new = psy_run_df.iloc[:max_trials]
                data = {}
                for i in psy_new.columns:
                    data[i] = psy_new[i].values
                new_psy_df_list.append(
                    pd.DataFrame(data=data, index=meta_new.index))
                new_meta_df_list.append(meta_new)

        meta1 = pd.concat(new_meta_df_list, axis=0)
        psy1 = pd.concat(new_psy_df_list, axis=0)

        # NOW TAKE TCA TRIAL FACTORS AND TRY CORRELATING FOR WITH PILLOW
        # put factors for a given rank into a dataframe

        ori = 'all'
        iteration = 0
        for rank in [18]:  # tensor.results
            data = {}
            for i in range(rank):
                fac = tensor.results[rank][iteration].factors[2][:, i]
                data['factor_' + str(i + 1)] = fac
            fac_df = pd.DataFrame(data=data, index=meta1.index)

            # psytracker columns followed by TCA factor columns
            single_ori = pd.concat([psy1, fac_df],
                                   axis=1).drop(columns='orientation')

            # pairwise Pearson R and p-value for every column pair
            num_corr = np.shape(single_ori)[1]
            corrmat = np.zeros((num_corr, num_corr))
            pmat = np.zeros((num_corr, num_corr))
            for i in range(num_corr):
                for k in range(num_corr):
                    corA, corP = pearsonr(single_ori.values[:, i],
                                          single_ori.values[:, k])
                    corrmat[i, k] = corA
                    pmat[i, k] = corP

            if mouse == mice[0]:
                y_label = single_ori.columns

        # stick chunks of corr matrix together
        x_labels.extend([mouse + ' ' + s for s in single_ori.columns[0:8]])
        corr_list.append(corrmat[:, 0:8])
        pmat_list.append(pmat[:, 0:8])

    # concatenate final matrix together
    corrmat = np.concatenate(corr_list, axis=1)
    pmat = np.concatenate(pmat_list, axis=1)
    annot = True
    figsize = (80, 20)

    # create your path for saving (uses the last mouse's plot directory)
    rankpath = os.path.join(savepath, 'rank ' + str(rank))
    if not os.path.isdir(rankpath):
        os.mkdir(rankpath)
    var_path_prefix = os.path.join(
        rankpath,
        mouse + '_psytrack-vs-tca_ori-' + str(ori) + '_rank-' + str(rank))

    plt.figure(figsize=figsize)
    sns.heatmap(corrmat,
                annot=annot,
                xticklabels=x_labels,
                yticklabels=y_label,
                square=False,
                cbar_kws={'label': 'correlation (R)'})
    plt.xticks(rotation=45, ha='right')
    plt.title('Pearson-R corrcoef: rank ' + str(rank))
    if save_pls:
        plt.savefig(var_path_prefix + '_corr.pdf', bbox_inches='tight')

    plt.figure(figsize=figsize)
    sns.heatmap(pmat,
                annot=annot,
                xticklabels=x_labels,
                yticklabels=y_label,
                square=False,
                cbar_kws={'label': 'p-value'})
    plt.xticks(rotation=45, ha='right')
    plt.title('Pearson-R p-values: rank ' + str(rank))
    if save_pls:
        plt.savefig(var_path_prefix + '_pvals.pdf', bbox_inches='tight')

    plt.figure(figsize=figsize)
    # p-values of exactly 0 give -inf; use the smallest finite value as vmin
    logger = np.log10(pmat).flatten()
    vmin = np.nanmin(logger[np.isfinite(logger)])
    vmax = 0
    sns.heatmap(np.log10(pmat),
                annot=annot,
                xticklabels=x_labels,
                yticklabels=y_label,
                vmin=vmin,
                vmax=vmax,
                square=False,
                cbar_kws={'label': 'log$_{10}$(p-value)'})
    plt.xticks(rotation=45, ha='right')
    plt.title('Pearson-R log$_{10}$(p-values): rank ' + str(rank))
    if save_pls:
        plt.savefig(var_path_prefix + '_log10pvals.pdf', bbox_inches='tight')

    # close plots after saving to save memory
    if save_pls:
        plt.close('all')