def _field_peak_and_com(pf, mask):
    """Return the peak coordinate and the rate-weighted centre of mass of one masked field.

    Parameters
    ----------
    pf : dict-like
        Place-field record with 2-D arrays under 'map', 'X' and 'Y' (same shape).
    mask : ndarray
        Field mask, multiplied element-wise with ``pf['map']``.

    Returns
    -------
    fcoor : ndarray, shape (2,)
        (X, Y) at the maximum of the masked rate map.
    com : ndarray, shape (2,)
        Centre of mass of (X, Y) weighted by the masked rate map.
    """
    maskedmap = pf['map'] * mask
    row, col = np.unravel_index(maskedmap.argmax(), maskedmap.shape)
    fcoor = np.array([pf['X'][row, col], pf['Y'][row, col]])
    XY = np.stack([pf['X'].ravel(), pf['Y'].ravel()])
    weights = (maskedmap / np.sum(maskedmap)).ravel().reshape(1, -1)
    com = np.sum(XY * weights, axis=1)
    return fcoor, com


def _concat_pass_arrays(chunk_df):
    """Flatten the per-pass arrays of an accepted-pass table into single arrays.

    Each row of ``chunk_df`` holds 1-D arrays in the columns 'x', 'y', 'angle',
    'tsp', 'spikex', 'spikey' and 'spikeangle'; the rows are concatenated in order.

    Returns
    -------
    pos : ndarray, shape (N, 2)
        Stacked (x, y) samples.
    hd : ndarray, shape (N,)
        'angle' samples (heading).
    tsp : ndarray, shape (M,)
        Spike times.
    possp : ndarray, shape (M, 2)
        Stacked (spikex, spikey).
    hdsp : ndarray, shape (M,)
        'spikeangle' values.
    """
    def cat(key):
        return np.concatenate(chunk_df[key].to_list())

    pos = np.stack([cat('x'), cat('y')]).T
    hd = cat('angle')
    tsp = cat('tsp')
    possp = np.stack([cat('spikex'), cat('spikey')]).T
    hdsp = cat('spikeangle')
    return pos, hd, tsp, possp, hdsp


def _single_field_precession(passdf, precesser, precess_filter, neuro_keys_dict, field_d,
                             kappa, aedges_precess):
    """Run single-field precession analysis and count passes at the precession angle.

    ``passdf`` is modified in place: its 'excluded_for_precess' column is (re)written to
    mark passes that are rejected or have ``chunked >= 2``.

    Returns
    -------
    fitted_precessdf : pandas.DataFrame
        Rows of the precession table whose 'fitted' flag is True (index reset).
    precessangle, precessR
        Second and third outputs of ``get_single_precessdf``.
    numpass_at_precess
        ``get_numpass_at_angle`` evaluated at ``precessangle``; None when no precession
        angle was returned or no pass was fitted.
    """
    accept_mask = (~passdf['rejected']) & (passdf['chunked'] < 2)
    passdf['excluded_for_precess'] = ~accept_mask
    precessdf, precessangle, precessR, _ = get_single_precessdf(
        passdf, precesser, precess_filter, neuro_keys_dict,
        field_d=field_d, kappa=kappa, bins=None)
    fitted_precessdf = precessdf[precessdf['fitted']].reset_index(drop=True)

    if (precessangle is None) or (fitted_precessdf.shape[0] < 1):
        return fitted_precessdf, precessangle, precessR, None

    # Post-hoc precession exclusion: sum the two per-bin pass-count histograms returned by
    # compute_precessangle (names suggest precessing / non-precessing passes — confirm).
    _, _, postdoc_dens = compute_precessangle(
        pass_angles=fitted_precessdf['mean_anglesp'].to_numpy(),
        pass_nspikes=fitted_precessdf['pass_nspikes'].to_numpy(),
        precess_mask=fitted_precessdf['precess_exist'].to_numpy(),
        kappa=None, bins=aedges_precess)
    _, passbins_p, passbins_np, _ = postdoc_dens
    numpass_at_precess = get_numpass_at_angle(target_angle=precessangle, aedge=aedges_precess,
                                              all_passbins=passbins_p + passbins_np)
    return fitted_precessdf, precessangle, precessR, numpass_at_precess


def pair_field_preprocess_exp(df, vthresh=5, sthresh=3, NShuffles=200, save_pth=None):
    """Compute per-pair field statistics (overlap, directionality, phase lags, precession).

    For every trial (row of ``df``) and every region CA1-CA3, each field pair listed in
    ``df.loc[ntrial, '<ca>pairs']`` is processed. One output row is produced for every pair
    that has accepted passes in both fields and at least one paired spike inside the
    recording's time range.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per trial (positional index 0..n-1) with columns 'wave', '<ca>pairs',
        '<ca>fields' and '<ca>indata' for ca in 'CA1', 'CA2', 'CA3'.
    vthresh, sthresh : float
        Thresholds forwarded to ``IndataProcessor``.
    NShuffles : int
        Number of time-shift shuffles for the p-value of the pair directionality.
    save_pth : str or path-like, optional
        If given, the resulting DataFrame is pickled to this path; if None nothing is written.

    Returns
    -------
    pandas.DataFrame
        One row per processed pair, after ``append_extrinsicity``.
    """
    columns = (
        'pair_id', 'ca', 'overlap',
        'border1', 'border2', 'area1', 'area2', 'xyval1', 'xyval2',
        'aver_rate1', 'aver_rate2', 'aver_rate_pair',
        'peak_rate1', 'peak_rate2',
        'minocc1', 'minocc2',
        'fieldcoor1', 'fieldcoor2', 'com1', 'com2',
        'rate_angle1', 'rate_angle2', 'rate_anglep',
        'rate_R1', 'rate_R2', 'rate_Rp',
        'num_spikes1', 'num_spikes2', 'num_spikes_pair',
        'phaselag_AB', 'phaselag_BA', 'corr_info_AB', 'corr_info_BA',
        'thetaT_AB', 'thetaT_BA',
        'rate_AB', 'rate_BA', 'corate', 'pair_rate',
        'kld', 'rate_R_pvalp',
        'precess_df1', 'precess_angle1', 'precess_R1',
        'precess_df2', 'precess_angle2', 'precess_R2',
        'numpass_at_precess1', 'numpass_at_precess2',
        'precess_dfp',
    )
    pairdata_dict = {key: [] for key in columns}

    num_trials = df.shape[0]
    aedges = np.linspace(-np.pi, np.pi, 36)
    abind = aedges[1] - aedges[0]
    sp_binwidth = 5
    aedges_precess = np.linspace(-np.pi, np.pi, 6)
    kappa_precess = 1
    precess_filter = PrecessionFilter()

    # Column-name mappings for the precession routines: single-field passes, then the
    # field-1 / field-2 columns of paired passes.
    neuro_keys_dict = dict(tsp='tsp', spikev='spikev', spikex='spikex', spikey='spikey',
                           spikeangle='spikeangle')
    neuro_keys_dict1 = dict(tsp='tsp1', spikev='spike1v', spikex='spike1x', spikey='spike1y',
                            spikeangle='spike1angle')
    neuro_keys_dict2 = dict(tsp='tsp2', spikev='spike2v', spikex='spike2x', spikey='spike2y',
                            spikeangle='spike2angle')

    pair_id = 0
    for ntrial in range(num_trials):
        wave = df.loc[ntrial, 'wave']
        precesser = PrecessionProcesser(wave=wave)

        for ca in ['CA%d' % (i + 1) for i in range(3)]:

            pair_df = df.loc[ntrial, ca + 'pairs']
            field_df = df.loc[ntrial, ca + 'fields']
            indata = df.loc[ntrial, ca + 'indata']
            # Any empty table makes the region unusable: no pairs means no output, and
            # empty indata / fields would crash IndataProcessor or the field lookups below.
            if (indata.shape[0] < 1) or (pair_df.shape[0] < 1) or (field_df.shape[0] < 1):
                continue

            tunner = IndataProcessor(indata, vthresh=vthresh, sthresh=sthresh, minpasstime=0.4)
            interpolater_angle = interp1d(tunner.t, tunner.angle)
            interpolater_x = interp1d(tunner.t, tunner.x)
            interpolater_y = interp1d(tunner.t, tunner.y)
            all_maxt, all_mint = tunner.t.max(), tunner.t.min()
            # NOTE(review): (max, min) order is what set_trange and the shuffle wrapper
            # receive; kept unchanged — confirm it matches their expectation.
            trange = (all_maxt, all_mint)
            dt = tunner.t[1] - tunner.t[0]  # assumes uniformly sampled tunner.t
            precesser.set_trange(trange)

            num_pairs = pair_df.shape[0]
            for npair in range(num_pairs):

                print('%d/%d trial, %s, %d/%d pair id=%d' % (ntrial, num_trials, ca, npair, num_pairs, pair_id))

                # ---- Single-field properties ----
                field_ids = pair_df.loc[npair, 'fi'] - 1  # minus 1 to convert to python index
                mask1, pf1, xyval1 = field_df.loc[field_ids[0], ['mask', 'pf', 'xyval']]
                mask2, pf2, xyval2 = field_df.loc[field_ids[1], ['mask', 'pf', 'xyval']]
                tsp1_raw = field_df.loc[field_ids[0], 'xytsp']['tsp']
                tsp2_raw = field_df.loc[field_ids[1], 'xytsp']['tsp']
                xaxis1, yaxis1 = pf1['X'][:, 0], pf1['Y'][0, :]
                xaxis2, yaxis2 = pf2['X'][:, 0], pf2['Y'][0, :]
                assert (np.all(xaxis1 == xaxis2) and np.all(yaxis1 == yaxis2))
                area1, area2 = mask1.sum(), mask2.sum()
                # Diameter of a disc with the same area as the field
                field_d1 = np.sqrt(area1 / np.pi) * 2
                field_d2 = np.sqrt(area2 / np.pi) * 2

                _, ks_dist, _ = dist_overlap(pf1['map'], pf2['map'], mask1, mask2)
                fcoor1, com1 = _field_peak_and_com(pf1, mask1)
                fcoor2, com2 = _field_peak_and_com(pf2, mask2)
                border1 = check_border(mask1, margin=2)
                border2 = check_border(mask2, margin=2)

                # Passes through each field; keep non-rejected passes with chunked < 2
                tok1, _ = tunner.get_idin(mask1, xaxis1, yaxis1)
                tok2, _ = tunner.get_idin(mask2, xaxis2, yaxis2)
                passdf1 = tunner.construct_singlefield_passdf(tok1, tsp1_raw, interpolater_x, interpolater_y, interpolater_angle)
                passdf2 = tunner.construct_singlefield_passdf(tok2, tsp2_raw, interpolater_x, interpolater_y, interpolater_angle)
                allchunk_df1 = passdf1[(~passdf1['rejected']) & (passdf1['chunked'] < 2)].reset_index(drop=True)
                allchunk_df2 = passdf2[(~passdf2['rejected']) & (passdf2['chunked'] < 2)].reset_index(drop=True)
                if (allchunk_df1.shape[0] < 1) or (allchunk_df2.shape[0] < 1):
                    continue

                pos1, hd1, tsp1, possp1, hdsp1 = _concat_pass_arrays(allchunk_df1)
                pos2, hd2, tsp2, possp2, hdsp2 = _concat_pass_arrays(allchunk_df2)
                nspks1, nspks2 = tsp1.shape[0], tsp2.shape[0]

                # Rates
                aver_rate1 = nspks1 / (pos1.shape[0] * dt)
                aver_rate2 = nspks2 / (pos2.shape[0] * dt)
                peak_rate1 = np.max(pf1['map'] * mask1)
                peak_rate2 = np.max(pf2['map'] * mask2)

                # Directionality
                occbins1, _ = np.histogram(hd1, bins=aedges)
                mlmer1 = DirectionerMLM(pos1, hd1, dt=dt, sp_binwidth=sp_binwidth, a_binwidth=abind)
                rate_angle1, rate_R1, normprob1_mlm = mlmer1.get_directionality(possp1, hdsp1)
                normprob1_mlm[np.isnan(normprob1_mlm)] = 0

                occbins2, _ = np.histogram(hd2, bins=aedges)
                mlmer2 = DirectionerMLM(pos2, hd2, dt=dt, sp_binwidth=sp_binwidth, a_binwidth=abind)
                rate_angle2, rate_R2, normprob2_mlm = mlmer2.get_directionality(possp2, hdsp2)
                normprob2_mlm[np.isnan(normprob2_mlm)] = 0
                minocc1, minocc2 = occbins1.min(), occbins2.min()

                # Single-field precession (+ post-hoc pass counts)
                fitted_precessdf1, precessangle1, precessR1, numpass_at_precess1 = _single_field_precession(
                    passdf1, precesser, precess_filter, neuro_keys_dict, field_d1, kappa_precess, aedges_precess)
                fitted_precessdf2, precessangle2, precessR2, numpass_at_precess2 = _single_field_precession(
                    passdf2, precesser, precess_filter, neuro_keys_dict, field_d2, kappa_precess, aedges_precess)

                # ---- Paired passes through the union of both fields ----
                mask_union = mask1 | mask2
                field_d_union = np.sqrt(mask_union.sum() / np.pi) * 2
                tok_pair, _ = tunner.get_idin(mask_union, xaxis1, yaxis1)
                # NOTE(review): tsp1/tsp2 here are the spikes concatenated from the accepted
                # single-field passes, not the raw field spikes used for
                # construct_singlefield_passdf above. Preserved as-is; confirm it is intended.
                pairedpasses = tunner.construct_pairfield_passdf(tok_pair, tok1, tok2, tsp1, tsp2, interpolater_x,
                                                                 interpolater_y, interpolater_angle)

                phase_finder = ThetaEstimator(0.005, 0.3, [5, 12])
                # Per-direction accumulators of field-1 spikes, field-2 spikes, spike counts, durations
                by_dir = {d: {'tsp1': [], 'tsp2': [], 'nspikes': [], 'duration': []}
                          for d in ('A->B', 'B->A')}
                t_all, passangles_all, x_all, y_all = [], [], [], []
                paired_tsp_list = []

                accepted_df = pairedpasses[(~pairedpasses['rejected']) & (pairedpasses['chunked'] < 2)].reset_index(drop=True)
                for npass in range(accepted_df.shape[0]):
                    t, ptsp1, ptsp2 = accepted_df.loc[npass, ['t', 'tsp1', 'tsp2']]
                    x, y, pass_angles, direction = accepted_df.loc[npass, ['x', 'y', 'angle', 'direction']]
                    duration = t.max() - t.min()

                    # Passes without any paired spike are ignored entirely
                    pairidx1, pairidx2 = find_pair_times(ptsp1, ptsp2)
                    paired_tsp1, paired_tsp2 = ptsp1[pairidx1], ptsp2[pairidx2]
                    if (paired_tsp1.shape[0] < 1) and (paired_tsp2.shape[0] < 1):
                        continue
                    paired_tsp_list.append(np.concatenate([paired_tsp1, paired_tsp2]))
                    passangles_all.append(pass_angles)
                    x_all.append(x)
                    y_all.append(y)
                    t_all.append(t)

                    acc = by_dir.get(direction)
                    if acc is not None:
                        acc['tsp1'].append(ptsp1)
                        acc['tsp2'].append(ptsp2)
                        acc['nspikes'].append(ptsp1.shape[0] + ptsp2.shape[0])
                        acc['duration'].append(duration)

                ab, ba = by_dir['A->B'], by_dir['B->A']

                # Phase lags
                thetaT_AB, phaselag_AB, corr_info_AB = phase_finder.find_theta_isi_hilbert(ab['tsp1'], ab['tsp2'])
                thetaT_BA, phaselag_BA, corr_info_BA = phase_finder.find_theta_isi_hilbert(ba['tsp1'], ba['tsp2'])

                # Pair precession on accepted, directional passes
                accept_mask = ((~pairedpasses['rejected']) & (pairedpasses['chunked'] < 2)
                               & pairedpasses['direction'].isin(['A->B', 'B->A']))
                pairedpasses['excluded_for_precess'] = ~accept_mask
                precess_dfp = precesser.get_single_precession(pairedpasses, neuro_keys_dict1, field_d_union, tag='1')
                precess_dfp = precesser.get_single_precession(precess_dfp, neuro_keys_dict2, field_d_union, tag='2')
                precess_dfp = precess_filter.filter_pair(precess_dfp)
                fitted_precess_dfp = precess_dfp[precess_dfp['fitted1'] & precess_dfp['fitted2']].reset_index(drop=True)

                # Paired spikes (restricted to the interpolators' time range)
                if not paired_tsp_list:
                    continue
                hd_pair = np.concatenate(passangles_all)
                x_pair, y_pair = np.concatenate(x_all), np.concatenate(y_all)
                pos_pair = np.stack([x_pair, y_pair]).T
                paired_tsp = np.concatenate(paired_tsp_list)
                paired_tsp = paired_tsp[(paired_tsp <= all_maxt) & (paired_tsp >= all_mint)]
                if paired_tsp.shape[0] < 1:
                    continue
                num_spikes_pair = paired_tsp.shape[0]
                hdsp_pair = interpolater_angle(paired_tsp)
                possp_pair = np.stack([interpolater_x(paired_tsp), interpolater_y(paired_tsp)]).T
                aver_rate_pair = num_spikes_pair / (x_pair.shape[0] * dt)

                # Pair directionality
                mlmer_pair = DirectionerMLM(pos_pair, hd_pair, dt, sp_binwidth, abind)
                rate_anglep, rate_Rp, normprobp_mlm = mlmer_pair.get_directionality(possp_pair, hdsp_pair)
                normprobp_mlm[np.isnan(normprobp_mlm)] = 0

                # Time-shift shuffling p-value
                if np.isnan(rate_Rp):
                    rate_R_pvalp = np.nan
                else:
                    rate_R_pvalp = timeshift_shuffle_exp_wrapper(paired_tsp_list, t_all, rate_Rp,
                                                                 NShuffles, mlmer_pair,
                                                                 interpolater_x, interpolater_y,
                                                                 interpolater_angle, trange)

                # Rates; a direction without passes gives a zero denominator -> nan/inf
                with np.errstate(divide='ignore', invalid='ignore'):
                    total_duration = np.sum(ab['duration'] + ba['duration'])
                    rate_AB = np.sum(ab['nspikes']) / np.sum(ab['duration'])
                    rate_BA = np.sum(ba['nspikes']) / np.sum(ba['duration'])
                    corate = np.sum(ab['nspikes'] + ba['nspikes']) / total_duration
                    pair_rate = num_spikes_pair / total_duration

                kld = calc_kld(normprob1_mlm, normprob2_mlm, normprobp_mlm)

                record = {
                    'pair_id': pair_id, 'ca': ca, 'overlap': ks_dist,
                    'border1': border1, 'border2': border2,
                    'area1': area1, 'area2': area2, 'xyval1': xyval1, 'xyval2': xyval2,
                    'aver_rate1': aver_rate1, 'aver_rate2': aver_rate2, 'aver_rate_pair': aver_rate_pair,
                    'peak_rate1': peak_rate1, 'peak_rate2': peak_rate2,
                    'minocc1': minocc1, 'minocc2': minocc2,
                    'fieldcoor1': fcoor1, 'fieldcoor2': fcoor2, 'com1': com1, 'com2': com2,
                    'rate_angle1': rate_angle1, 'rate_angle2': rate_angle2, 'rate_anglep': rate_anglep,
                    'rate_R1': rate_R1, 'rate_R2': rate_R2, 'rate_Rp': rate_Rp,
                    'num_spikes1': nspks1, 'num_spikes2': nspks2, 'num_spikes_pair': num_spikes_pair,
                    'phaselag_AB': phaselag_AB, 'phaselag_BA': phaselag_BA,
                    'corr_info_AB': corr_info_AB, 'corr_info_BA': corr_info_BA,
                    'thetaT_AB': thetaT_AB, 'thetaT_BA': thetaT_BA,
                    'rate_AB': rate_AB, 'rate_BA': rate_BA, 'corate': corate, 'pair_rate': pair_rate,
                    'kld': kld, 'rate_R_pvalp': rate_R_pvalp,
                    'precess_df1': fitted_precessdf1, 'precess_angle1': precessangle1,
                    'precess_R1': precessR1,
                    'precess_df2': fitted_precessdf2, 'precess_angle2': precessangle2,
                    'precess_R2': precessR2,
                    'numpass_at_precess1': numpass_at_precess1,
                    'numpass_at_precess2': numpass_at_precess2,
                    'precess_dfp': fitted_precess_dfp,
                }
                for key in columns:
                    pairdata_dict[key].append(record[key])

                pair_id += 1

    pairdata = pd.DataFrame(pairdata_dict)
    pairdata = append_extrinsicity(pairdata)
    # save_pth defaults to None; only write when a destination was given
    if save_pth is not None:
        pairdata.to_pickle(save_pth)
    return pairdata
