示例#1
0
def bin_array(y, nr_bins):
    """Digitise circular values into ``nr_bins`` equal-width angular bins.

    The values are first wrapped onto the circle with
    ``pycircstat.cdiff(y, 0)`` and then assigned bin indices with
    ``np.digitize``. The bin edges start just below -pi and are spaced
    ``2 * pi / nr_bins`` apart.

    Parameters
    ----------
    y : array_like
        Angles, presumably in radians.
    nr_bins : int
        Number of bins.

    Returns
    -------
    y : ndarray
        The input converted to a NumPy array (not wrapped).
    y_bin : ndarray of int
        Bin index for every element, as returned by ``np.digitize``.
    """
    values = np.array(y)
    width = 2 * math.pi / nr_bins
    # Edges begin slightly below -pi so that -pi itself falls inside the first bin.
    edges = np.arange(-math.pi - 0.000001, math.pi - width, width)
    wrapped = pycircstat.cdiff(values, 0)
    return values, np.digitize(wrapped, edges)
示例#2
0
def calculate_bias(signed_error, target, other, subid, confidence=None, nbin=64, pbin=0.25):
    """Compute the binned response bias against the target/other angular difference.

    Parameters
    ----------
    signed_error : ndarray
        Signed response error per trial (presumably radians, given the
        [-pi, pi] circular spread below). It is demeaned before binning.
    target, other : array_like
        Target and other-item orientations in degrees (converted with
        ``np.radians``).
    subid
        Subject identifier, repeated in the 'subid' column.
    confidence : optional
        Unused; kept for backward compatibility.
    nbin : int, default 64
        Number of bins passed to ``circ_bini``.
    pbin : float, default 0.25
        Fraction of the data per bin, passed to ``circ_bini``.

    Returns
    -------
    pandas.DataFrame
        One row per bin with columns 'subid', 'bincentre' (circular mean of the
        angular difference), 'bias' (mean signed error), 'prec' (circular std
        of the signed error) and 'binid'.
    """
    targori = np.radians(target)
    otherori = np.radians(other)

    # Signed circular difference of the other item relative to the target.
    angdiff = circstat.cdiff(otherori, targori)

    # Demean the error to remove any overall angular bias in responding
    # (agnostic to any other data).
    signed_error = np.subtract(signed_error, signed_error.mean())

    # Row i of the result is used to index the trials of bin i.
    bins = circ_bini(angdiff, nbin, pbin)

    biases = np.full(nbin, np.nan)
    bincentre = np.full(nbin, np.nan)
    precs = np.full(nbin, np.nan)
    for i in range(nbin):
        in_bin = bins[i, :]
        biases[i] = np.mean(signed_error[in_bin])
        bincentre[i] = sp.stats.circmean(angdiff[in_bin], high=np.pi, low=-np.pi)
        precs[i] = sp.stats.circstd(signed_error[in_bin], high=np.pi, low=-np.pi)

    df = pd.DataFrame()
    df['subid'] = np.full(nbin, subid)
    df['bincentre'] = bincentre
    df['bias'] = biases
    df['prec'] = precs
    df['binid'] = np.arange(nbin)
    return df
    def compute_hilbert_for_cluster(self, this_cluster_name):
        """Band-pass a cluster's EEG around its mean frequency and return phase and power.

        Parameters
        ----------
        this_cluster_name : str
            Column of ``self.res['clusters']``. Rows where it is not NaN are
            the cluster's electrodes, and their values in this column are
            averaged to give the band centre (presumably per-electrode
            frequencies).

        Returns
        -------
        phase_data
            Copy of the filtered EEG (dims 'channel', 'event', 'time'). Its data
            is each channel's unwrapped Hilbert phase, expressed as a circular
            difference from the circular mean phase across channels.
        power_data
            Copy of the filtered EEG whose data is the squared Hilbert amplitude.
        cluster_mean_freq : float
            Mean of the cluster column over its member rows.
        """
        clusters = self.res['clusters']

        # first, get the eeg for just channels in cluster
        cluster_rows = clusters[this_cluster_name].notna()
        cluster_elec_labels = clusters[cluster_rows]['label']
        # np.isin replaces np.in1d, which is deprecated as of NumPy 2.0.
        in_cluster = np.isin(self.subject_data.channel, cluster_elec_labels)
        cluster_eeg = self.subject_data[:, in_cluster]

        # bandpass eeg at the mean frequency, making sure the lower frequency isn't too low
        cluster_mean_freq = clusters[cluster_rows][this_cluster_name].mean()
        cluster_freq_range = [
            max(cluster_mean_freq - self.hilbert_half_range, SubjectTravelingWaveAnalysis.LOWER_MIN_FREQ),
            cluster_mean_freq + self.hilbert_half_range,
        ]
        filtered_eeg = RAM_helpers.band_pass_eeg(cluster_eeg, cluster_freq_range)
        filtered_eeg = filtered_eeg.transpose('channel', 'event', 'time')

        # run the hilbert transform along time
        complex_hilbert_res = hilbert(filtered_eeg.data, N=filtered_eeg.shape[-1], axis=-1)

        # compute the phase of the filtered eeg
        phase_data = filtered_eeg.copy()
        phase_data.data = np.unwrap(np.angle(complex_hilbert_res))

        # compute the power
        power_data = filtered_eeg.copy()
        power_data.data = np.abs(complex_hilbert_res) ** 2

        # mean phase across channels (axis 0 after the transpose) and each
        # electrode's circular phase difference from it
        ref_phase = pycircstat.mean(phase_data.data, axis=0)
        phase_data.data = pycircstat.cdiff(phase_data.data, ref_phase)
        return phase_data, power_data, cluster_mean_freq
