Example #1
def baseline_stimulus_correlation_mean(off_firing, baseline_window, stimulus_time,
                                  stimulus_window_size, step_size, shuffle_repeats, file, laser):
    corr_dat = pd.DataFrame() 
    all_pre_stim_dist = []
    all_stim_dist = []
    for taste in range(len(off_firing)):
        data = off_firing[taste]
        
        baseline_start = int(baseline_window[0]/step_size)
        baseline_end = int(baseline_window[1]/step_size)
        stim_start = int(stimulus_time/step_size)
        stim_end = int((stimulus_time + stimulus_window_size)/step_size)
        
        mean_pre_stim = np.mean(data[:,:,baseline_start:baseline_end],axis = 2).T # (trials x neurons)
        pre_stim_dist = np.tril(dist_mat(mean_pre_stim,mean_pre_stim)) # Zero out the upper triangle to prevent double counting
        all_pre_stim_dist.append(pre_stim_dist)
        
        mean_stim_dat = np.mean(data[:,:,stim_start:stim_end],axis=2).T
        stim_dist = np.tril(dist_mat(mean_stim_dat,mean_stim_dat))
        all_stim_dist.append(stim_dist)
        
        temp_corr = pearsonr(pre_stim_dist[pre_stim_dist.nonzero()].flatten(),stim_dist[stim_dist.nonzero()].flatten())
        temp_corr_dat = pd.DataFrame(dict(file = file, taste = taste, 
                baseline_end = baseline_end*step_size, rho = temp_corr[0],p = temp_corr[1],
                index = [corr_dat.shape[0]], shuffle = False, pre_stim_window_size = (baseline_end - baseline_start)*step_size,laser = laser))
        corr_dat = pd.concat([corr_dat, temp_corr_dat])
        
        for repeat in range(shuffle_repeats):
            output = stim_corr_shuffle(pre_stim_dist, stim_dist,baseline_end, baseline_start, step_size, file, taste, laser, corr_dat)
            corr_dat = pd.concat([corr_dat,output])
        
    return corr_dat, all_pre_stim_dist, all_stim_dist
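All of these examples call a dist_mat helper that is not shown on this page. A minimal sketch of one possible implementation, assuming dist_mat(a, b) returns the pairwise Euclidean distance matrix between the rows of a and the rows of b:

import numpy as np
from scipy.spatial.distance import cdist

def dist_mat(a, b):
    # Assumed behavior: pairwise Euclidean distances between rows of a and rows of b,
    # so dist_mat(x, x) gives a symmetric (n_points x n_points) distance matrix.
    return cdist(a, b, metric='euclidean')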
Example #2
def baseline_stimulus_correlation_acc(off_firing, baseline_window, stimulus_time,
                                  stimulus_window_size, step_size, shuffle_repeats, file, laser):
    corr_dat = pd.DataFrame() 
    all_pre_stim_dist = []
    all_stim_dist = []
    for taste in range(len(off_firing)):
        data = off_firing[taste]
        
        baseline_start = int(baseline_window[0]/step_size)
        baseline_end = int(baseline_window[1]/step_size)
        stim_start = int(stimulus_time/step_size)
        stim_end = int((stimulus_time + stimulus_window_size)/step_size)
        
        pre_dat = data[:,:,baseline_start:baseline_end]
        pre_dists = np.zeros((pre_dat.shape[1],pre_dat.shape[1],pre_dat.shape[2]))
        for time_bin in range(pre_dists.shape[2]):
            pre_dists[:,:,time_bin] = dist_mat(pre_dat[:,:,time_bin].T,pre_dat[:,:,time_bin].T)
        pre_stim_dist = np.tril(np.sum(pre_dists,axis = 2))
        all_pre_stim_dist.append(pre_stim_dist)
        
        stim_dat = data[:,:,stim_start:stim_end]
        stim_dists = np.zeros((stim_dat.shape[1],stim_dat.shape[1],stim_dat.shape[2]))
        for time_bin in range(stim_dists.shape[2]):
            stim_dists[:,:,time_bin] = dist_mat(stim_dat[:,:,time_bin].T,stim_dat[:,:,time_bin].T)
        sum_stim_dists = np.tril(np.sum(stim_dists,axis = 2))
        all_stim_dist.append(sum_stim_dists)
        
        temp_corr = pearsonr(pre_stim_dist[pre_stim_dist.nonzero()].flatten(),sum_stim_dists[sum_stim_dists.nonzero()].flatten())
        temp_corr_dat = pd.DataFrame(dict(file = file, taste = taste, 
                baseline_end = baseline_end*step_size, rho = temp_corr[0],p = temp_corr[1],
                index = [corr_dat.shape[0]], shuffle = False, pre_stim_window_size = (baseline_end - baseline_start)*step_size,laser = laser))
        corr_dat = pd.concat([corr_dat, temp_corr_dat])
        
        for repeat in range(shuffle_repeats):
            output = stim_corr_shuffle(pre_stim_dist, sum_stim_dists,baseline_end, baseline_start, step_size, file, taste, laser, corr_dat)
            corr_dat = pd.concat([corr_dat,output])
        
    return corr_dat, all_pre_stim_dist, all_stim_dist
for x in dir_list:
    file_list = file_list + glob.glob(x + '/**/' + '*.h5',recursive=True)

file  = 4

this_dir = file_list[file].split(sep='/')[-2]
data_dir = os.path.dirname(file_list[file])
data = ephys_data(data_dir = data_dir ,file_id = file, use_chosen_units = True)
data.firing_rate_params = dict(zip(['step_size','window_size','total_time','calc_type','baks_len'],
                               [25,250,7000,'conv',700]))
data.get_data()
data.get_firing_rates()
data.get_normalized_firing()
data.firing_overview('off')

all_firing_array = np.asarray(data.all_normal_off_firing)

#taste = 2
X = all_firing_array.swapaxes(1,2) #data.all_normal_off_firing[:,:,100:200].swapaxes(1,2) #
rank = 7

# Fit CP tensor decomposition.
U = tt.cp_als(X, rank=rank, verbose=True)

# Plot the low-dimensional factors from the fit.
fig, _, _ = tt.plot_factors(U.factors)

## We should be able to see differences in tastes by using distance matrices on trial factors
trial_factors = U.factors.factors[-1]
trial_distances = dist_mat(trial_factors,trial_factors)
plt.figure();plt.imshow(trial_distances)
def firing_correlation(firing_array,
                       baseline_window,
                       stimulus_window,
                       data_step_size=25,
                       shuffle_repeats=100,
                       accumulated=False):
    """
    General function, not bound by object parameters
    Calculates correlations in 2 windows of a firin_array (defined below) 
        according to either accumulated distance or distance of mean points
    PARAMS
    :firing_array: (nrn x trial x time) array of firing rates
    :baseline_window: Tuple of time in ms of what window to take for BASELINE firing
    :stimulus_window: Tuple of time in ms of what window to take for STIMULUS firing
    :data_step_size: Resolution at which the data was binned (if at all)
    :shuffle repeats: How many shuffle repeats to perform for analysis control
    :accumulated:   If True -> will calculate temporally integrated pair-wise distances between all points
                    If False -> will calculate distance between mean of all points  
    """
    # Calculate indices for slicing data
    baseline_start_ind = int(baseline_window[0] / data_step_size)
    baseline_end_ind = int(baseline_window[1] / data_step_size)
    stim_start_ind = int(stimulus_window[0] / data_step_size)
    stim_end_ind = int(stimulus_window[1] / data_step_size)

    pre_dat = firing_array[:, :, baseline_start_ind:baseline_end_ind]
    stim_dat = firing_array[:, :, stim_start_ind:stim_end_ind]

    if accumulated:
        # Calculate accumulated pair-wise distances for baseline data
        pre_dists = np.zeros(
            (pre_dat.shape[1], pre_dat.shape[1], pre_dat.shape[2]))
        for time_bin in range(pre_dists.shape[2]):
            pre_dists[:, :, time_bin] = dist_mat(pre_dat[:, :, time_bin].T,
                                                 pre_dat[:, :, time_bin].T)
        sum_pre_dist = np.sum(pre_dists, axis=2)

        # Calculate accumulated pair-wise distances for post-stimulus data
        stim_dists = np.zeros(
            (stim_dat.shape[1], stim_dat.shape[1], stim_dat.shape[2]))
        for time_bin in range(stim_dists.shape[2]):
            stim_dists[:, :, time_bin] = dist_mat(stim_dat[:, :, time_bin].T,
                                                  stim_dat[:, :, time_bin].T)
        sum_stim_dist = np.sum(stim_dists, axis=2)

        # Correlate only the upper triangle (excluding the diagonal) to avoid double counting pairs
        indices = np.mask_indices(stim_dat.shape[1], np.triu, 1)
        rho, p = pearsonr(sum_pre_dist[indices], sum_stim_dist[indices])

        pre_mat, stim_mat = sum_pre_dist, sum_stim_dist

    else:
        # Calculate pair-wise distances between time-averaged baseline points
        mean_pre = np.mean(pre_dat, axis=2)
        mean_pre_dist = dist_mat(mean_pre.T, mean_pre.T)

        # Calculate pair-wise distances between time-averaged post-stimulus points
        mean_stim = np.mean(stim_dat, axis=2)
        mean_stim_dist = dist_mat(mean_stim.T, mean_stim.T)

        indices = np.mask_indices(stim_dat.shape[1], np.triu, 1)
        rho, p = pearsonr(mean_pre_dist[indices], mean_stim_dist[indices])

        pre_mat, stim_mat = mean_pre_dist, mean_stim_dist

    rho_sh_vec = np.empty(shuffle_repeats)
    p_sh_vec = np.empty(shuffle_repeats)
    for repeat in range(shuffle_repeats):
        rho_sh_vec[repeat], p_sh_vec[repeat] = pearsonr(
            np.random.permutation(pre_mat[indices]), stim_mat[indices])

    return rho, p, rho_sh_vec, p_sh_vec, pre_mat, stim_mat
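A minimal usage sketch for firing_correlation on synthetic data; the array shape, window boundaries, and variable names below are illustrative assumptions, not values from the original recordings:

import numpy as np

fake_firing = np.random.random((10, 15, 280))  # hypothetical (neurons x trials x time) array, 25 ms bins
rho, p, rho_sh, p_sh, pre_mat, stim_mat = firing_correlation(
        fake_firing,
        baseline_window=(0, 2000),      # ms
        stimulus_window=(2000, 4000),   # ms
        data_step_size=25,
        shuffle_repeats=100,
        accumulated=False)
print(rho, p)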
Example #5
    fig, ax = plt.subplots()
    for idx, shrink in enumerate(shrink_list[::-1]):

        ## Normal loading
        # beads = np.loadtxt("beads_from_func/" + cell + "_" + str(shrink).ljust(4, '0') + "_" + tag + "_beads.txt", delimiter=",")
        # disps = np.loadtxt("beads_from_func/" + cell + "_" + str(shrink).ljust(4, '0') + "_" + tag + "_disps.txt", delimiter=",")
        ## Loading for check_randomness
        beads = np.loadtxt("simultaneous_beads/" + cell + "_beads.txt",
                           delimiter=",")
        disps = np.loadtxt("simultaneous_beads/" + cell + "_" +
                           str(shrink).ljust(4, '0') + "_" + tag +
                           "_disps.txt",
                           delimiter=",")

        disp_mag = np.linalg.norm(disps, axis=1)
        d_mat = dist_mat(beads, cell_points)
        dist2cell = d_mat.min(1)
        # fig, ax = plt.subplots()
        # ax.scatter(dist2cell, disp_mag, c='k', s=3)
        xx, yy, n = get_bins(dist2cell, disp_mag)
        ax.plot(xx, yy, '-o', c=color_list[idx])

        if idx == 0:
            temp_x_lim = ax.get_xlim()
            temp_y_lim = ax.get_ylim()

            #########################
            ax2 = ax.twinx()
            ax2.bar(xx + idx - 2,
                    n,
                    1,
Example #6
 # =============================================================================
 # smooth_long = smooth_array[0,:,:,:]
 # for taste in range(1,smooth_array.shape[0]):
 #     smooth_long = np.concatenate((smooth_long, smooth_array[taste,:,:,:]),axis=1)
 # =============================================================================
     
 # Make joint distribution over baseline and post-stimulus firing
             
 for taste in range(smooth_array.shape[0]):
     for neuron in range(smooth_array.shape[1]):
         this_neuron = smooth_array[taste,neuron,:,:]
         this_pre = this_neuron[:,baseline_inds]
         this_post = this_neuron[:,stimulus_inds]
         
         # Simultaneously compute the distance correlation to see if the methods corroborate each other
         pre_dists = dist_mat(this_pre,this_pre)
         post_dists = dist_mat(this_post,this_post)
         dist_corr = pearsonr(pre_dists[np.tril_indices(pre_dists.shape[0])],post_dists[np.tril_indices(post_dists.shape[0])])[0]
         all_dist_corrs.append(dist_corr)
         
         # Determine quartiles over all trials, but for post and pre separately
         quartiles = np.linspace(0,100,symbols+1)
         pre_vals = np.percentile(this_pre.flatten(),quartiles)
         post_vals = np.percentile(this_post.flatten(),quartiles)
         
         pre_bins = np.zeros((this_neuron.shape[0],symbols))
         post_bins = np.zeros((this_neuron.shape[0],symbols))
         for trial in range(this_neuron.shape[0]):
             for val in range(symbols):
                 pre_bins[trial,val] = np.sum((this_pre[trial,:] < pre_vals[val+1]) &
                               (this_pre[trial,:] >= pre_vals[val])) / len(this_pre[trial,:])
Example #7
# All data points
plot_data = np.concatenate((all_baseline,all_stimulus),axis=1)
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = Axes3D(fig)
colors  = np.concatenate((np.ones((1,all_baseline.shape[1])),np.multiply(np.ones((1,all_stimulus.shape[1])),2)),axis=1)
base = ax.scatter(all_baseline[0,:],all_baseline[1,:],all_baseline[2,:],c=range(len(all_baseline[0,:])),label='baseline')
stim = ax.scatter(all_stimulus[0,:],all_stimulus[1,:],all_stimulus[2,:],c='blue',label='stimulus')
plt.colorbar(base)

# Mean trials
mean_baseline_trajs = np.mean(baseline_trajs,axis=2)
mean_stimulus_trajs = np.mean(stimulus_trajs,axis=2)
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = Axes3D(fig)
colors  = np.concatenate((np.ones((1,all_baseline.shape[1])),np.multiply(np.ones((1,all_stimulus.shape[1])),2)),axis=1)
base = ax.scatter(mean_baseline_trajs[0,:],mean_baseline_trajs[1,:],mean_baseline_trajs[2,:],c=range(len(mean_baseline_trajs[0,:])),label='baseline',cmap='hsv')
ax.plot(mean_baseline_trajs[0,:],mean_baseline_trajs[1,:],mean_baseline_trajs[2,:],linewidth=0.5)
stim = ax.scatter(mean_stimulus_trajs[0,:],mean_stimulus_trajs[1,:],mean_stimulus_trajs[2,:],c='black',label='stimulus')
plt.colorbar(base)

# Distance
distance_data = np.concatenate((all_trajectories[:,:,range(80)],all_trajectories[:,:,range(80,160)]),axis=1)
distance_array = np.empty((distance_data.shape[1],distance_data.shape[1],distance_data.shape[2]))
for time in range(distance_array.shape[2]):
    distance_array[:,:,time] = dist_mat(distance_data[:,:,time].T,distance_data[:,:,time].T)
fin_acc_dists = np.mean(distance_array,axis=2)

mean_distance_data = np.mean(distance_data,axis=2)
fin_mean_dists = dist_mat(mean_distance_data.T,mean_distance_data.T)
        this_off = all_firing_array[taste, :, :, :]
        this_spikes = all_spikes_array[taste, :, :, :]

        # =============================================================================
        #         total_off = this_off[0,:,:]
        #         for nrn in range(1,this_off.shape[0]):
        #             total_off = np.concatenate((total_off,this_off[int(nrn),:,:]),axis=1)
        # =============================================================================

        # Tensor decomposition for trial clustering
        rank = 3  #X.shape[0]
        U = tt.cp_als(np.swapaxes(this_off, 1, 2), rank=rank, verbose=True)
        fig, ax, po = tt.plot_factors(U.factors)
        [nrn_f, time_f, trial_f] = U.factors.factors
        trial_dists = exposure.equalize_hist(dist_mat(trial_f, trial_f))

        #trial_dists = exposure.equalize_hist(dist_mat(total_off,total_off))

        # =============================================================================
        # reduced_off_pca = pca(n_components = 15).fit(total_off)
        # reduced_off = reduced_off_pca.transform(total_off)
        # trial_dists_red = dist_mat(reduced_off,reduced_off)
        # =============================================================================

        clf = kmeans(n_clusters=n_components, n_init=200)
        this_groups = clf.fit_predict(trial_dists)

        # =============================================================================
        #         gmm = GaussianMixture(n_components=n_components, covariance_type='full',
        #                               n_init = 500).fit(trial_dists)
Example #9
this_spikes = spikes_array[taste, :, :, 2000:5000]

mean_firing = np.mean(this_firing, axis=1)

# Reduce to 1D
mean_pca = pca(n_components=1).fit(mean_firing.T)
reduced_mean = mean_pca.transform(mean_firing.T)

# Reduce all other trials using same transformation
reduced_trials = np.zeros(this_firing.shape[1:])
for trial in range(this_firing.shape[1]):
    reduced_trials[trial, :] = mean_pca.transform(
        this_firing[:, trial, :].T).flatten()

# Calculate distances and perform k-means on trials
nrn_dist = exposure.equalize_hist(dist_mat(reduced_trials, reduced_trials))

n_components = 3
clf = kmeans(n_clusters=n_components, n_init=500)
this_groups = clf.fit_predict(nrn_dist)

trial_order = np.argsort(this_groups)

# Pull out and cluster distance matrices
clust_post_dist = nrn_dist[trial_order, :]
clust_post_dist = clust_post_dist[:, trial_order]

## Distance matrix cluster plots
plt.figure()
plt.subplot(221)
plt.imshow(exposure.equalize_hist(nrn_dist))
    for nrn in range(data.off_spikes[0].shape[0]):
        for taste in range(4):
            
            # Only take neurons which fire in every trial
            this_spikes = data.off_spikes[taste]
            this_spikes = this_spikes[nrn,:,2000:4000]
            if not (np.sum(np.sum(this_spikes,axis=1) == 0) > 0):
                
                this_off = data.normal_off_firing[taste]
                this_off = this_off[nrn,:,80:160]
                
                mean_coeff_var = np.mean(np.std(this_off,axis=0)/np.mean(this_off,axis=0))
                
                #this_off_red = pca(n_components = 5).fit_transform(this_off)
                
                nrn_dist = exposure.equalize_hist(dist_mat(this_off,this_off))
                
            # =============================================================================
            #     gmm = GaussianMixture(n_components=n_components, covariance_type='full',
            #                           n_init = 100)
            #     this_groups = gmm.fit(nrn_dist).predict(nrn_dist)
            # =============================================================================
                
                clf = kmeans(n_clusters = n_components, n_init = 100)
                this_groups = clf.fit_predict(nrn_dist)
                

                group_sizes  = np.asarray([sum(this_groups == x) for x in np.unique(this_groups)])
                min_group_size = len(this_groups)/3
                #max_group_size = len(this_groups)*2/3
                dynamic_criterion = (np.sum(group_sizes >= min_group_size) >= 2) # At least 2 groups contain at least 1/3 of the total trials
    return np.mean(cluster_strengths)

# =============================================================================
# =============================================================================
all_strengths = []
all_deviations = []
repeats = 50
points = 15
dims = 120

for repeat in range(repeats):
    fake_dat = np.random.normal(size = (points,dims))
    #plt.scatter(dat[:,0],dat[:,1])
    
    n_components = 2
    dist = exposure.equalize_hist(dist_mat(fake_dat,fake_dat))
    
    gmm = GaussianMixture(n_components=n_components, covariance_type='full',
                          n_init = n_components*100).fit(dist)
    print(gmm.predict(dist))
    
    this_groups = gmm.predict(dist)
    trial_order = np.argsort(this_groups)
    
    # Pull out and cluster distance matrices
    clust_post_dist = dist[trial_order,:]
    clust_post_dist = clust_post_dist[:,trial_order]
    all_strengths.append(clust_strength(fake_dat,this_groups))
    
plt.hist(all_strengths,20)
# =============================================================================
    all_off_firing_long = np.concatenate(
        (all_off_firing_long, all_off_firing[int(nrn), :, :]), axis=1)

all_off_red_pca = pca(n_components=20).fit(all_off_firing_long)
all_off_red = all_off_red_pca.transform(all_off_firing_long)

plt.imshow(exposure.equalize_hist(all_off_red))
groups = np.sort(np.asarray([0, 1, 2, 3] * 15))
plt.figure()
plt.scatter(all_off_red[:, 0], all_off_red[:, 1], c=groups)
plt.colorbar()

taste_lda = lda().fit(all_off_red, groups)
print(np.mean(taste_lda.predict(all_off_red) == groups))

trial_dist = dist_mat(all_off_firing_long, all_off_firing_long)
plt.figure()
plt.imshow(exposure.equalize_hist(trial_dist))

##

n_components = 3
taste = 1
pre_inds = np.arange(0, 80)
post_inds = np.arange(80, 160)

this_off = data.normal_off_firing[taste]
this_off_pre = this_off[:, :, pre_inds]
this_off_post = this_off[:, :, post_inds]

total_off_post = this_off_post[0, :, :]
plt.show()

# Part B -> Run over all trials but use the respective baseline for every trial

# Part C -> Calculate a distance matrix... patterning will show regularity of the trajectory
# Normalize and sum dist matrices for all trials (a normalization sketch follows the plotting code below)
# Assuming state transitions are bounded in time, there should be a trend
# Observation: Trials for different tastes show different structure
taste = 0
#dist_array = np.empty((firing_len,firing_len,np.int(off_firing[0].shape[1]/firing_len)))
#dist_array = np.empty((firing_len,firing_len,off_firing[0].shape[0]))
dist_array = np.empty((80, 80, off_firing[0].shape[0]))
for trial in range(dist_array.shape[2]):
    #dat = np.transpose(off_firing[taste][:,firing_len*trial:firing_len*(trial+1)])
    dat = np.transpose(off_firing[taste][trial, :, 0:80])
    dist_array[:, :, trial] = dist_mat(dat, dat)

fig = plt.figure(figsize=(21, 6))
columns = 7
rows = 2
count = 0
for i in range(1, columns * rows + 1):
    fig.add_subplot(rows, columns, i)
    #plt.imshow(stats.zscore(dist_array[:,:,count]))
    plt.imshow(dist_array[:, :, count])
    count += 1
plt.tight_layout()
plt.show()
#fig.savefig('taste%i.png' % taste)
plt.imshow(np.mean(dist_array, axis=2))
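The Part C comment above calls for normalizing each trial's distance matrix before combining them, but the code only averages the raw matrices. A minimal sketch of that normalization step, assuming z-scoring each matrix (as the commented-out stats.zscore line hints) is the intended normalization:

from scipy import stats

# Hypothetical normalization: z-score each trial's distance matrix so trials with
# different overall firing scales contribute comparably to the averaged pattern.
norm_dist_array = np.zeros(dist_array.shape)
for trial in range(dist_array.shape[2]):
    norm_dist_array[:, :, trial] = stats.zscore(dist_array[:, :, trial], axis=None)
plt.figure()
plt.imshow(np.mean(norm_dist_array, axis=2))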
for nrn in range(mean_off_dat.shape[1]):
    plt.figure()
    for taste in range(mean_off_dat.shape[0]):
        plt.subplot(4,1,taste+1)
        plt.plot(np.linspace(0,2.5,len(time_inds)),off_dat[taste,nrn,:,time_inds],c = 'r')
        plt.plot(np.linspace(0,2.5,len(time_inds)),on_dat[taste,nrn,:,time_inds],c='b')

# =============================================================================
# Split trials for a taste into 2 groups to remove dead trials, then do average
# =============================================================================
nrn = 4
for taste in range(4):
    this_nrn_off = np.asarray(data.normal_off_firing)[taste,nrn,:,:]
    this_nrn_on = np.asarray(data.normal_on_firing)[taste,nrn,:,:]
    
    off_dist = exposure.equalize_hist(dist_mat(this_nrn_off,this_nrn_off))
    on_dist = exposure.equalize_hist(dist_mat(this_nrn_on,this_nrn_on))
    
    n_components = 2
    clf = kmeans(n_clusters = n_components, n_init = 500)
    off_groups = clf.fit_predict(off_dist)
    on_groups = clf.fit_predict(on_dist)
    
    
    off_order = np.argsort(off_groups)
    on_order = np.argsort(on_groups)
    
    sorted_off = this_nrn_off[off_order]
    sorted_on = this_nrn_on[on_order]
    
    off_clust_list = []
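The snippet is cut off here; a minimal sketch of the averaging step the comment above describes (split trials into 2 groups, then average), assuming off_clust_list is meant to hold the per-cluster mean traces:

    # Hypothetical continuation: average trials within each k-means cluster
    # (illustrative assumption; the original averaging code is not shown).
    for group in np.unique(off_groups):
        off_clust_list.append(np.mean(this_nrn_off[off_groups == group], axis=0))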
        this_spiketimes = rspiketimes[i,j,:]
        this_spiketimes = this_spiketimes[this_spiketimes > 0]
        this_spiketimes = np.asarray([int(x) for x in this_spiketimes])
        rspikes[i,j,this_spiketimes - 1] = 1

# =============================================================================
# =============================================================================
nrn = 1
n_components = clusters
this_spikes = rspikes[:,nrn,:]

this_off = np.zeros(this_spikes.shape)
for trial in range(this_off.shape[0]):
    this_off[trial,:] = gauss_filt(this_spikes[trial,:],100)

nrn_dist = exposure.equalize_hist(dist_mat(this_off,this_off))

gmm = GaussianMixture(n_components=n_components, covariance_type='full',
                      n_init = n_components*200).fit(nrn_dist)
print(gmm.predict(nrn_dist))
        

this_groups = gmm.predict(nrn_dist)
trial_order = np.argsort(this_groups)

# Pull out and cluster distance matrices
clust_post_dist = nrn_dist[trial_order,:]
clust_post_dist = clust_post_dist[:,trial_order]

## Distance matrix cluster plots
plt.figure()
                    binned_firing[taste,trial,this_bin] = np.median(this_dat)
        
        binned_firing_long = binned_firing[0,:,:]
        for taste in range(1,binned_firing.shape[0]):
            binned_firing_long = np.concatenate((binned_firing_long,binned_firing[taste,:,:]))
            
        pred_taste = np.zeros((this_nrn_firing.shape[0],len(test_inds)))
        train_dat = binned_firing_long[train_inds,:]
        test_dat = binned_firing_long[test_inds,:]
        
        mean_taste_firing = np.zeros((binned_firing.shape[0],binned_firing_long.shape[1]))
        for taste in range(binned_firing.shape[0]):
            mean_taste_firing[taste,:] = np.mean(train_dat[train_label==taste,:],axis=0)
        
        for trial in range(test_dat.shape[0]):
            this_dists = dist_mat(mean_taste_firing,test_dat[trial,:][np.newaxis,:]).flatten()
            pred_taste[:,trial] = this_dists/np.sum(this_dists)
        
        psth_class_accuracy = np.mean(np.argmin(pred_taste,axis=0)==test_label)*100
        this_frac_psth_accuracy.append(psth_class_accuracy)
        
        if np.mod(i,10)==0:
            print(i)
        
    all_prob_class_acc.append(this_frac_prob_accuracy)
    all_psth_class_acc.append(this_frac_psth_accuracy)
    
plt.figure()
plt.errorbar(x = all_train_fractions,y=np.mean(np.asarray(all_prob_class_acc),axis=1),yerr = np.std(np.asarray(all_prob_class_acc),axis=1),
             color = 'blue',label = 'Prob')
plt.errorbar(x=all_train_fractions,y=np.mean(np.asarray(all_psth_class_acc),axis=1),yerr = np.std(np.asarray(all_psth_class_acc),axis=1),