def plot_pair_examples(df, vthresh=5, sthresh=3, plot_dir=None):


    def select_cases(ntrial, ca, npair):

        if (ntrial==0) and (ca== 'CA1') and (npair==9):
            return 'kld', 1
        elif (ntrial==7) and (ca== 'CA1') and (npair==22):
            return 'kld', 0

        elif (ntrial==123) and (ca== 'CA1') and (npair==0):  # 628
            return 'eg', 14
        elif (ntrial==9) and (ca== 'CA1') and (npair==1):  # 140
            return 'eg', 4
        elif (ntrial==56) and (ca== 'CA1') and (npair==1):  # 334
            return 'eg', 8
        elif (ntrial==73) and (ca== 'CA1') and (npair==2):  # 394
            return 'eg', 12
        elif (ntrial==17) and (ca== 'CA2') and (npair==4):  # 256
            return 'eg', 0
        elif (ntrial==18) and (ca== 'CA2') and (npair==4):  # 263
            return 'eg', 6
        elif (ntrial==26) and (ca== 'CA3') and (npair==1):  # 299
            return 'eg', 10
        elif (ntrial==21) and (ca== 'CA3') and (npair==2):  # 283
            return 'eg', 2
        else:
            return None, None

    all_ntrials = [0, 7, 123, 9, 56, 73, 17, 18, 26, 21]


    # Paired spikes
    figw_pairedsp = total_figw*0.8
    figh_pairedsp = figw_pairedsp/5


    # Pair eg
    figw_paireg = total_figw*0.8
    figh_paireg = figw_paireg/4 * 1.1
    fig_paireg = plt.figure(figsize=(figw_paireg, figh_paireg))
    ax_paireg = np.array([
        fig_paireg.add_subplot(2, 8, 1), fig_paireg.add_subplot(2, 8, 2, polar=True),
         fig_paireg.add_subplot(2, 8, 3), fig_paireg.add_subplot(2, 8, 4, polar=True),
         fig_paireg.add_subplot(2, 8, 5), fig_paireg.add_subplot(2, 8, 6, polar=True),
         fig_paireg.add_subplot(2, 8, 7), fig_paireg.add_subplot(2, 8, 8, polar=True),
         fig_paireg.add_subplot(2, 8, 9), fig_paireg.add_subplot(2, 8, 10, polar=True),
         fig_paireg.add_subplot(2, 8, 11), fig_paireg.add_subplot(2, 8, 12, polar=True),
         fig_paireg.add_subplot(2, 8, 13), fig_paireg.add_subplot(2, 8, 14, polar=True),
         fig_paireg.add_subplot(2, 8, 15), fig_paireg.add_subplot(2, 8, 16, polar=True),
    ])

    # KLD
    figw_kld = total_figw*0.9/2  # leave 0.2 for colorbar in fig 5
    figh_kld = total_figw*0.9/4
    fig_kld, ax_kld = plt.subplots(2, 4, figsize=(figw_kld, figh_kld), subplot_kw={'polar':True})


    num_trials = df.shape[0]
    aedges = np.linspace(-np.pi, np.pi, 36)
    aedm = midedges(aedges)
    abind = aedges[1] - aedges[0]
    sp_binwidth = 5
    precess_filter = PrecessionFilter()
    for ntrial in range(num_trials):
        if ntrial not in all_ntrials:
            continue
        wave = df.loc[ntrial, 'wave']
        precesser = PrecessionProcesser(sthresh=sthresh, vthresh=vthresh, wave=wave)


        for ca in ['CA%d' % (i + 1) for i in range(3)]:

            # Get data
            pair_df = df.loc[ntrial, ca + 'pairs']
            field_df = df.loc[ntrial, ca + 'fields']
            indata = df.loc[ntrial, ca + 'indata']
            if (indata.shape[0] < 1) & (pair_df.shape[0] < 1) & (field_df.shape[0] < 1):
                continue

            tunner = IndataProcessor(indata, vthresh=vthresh, smooth=True)
            interpolater_angle = interp1d(tunner.t, tunner.angle)
            interpolater_x = interp1d(tunner.t, tunner.x)
            interpolater_y = interp1d(tunner.t, tunner.y)
            all_maxt, all_mint = tunner.t.max(), tunner.t.min()
            trange = (all_maxt, all_mint)
            dt = tunner.t[1] - tunner.t[0]
            precesser.set_trange(trange)

            ##  Loop for pairs
            num_pairs = pair_df.shape[0]
            for npair in range(num_pairs):

                case, case_axid = select_cases(ntrial, ca, npair)
                if case is None:
                    continue

                print('trial %d, %s, pair %d, case=%s' % (ntrial, ca, npair, case))

                # find within-mask indexes
                field_ids = pair_df.loc[npair, 'fi'] - 1  # minus 1 to convert to python index
                mask1 = field_df.loc[field_ids[0], 'mask']
                mask2 = field_df.loc[field_ids[1], 'mask']

                # field's boundaries
                xyval1 = field_df.loc[field_ids[0], 'xyval']
                xyval2 = field_df.loc[field_ids[1], 'xyval']

                # Find overlap
                pf1 = field_df.loc[field_ids[0], 'pf']
                pf2 = field_df.loc[field_ids[1], 'pf']
                _, ks_dist, _ = dist_overlap(pf1['map'], pf2['map'], mask1, mask2)

                # Field's center coordinates
                maskedmap1, maskedmap2 = pf1['map'] * mask1, pf2['map'] * mask2
                cooridx1 = np.unravel_index(maskedmap1.argmax(), maskedmap1.shape)
                cooridx2 = np.unravel_index(maskedmap2.argmax(), maskedmap2.shape)
                fcoor1 = np.array([pf1['X'][cooridx1[0], cooridx1[1]], pf1['Y'][cooridx1[0], cooridx1[1]]])
                fcoor2 = np.array([pf2['X'][cooridx2[0], cooridx2[1]], pf2['Y'][cooridx2[0], cooridx2[1]]])


                # get single fields' statistics
                passdf1 = field_df.loc[field_ids[0], 'passes']
                passdf2 = field_df.loc[field_ids[1], 'passes']

                (x1list, y1list, t1list, angle1list), tsp1list = append_info_from_passes(passdf1, vthresh, sthresh,
                                                                                         trange)
                (x2list, y2list, t2list, angle2list), tsp2list = append_info_from_passes(passdf2, vthresh, sthresh,
                                                                                         trange)


                if (len(t1list) < 1) or (len(t2list) < 1) or (len(tsp1list) < 1) or (len(tsp2list) < 1):
                    continue

                x1, x2 = np.concatenate(x1list), np.concatenate(x2list)
                y1, y2 = np.concatenate(y1list), np.concatenate(y2list)
                hd1, hd2 = np.concatenate(angle1list), np.concatenate(angle2list)
                pos1, pos2 = np.stack([x1, y1]).T, np.stack([x2, y2]).T

                tsp1, tsp2 = np.concatenate(tsp1list), np.concatenate(tsp2list)
                xsp1, xsp2 = interpolater_x(tsp1), interpolater_x(tsp2)
                ysp1, ysp2 = interpolater_y(tsp1), interpolater_y(tsp2)
                possp1, possp2 = np.stack([xsp1, ysp1]).T, np.stack([xsp2, ysp2]).T
                hdsp1, hdsp2 = interpolater_angle(tsp1), interpolater_angle(tsp2)
                nspks1, nspks2 = tsp1.shape[0], tsp2.shape[0]


                # Directionality
                biner1 = DirectionerBining(aedges, hd1)
                fangle1, fR1, (spbins1, occbins1, normprob1) = biner1.get_directionality(hdsp1)
                mlmer1 = DirectionerMLM(pos1, hd1, dt=dt, sp_binwidth=sp_binwidth, a_binwidth=abind)
                fangle1_mlm, fR1_mlm, normprob1_mlm = mlmer1.get_directionality(possp1, hdsp1)
                normprob1_mlm[np.isnan(normprob1_mlm)] = 0
                biner2 = DirectionerBining(aedges, hd2)
                fangle2, fR2, (spbins2, occbins2, normprob2) = biner2.get_directionality(hdsp2)
                mlmer2 = DirectionerMLM(pos2, hd2, dt=dt, sp_binwidth=sp_binwidth, a_binwidth=abind)
                fangle2_mlm, fR2_mlm, normprob2_mlm = mlmer2.get_directionality(possp2, hdsp2)
                normprob2_mlm[np.isnan(normprob2_mlm)] = 0

                # Pass-based phaselag, extrinsic/intrinsic
                phase_finder = ThetaEstimator(0.005, 0.3, [5, 12])
                AB_tsp1_list, BA_tsp1_list = [], []
                AB_tsp2_list, BA_tsp2_list = [], []
                nspikes_AB_list, nspikes_BA_list = [], []
                duration_AB_list, duration_BA_list = [], []
                t_all = []
                passangles_all, x_all, y_all = [], [], []
                paired_tsp_list = []

                # Precession per pair
                neuro_keys_dict1 = dict(tsp='tsp1', spikev='spike1v', spikex='spike1x', spikey='spike1y',
                                        spikeangle='spike1angle')
                neuro_keys_dict2 = dict(tsp='tsp2', spikev='spike2v', spikex='spike2x', spikey='spike2y',
                                        spikeangle='spike2angle')

                pass_dict_keys = precesser._gen_precess_infokeys()
                pass_dict1 = precesser.gen_precess_dict(tag='1')
                pass_dict2 = precesser.gen_precess_dict(tag='2')
                pass_dictp = {**pass_dict1, **pass_dict2, **{'direction':[]}}


                pairedpasses = pair_df.loc[npair, 'pairedpasses']
                num_passes = pairedpasses.shape[0]



                for npass in range(num_passes):

                    # Get spikes
                    tsp1, tsp2, vsp1, vsp2 = pairedpasses.loc[npass, ['tsp1', 'tsp2', 'spike1v', 'spike2v']]
                    x, y, pass_angles, v = pairedpasses.loc[npass, ['x', 'y', 'angle', 'v']]

                    # Straightness
                    straightrank = compute_straightness(pass_angles)
                    if straightrank < sthresh:
                        continue

                    # Find direction
                    infield1, infield2, t = pairedpasses.loc[npass, ['infield1', 'infield2', 't']]
                    loc, direction = ThetaEstimator.find_direction(infield1, infield2)
                    duration = t.max() - t.min()

                    # Speed threshold
                    passmask = v > vthresh
                    spmask1, spmask2 = vsp1 > vthresh, vsp2 > vthresh
                    xinv, yinv, pass_angles_inv, tinv = x[passmask], y[passmask], pass_angles[passmask], t[passmask]
                    tsp1_inv, tsp2_inv = tsp1[spmask1], tsp2[spmask2]

                    # Find paired spikes
                    pairidx1, pairidx2 = find_pair_times(tsp1_inv, tsp2_inv)
                    paired_tsp1, paired_tsp2 = tsp1_inv[pairidx1], tsp2_inv[pairidx2]
                    if (paired_tsp1.shape[0] < 1) and (paired_tsp2.shape[0] < 1):
                        continue
                    paired_tsp_eachpass = np.concatenate([paired_tsp1, paired_tsp2])
                    paired_tsp_list.append(paired_tsp_eachpass)

                    # Get pass info
                    passangles_all.append(pass_angles_inv)
                    x_all.append(xinv)
                    y_all.append(yinv)
                    t_all.append(tinv)

                    if direction == 'A->B':
                        AB_tsp1_list.append(tsp1_inv)
                        AB_tsp2_list.append(tsp2_inv)
                        nspikes_AB_list.append(tsp1_inv.shape[0] + tsp2_inv.shape[0])
                        duration_AB_list.append(duration)

                    elif direction == 'B->A':
                        BA_tsp1_list.append(tsp1_inv)
                        BA_tsp2_list.append(tsp2_inv)
                        nspikes_BA_list.append(tsp1_inv.shape[0] + tsp2_inv.shape[0])
                        duration_BA_list.append(duration)

                    if (direction == 'A->B') or (direction == 'B->A'):
                        precess1 = precesser._get_precession(pairedpasses, npass, neuro_keys_dict1)
                        precess2 = precesser._get_precession(pairedpasses, npass, neuro_keys_dict2)
                        if (precess1 is None) or (precess2 is None):
                            continue
                        else:
                            pass_dictp = precesser.append_pass_dict(pass_dictp, precess1, tag='1')
                            pass_dictp = precesser.append_pass_dict(pass_dictp, precess2, tag='2')
                            pass_dictp['direction'].append(direction)



                    ############## Plot paired spikes examples ##############
                    if (ntrial==26) and (ca=='CA3') and (npair==1) and (npass==10):

                        if (tsp1_inv.shape[0] != 0) or (tsp2_inv.shape[0] != 0):
                            mintsp_plt = np.min(np.concatenate([tsp1_inv, tsp2_inv]))
                            tsp1_inv = tsp1_inv - mintsp_plt
                            tsp2_inv = tsp2_inv - mintsp_plt
                            tmp_idx1, tmp_idx2 = find_pair_times(tsp1_inv, tsp2_inv)
                            pairedsp1, pairedsp2 = tsp1_inv[tmp_idx1], tsp2_inv[tmp_idx2]

                            fig_pairsp, ax_pairsp = plt.subplots(figsize=(figw_pairedsp, figh_pairedsp))
                            ax_pairsp.eventplot(tsp1_inv, color='k', lineoffsets=0, linelengths=1, linewidths=0.75)
                            ax_pairsp.eventplot(tsp2_inv, color='k', lineoffsets=1, linelengths=1, linewidths=0.75)
                            ax_pairsp.eventplot(pairedsp1, color='darkorange', lineoffsets=0, linelengths=1, linewidths=0.75)
                            ax_pairsp.eventplot(pairedsp2, color='darkorange', lineoffsets=1, linelengths=1, linewidths=0.75)
                            ax_pairsp.set_yticks([0, 1])
                            ax_pairsp.set_yticklabels(['Field A', 'Field B'])
                            ax_pairsp.set_ylim(-0.7, 1.7)
                            ax_pairsp.tick_params(labelsize=ticksize)
                            ax_pairsp.set_xlabel('t (s)', fontsize=fontsize)
                            ax_pairsp.xaxis.set_label_coords(1, -0.075)
                            ax_pairsp.spines['left'].set_visible(False)
                            ax_pairsp.spines['right'].set_visible(False)
                            ax_pairsp.spines['top'].set_visible(False)

                            fig_pairsp.tight_layout()
                            fig_pairsp.savefig(os.path.join(plot_dir, 'example_pairedspikes.%s' % (figext)), dpi=dpi)
                            # fig_pairsp.savefig(os.path.join(plot_dir, 'examples_pairspikes', 'trial-%d_%s_pair-%d_pass-%d.%s' % (ntrial, ca, npair, npass, figext)), dpi=dpi)


                # Paired spikes
                if (len(paired_tsp_list) == 0) or (len(passangles_all) == 0):
                    continue
                hd_pair = np.concatenate(passangles_all)
                x_pair, y_pair = np.concatenate(x_all), np.concatenate(y_all)
                pos_pair = np.stack([x_pair, y_pair]).T

                paired_tsp = np.concatenate(paired_tsp_list)
                paired_tsp = paired_tsp[(paired_tsp <= all_maxt) & (paired_tsp >= all_mint)]
                if paired_tsp.shape[0] < 1:
                    continue
                num_spikes_pair = paired_tsp.shape[0]

                hdsp_pair = interpolater_angle(paired_tsp)
                xsp_pair = interpolater_x(paired_tsp)
                ysp_pair = interpolater_y(paired_tsp)
                possp_pair = np.stack([xsp_pair, ysp_pair]).T

                # Pair Directionality
                biner_pair = DirectionerBining(aedges, hd_pair)
                fanglep, fRp, (spbinsp, occbinsp, normprobp) = biner_pair.get_directionality(hdsp_pair)
                mlmer_pair = DirectionerMLM(pos_pair, hd_pair, dt, sp_binwidth, abind)
                fanglep_mlm, fRp_mlm, normprobp_mlm = mlmer_pair.get_directionality(possp_pair, hdsp_pair)
                normprobp_mlm[np.isnan(normprobp_mlm)] = 0

                # KLD
                kld_mlm = calc_kld(normprob1_mlm, normprob2_mlm, normprobp_mlm)

                ############## Plot pair examples ##############
                if case == 'eg':
                    ax_paireg[case_axid].plot(tunner.x, tunner.y, c='0.8', linewidth=0.75)
                    ax_paireg[case_axid].plot(xyval1[:, 0], xyval1[:, 1], c='k', linewidth=1)
                    ax_paireg[case_axid].plot(xyval2[:, 0], xyval2[:, 1], c='k', linewidth=1)

                    ax_paireg[case_axid].axis('off')
                    xp, yp = circular_density_1d(aedm, 20 * np.pi, 60, (-np.pi, np.pi), w=normprobp_mlm)
                    ax_paireg[case_axid+1] = directionality_polar_plot(ax_paireg[case_axid+1], xp, yp, fanglep_mlm, linewidth=0.75)

                    ax_paireg[case_axid].text(0.6, -0.05, '%0.2f' % (ks_dist), fontsize=legendsize, c='k', transform=ax_paireg[case_axid].transAxes)
                    ax_paireg[case_axid+1].text(0.6, -0.05, '%0.2f' % (fRp_mlm), fontsize=legendsize, c='k', transform=ax_paireg[case_axid+1].transAxes)

                ############## Plot KLD examples ##############
                if case == 'kld':
                    x1, y1 = circular_density_1d(aedm, 20 * np.pi, 60, (-np.pi, np.pi), w=normprob1_mlm)
                    x2, y2 = circular_density_1d(aedm, 20 * np.pi, 60, (-np.pi, np.pi), w=normprob2_mlm)
                    xp, yp = circular_density_1d(aedm, 20 * np.pi, 60, (-np.pi, np.pi), w=normprobp_mlm)

                    indep_prob = normprob1_mlm * normprob2_mlm
                    indep_prob = indep_prob / np.sum(indep_prob)
                    indep_angle = circmean(aedm, w=indep_prob, d=abind)
                    xindep, yindep = circular_density_1d(aedm, 20 * np.pi, 60, (-np.pi, np.pi), w=indep_prob)

                    kld_linewidth = 0.75
                    ax_kld[case_axid, 0] = directionality_polar_plot(ax_kld[case_axid, 0], x1, y1, fangle1_mlm, linewidth=kld_linewidth)
                    ax_kld[case_axid, 1] = directionality_polar_plot(ax_kld[case_axid, 1], x2, y2, fangle2_mlm, linewidth=kld_linewidth)
                    ax_kld[case_axid, 2] = directionality_polar_plot(ax_kld[case_axid, 2], xindep, yindep, indep_angle, linewidth=kld_linewidth)
                    ax_kld[case_axid, 3] = directionality_polar_plot(ax_kld[case_axid, 3], xp, yp, fanglep_mlm, linewidth=kld_linewidth)


                    kldtext_x, kldtext_y = 0, -0.2
                    ax_kld[case_axid, 3].text(kldtext_x, kldtext_y, 'KLD=%0.2f' % (kld_mlm), fontsize=legendsize, transform=ax_kld[case_axid, 3].transAxes)