def plot_marginal_slices(ax, xx, yy, zz, selected_x, slice_yrange, slicegap):
    """
    Plot vertically offset marginal slices of ``zz`` at chosen x values, as a function of y.

    Each slice is normalised to sum to one, shifted so that its minimum sits at
    ``idx * slicegap``, and drawn in its own colour. A black marker is placed at
    the slice's weighted circular mean of y, repeated at +/- 2*pi.

    Parameters
    ----------
    ax : axes object
        Axes to draw on (``plot`` and ``scatter`` are called on it).
    xx : ndarray
        2-D grid of x coordinates; the x axis is taken from its first row.
    yy : ndarray
        2-D grid of y coordinates; the y axis is taken from its first column.
        Presumably angles in radians (circular mean, +/- 2*pi copies).
    zz : ndarray
        Values on the grid, indexed as ``zz[y_index, x_index]``.
    selected_x : sequence of float
        x values to slice at; each uses the nearest grid column. Any sized
        sequence (tuple, list, 1-D array) is accepted.
    slice_yrange : tuple
        (y_start, y_end). The slice covers grid rows from the row nearest
        y_start up to, but excluding, the row nearest y_end.
    slicegap : float
        Vertical offset between consecutive slices.

    Returns
    -------
    ax : axes object
        The same axes.
    color_list : list
        The colour used for each slice, in ``selected_x`` order.
    """
    yax = yy[:, 0]  # y axis
    xax = xx[0, :]  # x axis

    # The y window is the same for every slice, so locate it once.
    y_ids = [np.argmin(np.abs(yax - bound)) for bound in slice_yrange]
    y_window = slice(y_ids[0], y_ids[1])
    yax_cut = yax[y_window]

    # len() rather than .shape[0] so tuples/lists work as documented.
    n_slices = len(selected_x)
    color_list = []
    for idx, xval in enumerate(selected_x):
        selected_xid = np.argmin(np.abs(xval - xax))

        slice_cut = zz[y_window, selected_xid]
        slice_norm = slice_cut / np.sum(slice_cut)
        slice_shifted = (slice_norm - slice_norm.min()) + idx * slicegap

        slice_color = cm.brg(idx / n_slices)
        color_list.append(slice_color)
        ax.plot(yax_cut, slice_shifted, c=slice_color)

        # Weighted circular centre of the slice, and the curve height nearest to it.
        y_cog = circmean(yax_cut, w=slice_norm)
        slice_point = slice_shifted[np.argmin(np.abs(cdiff(yax_cut, y_cog)))]
        # Draw the marker at the centre and its +/- 2*pi copies.
        for offset in (0, -2 * np.pi, 2 * np.pi):
            ax.scatter(y_cog + offset, slice_point, marker='.', c=['k'], zorder=3.1)

    return ax, color_list