def plot_placefield_examples(rawdf, save_dir=None):
    """Plot a hand-picked set of single place-field examples.

    Fields are numbered per region (CA1-CA3) with a running counter that is
    shared across trials. A field is plotted only when its label
    ``'<CA>-field<id>'`` is in the selected example set below. Each example
    figure shows:

    - the trajectory;
    - the field boundary (``xyval``);
    - the field's in-session spikes, coloured by the interpolated movement
      angle;
    - a polar inset with the MLM rate-directionality density, its preferred
      angle, the peak rate of the field map and the resultant length.

    A colour-wheel legend is also saved.

    Parameters
    ----------
    rawdf : pandas.DataFrame
        One row per trial with a ``'wave'`` column and, for each region
        ``ca``, the columns ``ca + 'fields'`` and ``ca + 'indata'``. Rows are
        accessed with ``.loc[0..n-1]``, so this presumably expects a default
        RangeIndex.
    save_dir : str
        Output directory. Despite the ``None`` default, a real path is
        required. The colour wheel is written directly into it, and the
        examples go into ``save_dir/example_fields2`` (created if missing).

    Returns
    -------
    None
        Figures are written to disk as ``.png`` and ``.eps`` files.
    """
    example_dir = join(save_dir, 'example_fields2')
    # exist_ok: re-running the script must not crash on an existing folder.
    os.makedirs(example_dir, exist_ok=True)

    field_figl = total_figw / 8
    field_linew = 0.75
    field_ms = 1
    # NOTE: this silences warnings process-wide, not only inside this function.
    warnings.filterwarnings("ignore")

    # Parameters
    vthresh = 5
    sthresh = 3
    num_trials = rawdf.shape[0]
    aedges = np.linspace(-np.pi, np.pi, 36)
    aedm = midedges(aedges)
    abind = aedges[1] - aedges[0]
    sp_binwidth = 5

    # Plot color wheel
    fig_ch, ax_ch = color_wheel('hsv', (field_figl, field_figl))
    fig_ch.savefig(os.path.join(save_dir, 'colorwheel.%s' % (figext)),
                   transparent=True,
                   dpi=dpi)
    plt.close(fig_ch)

    # Selected Examples
    example_set = {
        'CA1-field16',
        'CA2-field11',
        'CA3-field378',  # Precession
        'CA1-field151',
        'CA1-field222',
        'CA1-field389',
        'CA1-field400',
        'CA1-field409',
        'CA2-field28',
        'CA2-field47',
        'CA3-field75',
    }
    # Running field counter per region, shared across all trials.
    fieldid = dict(CA1=0, CA2=0, CA3=0)

    for ntrial in range(num_trials):

        wave = rawdf.loc[ntrial, 'wave']
        precesser = PrecessionProcesser(wave=wave)
        for ca in ['CA%d' % i for i in range(1, 4)]:

            # Get data
            field_df = rawdf.loc[ntrial, ca + 'fields']
            indata = rawdf.loc[ntrial, ca + 'indata']
            if indata.shape[0] < 1:
                continue
            trajx, trajy = indata['x'].to_numpy(), indata['y'].to_numpy()
            num_fields = field_df.shape[0]
            tunner = IndataProcessor(indata,
                                     vthresh=vthresh,
                                     sthresh=sthresh,
                                     minpasstime=0.4,
                                     smooth=None)
            interpolater_angle = interp1d(tunner.t, tunner.angle)
            interpolater_x = interp1d(tunner.t, tunner.x)
            interpolater_y = interp1d(tunner.t, tunner.y)
            all_maxt, all_mint = tunner.t.max(), tunner.t.min()
            # Stored as (t_max, t_min), the order used throughout this file.
            trange = (all_maxt, all_mint)
            dt = tunner.t[1] - tunner.t[0]
            precesser.set_trange(trange)
            for nfield in range(num_fields):
                field_label = '%s-field%d' % (ca, fieldid[ca])
                if field_label not in example_set:
                    fieldid[ca] += 1
                    continue
                print('Plotting fields: Trial %d/%d %s field %d/%d' %
                      (ntrial, num_trials, ca, nfield, num_fields))

                # Get field spikes
                xytsp, xyval = field_df.loc[nfield, ['xytsp', 'xyval']]
                tsp, xsp, ysp = xytsp['tsp'], xytsp['xsp'], xytsp['ysp']
                pf = field_df.loc[nfield, 'pf']

                # Spikes strictly inside the recorded time range, with angles
                # interpolated from the trajectory.
                insession = np.where((tsp < all_maxt) & (tsp > all_mint))[0]
                xspins, yspins = xsp[insession], ysp[insession]
                anglespins = interpolater_angle(tsp[insession])

                # Passes that survive the speed/straightness thresholds
                passdf = field_df.loc[nfield, 'passes']
                (xl, yl, tl,
                 hdl), tspl = append_info_from_passes(passdf, vthresh, sthresh,
                                                      trange)
                if (len(tspl) < 1) or (len(tl) < 1):
                    # NOTE(review): fieldid[ca] is NOT advanced here, so the
                    # next field is checked against this same label. Confirm
                    # this matches how the example labels above were numbered.
                    continue

                x = np.concatenate(xl)
                y = np.concatenate(yl)
                pos = np.stack([x, y]).T
                hd = np.concatenate(hdl)
                tsp = np.concatenate(tspl)
                xsp, ysp = interpolater_x(tsp), interpolater_y(tsp)
                possp = np.stack([xsp, ysp]).T
                hdsp = interpolater_angle(tsp)

                # Rate directionality (MLM estimate)
                mlmer = DirectionerMLM(pos, hd, dt, sp_binwidth, abind)
                fieldangle, fieldR, norm_prob = mlmer.get_directionality(
                    possp, hdsp)
                norm_prob[np.isnan(norm_prob)] = 0

                # (Plot) Place field example
                fig2 = plt.figure(figsize=(field_figl * 2, field_figl))
                peak_rate = pf['map'].max()
                ax_field2 = fig2.add_subplot(1, 2, 1)
                ax_field2.plot(trajx, trajy, c='0.8', linewidth=0.25)
                ax_field2.plot(xyval[:, 0],
                               xyval[:, 1],
                               c='k',
                               zorder=3,
                               linewidth=field_linew)
                ax_field2.scatter(xspins,
                                  yspins,
                                  c=anglespins,
                                  marker='.',
                                  cmap='hsv',
                                  s=field_ms,
                                  zorder=2.5)
                ax_field2.axis('off')

                x_new, y_new = circular_density_1d(aedm,
                                                   30 * np.pi,
                                                   100, (-np.pi, np.pi),
                                                   w=norm_prob)
                rmax = y_new.max()
                ax_polar = fig2.add_axes([0.33, 0.3, 0.6, 0.6], polar=True)
                ax_polar.plot(x_new,
                              y_new,
                              c='0.3',
                              linewidth=field_linew,
                              zorder=2.1)
                # Close the density curve (last point back to the first).
                ax_polar.plot([x_new[-1], x_new[0]], [y_new[-1], y_new[0]],
                              c='0.3',
                              linewidth=field_linew,
                              zorder=2.1)
                ax_polar.annotate("",
                                  xy=(fieldangle, rmax),
                                  xytext=(0, 0),
                                  color='k',
                                  zorder=3,
                                  arrowprops=dict(arrowstyle="->"))
                ax_polar.annotate(r'$\theta_{rate}$',
                                  xy=(fieldangle, rmax),
                                  fontsize=legendsize)
                ax_polar.spines['polar'].set_visible(False)
                ax_polar.set_xticks([0, np.pi / 2, np.pi, 3 * np.pi / 2])
                ax_polar.set_yticks([])
                ax_polar.set_yticklabels([])
                ax_polar.set_xticklabels([])

                # Peak rate and resultant length, in figure coordinates.
                ax_polar.annotate('%0.2f' % (peak_rate),
                                  xy=(0.75, 0.05),
                                  color='k',
                                  zorder=3,
                                  fontsize=legendsize,
                                  xycoords='figure fraction')
                ax_polar.annotate('%0.2f' % (fieldR),
                                  xy=(0.75, 0.175),
                                  color='k',
                                  zorder=3,
                                  fontsize=legendsize,
                                  xycoords='figure fraction')

                fig2.tight_layout()
                fig2.savefig(os.path.join(example_dir, '%s.png' % field_label),
                             dpi=dpi)
                fig2.savefig(os.path.join(example_dir, '%s.eps' % field_label),
                             dpi=dpi)
                # Close this specific figure, not whichever one is "current".
                plt.close(fig2)

                fieldid[ca] += 1
def single_field_preprocess_networks(simdata,
                                     radius=10,
                                     vthresh=2,
                                     sthresh=80,
                                     NShuffles=200,
                                     subsample_fraction=1,
                                     save_pth=None):
    """Compute single-field directionality and precession for simulated neurons.

    Each simulated neuron is given one circular field of ``radius`` around its
    position in ``NeuronPos``. For every neuron, the function:

    - segments the passes through that field;
    - estimates rate directionality with ``DirectionerMLM``;
    - obtains a time-shift shuffle p-value for it;
    - fits phase precession per pass.

    Parameters
    ----------
    simdata : dict-like
        Must provide:

        - ``'Indata'``: behaviour table with at least ``'t'`` and ``'phase'``
          columns.
        - ``'SpikeData'``: with ``'neuronidx'`` and ``'tidxsp'`` columns.
        - ``'NeuronPos'``: with ``'neuronx'`` and ``'neurony'`` columns.
    radius : float
        Field radius around each neuron's position.
    vthresh : float
        Speed threshold. Default = 2 in this signature (the experimental
        pipeline ``single_field_preprocess_exp`` uses 5).
    sthresh : float
        Straightness threshold. Default = 80. Determined by the same
        percentile (10%) of passes excluded in Emily's data (sthresh = 3
        there).
    NShuffles : int
        Number of time-shift shuffles for the directionality p-value.
    subsample_fraction : float
        Fraction of each neuron's spikes kept by random subsampling, seeded
        with the neuron index. =1 if no subsampling is needed.
    save_pth : str or None
        If given, the resulting DataFrame is pickled to this path.

    Returns
    -------
    pandas.DataFrame
        One row per neuron with usable passes.
    """

    datadict = dict(num_spikes=[],
                    border=[],
                    aver_rate=[],
                    peak_rate=[],
                    fieldangle_mlm=[],
                    fieldR_mlm=[],
                    spike_bins=[],
                    occ_bins=[],
                    shift_pval_mlm=[],
                    precess_df=[],
                    precess_angle=[],
                    precess_angle_low=[],
                    precess_R=[])

    Indata = simdata['Indata']
    SpikeData = simdata['SpikeData']
    NeuronPos = simdata['NeuronPos']
    aedges = np.linspace(-np.pi, np.pi, 36)
    abind = aedges[1] - aedges[0]
    sp_binwidth = 5

    # NOTE(review): positional arguments; confirm that IndataProcessor's third
    # parameter is the one meant to receive True.
    tunner = IndataProcessor(Indata, vthresh, True)
    # Simulation: theta amplitude is set to a constant 1.
    wave = dict(tax=Indata['t'].to_numpy(),
                phase=Indata['phase'].to_numpy(),
                theta=np.ones(Indata.shape[0]))

    interpolater_angle = interp1d(tunner.t, tunner.angle)
    interpolater_x = interp1d(tunner.t, tunner.x)
    interpolater_y = interp1d(tunner.t, tunner.y)
    # Stored as (t_max, t_min), the order used throughout this file.
    trange = (tunner.t.max(), tunner.t.min())
    dt = tunner.t[1] - tunner.t[0]
    precesser = PrecessionProcesser(sthresh=sthresh,
                                    vthresh=vthresh,
                                    wave=wave)
    precesser.set_trange(trange)
    precess_filter = PrecessionFilter()
    # Spike-count split points; [0] is used as the 25% quantile cutoff below.
    nspikes_stats = (13, 40)
    pass_nspikes = []

    # Column names of the pass DataFrame (loop-invariant).
    neuro_keys_dict = dict(tsp='tsp',
                           spikev='spikev',
                           spikex='spikex',
                           spikey='spikey',
                           spikeangle='spikeangle')

    num_neurons = NeuronPos.shape[0]

    for nidx in range(num_neurons):
        print('%d/%d Neuron' % (nidx, num_neurons))
        # Get spike indexes + subsample
        spdf = SpikeData[SpikeData['neuronidx'] == nidx].reset_index(drop=True)
        tidxsp = spdf['tidxsp'].to_numpy().astype(int)
        if subsample_fraction < 1:
            np.random.seed(nidx)
            num_total = tidxsp.shape[0]
            num_keep = int(num_total * subsample_fraction)
            # Sample from all indices 0..n-1. The previous range(n - 1) could
            # never keep the last spike and raised on spike-less neurons.
            sampled_idx = np.sort(
                np.random.choice(num_total, num_keep, replace=False))
            tidxsp = tidxsp[sampled_idx]

        # Samples of the trajectory inside this neuron's field
        neuronx, neurony = NeuronPos.loc[nidx, ['neuronx', 'neurony']]
        dist = np.sqrt((neuronx - tunner.x)**2 + (neurony - tunner.y)**2)
        tok = dist < radius

        # Check border
        border = check_border_sim(neuronx, neurony, radius, (-40, 40))

        # Get info from passes
        all_passidx = segment_passes(tok)
        passdf = construct_passdf_sim(tunner, all_passidx, tidxsp)
        beh_info, all_tsp_list = append_info_from_passes(
            passdf, vthresh, sthresh, trange)
        (all_x_list, all_y_list, all_t_list, all_passangles_list) = beh_info
        if (len(all_tsp_list) < 1) or (len(all_x_list) < 1):
            continue
        all_x = np.concatenate(all_x_list)
        all_y = np.concatenate(all_y_list)
        all_passangles = np.concatenate(all_passangles_list)
        all_tsp = np.concatenate(all_tsp_list)
        all_anglesp = interpolater_angle(all_tsp)
        xsp, ysp = interpolater_x(all_tsp), interpolater_y(all_tsp)
        pos = np.stack([all_x, all_y]).T
        possp = np.stack([xsp, ysp]).T

        # Average firing rate: spikes / (number of samples * sample period)
        aver_rate = all_tsp.shape[0] / (all_x.shape[0] * dt)

        # Field's directionality
        num_spikes = all_tsp.shape[0]
        occ_bins, _ = np.histogram(all_passangles, bins=aedges)
        spike_bins, _ = np.histogram(all_anglesp, bins=aedges)
        mlmer = DirectionerMLM(pos,
                               all_passangles,
                               dt,
                               sp_binwidth=sp_binwidth,
                               a_binwidth=abind)
        fieldangle_mlm, fieldR_mlm, norm_prob_mlm = mlmer.get_directionality(
            possp, all_anglesp)

        # Precession per pass
        precessdf, precess_angle, precess_R, _ = get_single_precessdf(
            passdf, precesser, precess_filter, neuro_keys_dict)

        filtered_precessdf = precessdf[precessdf['precess_exist']].reset_index(
            drop=True)

        pass_nspikes.extend(filtered_precessdf['pass_nspikes'])

        # Precession restricted to passes with few spikes
        precess_angle_low = None
        if precessdf.shape[0] > 0:
            ldf = precessdf[precessdf['pass_nspikes'] <
                            nspikes_stats[0]]  # 25% quantile
            if (ldf.shape[0] > 0) and (ldf['precess_exist'].sum() > 0):
                precess_angle_low, _, _ = compute_precessangle(
                    ldf['spike_angle'].to_numpy(),
                    ldf['pass_nspikes'].to_numpy(),
                    ldf['precess_exist'].to_numpy())

        # Time shift shuffling
        shi_pval_mlm = timeshift_shuffle_exp_wrapper(
            all_tsp_list, all_t_list, fieldR_mlm, NShuffles, mlmer,
            interpolater_x, interpolater_y, interpolater_angle, trange)

        datadict['border'].append(border)
        datadict['num_spikes'].append(num_spikes)
        datadict['aver_rate'].append(aver_rate)
        datadict['peak_rate'].append(None)
        datadict['fieldangle_mlm'].append(fieldangle_mlm)
        datadict['fieldR_mlm'].append(fieldR_mlm)
        datadict['spike_bins'].append(spike_bins)
        datadict['occ_bins'].append(occ_bins)
        datadict['shift_pval_mlm'].append(shi_pval_mlm)
        datadict['precess_df'].append(precessdf)
        datadict['precess_angle'].append(precess_angle)
        datadict['precess_angle_low'].append(precess_angle_low)
        datadict['precess_R'].append(precess_R)

    print('Num spikes:\n')
    print(
        pd.DataFrame({'pass_nspikes':
                      pass_nspikes})['pass_nspikes'].describe())
    datadf = pd.DataFrame(datadict)

    # Only persist when a path is given; to_pickle(None) would raise.
    if save_pth is not None:
        datadf.to_pickle(save_pth)
    return datadf