示例#5
0
    def compute_hilbert_for_cluster(self, this_cluster_name):
        """Band-pass a cluster's EEG around its mean frequency and return phase and power.

        Parameters
        ----------
        this_cluster_name : str
            Column of ``self.res['clusters']``. Rows where it is not NaN are
            the cluster's electrodes, and their values in this column are
            averaged to give the band centre (presumably per-electrode
            frequencies).

        Returns
        -------
        phase_data
            Copy of the filtered EEG (dims 'channel', 'event', 'time'). Its data
            is each channel's unwrapped Hilbert phase, expressed as a circular
            difference from the circular mean phase across channels.
        power_data
            Copy of the filtered EEG whose data is the squared Hilbert amplitude.
        cluster_mean_freq : float
            Mean of the cluster column over its member rows.
        """
        clusters = self.res['clusters']

        # first, get the eeg for just channels in cluster
        cluster_rows = clusters[this_cluster_name].notna()
        cluster_elec_labels = clusters[cluster_rows]['label']
        # np.isin replaces np.in1d, which is deprecated as of NumPy 2.0.
        in_cluster = np.isin(self.subject_data.channel, cluster_elec_labels)
        cluster_eeg = self.subject_data[:, in_cluster]

        # bandpass eeg at the mean frequency, making sure the lower frequency isn't too low
        cluster_mean_freq = clusters[cluster_rows][this_cluster_name].mean()
        cluster_freq_range = [
            max(cluster_mean_freq - self.hilbert_half_range,
                SubjectTravelingWaveAnalysis.LOWER_MIN_FREQ),
            cluster_mean_freq + self.hilbert_half_range,
        ]
        filtered_eeg = ecog_helpers.band_pass_eeg(cluster_eeg,
                                                  cluster_freq_range)
        filtered_eeg = filtered_eeg.transpose('channel', 'event', 'time')

        # run the hilbert transform along time
        complex_hilbert_res = hilbert(filtered_eeg.data,
                                      N=filtered_eeg.shape[-1],
                                      axis=-1)

        # compute the phase of the filtered eeg
        phase_data = filtered_eeg.copy()
        phase_data.data = np.unwrap(np.angle(complex_hilbert_res))

        # compute the power
        power_data = filtered_eeg.copy()
        power_data.data = np.abs(complex_hilbert_res)**2

        # mean phase across channels (axis 0 after the transpose) and each
        # electrode's circular phase difference from it
        ref_phase = pycircstat.mean(phase_data.data, axis=0)
        phase_data.data = pycircstat.cdiff(phase_data.data, ref_phase)
        return phase_data, power_data, cluster_mean_freq
示例#6
0
def test_circular_distance():
    """The circular difference of an angle array with itself is zero everywhere."""
    angles = np.array([4.85065953, 0.79063862, 1.35698570])
    expected = np.zeros_like(angles)
    assert_allclose(pycircstat.cdiff(angles, angles), expected)
示例#7
0
def compute_asymmetry(data, y=None, reference=None, min_deg=30, max_deg=60):
    """Compute the asymmetry of centred class evidence for CW vs CCW trials.

    Trials are split by the signed circular difference
    ``pycircstat.cdiff(reference, y)``:

    - "CW" trials have a difference in ``(-max, -min]``.
    - "CCW" trials have a difference in ``[min, max)``.

    The degree cut-offs are converted as ``deg / 90 * pi``. That is twice the
    usual degree-to-radian factor, presumably because orientations spanning
    180 degrees are mapped onto the full circle; confirm before changing the
    cut-offs.

    Evidence is averaged over each trial set. For each set, the mean evidence
    of a low block of classes minus that of a high block is taken, and
    ``shift`` is the CW contrast minus the CCW contrast.

    Parameters
    ----------
    data : ndarray or dict
        Three-dimensional array of [trial repeats by classes by time points].
        Correct class is in position n_classes/2.

        Alternatively, a dict with keys 'single_trial_ev_centered', 'y' and
        'reference' that supplies all three inputs.
    y : ndarray
        Presented stimulus in radians, one value per trial.
    reference : ndarray
        Reference stimulus in radians, one value per trial.
    min_deg : float
        Minimum angular distance cut-off (see conversion above) for a trial
        to be included.
    max_deg : float
        Maximum angular distance cut-off (see conversion above) for a trial
        to be included.

    Returns
    -------
    shift : ndarray
        Asymmetry score for every time point.
    CW : ndarray
        [classes by time points] mean evidence for CW angular difference trials.
    CCW : ndarray
        [classes by time points] mean evidence for CCW angular difference trials.

    Raises
    ------
    ValueError
        If ``data`` is an array and ``y`` or ``reference`` is missing, or if
        either has a value above 2*pi (i.e. does not look like radians).
    """
    # Unpack the dictionary form of the inputs.
    if isinstance(data, dict):
        reference = data['reference']
        y = data['y']
        data = data['single_trial_ev_centered']

    if isinstance(data, np.ndarray):
        if y is None or reference is None:
            raise ValueError("Specify both y and reference variable")
        # Values above 2*pi suggest degrees were passed instead of radians.
        if y.max() > (2 * np.pi) or reference.max() > (2 * np.pi):
            raise ValueError("Define y and reference in radians")

    # Unpacking also enforces the 3-D [trials, classes, time] layout.
    _, n_classes, _ = data.shape

    angular_difference = pycircstat.cdiff(reference, y)
    lower = min_deg / 90 * np.pi
    upper = max_deg / 90 * np.pi

    # Evidence for all classes, averaged over CW / CCW trials.
    cw_trials = (angular_difference <= -lower) & (angular_difference > -upper)
    ccw_trials = (angular_difference >= lower) & (angular_difference < upper)
    CW = data[cw_trials, :, :].mean(0)
    CCW = data[ccw_trials, :, :].mean(0)

    # NOTE(review): these class windows are kept exactly as originally written.
    # They are not mirror images around class n_classes/2: the low block stops
    # before class n_classes/2 - 1, while the high block starts at class
    # n_classes/2 and omits the last class. Confirm this is intended.
    low_block = slice(0, int(n_classes / 2 - 1))
    high_block = slice(int(n_classes / 2), n_classes - 1)
    shift = ((CW[low_block, :].mean(0) - CW[high_block, :].mean(0)) -
             (CCW[low_block, :].mean(0) - CCW[high_block, :].mean(0)))

    return shift, CW, CCW