def single_field_preprocess_exp(df,
                                vthresh=5,
                                sthresh=3,
                                NShuffles=200,
                                save_pth=None):
    fielddf_dict = dict(ca=[],
                        num_spikes=[],
                        border=[],
                        aver_rate=[],
                        peak_rate=[],
                        rate_angle=[],
                        rate_R=[],
                        rate_R_pval=[],
                        minocc=[],
                        field_area=[],
                        field_bound=[],
                        precess_df=[],
                        precess_angle=[],
                        precess_angle_low=[],
                        precess_R=[],
                        precess_R_pval=[],
                        numpass_at_precess=[],
                        numpass_at_precess_low=[])

    num_trials = df.shape[0]
    aedges = np.linspace(-np.pi, np.pi, 36)
    abind = aedges[1] - aedges[0]
    sp_binwidth = 5

    aedges_precess = np.linspace(-np.pi, np.pi, 6)
    kappa_precess = 1
    precess_filter = PrecessionFilter()
    nspikes_stats = {
        'CA1': 6,
        'CA2': 6,
        'CA3': 7
    }  # 25% quantile of precessing passes
    for ntrial in range(num_trials):

        wave = df.loc[ntrial, 'wave']
        precesser = PrecessionProcesser(wave=wave)

        for ca in ['CA%d' % (i + 1) for i in range(3)]:

            # Get data
            field_df = df.loc[ntrial, ca + 'fields']
            indata = df.loc[ntrial, ca + 'indata']
            num_fields = field_df.shape[0]
            if indata.shape[0] < 1:
                continue

            tunner = IndataProcessor(indata,
                                     vthresh=vthresh,
                                     sthresh=sthresh,
                                     minpasstime=0.4)
            interpolater_angle = interp1d(tunner.t, tunner.angle)
            interpolater_x = interp1d(tunner.t, tunner.x)
            interpolater_y = interp1d(tunner.t, tunner.y)
            trange = (tunner.t.max(), tunner.t.min())
            dt = tunner.t[1] - tunner.t[0]
            precesser.set_trange(trange)

            for nf in range(num_fields):
                print('%d/%d trial, %s, %d/%d field' %
                      (ntrial, num_trials, ca, nf, num_fields))
                # Get field info
                mask, pf, xyval = field_df.loc[nf, ['mask', 'pf', 'xyval']]
                tsp = field_df.loc[nf, 'xytsp']['tsp']
                xaxis, yaxis = pf['X'][:, 0], pf['Y'][0, :]
                field_area = np.sum(mask)
                field_d = np.sqrt(field_area / np.pi) * 2
                border = check_border(mask, margin=2)

                # Construct passes (segment & chunk)
                tok, idin = tunner.get_idin(mask, xaxis, yaxis)
                passdf = tunner.construct_singlefield_passdf(
                    tok, tsp, interpolater_x, interpolater_y,
                    interpolater_angle)
                allchunk_df = passdf[(~passdf['rejected'])
                                     & (passdf['chunked'] < 2)].reset_index(
                                         drop=True)
                # allchunk_df = passdf[(~passdf['rejected'])].reset_index(drop=True)

                # Get info from passdf and interpolate
                if allchunk_df.shape[0] < 1:
                    continue
                all_x_list, all_y_list = allchunk_df['x'].to_list(
                ), allchunk_df['y'].to_list()
                all_t_list, all_passangles_list = allchunk_df['t'].to_list(
                ), allchunk_df['angle'].to_list()
                all_tsp_list, all_chunked_list = allchunk_df['tsp'].to_list(
                ), allchunk_df['chunked'].to_list()
                all_x = np.concatenate(all_x_list)
                all_y = np.concatenate(all_y_list)
                all_passangles = np.concatenate(all_passangles_list)
                all_tsp = np.concatenate(all_tsp_list)
                all_anglesp = np.concatenate(
                    allchunk_df['spikeangle'].to_list())
                xsp, ysp = np.concatenate(
                    allchunk_df['spikex'].to_list()), np.concatenate(
                        allchunk_df['spikey'].to_list())
                pos = np.stack([all_x, all_y]).T
                possp = np.stack([xsp, ysp]).T

                # Average firing rate
                aver_rate = all_tsp.shape[0] / (all_x.shape[0] * dt)
                peak_rate = np.max(field_df.loc[nf, 'pf']['map'] * mask)

                # Field's directionality - need angle, anglesp, pos
                num_spikes = all_tsp.shape[0]
                occ_bins, _ = np.histogram(all_passangles, bins=aedges)
                minocc = occ_bins.min()
                mlmer = DirectionerMLM(pos,
                                       all_passangles,
                                       dt,
                                       sp_binwidth=sp_binwidth,
                                       a_binwidth=abind)
                rate_angle, rate_R, norm_prob_mlm = mlmer.get_directionality(
                    possp, all_anglesp)

                # Time shift shuffling for rate directionality
                if np.isnan(rate_R):
                    rate_R_pval = np.nan
                else:
                    rate_R_pval = timeshift_shuffle_exp_wrapper(
                        all_tsp_list, all_t_list, rate_R, NShuffles, mlmer,
                        interpolater_x, interpolater_y, interpolater_angle,
                        trange)

                # Precession per pass
                neuro_keys_dict = dict(tsp='tsp',
                                       spikev='spikev',
                                       spikex='spikex',
                                       spikey='spikey',
                                       spikeangle='spikeangle')
                accept_mask = (~passdf['rejected']) & (passdf['chunked'] < 2)
                passdf['excluded_for_precess'] = ~accept_mask
                precessdf, precess_angle, precess_R, _ = get_single_precessdf(
                    passdf,
                    precesser,
                    precess_filter,
                    neuro_keys_dict,
                    field_d=field_d,
                    kappa=kappa_precess,
                    bins=None)
                fitted_precessdf = precessdf[precessdf['fitted']].reset_index(
                    drop=True)
                # Proceed only if precession exists
                if (precess_angle is not None) and (
                        fitted_precessdf['precess_exist'].sum() > 0):

                    # Post-hoc precession exclusion
                    _, binR, postdoc_dens = compute_precessangle(
                        pass_angles=fitted_precessdf['mean_anglesp'].to_numpy(
                        ),
                        pass_nspikes=fitted_precessdf['pass_nspikes'].to_numpy(
                        ),
                        precess_mask=fitted_precessdf['precess_exist'].
                        to_numpy(),
                        kappa=None,
                        bins=aedges_precess)
                    (_, passbins_p, passbins_np, _) = postdoc_dens
                    all_passbins = passbins_p + passbins_np
                    numpass_at_precess = get_numpass_at_angle(
                        target_angle=precess_angle,
                        aedge=aedges_precess,
                        all_passbins=all_passbins)

                    # Precession - low-spike passes
                    ldf = fitted_precessdf[fitted_precessdf['pass_nspikes'] <
                                           nspikes_stats[ca]]  # 25% quantile
                    if (ldf.shape[0] > 0) and (ldf['precess_exist'].sum() > 0):
                        precess_angle_low, _, _ = compute_precessangle(
                            pass_angles=ldf['mean_anglesp'].to_numpy(),
                            pass_nspikes=ldf['pass_nspikes'].to_numpy(),
                            precess_mask=ldf['precess_exist'].to_numpy(),
                            kappa=kappa_precess,
                            bins=None)
                        _, _, postdoc_dens_low = compute_precessangle(
                            pass_angles=ldf['mean_anglesp'].to_numpy(),
                            pass_nspikes=ldf['pass_nspikes'].to_numpy(),
                            precess_mask=ldf['precess_exist'].to_numpy(),
                            kappa=None,
                            bins=aedges_precess)
                        (_, passbins_p_low, passbins_np_low,
                         _) = postdoc_dens_low
                        all_passbins_low = passbins_p_low + passbins_np_low
                        numpass_at_precess_low = get_numpass_at_angle(
                            target_angle=precess_angle_low,
                            aedge=aedges_precess,
                            all_passbins=all_passbins_low)
                    else:
                        precess_angle_low = None
                        numpass_at_precess_low = None

                    # # Precession - time shift shuffle
                    # psr = PassShuffler(all_t_list, all_tsp_list)
                    # all_shuffled_R = np.zeros(NShuffles)
                    # shufi = 0
                    # fail_times = 0
                    # repeated_times = 0
                    # tmpt = time.time()
                    # while (shufi < NShuffles) & (fail_times < NShuffles):
                    #
                    #     # Re-construct passdf
                    #     shifted_tsp_boxes = psr.timeshift_shuffle(seed=shufi+fail_times, return_concat=False)
                    #
                    #
                    #     shifted_spikex_boxes, shifted_spikey_boxes = [], []
                    #     shifted_spikeangle_boxes, shifted_rejected_boxes = [], []
                    #
                    #     for boxidx, shuffled_tsp_box in enumerate(shifted_tsp_boxes):
                    #         shuffled_tsp_box = shuffled_tsp_box[(shuffled_tsp_box < trange[0]) & (shuffled_tsp_box > trange[1])]
                    #         shifted_spikex_boxes.append(interpolater_x(shuffled_tsp_box))
                    #         shifted_spikey_boxes.append(interpolater_y(shuffled_tsp_box))
                    #         shifted_spikeangle_boxes.append(interpolater_angle(shuffled_tsp_box))
                    #         rejected, _ = tunner.rejection_singlefield(shuffled_tsp_box, all_t_list[boxidx], all_passangles_list[boxidx])
                    #         shifted_rejected_boxes.append(rejected)
                    #     shuffled_passdf = pd.DataFrame({'x':all_x_list, 'y':all_y_list, 't':all_t_list, 'angle':all_passangles_list,
                    #                                     'chunked':all_chunked_list, 'rejected':shifted_rejected_boxes,
                    #                                     neuro_keys_dict['tsp']:shifted_tsp_boxes, neuro_keys_dict['spikex']:shifted_spikex_boxes,
                    #                                     neuro_keys_dict['spikey']:shifted_spikey_boxes, neuro_keys_dict['spikeangle']:shifted_spikeangle_boxes})
                    #     shuffled_accept_mask = (~shuffled_passdf['rejected']) & (shuffled_passdf['chunked'] < 2)
                    #     shuffled_passdf['excluded_for_precess'] = ~shuffled_accept_mask
                    #     shuf_precessdf, shuf_precessangle, shuf_precessR, _ = get_single_precessdf(shuffled_passdf, precesser, precess_filter, neuro_keys_dict, occ_bins.min(),
                    #                                                                                field_d=field_d, kappa=kappa_precess, bins=None)
                    #
                    #     if (shuf_precessdf.shape[0] > 0) and (shuf_precessR is not None):
                    #         all_shuffled_R[shufi] = shuf_precessR
                    #         shufi += 1
                    #     else:
                    #         fail_times += 1
                    # precess_R_pval = 1 - np.nanmean(precess_R > all_shuffled_R)
                    # precess_shuffletime = time.time()-tmpt
                    # print('Shuf %0.4f - %0.4f - %0.4f, Target %0.4f'%(np.quantile(all_shuffled_R, 0.25), np.quantile(all_shuffled_R, 0.5), np.quantile(all_shuffled_R, 0.75), precess_R))
                    # print('Best precess=%0.2f, pval=%0.5f, time=%0.2f, failed=%d, repeated=%d'%(precess_angle, precess_R_pval, precess_shuffletime, fail_times, repeated_times))

                    precess_R_pval = 1

                else:
                    numpass_at_precess = None
                    precess_angle_low = None
                    numpass_at_precess_low = None
                    precess_R_pval = None

                fielddf_dict['ca'].append(ca)
                fielddf_dict['num_spikes'].append(num_spikes)
                fielddf_dict['field_area'].append(field_area)
                fielddf_dict['field_bound'].append(xyval)
                fielddf_dict['border'].append(border)
                fielddf_dict['aver_rate'].append(aver_rate)
                fielddf_dict['peak_rate'].append(peak_rate)
                fielddf_dict['rate_angle'].append(rate_angle)
                fielddf_dict['rate_R'].append(rate_R)
                fielddf_dict['rate_R_pval'].append(rate_R_pval)
                fielddf_dict['minocc'].append(minocc)
                fielddf_dict['precess_df'].append(fitted_precessdf)
                fielddf_dict['precess_angle'].append(precess_angle)
                fielddf_dict['precess_angle_low'].append(precess_angle_low)
                fielddf_dict['precess_R'].append(precess_R)
                fielddf_dict['precess_R_pval'].append(precess_R_pval)
                fielddf_dict['numpass_at_precess'].append(numpass_at_precess)
                fielddf_dict['numpass_at_precess_low'].append(
                    numpass_at_precess_low)

                # tmpdf = pd.DataFrame(dict(ca=fielddf_dict['ca'], pval=fielddf_dict['precess_R_pval']))
                # for catmp, cadftmp in tmpdf.groupby('ca'):
                #     nonnan_count = cadftmp[~cadftmp['pval'].isna()].shape[0]
                #     if nonnan_count ==0:
                #         nonnan_count = 1
                #     sig_count = cadftmp[cadftmp['pval'] < 0.05].shape[0]
                #     print('%s: ALL %d/%d=%0.2f, Among Precess %d/%d=%0.2f'%(catmp, sig_count, cadftmp.shape[0],
                #                                                             sig_count/cadftmp.shape[0], sig_count,
                #                                                             nonnan_count, sig_count/nonnan_count))

    fielddf_raw = pd.DataFrame(fielddf_dict)

    # Assign field ids within ca
    fielddf_raw['fieldid_ca'] = 0
    for ca, cadf in fielddf_raw.groupby('ca'):
        fieldidca = np.arange(cadf.shape[0]) + 1
        index_ca = cadf.index
        fielddf_raw.loc[index_ca, 'fieldid_ca'] = fieldidca

    fielddf_raw.to_pickle(save_pth)
    return fielddf_raw