示例#8
0
    def analysis(self):
        """
        Run the pairwise ROI phase-synchrony analysis and store results in ``self.res``.

        For every pair of ROIs in ``self.roi_list``, and every pair of electrodes
        with one from each ROI, the phase difference between the electrodes is
        summarised with ``calc_circ_stats``. If ``self.do_perm_test`` is set,
        ``self.compute_null_stats`` is also run. Results are stored in
        ``self.res`` under a key joined from the two ROIs' labels.

        Returns None. It returns early, after printing a message, when no data
        is loaded, when ``self.recall_filter_func`` is missing, or when any ROI
        has no electrodes.
        """
        if self.subject_data is None:
            print('%s: compute or load data first with .load_data()!' % self.subject)
            return

        # Get recalled or not labels
        if self.recall_filter_func is None:
            print('%s SME: please provide a .recall_filter_func function.' % self.subject)
            return
        recalled = self.recall_filter_func(self.subject_data)

        # filter to electrodes in ROIs. First get broad electrode region labels
        region_df = self.bin_eloctrodes_into_rois()
        region_df['merged_col'] = region_df['hemi'] + '-' + region_df['region']

        # every ROI needs at least one electrode, otherwise there is nothing to pair
        for roi in self.roi_list:
            if not any(np.any(region_df.merged_col == label) for label in roi):
                print('{}: no {} electrodes, cannot compute synchrony.'.format(self.subject, roi))
                return

        # then filter into just the ROIs defined above
        roi_labels = [label for roi in self.roi_list for label in roi]
        elecs_to_use = region_df.merged_col.isin(roi_labels)
        elec_scheme = self.elec_info.copy(deep=True)
        elec_scheme['ROI'] = region_df.merged_col[elecs_to_use]
        elec_scheme = elec_scheme[elecs_to_use].reset_index()

        if self.use_wavelets:
            phase_data = MorletWaveletFilter(self.subject_data[:, elecs_to_use], self.wavelet_freq,
                                             output='phase', width=5, cpus=12,
                                             verbose=False).filter()
        else:
            # band pass eeg, then take the phase at each timepoint
            phase_data = RAM_helpers.band_pass_eeg(self.subject_data[:, elecs_to_use], self.hilbert_band_pass_range)
            phase_data.data = np.angle(hilbert(phase_data.data, N=phase_data.shape[-1], axis=-1))

        # remove the buffer
        phase_data = phase_data.remove_buffer(self.buf_ms / 1000.)

        # (calc_circ_stats key, self.res key) per statistic; this order is the
        # order the keys appear in each result dict
        stat_keys = (
            ('elec_pair_pval', 'elec_pair_pvals'),
            ('elec_pair_z', 'elec_pair_zs'),
            ('elec_pair_rvl', 'elec_pair_rvls'),
            ('elec_pair_pval_rec', 'elec_pair_pvals_rec'),
            ('elec_pair_z_rec', 'elec_pair_zs_rec'),
            ('elec_pair_pval_nrec', 'elec_pair_pvals_nrec'),
            ('elec_pair_z_nrec', 'elec_pair_zs_nrec'),
            ('elec_pair_rvl_rec', 'elec_pair_rvls_rec'),
            ('elec_pair_rvl_nrec', 'elec_pair_rvls_nrec'),
        )

        # loop over each pair of ROIs
        for region_pair in combinations(self.roi_list, 2):
            elecs_region_1 = np.where(elec_scheme.ROI.isin(region_pair[0]))[0]
            elecs_region_2 = np.where(elec_scheme.ROI.isin(region_pair[1]))[0]

            elec_label_pairs = []
            pair_stats = {res_key: [] for _, res_key in stat_keys}
            delta_mem_rayleigh_zscores = []
            delta_mem_rvl_zscores = []
            elec_pair_phase_diffs = []

            # loop over all pairs of electrodes in the ROIs
            for elec_1 in elecs_region_1:
                for elec_2 in elecs_region_2:
                    elec_label_pairs.append([elec_scheme.iloc[elec_1].label, elec_scheme.iloc[elec_2].label])

                    # circular difference in phase values for this electrode pair
                    elec_pair_phase_diff = pycircstat.cdiff(phase_data[:, elec_1], phase_data[:, elec_2])
                    if self.include_phase_diffs_in_res:
                        elec_pair_phase_diffs.append(elec_pair_phase_diff)

                    # compute the circular stats
                    elec_pair_stats = calc_circ_stats(elec_pair_phase_diff, recalled, do_perm=False)
                    for stat_key, res_key in stat_keys:
                        pair_stats[res_key].append(elec_pair_stats[stat_key])

                    # compute null distributions for the memory stats
                    if self.do_perm_test:
                        delta_mem_rayleigh_zscore, delta_mem_rvl_zscore = self.compute_null_stats(
                            elec_pair_phase_diff, recalled, elec_pair_stats)
                        delta_mem_rayleigh_zscores.append(delta_mem_rayleigh_zscore)
                        delta_mem_rvl_zscores.append(delta_mem_rvl_zscore)

            region_res = {'elec_label_pairs': elec_label_pairs}
            for res_key, values in pair_stats.items():
                region_res[res_key] = np.stack(values, 0)
            if self.do_perm_test:
                region_res['delta_mem_rayleigh_zscores'] = np.stack(delta_mem_rayleigh_zscores, 0)
                region_res['delta_mem_rvl_zscores'] = np.stack(delta_mem_rvl_zscores, 0)
            if self.include_phase_diffs_in_res:
                region_res['elec_pair_phase_diffs'] = np.stack(elec_pair_phase_diffs, -1)
            region_res['time'] = phase_data.time.data
            region_res['recalled'] = recalled

            region_pair_key = '+'.join(['-'.join(r) for r in region_pair])
            self.res[region_pair_key] = region_res
    def analysis(self):
        """
        Run the pairwise ROI phase-synchrony analysis and store results in ``self.res``.

        For every pair of ROIs in ``self.roi_list``, and every pair of electrodes
        with one from each ROI, the phase difference between the electrodes is
        summarised with ``calc_circ_stats``. If ``self.do_perm_test`` is set,
        ``self.compute_null_stats`` is also run. Results are stored in
        ``self.res`` under a key joined from the two ROIs' labels.

        Returns None. It returns early, after printing a message, when no data
        is loaded, when ``self.recall_filter_func`` is missing, or when any ROI
        has no electrodes.
        """
        if self.subject_data is None:
            print('%s: compute or load data first with .load_data()!' % self.subject)
            return

        # Get recalled or not labels
        if self.recall_filter_func is None:
            print('%s SME: please provide a .recall_filter_func function.' % self.subject)
            return
        recalled = self.recall_filter_func(self.subject_data)

        # filter to electrodes in ROIs. First get broad electrode region labels
        region_df = self.bin_eloctrodes_into_rois()
        region_df['merged_col'] = region_df['hemi'] + '-' + region_df['region']

        # every ROI needs at least one electrode, otherwise there is nothing to pair
        for roi in self.roi_list:
            if not any(np.any(region_df.merged_col == label) for label in roi):
                print('{}: no {} electrodes, cannot compute synchrony.'.format(self.subject, roi))
                return

        # then filter into just the ROIs defined above
        roi_labels = [label for roi in self.roi_list for label in roi]
        elecs_to_use = region_df.merged_col.isin(roi_labels)
        elec_scheme = self.elec_info.copy(deep=True)
        elec_scheme['ROI'] = region_df.merged_col[elecs_to_use]
        elec_scheme = elec_scheme[elecs_to_use].reset_index()

        if self.use_wavelets:
            phase_data = MorletWaveletFilter(self.subject_data[:, elecs_to_use], self.wavelet_freq,
                                             output='phase', width=5, cpus=12,
                                             verbose=False).filter()
        else:
            # band pass eeg, then take the phase at each timepoint
            phase_data = ecog_helpers.band_pass_eeg(self.subject_data[:, elecs_to_use], self.hilbert_band_pass_range)
            phase_data.data = np.angle(hilbert(phase_data.data, N=phase_data.shape[-1], axis=-1))

        # remove the buffer
        phase_data = phase_data.remove_buffer(self.buf_ms / 1000.)

        # (calc_circ_stats key, self.res key) per statistic; this order is the
        # order the keys appear in each result dict
        stat_keys = (
            ('elec_pair_pval', 'elec_pair_pvals'),
            ('elec_pair_z', 'elec_pair_zs'),
            ('elec_pair_rvl', 'elec_pair_rvls'),
            ('elec_pair_pval_rec', 'elec_pair_pvals_rec'),
            ('elec_pair_z_rec', 'elec_pair_zs_rec'),
            ('elec_pair_pval_nrec', 'elec_pair_pvals_nrec'),
            ('elec_pair_z_nrec', 'elec_pair_zs_nrec'),
            ('elec_pair_rvl_rec', 'elec_pair_rvls_rec'),
            ('elec_pair_rvl_nrec', 'elec_pair_rvls_nrec'),
        )

        # loop over each pair of ROIs
        for region_pair in combinations(self.roi_list, 2):
            elecs_region_1 = np.where(elec_scheme.ROI.isin(region_pair[0]))[0]
            elecs_region_2 = np.where(elec_scheme.ROI.isin(region_pair[1]))[0]

            elec_label_pairs = []
            pair_stats = {res_key: [] for _, res_key in stat_keys}
            delta_mem_rayleigh_zscores = []
            delta_mem_rvl_zscores = []
            elec_pair_phase_diffs = []

            # loop over all pairs of electrodes in the ROIs
            for elec_1 in elecs_region_1:
                for elec_2 in elecs_region_2:
                    elec_label_pairs.append([elec_scheme.iloc[elec_1].label, elec_scheme.iloc[elec_2].label])

                    # circular difference in phase values for this electrode pair
                    elec_pair_phase_diff = pycircstat.cdiff(phase_data[:, elec_1], phase_data[:, elec_2])
                    if self.include_phase_diffs_in_res:
                        elec_pair_phase_diffs.append(elec_pair_phase_diff)

                    # compute the circular stats
                    elec_pair_stats = calc_circ_stats(elec_pair_phase_diff, recalled, do_perm=False)
                    for stat_key, res_key in stat_keys:
                        pair_stats[res_key].append(elec_pair_stats[stat_key])

                    # compute null distributions for the memory stats
                    if self.do_perm_test:
                        delta_mem_rayleigh_zscore, delta_mem_rvl_zscore = self.compute_null_stats(
                            elec_pair_phase_diff, recalled, elec_pair_stats)
                        delta_mem_rayleigh_zscores.append(delta_mem_rayleigh_zscore)
                        delta_mem_rvl_zscores.append(delta_mem_rvl_zscore)

            region_res = {'elec_label_pairs': elec_label_pairs}
            for res_key, values in pair_stats.items():
                region_res[res_key] = np.stack(values, 0)
            if self.do_perm_test:
                region_res['delta_mem_rayleigh_zscores'] = np.stack(delta_mem_rayleigh_zscores, 0)
                region_res['delta_mem_rvl_zscores'] = np.stack(delta_mem_rvl_zscores, 0)
            if self.include_phase_diffs_in_res:
                region_res['elec_pair_phase_diffs'] = np.stack(elec_pair_phase_diffs, -1)
            region_res['time'] = phase_data.time.data
            region_res['recalled'] = recalled

            region_pair_key = '+'.join(['-'.join(r) for r in region_pair])
            self.res[region_pair_key] = region_res
def test_circular_distance():
    """Differences between identical angles wrap to zero (within tolerance)."""
    sample = np.array([4.85065953, 0.79063862, 1.35698570])
    self_diff = pycircstat.cdiff(sample, sample)
    assert_allclose(self_diff, np.zeros_like(sample))
示例#11
0
def _binned_bias_stats(signed_error, angdiff, nbin, pbin):
    """Return per-bin (biases, bincentre, precs, binid) over ``circ_bini(angdiff, nbin, pbin)`` bins.

    Row ``i`` of the ``circ_bini`` result indexes the trials of bin ``i``.

    - ``biases``: mean signed error per bin.
    - ``bincentre``: circular mean of ``angdiff`` per bin.
    - ``precs``: circular std of ``signed_error`` per bin.
    - ``binid``: ``0 .. nbin - 1``.

    Both circular statistics use the range [-pi, pi].
    """
    bins = circ_bini(angdiff, nbin, pbin)
    biases = np.full(nbin, np.nan)
    bincentre = np.full(nbin, np.nan)
    precs = np.full(nbin, np.nan)
    for i in range(nbin):
        in_bin = bins[i, :]
        biases[i] = np.mean(signed_error[in_bin])
        bincentre[i] = sp.stats.circmean(angdiff[in_bin], high=np.pi, low=-np.pi)
        precs[i] = sp.stats.circstd(signed_error[in_bin], high=np.pi, low=-np.pi)
    return biases, bincentre, precs, np.arange(nbin)