def _precession_with_posthoc(passdf, precesser, precess_filter,
                             neuro_keys_dict, field_d, kappa_precess,
                             aedges_precess):
    """Fit single-field precession for ``passdf`` and count passes at its angle.

    Passes that are rejected or have ``chunked >= 2`` are flagged in place via
    an ``'excluded_for_precess'`` column on ``passdf`` before fitting.

    Returns
    -------
    tuple
        ``(fitted_precessdf, precessangle, precessR, numpass_at_precess)``.
        ``numpass_at_precess`` is None when no precession angle was obtained
        or no pass was fitted.
    """
    accept_mask = (~passdf['rejected']) & (passdf['chunked'] < 2)
    passdf['excluded_for_precess'] = ~accept_mask
    precessdf, precessangle, precessR, _ = get_single_precessdf(
        passdf,
        precesser,
        precess_filter,
        neuro_keys_dict,
        field_d=field_d,
        kappa=kappa_precess,
        bins=None)
    fitted_precessdf = precessdf[precessdf['fitted']].reset_index(drop=True)
    if (precessangle is None) or (fitted_precessdf.shape[0] < 1):
        return fitted_precessdf, precessangle, precessR, None

    # Post-hoc precession exclusion: bin the fitted passes by angle (with and
    # without precession) and count how many fall at the precession angle.
    _, _, postdoc_dens = compute_precessangle(
        pass_angles=fitted_precessdf['mean_anglesp'].to_numpy(),
        pass_nspikes=fitted_precessdf['pass_nspikes'].to_numpy(),
        precess_mask=fitted_precessdf['precess_exist'].to_numpy(),
        kappa=None,
        bins=aedges_precess)
    (_, passbins_p, passbins_np, _) = postdoc_dens
    all_passbins = passbins_p + passbins_np
    numpass_at_precess = get_numpass_at_angle(target_angle=precessangle,
                                              aedge=aedges_precess,
                                              all_passbins=all_passbins)
    return fitted_precessdf, precessangle, precessR, numpass_at_precess