def calculate_bias_covar(signed_error, target, other, subid, nbin=64, pbin=0.25, covariate=None, covariate_label=None, median_split=None):
    """Compute the binned response bias, optionally median-split by a covariate.

    Parameters
    ----------
    signed_error : ndarray
        Signed response error per trial; demeaned before binning.
    target, other : array_like
        Target and other-item orientations in degrees (converted with
        ``np.radians``).
    subid
        Subject identifier, written to the 'subid' column.
    nbin : int, default 64
        Number of bins passed to ``circ_bini``.
    pbin : float, default 0.25
        Fraction of the data per bin, passed to ``circ_bini``.
    covariate : ndarray, optional
        Per-trial covariate used for the median split.
    covariate_label : optional
        Unused; accepted for backward compatibility.
    median_split : bool, optional
        If truthy and ``covariate`` is given, trials are split into
        ``covariate <= median`` ('below') and ``covariate > median``
        ('above'), and each half is binned separately.

    Returns
    -------
    pandas.DataFrame
        Without a split: one row per bin with columns 'subid', 'bincentre',
        'bias', 'prec', 'binid'.

        With a split: the 'above' rows followed by the 'below' rows, with
        columns 'biases', 'bincentre', 'precs', 'binid', 'median_split',
        'subid'.

        NOTE(review): the bias/precision column names differ between the two
        modes ('bias'/'prec' vs 'biases'/'precs'). They are kept for
        compatibility with existing callers.
    """
    targori = np.radians(target)
    otherori = np.radians(other)

    # Signed circular difference of the other item relative to the target.
    angdiff = circstat.cdiff(otherori, targori)

    # Demean the error to remove any overall angular bias in responding
    # (agnostic to any other data).
    signed_error = np.subtract(signed_error, signed_error.mean())

    # Only split when a covariate is available to split on.
    if median_split and covariate is not None:
        covar_median = np.median(covariate)
        split_masks = (
            ('below', np.less_equal(covariate, covar_median)),
            ('above', np.greater(covariate, covar_median)),
        )
        halves = {}
        for split_label, trials in split_masks:
            biases, bincentre, precs, binid = _binned_bias_stats(
                signed_error[trials], angdiff[trials], nbin, pbin)
            half = pd.DataFrame({'biases': biases, 'bincentre': bincentre,
                                 'precs': precs, 'binid': binid.astype(int)})
            half['median_split'] = split_label
            halves[split_label] = half
        df = pd.concat([halves['above'], halves['below']])
    else:
        biases, bincentre, precs, binid = _binned_bias_stats(signed_error, angdiff, nbin, pbin)
        df = pd.DataFrame({'subid': np.full(nbin, subid), 'bincentre': bincentre,
                           'bias': biases, 'prec': precs, 'binid': binid})

    df['subid'] = subid
    return df
    # NOTE(review): this fragment follows `return df` at the same indentation,
    # so as written it is unreachable. It looks like the body of a loop from an
    # analysis script. evidence, df_read, inx, inx2, nr_bins, sb_count,
    # cos_convolve, classifier_output, classifier_output_tuning, left, right
    # and shift are all defined outside the visible code.
    # #fitting evidence
    # classifier_output[:,sb_count,2] = evidence['cos_convolved']
    
    
    # convolve with cosine and obtain evidence
    evidence['single_trial_evidence'] = evidence['single_trial_evidence_store']
    # bin_array returns the presented angles (as an array) and their bin
    # indices; the indices are stored as evidence['y'].
    y, evidence['y'] = bin_array(np.array(df_read['presented'])[inx],nr_bins)
    evidence = cos_convolve(evidence)
    # #fitting evidence
    # Store the convolved evidence and centred prediction at column sb_count.
    classifier_output[:,sb_count,0] = evidence['cos_convolved']
    classifier_output_tuning[:,:,sb_count,0] = evidence['centered_prediction']    
    
    # df_read.to_csv((projectloc + '/saved_data_forNick/behavioural_betweentrial_S%02d_June2021.csv' %sb_count))
    # scipy.io.savemat(projectloc + '/saved_data_forNick/LDA_betweentrial_S%02d_June2021.mat' %sb_count, {'data':evidence['single_trial_ev_centered']})

    # Signed circular difference between the previous non-probe angle and the
    # presented angle, per trial.
    df_read['diff_stim_prev_pang'] = pycircstat.cdiff(df_read['prev_non_probe_ang'],df_read['presented'])
    
    # Cut-offs in degrees. They are converted below as deg / 90 * pi
    # (presumably orientations doubled onto the full circle; confirm).
    min_deg = 0 #2.903#2.8125 #10
    max_deg = 60 #58.06 #56.25 #50
    # right[:,:,sb_count] = evidence['single_trial_ev_centered'][(df_read['diff_stim_prev_pang'][inx] >= min_deg/90*np.pi) & (df_read['diff_stim_prev_pang'][inx] < max_deg/90*np.pi),:,:].mean(0)
    # "right": evidence averaged over trials (selected by inx, then inx2)
    # whose difference lies in [min_deg, max_deg) after conversion.
    right[:,:,sb_count] = evidence['single_trial_ev_centered'][inx2][(df_read['diff_stim_prev_pang'][inx][inx2] >= min_deg/90*np.pi) & (df_read['diff_stim_prev_pang'][inx][inx2] < max_deg/90*np.pi),:,:].mean(0)
    # left[:,:,sb_count] = evidence['single_trial_ev_centered'][(df_read['diff_stim_prev_pang'][inx] <= -min_deg/90*np.pi) & (df_read['diff_stim_prev_pang'][inx] > -max_deg/90*np.pi),:,:].mean(0)
    # "left": the same, for differences in (-max_deg, -min_deg].
    left[:,:,sb_count] = evidence['single_trial_ev_centered'][inx2][(df_read['diff_stim_prev_pang'][inx][inx2] <= -min_deg/90*np.pi) & (df_read['diff_stim_prev_pang'][inx][inx2] > -max_deg/90*np.pi),:,:].mean(0)

    # Asymmetry: the "left" contrast minus the "right" contrast, where each
    # contrast is the mean evidence of classes 0-3 minus that of classes 5-8.
    shift[:,sb_count] = (left[0:4,:,sb_count].mean(0) - left[5:9,:,sb_count].mean(0)) - (right[0:4,:,sb_count].mean(0) - right[5:9,:,sb_count].mean(0)) 

#%%   
# stim_prev = {}
# stim_prev['Crossdec_prev_probe_Protect'] = classifier_output[:,:,2]
# stim_prev['shift_prev_probe_Protect'] = shift