def pair_field_preprocess_Romani(simdata,
                                 save_pth,
                                 radius=2,
                                 vthresh=2,
                                 sthresh=80,
                                 NShuffles=200):
    """Compute pair-field statistics for sampled neuron pairs of the Romani simulation.

    200 neurons are drawn without replacement (seed 0) from ``NeuronPos``.
    The first 80 are paired with the next 80, giving 6400 candidate pairs.
    A candidate pair is kept only if:

    - its fields overlap, i.e. at least 10 samples lie within ``radius`` of
      both neuron positions;
    - its subsampled spike trains contain at least 14 spike differences
      shorter than 0.1 (as computed by ``pair_diff``);
    - valid passes and paired spikes exist.

    For each kept pair, the function computes single-field and paired
    directionality, precession, phase lags, rates and the KL divergence. It
    collects one row per pair.

    Parameters
    ----------
    simdata : mapping
        Must provide three entries:

        - ``'Indata'``: DataFrame with at least ``'t'`` and ``'phase'``
          columns; also consumed by ``IndataProcessor``.
        - ``'SpikeData'``: DataFrame with ``'neuronidx'`` and ``'tidxsp'``
          columns; ``tidxsp`` is used as an index into ``tunner.t``.
        - ``'NeuronPos'``: DataFrame with one two-valued (x, y) row per
          neuron.
    save_pth : str or path-like
        Destination of the pickled result. It is written twice, once before
        and once after ``append_extrinsicity``.
    radius : float
        Field radius. A sample belongs to a neuron's field when it lies
        strictly closer than ``radius`` to the neuron's position.
    vthresh, sthresh : float
        Forwarded to ``IndataProcessor``.
    NShuffles : int
        Number of shuffles passed to ``timeshift_shuffle_exp_wrapper`` for
        the paired-spike directionality p-value (``rate_R_pvalp``).

    Returns
    -------
    pandas.DataFrame
        One row per kept pair, with the keys of ``pairdata_dict``. It also
        has an ``'overlap'`` column: the neuron distance rescaled so that the
        closest pair is 1 and the farthest is 0. The frame is then
        post-processed by ``append_extrinsicity``.

    Notes
    -----
    ``NeuronPos`` must contain at least 200 neurons, because sampling is done
    without replacement. The global ``np.random`` state is reseeded for
    sampling and for every pair.
    """
    pairdata_dict = dict(
        neuron1id=[],
        neuron2id=[],
        neuron1pos=[],
        neuron2pos=[],
        neurondist=[],
        # overlap is calculated afterward
        border1=[],
        border2=[],
        aver_rate1=[],
        aver_rate2=[],
        aver_rate_pair=[],
        com1=[],
        com2=[],
        rate_angle1=[],
        rate_angle2=[],
        rate_anglep=[],
        rate_R1=[],
        rate_R2=[],
        rate_Rp=[],
        num_spikes1=[],
        num_spikes2=[],
        num_spikes_pair=[],
        phaselag_AB=[],
        phaselag_BA=[],
        corr_info_AB=[],
        corr_info_BA=[],
        rate_AB=[],
        rate_BA=[],
        corate=[],
        pair_rate=[],
        kld=[],
        rate_R_pvalp=[],
        precess_df1=[],
        precess_angle1=[],
        precess_R1=[],
        precess_df2=[],
        precess_angle2=[],
        precess_R2=[],
        numpass_at_precess1=[],
        numpass_at_precess2=[],
        precess_dfp=[])

    Indata = simdata['Indata']
    SpikeData = simdata['SpikeData']
    NeuronPos = simdata['NeuronPos']
    # Phase comes from the simulation; the 'theta' entry is filled with ones.
    wave = dict(tax=Indata['t'].to_numpy(),
                phase=Indata['phase'].to_numpy(),
                theta=np.ones(Indata.shape[0]))

    # Settings
    subsample_fraction = 0.25  # fraction of each neuron's spikes kept
    minpasstime = 0.4
    minspiketresh = 14  # min. number of near-coincident spike differences
    default_T = 1 / 10  # coincidence window, in the units of tunner.t
    aedges = np.linspace(-np.pi, np.pi, 36)
    abind = aedges[1] - aedges[0]
    sp_binwidth = 2 * np.pi / 16
    precesser = PrecessionProcesser(wave=wave)
    precess_filter = PrecessionFilter()
    kappa_precess = 1
    aedges_precess = np.linspace(-np.pi, np.pi, 6)
    field_d = radius * 2  # field diameter, used for single and paired fields

    # Precomputation
    tunner = IndataProcessor(Indata,
                             vthresh=vthresh,
                             sthresh=sthresh,
                             minpasstime=minpasstime)
    interpolater_angle = interp1d(tunner.t, tunner.angle)
    interpolater_x = interp1d(tunner.t, tunner.x)
    interpolater_y = interp1d(tunner.t, tunner.y)
    all_maxt, all_mint = tunner.t.max(), tunner.t.min()
    # NOTE: trange is ordered (max, min); the consumers below receive it as is.
    trange = (all_maxt, all_mint)
    dt = tunner.t[1] - tunner.t[0]  # assumes uniform sampling of tunner.t
    precesser.set_trange(trange)

    # Combination of pairs
    np.random.seed(0)
    sampled_idx_all = np.random.choice(NeuronPos.shape[0], 200,
                                       replace=False).astype(int)
    samples1 = sampled_idx_all[0:80]
    samples2 = sampled_idx_all[80:160]
    pair_indices = [(i, j) for i in samples1 for j in samples2]
    total_num = len(pair_indices)

    # enumerate() keeps the progress counter correct whichever `continue`
    # below is taken.
    for progress, (i, j) in enumerate(pair_indices):
        print('%d/%d (%d,%d): sample %d' %
              (progress, total_num, i, j, len(pairdata_dict['neuron1pos'])))

        # Distance between the two neuron positions
        neu1x, neu1y = NeuronPos.iloc[i]
        neu2x, neu2y = NeuronPos.iloc[j]
        neurondist = np.sqrt((neu1x - neu2x)**2 + (neu1y - neu2y)**2)

        # Boolean masks of the trajectory samples inside each field
        tok1 = np.sqrt((tunner.x - neu1x)**2 +
                       (tunner.y - neu1y)**2) < radius
        tok2 = np.sqrt((tunner.x - neu2x)**2 +
                       (tunner.y - neu2y)**2) < radius
        tok_union = tok1 | tok2
        tok_intersect = tok1 & tok2
        if np.sum(tok_intersect) < 10:  # fields must overlap
            continue

        # Spike time indexes of both neurons
        spdf1 = SpikeData[SpikeData['neuronidx'] == i].reset_index(drop=True)
        spdf2 = SpikeData[SpikeData['neuronidx'] == j].reset_index(drop=True)
        tidxsp1 = spdf1['tidxsp'].to_numpy().astype(int)
        tidxsp2 = spdf2['tidxsp'].to_numpy().astype(int)
        if (tidxsp1.shape[0] < 1) or (tidxsp2.shape[0] < 1):
            # A silent neuron cannot form a pair; np.random.choice below
            # would also raise ValueError for an empty spike train.
            continue

        # Subsample spikes, reproducibly seeded per pair.
        # NOTE(review): the population is `n - 1`, so the last spike of each
        # train can never be drawn. It is kept so that subsamples stay
        # identical to those of earlier runs.
        np.random.seed(i * j)
        sampled_tidxsp1 = np.random.choice(tidxsp1.shape[0] - 1,
                                           int(tidxsp1.shape[0] *
                                               subsample_fraction),
                                           replace=False)
        sampled_tidxsp1.sort()
        tidxsp1 = tidxsp1[sampled_tidxsp1]
        np.random.seed(i * j + 1)
        sampled_tidxsp2 = np.random.choice(tidxsp2.shape[0] - 1,
                                           int(tidxsp2.shape[0] *
                                               subsample_fraction),
                                           replace=False)
        sampled_tidxsp2.sort()
        tidxsp2 = tidxsp2[sampled_tidxsp2]
        tsp1 = tunner.t[tidxsp1]
        tsp2 = tunner.t[tidxsp2]

        # Condition: enough near-coincident spikes between the two neurons
        tsp_diff = pair_diff(tsp1, tsp2)
        spok = np.sum(np.abs(tsp_diff) < default_T)
        if spok < minspiketresh:
            continue

        # Check border
        border1 = check_border_sim(neu1x, neu1y, radius, (0, 2 * np.pi))
        border2 = check_border_sim(neu2x, neu2y, radius, (0, 2 * np.pi))

        # Construct passes (segment & chunk) for field 1 and 2
        passdf1 = tunner.construct_singlefield_passdf(
            tok1, tsp1, interpolater_x, interpolater_y, interpolater_angle)
        passdf2 = tunner.construct_singlefield_passdf(
            tok2, tsp2, interpolater_x, interpolater_y, interpolater_angle)
        # Directionality uses every non-rejected pass; chunked passes are
        # excluded only for precession (see _precession_with_posthoc).
        allchunk_df1 = passdf1[~passdf1['rejected']].reset_index(drop=True)
        allchunk_df2 = passdf2[~passdf2['rejected']].reset_index(drop=True)
        if (allchunk_df1.shape[0] < 1) or (allchunk_df2.shape[0] < 1):
            continue

        x1 = np.concatenate(allchunk_df1['x'].to_list())
        x2 = np.concatenate(allchunk_df2['x'].to_list())
        y1 = np.concatenate(allchunk_df1['y'].to_list())
        y2 = np.concatenate(allchunk_df2['y'].to_list())
        hd1 = np.concatenate(allchunk_df1['angle'].to_list())
        hd2 = np.concatenate(allchunk_df2['angle'].to_list())
        pos1, pos2 = np.stack([x1, y1]).T, np.stack([x2, y2]).T

        # NOTE: from here on tsp1/tsp2 hold only the spikes inside the
        # accepted passes of each field; the paired passes below use these.
        tsp1 = np.concatenate(allchunk_df1['tsp'].to_list())
        tsp2 = np.concatenate(allchunk_df2['tsp'].to_list())
        xsp1 = np.concatenate(allchunk_df1['spikex'].to_list())
        xsp2 = np.concatenate(allchunk_df2['spikex'].to_list())
        ysp1 = np.concatenate(allchunk_df1['spikey'].to_list())
        ysp2 = np.concatenate(allchunk_df2['spikey'].to_list())
        possp1, possp2 = np.stack([xsp1, ysp1]).T, np.stack([xsp2, ysp2]).T
        hdsp1 = np.concatenate(allchunk_df1['spikeangle'].to_list())
        hdsp2 = np.concatenate(allchunk_df2['spikeangle'].to_list())
        nspks1, nspks2 = tsp1.shape[0], tsp2.shape[0]

        # Rates: spikes / (number of in-pass samples * dt)
        aver_rate1 = nspks1 / (x1.shape[0] * dt)
        aver_rate2 = nspks2 / (x2.shape[0] * dt)

        # Directionality
        mlmer1 = DirectionerMLM(pos1,
                                hd1,
                                dt=dt,
                                sp_binwidth=sp_binwidth,
                                a_binwidth=abind)
        rate_angle1, rate_R1, normprob1_mlm = mlmer1.get_directionality(
            possp1, hdsp1)
        normprob1_mlm[np.isnan(normprob1_mlm)] = 0

        mlmer2 = DirectionerMLM(pos2,
                                hd2,
                                dt=dt,
                                sp_binwidth=sp_binwidth,
                                a_binwidth=abind)
        rate_angle2, rate_R2, normprob2_mlm = mlmer2.get_directionality(
            possp2, hdsp2)
        normprob2_mlm[np.isnan(normprob2_mlm)] = 0
        neuro_keys_dict = dict(tsp='tsp',
                               spikev='spikev',
                               spikex='spikex',
                               spikey='spikey',
                               spikeangle='spikeangle')

        # Precession & post-hoc exclusion for each field
        (fitted_precessdf1, precessangle1, precessR1,
         numpass_at_precess1) = _precession_with_posthoc(
             passdf1, precesser, precess_filter, neuro_keys_dict, field_d,
             kappa_precess, aedges_precess)
        (fitted_precessdf2, precessangle2, precessR2,
         numpass_at_precess2) = _precession_with_posthoc(
             passdf2, precesser, precess_filter, neuro_keys_dict, field_d,
             kappa_precess, aedges_precess)

        # # Paired field processing
        pairedpasses = tunner.construct_pairfield_passdf(
            tok_union, tok1, tok2, tsp1, tsp2, interpolater_x,
            interpolater_y, interpolater_angle)

        phase_finder = ThetaEstimator(0.005, 0.3, [5, 12])
        AB_tsp1_list, BA_tsp1_list = [], []
        AB_tsp2_list, BA_tsp2_list = [], []
        nspikes_AB_list, nspikes_BA_list = [], []
        duration_AB_list, duration_BA_list = [], []
        t_all = []
        passangles_all, x_all, y_all = [], [], []
        paired_tsp_list = []

        accepted_df = pairedpasses[~pairedpasses['rejected']].reset_index(
            drop=True)
        for npass in range(accepted_df.shape[0]):

            # Per-pass data (pass_tsp* are this pass's spikes of each neuron)
            t, pass_tsp1, pass_tsp2 = accepted_df.loc[npass,
                                                      ['t', 'tsp1', 'tsp2']]
            x, y, pass_angles, direction = accepted_df.loc[
                npass, ['x', 'y', 'angle', 'direction']]
            duration = t.max() - t.min()

            # Find paired spikes
            pairidx1, pairidx2 = find_pair_times(pass_tsp1, pass_tsp2)
            paired_tsp1 = pass_tsp1[pairidx1]
            paired_tsp2 = pass_tsp2[pairidx2]
            if (paired_tsp1.shape[0] < 1) and (paired_tsp2.shape[0] < 1):
                continue
            paired_tsp_list.append(np.concatenate([paired_tsp1, paired_tsp2]))
            passangles_all.append(pass_angles)
            x_all.append(x)
            y_all.append(y)
            t_all.append(t)
            if direction == 'A->B':
                AB_tsp1_list.append(pass_tsp1)
                AB_tsp2_list.append(pass_tsp2)
                nspikes_AB_list.append(pass_tsp1.shape[0] +
                                       pass_tsp2.shape[0])
                duration_AB_list.append(duration)

            elif direction == 'B->A':
                BA_tsp1_list.append(pass_tsp1)
                BA_tsp2_list.append(pass_tsp2)
                nspikes_BA_list.append(pass_tsp1.shape[0] +
                                       pass_tsp2.shape[0])
                duration_BA_list.append(duration)

        # Phase lags
        thetaT_AB, phaselag_AB, corr_info_AB = phase_finder.find_theta_isi_hilbert(
            AB_tsp1_list, AB_tsp2_list)
        thetaT_BA, phaselag_BA, corr_info_BA = phase_finder.find_theta_isi_hilbert(
            BA_tsp1_list, BA_tsp2_list)

        # Pair precession
        neuro_keys_dict1 = dict(tsp='tsp1',
                                spikev='spike1v',
                                spikex='spike1x',
                                spikey='spike1y',
                                spikeangle='spike1angle')
        neuro_keys_dict2 = dict(tsp='tsp2',
                                spikev='spike2v',
                                spikex='spike2x',
                                spikey='spike2y',
                                spikeangle='spike2angle')

        accept_mask = (~pairedpasses['rejected']) & (
            pairedpasses['chunked'] < 2) & (
                (pairedpasses['direction'] == 'A->B') |
                (pairedpasses['direction'] == 'B->A'))

        pairedpasses['excluded_for_precess'] = ~accept_mask
        precess_dfp = precesser.get_single_precession(pairedpasses,
                                                      neuro_keys_dict1,
                                                      field_d,
                                                      tag='1')
        precess_dfp = precesser.get_single_precession(precess_dfp,
                                                      neuro_keys_dict2,
                                                      field_d,
                                                      tag='2')
        precess_dfp = precess_filter.filter_pair(precess_dfp)
        fitted_precess_dfp = precess_dfp[
            precess_dfp['fitted1']
            & precess_dfp['fitted2']].reset_index(drop=True)

        # Paired spikes
        if (len(paired_tsp_list) == 0) or (len(passangles_all) == 0):
            continue
        hd_pair = np.concatenate(passangles_all)
        x_pair, y_pair = np.concatenate(x_all), np.concatenate(y_all)
        pos_pair = np.stack([x_pair, y_pair]).T
        paired_tsp = np.concatenate(paired_tsp_list)
        # Keep spikes inside the interpolators' time domain
        paired_tsp = paired_tsp[(paired_tsp <= all_maxt)
                                & (paired_tsp >= all_mint)]
        if paired_tsp.shape[0] < 1:
            continue
        num_spikes_pair = paired_tsp.shape[0]
        hdsp_pair = interpolater_angle(paired_tsp)
        xsp_pair = interpolater_x(paired_tsp)
        ysp_pair = interpolater_y(paired_tsp)
        possp_pair = np.stack([xsp_pair, ysp_pair]).T
        aver_rate_pair = num_spikes_pair / (x_pair.shape[0] * dt)

        # Pair Directionality
        mlmer_pair = DirectionerMLM(pos_pair, hd_pair, dt, sp_binwidth, abind)
        rate_anglep, rate_Rp, normprobp_mlm = mlmer_pair.get_directionality(
            possp_pair, hdsp_pair)
        normprobp_mlm[np.isnan(normprobp_mlm)] = 0

        # Time shift shuffling
        rate_R_pvalp = timeshift_shuffle_exp_wrapper(
            paired_tsp_list, t_all, rate_Rp, NShuffles, mlmer_pair,
            interpolater_x, interpolater_y, interpolater_angle, trange)

        # Rates. A direction without passes gives 0/0, i.e. NaN.
        with np.errstate(divide='ignore', invalid='ignore'):
            rate_AB = np.sum(nspikes_AB_list) / np.sum(duration_AB_list)
            rate_BA = np.sum(nspikes_BA_list) / np.sum(duration_BA_list)
            corate = np.sum(nspikes_AB_list +
                            nspikes_BA_list) / np.sum(duration_AB_list +
                                                      duration_BA_list)
            pair_rate = num_spikes_pair / np.sum(duration_AB_list +
                                                 duration_BA_list)

        # KLD
        kld = calc_kld(normprob1_mlm, normprob2_mlm, normprobp_mlm)

        pairdata_dict['neuron1id'].append(i)
        pairdata_dict['neuron2id'].append(j)
        pairdata_dict['neuron1pos'].append(NeuronPos.iloc[i].to_numpy())
        pairdata_dict['neuron2pos'].append(NeuronPos.iloc[j].to_numpy())
        pairdata_dict['neurondist'].append(neurondist)

        pairdata_dict['border1'].append(border1)
        pairdata_dict['border2'].append(border2)

        pairdata_dict['aver_rate1'].append(aver_rate1)
        pairdata_dict['aver_rate2'].append(aver_rate2)
        pairdata_dict['aver_rate_pair'].append(aver_rate_pair)
        pairdata_dict['com1'].append(NeuronPos.iloc[i].to_numpy())
        pairdata_dict['com2'].append(NeuronPos.iloc[j].to_numpy())

        pairdata_dict['rate_angle1'].append(rate_angle1)
        pairdata_dict['rate_angle2'].append(rate_angle2)
        pairdata_dict['rate_anglep'].append(rate_anglep)
        pairdata_dict['rate_R1'].append(rate_R1)
        pairdata_dict['rate_R2'].append(rate_R2)
        pairdata_dict['rate_Rp'].append(rate_Rp)

        pairdata_dict['num_spikes1'].append(nspks1)
        pairdata_dict['num_spikes2'].append(nspks2)
        pairdata_dict['num_spikes_pair'].append(num_spikes_pair)

        pairdata_dict['phaselag_AB'].append(phaselag_AB)
        pairdata_dict['phaselag_BA'].append(phaselag_BA)
        pairdata_dict['corr_info_AB'].append(corr_info_AB)
        pairdata_dict['corr_info_BA'].append(corr_info_BA)

        pairdata_dict['rate_AB'].append(rate_AB)
        pairdata_dict['rate_BA'].append(rate_BA)
        pairdata_dict['corate'].append(corate)
        pairdata_dict['pair_rate'].append(pair_rate)
        pairdata_dict['kld'].append(kld)
        pairdata_dict['rate_R_pvalp'].append(rate_R_pvalp)

        pairdata_dict['precess_df1'].append(fitted_precessdf1)
        pairdata_dict['precess_angle1'].append(precessangle1)
        pairdata_dict['precess_R1'].append(precessR1)
        pairdata_dict['precess_df2'].append(fitted_precessdf2)
        pairdata_dict['precess_angle2'].append(precessangle2)
        pairdata_dict['precess_R2'].append(precessR2)
        pairdata_dict['numpass_at_precess1'].append(numpass_at_precess1)
        pairdata_dict['numpass_at_precess2'].append(numpass_at_precess2)
        pairdata_dict['precess_dfp'].append(fitted_precess_dfp)

    pairdata = pd.DataFrame(pairdata_dict)

    # Convert distance to overlap: closest pair -> 1, farthest pair -> 0
    dist_range = pairdata['neurondist'].max() - pairdata['neurondist'].min()
    pairdata['overlap'] = (pairdata['neurondist'].max() -
                           pairdata['neurondist']) / dist_range
    pairdata.to_pickle(save_pth)  # checkpoint before post-processing
    pairdata = append_extrinsicity(pairdata)
    pairdata.to_pickle(save_pth)
    return pairdata
def single_field_preprocess_Romani(simdata,
                                   radius=2,
                                   vthresh=2,
                                   sthresh=80,
                                   NShuffles=200,
                                   save_pth=None,
                                   subsample_fraction=0.4):
    """Compute per-neuron rate directionality and phase precession on simulated data.

    For every neuron in ``simdata['NeuronPos']``, its spikes are randomly
    subsampled and passes through a circular field of ``radius`` around the
    neuron position are extracted. Rate directionality is computed, with a
    time-shift shuffling p-value, along with pass-wise phase precession.
    One row is collected per neuron that has at least one non-rejected pass.
    Neurons without such a pass are skipped, so the row index of the result
    is not necessarily the neuron index. The resulting DataFrame is pickled
    to ``save_pth``.

    Parameters
    ----------
    simdata : dict-like
        Must provide three entries:

        - ``'Indata'``: DataFrame with at least the columns ``'t'`` and
          ``'phase'``.
        - ``'SpikeData'``: DataFrame with the columns ``'neuronidx'`` and
          ``'tidxsp'``. The latter are used as row labels into ``Indata``.
        - ``'NeuronPos'``: DataFrame whose first two columns are the
          neuron's x and y position.
    radius : float
        Radius of the circular field around each neuron position. Default = 2.
    vthresh : float
        Default = 2. Determined by the ratio between the average speed in
        Emily's data and the target vthresh of 5 there.
    sthresh : float
        Default = 80. Determined by the same percentile (10%) of passes
        excluded in Emily's data (sthresh = 3 there). It differs because
        straightrank depends on the sampling frequency; the sampling period
        is 1 ms in the simulation.
    NShuffles : int
        Number of time-shift shuffles for the rate-directionality p-value.
    save_pth : str or path-like
        Destination of the pickled result DataFrame. Required despite the
        ``None`` default.
    subsample_fraction : float
        Fraction (between 0 and 1) of each neuron's spikes that is randomly
        kept. Default = 0.4.

    Output
    ------
    The pickled DataFrame has the columns num_spikes, border, aver_rate,
    rate_angle, rate_R, rate_R_pval, minocc, precess_df, precess_angle,
    precess_angle_low, precess_R, numpass_at_precess and
    numpass_at_precess_low.

    Raises
    ------
    ValueError
        If ``save_pth`` is None. This is checked before any computation.
    """
    if save_pth is None:
        # Fail before the long per-neuron loop rather than at the final
        # to_pickle call, which cannot write to a None path.
        raise ValueError('save_pth must be given; results are written with DataFrame.to_pickle')

    datadict = dict(num_spikes=[],
                    border=[],
                    aver_rate=[],
                    rate_angle=[],
                    rate_R=[],
                    rate_R_pval=[],
                    minocc=[],
                    precess_df=[],
                    precess_angle=[],
                    precess_angle_low=[],
                    precess_R=[],
                    numpass_at_precess=[],
                    numpass_at_precess_low=[])

    Indata, SpikeData, NeuronPos = simdata['Indata'], simdata[
        'SpikeData'], simdata['NeuronPos']
    aedges = np.linspace(-np.pi, np.pi, 36)
    abind = aedges[1] - aedges[0]
    sp_binwidth = 2 * np.pi / 16
    tunner = IndataProcessor(Indata,
                             vthresh=vthresh,
                             sthresh=sthresh,
                             minpasstime=0.4)
    # 'theta' is filled with ones. Presumably this is a placeholder, since
    # only the time axis and phase come from the simulation. TODO confirm.
    wave = dict(tax=Indata['t'].to_numpy(),
                phase=Indata['phase'].to_numpy(),
                theta=np.ones(Indata.shape[0]))

    interpolater_angle = interp1d(tunner.t, tunner.angle)
    interpolater_x = interp1d(tunner.t, tunner.x)
    interpolater_y = interp1d(tunner.t, tunner.y)
    # NOTE(review): ordered (max, min), not the usual (min, max). This is
    # kept as-is because set_trange / timeshift_shuffle_exp_wrapper define
    # the expected order. Confirm against their implementation.
    trange = (tunner.t.max(), tunner.t.min())
    # Assumes tunner.t is uniformly sampled.
    dt = tunner.t[1] - tunner.t[0]
    precesser = PrecessionProcesser(wave=wave)
    precesser.set_trange(trange)
    precess_filter = PrecessionFilter()
    # Passes with fewer spikes than this are treated as "low-spike" passes.
    # Per the original note, this is about the 25% quantile.
    lowspike_num = 13

    kappa_precess = 1
    aedges_precess = np.linspace(-np.pi, np.pi, 6)

    num_neurons = NeuronPos.shape[0]

    for nidx in range(num_neurons):
        print('%d/%d Neuron' % (nidx, num_neurons))
        # Get spike indexes + subsample
        spdf = SpikeData[SpikeData['neuronidx'] == nidx].reset_index(drop=True)
        tidxsp = spdf['tidxsp'].to_numpy().astype(int)
        num_all_spikes = tidxsp.shape[0]
        # Seed per neuron so the subsample is reproducible. This also seeds
        # the global RNG used later in this iteration.
        np.random.seed(nidx)
        # Draw from all positions 0..num_all_spikes-1 so every spike, including
        # the last, can be kept. Drawing from num_all_spikes - 1 positions
        # instead would exclude the last spike and raise for neurons with no
        # spikes.
        sampled_tidxsp = np.random.choice(num_all_spikes,
                                          int(num_all_spikes *
                                              subsample_fraction),
                                          replace=False)
        sampled_tidxsp.sort()
        tidxsp = tidxsp[sampled_tidxsp]
        tsp = Indata.loc[tidxsp, 't'].to_numpy()
        neuron_pos = NeuronPos.iloc[nidx].to_numpy()

        # Check border
        border = check_border_sim(neuron_pos[0], neuron_pos[1], radius,
                                  (0, 2 * np.pi))

        # Construct passdf from the samples within `radius` of the neuron
        dist = np.sqrt((neuron_pos[0] - tunner.x)**2 +
                       (neuron_pos[1] - tunner.y)**2)
        tok = dist < radius
        passdf = tunner.construct_singlefield_passdf(tok, tsp, interpolater_x,
                                                     interpolater_y,
                                                     interpolater_angle)
        # Directionality uses every non-rejected pass. The precession analysis
        # below additionally excludes passes with chunked >= 2.
        allchunk_df = passdf[(~passdf['rejected'])].reset_index(drop=True)

        # Get info from passdf and interpolate
        if allchunk_df.shape[0] < 1:
            continue
        all_x_list, all_y_list = allchunk_df['x'].to_list(
        ), allchunk_df['y'].to_list()
        all_t_list, all_passangles_list = allchunk_df['t'].to_list(
        ), allchunk_df['angle'].to_list()
        all_tsp_list, all_chunked_list = allchunk_df['tsp'].to_list(
        ), allchunk_df['chunked'].to_list()
        all_x = np.concatenate(all_x_list)
        all_y = np.concatenate(all_y_list)
        all_passangles = np.concatenate(all_passangles_list)
        all_tsp = np.concatenate(all_tsp_list)
        all_anglesp = np.concatenate(allchunk_df['spikeangle'].to_list())
        xsp, ysp = np.concatenate(
            allchunk_df['spikex'].to_list()), np.concatenate(
                allchunk_df['spikey'].to_list())
        pos = np.stack([all_x, all_y]).T
        possp = np.stack([xsp, ysp]).T

        # Average firing rate: spike count / (number of in-field samples * dt)
        aver_rate = all_tsp.shape[0] / (all_x.shape[0] * dt)

        # Field's directionality
        num_spikes = all_tsp.shape[0]
        occ_bins, _ = np.histogram(all_passangles, bins=aedges)
        minocc = occ_bins.min()
        mlmer = DirectionerMLM(pos,
                               all_passangles,
                               dt,
                               sp_binwidth=sp_binwidth,
                               a_binwidth=abind)
        rate_angle, rate_R, norm_prob_mlm = mlmer.get_directionality(
            possp, all_anglesp)

        # Time shift shuffling for rate directionality
        rate_R_pval = timeshift_shuffle_exp_wrapper(all_tsp_list, all_t_list,
                                                    rate_R, NShuffles, mlmer,
                                                    interpolater_x,
                                                    interpolater_y,
                                                    interpolater_angle, trange)

        # Precession per pass
        neuro_keys_dict = dict(tsp='tsp',
                               spikev='spikev',
                               spikex='spikex',
                               spikey='spikey',
                               spikeangle='spikeangle')
        accept_mask = (~passdf['rejected']) & (passdf['chunked'] < 2)
        passdf['excluded_for_precess'] = ~accept_mask
        precessdf, precess_angle, precess_R, _ = get_single_precessdf(
            passdf,
            precesser,
            precess_filter,
            neuro_keys_dict,
            field_d=radius * 2,
            kappa=kappa_precess,
            bins=None)
        fitted_precessdf = precessdf[precessdf['fitted']].reset_index(
            drop=True)

        # Proceed only if precession exists
        if (precess_angle
                is not None) and (fitted_precessdf['precess_exist'].sum() > 0):

            # Post-hoc precession exclusion
            _, binR, postdoc_dens = compute_precessangle(
                pass_angles=fitted_precessdf['mean_anglesp'].to_numpy(),
                pass_nspikes=fitted_precessdf['pass_nspikes'].to_numpy(),
                precess_mask=fitted_precessdf['precess_exist'].to_numpy(),
                kappa=None,
                bins=aedges_precess)
            (_, passbins_p, passbins_np, _) = postdoc_dens
            all_passbins = passbins_p + passbins_np
            numpass_at_precess = get_numpass_at_angle(
                target_angle=precess_angle,
                aedge=aedges_precess,
                all_passbins=all_passbins)

            # Precession - low-spike passes
            ldf = fitted_precessdf[fitted_precessdf['pass_nspikes'] <
                                   lowspike_num]
            if (ldf.shape[0] > 0) and (ldf['precess_exist'].sum() > 0):
                precess_angle_low, _, _ = compute_precessangle(
                    pass_angles=ldf['mean_anglesp'].to_numpy(),
                    pass_nspikes=ldf['pass_nspikes'].to_numpy(),
                    precess_mask=ldf['precess_exist'].to_numpy(),
                    kappa=kappa_precess,
                    bins=None)
                _, _, postdoc_dens_low = compute_precessangle(
                    pass_angles=ldf['mean_anglesp'].to_numpy(),
                    pass_nspikes=ldf['pass_nspikes'].to_numpy(),
                    precess_mask=ldf['precess_exist'].to_numpy(),
                    kappa=None,
                    bins=aedges_precess)
                (_, passbins_p_low, passbins_np_low, _) = postdoc_dens_low
                all_passbins_low = passbins_p_low + passbins_np_low
                numpass_at_precess_low = get_numpass_at_angle(
                    target_angle=precess_angle_low,
                    aedge=aedges_precess,
                    all_passbins=all_passbins_low)
            else:
                precess_angle_low = None
                numpass_at_precess_low = None

        else:
            numpass_at_precess = None
            precess_angle_low = None
            numpass_at_precess_low = None

        datadict['num_spikes'].append(num_spikes)
        datadict['border'].append(border)
        datadict['aver_rate'].append(aver_rate)

        datadict['rate_angle'].append(rate_angle)
        datadict['rate_R'].append(rate_R)
        datadict['rate_R_pval'].append(rate_R_pval)
        datadict['minocc'].append(minocc)

        datadict['precess_df'].append(precessdf)
        datadict['precess_angle'].append(precess_angle)
        datadict['precess_angle_low'].append(precess_angle_low)
        datadict['precess_R'].append(precess_R)

        datadict['numpass_at_precess'].append(numpass_at_precess)
        datadict['numpass_at_precess_low'].append(numpass_at_precess_low)

    datadf = pd.DataFrame(datadict)

    datadf.to_pickle(save_pth)