import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
from scipy.stats import norm, poisson
from scipy.ndimage import gaussian_filter1d as gs  # Assumed: `gs` below is a 1-D Gaussian smoother
from sklearn.linear_model import LinearRegression

# Project-specific helpers (assumed): `re` is the project's regression/predictor
# module (note it shadows the standard-library re), and `reg_f` provides
# regression_code for mass-univariate t-statistics.
import regressions as re
import regression_function as reg_f


def plot_firing_rate_time_course(experiment):
    for session in experiment:
        predictor_A_Task_1, predictor_A_Task_2, predictor_A_Task_3,\
        predictor_B_Task_1, predictor_B_Task_2, predictor_B_Task_3, reward,\
        predictor_a_good_task_1,predictor_a_good_task_2, predictor_a_good_task_3 = re.predictors_pokes(session)

        aligned_spikes = session.aligned_rates
        n_neurons = aligned_spikes.shape[1]

        spikes_B_task_1 = aligned_spikes[np.where(predictor_B_Task_1 == 1)]
        spikes_A_task_1 = aligned_spikes[np.where(predictor_A_Task_1 == 1)]
        spikes_B_task_2 = aligned_spikes[np.where(predictor_B_Task_2 == 1)]
        spikes_A_task_2 = aligned_spikes[np.where(predictor_A_Task_2 == 1)]
        spikes_B_task_3 = aligned_spikes[np.where(predictor_B_Task_3 == 1)]
        spikes_A_task_3 = aligned_spikes[np.where(predictor_A_Task_3 == 1)]
        fig, axes = plt.subplots(figsize=(15, 5),
                                 ncols=n_neurons,
                                 sharex=True,
                                 sharey='col')
        axes = np.atleast_1d(axes) # plt.subplots returns a bare Axes when ncols == 1
        mean_spikes_B_task_1 = np.mean(spikes_B_task_1, axis=0)
        mean_spikes_A_task_1 = np.mean(spikes_A_task_1, axis=0)
        mean_spikes_B_task_2 = np.mean(spikes_B_task_2, axis=0)
        mean_spikes_A_task_2 = np.mean(spikes_A_task_2, axis=0)
        mean_spikes_B_task_3 = np.mean(spikes_B_task_3, axis=0)
        mean_spikes_A_task_3 = np.mean(spikes_A_task_3, axis=0)
        for neuron in range(n_neurons):
            axes[neuron].plot(mean_spikes_B_task_1[neuron],
                              label='B Task 1')
            axes[neuron].plot(mean_spikes_A_task_1[neuron],
                              label='A Task 1')
            axes[neuron].plot(mean_spikes_B_task_2[neuron],
                              label='B Task 2')
            axes[neuron].plot(mean_spikes_A_task_2[neuron],
                              label='A Task 2')
            axes[neuron].plot(mean_spikes_B_task_3[neuron],
                              label='B Task 3')
            axes[neuron].plot(mean_spikes_A_task_3[neuron],
                              label='A Task 3')
        axes[0].legend()
        fig.suptitle(session.file_name)
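# Usage sketch (hypothetical names): `experiment` is assumed to be an iterable
# of session objects exposing `aligned_rates` of shape (n_trials, n_neurons,
# n_timepoints) and compatible with re.predictors_pokes, e.g.:
#
#     plot_firing_rate_time_course(experiment_aligned_HP)
#     plt.show()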
def block_plot():
    # Note: relies on module-level globals defined elsewhere in the project:
    # experiment_aligned_HP, experiment_aligned_PFC, ind_n_HP and ind_n_PFC
    # (the ind_n_* variables are assumed to be np.where outputs indexing the
    # neuron of interest).
    neuron_count_HP = 0
    neuron_count_PFC = 0

    for s,session in enumerate(experiment_aligned_HP):
        aligned_spikes = session.aligned_rates[:]        
        n_trials, n_neurons, n_timepoints = aligned_spikes.shape 
            
        for n in range(n_neurons):
            neuron_count_HP += 1           
            if neuron_count_HP == ind_n_HP[0][0]+1:
                spikes = aligned_spikes[:,n,:]
                spikes = np.mean(spikes,axis = 1)
                # Getting out task indices
                task = session.trial_data['task']
                forced_trials = session.trial_data['forced_trial']
                non_forced_array = np.where(forced_trials == 0)[0]
                task_non_forced = task[non_forced_array]
                task_1 = np.where(task_non_forced == 1)[0]
                task_2 = np.where(task_non_forced == 2)[0]    
                predictor_A_Task_1, predictor_A_Task_2, predictor_A_Task_3,\
                predictor_B_Task_1, predictor_B_Task_2, predictor_B_Task_3, reward,\
                predictor_a_good_task_1,predictor_a_good_task_2, predictor_a_good_task_3 = re.predictors_pokes(session)    
    
                # Getting out task indices
                forced_trials = session.trial_data['forced_trial']
                outcomes = session.trial_data['outcomes']
    
                choices = session.trial_data['choices']
                non_forced_array = np.where(forced_trials == 0)[0]
                states  = session.trial_data['state']
                states = states[non_forced_array]
                
                choices = choices[non_forced_array]
                outcomes = outcomes[non_forced_array]
                ones = np.ones(len(choices))
                # Getting out task indices
                predictors_all = OrderedDict([('latent_state',states),
                                  ('choice', choices),
                                  ('reward', outcomes),
                                  ('ones', ones)])
                X_all = np.vstack(predictors_all.values()).T[:len(choices),:].astype(float) # Full-session design matrix (not used further below)


                choices = choices[:len(task_1)]
                outcomes = outcomes[:len(task_1)]
                latent_state = np.ones(len(task_1))
                latent_state[predictor_a_good_task_1] = -1
                ones = np.ones(len(task_1))
                spikes = spikes[:len(task_1)]

                
                predictors = OrderedDict([('latent_state',latent_state),
                                  ('choice', choices),
                                  ('reward', outcomes),
                                  ('ones', ones)])
    
                X = np.vstack(predictors.values()).T[:len(choices),:].astype(float)
                t = reg_f.regression_code(spikes[:, np.newaxis], X) # t-statistics for this neuron (not used further below)
                

                plt.figure(1)
                spikes = aligned_spikes[:,n,:]
                spikes = np.mean(spikes,axis = 1)


                x = np.arange(len(spikes))
                plt.plot(x,spikes)
                
                max_y = int(np.max(spikes) + 5)
        
     
                forced_trials = session.trial_data['forced_trial']
                outcomes = session.trial_data['outcomes']
    
                choices = session.trial_data['choices']
                non_forced_array = np.where(forced_trials == 0)[0]
                           
                choices = choices[non_forced_array]
                aligned_spikes = aligned_spikes[:len(choices),:,:]
                outcomes = outcomes[non_forced_array]
                states  = session.trial_data['state']
                states = states[non_forced_array]
                
                task = session.trial_data['task']
                task_non_forced = task[non_forced_array]
                task_1 = np.where(task_non_forced == 1)[0]
                task_2 = np.where(task_non_forced == 2)[0] 
                task_3 = np.where(task_non_forced == 3)[0]  

                # Getting out task indices
                reward_ind = np.where(outcomes == 1)
                plt.plot(reward_ind[0],outcomes[reward_ind]+max_y+2, "v", color = 'red', alpha = 0.7, markersize=1, label = 'reward')
                choices_ind = np.where(choices == 1 )
                conj_a_reward =  np.where((outcomes == 1) & (choices == 1))
                a_no_reward = np.where((outcomes == 0) & (choices == 1))
                conj_b_reward =  np.where((outcomes == 1) & (choices == 0))
                b_no_reward = np.where((outcomes == 0) & (choices == 0))
                
                plt.plot(choices_ind[0], choices[choices_ind]+max_y+5,"x", color = 'green', alpha = 0.7, markersize=3, label = 'choice')
                plt.plot(states+max_y, color = 'black', alpha = 0.7, label = 'State')
                plt.plot(task_1,np.zeros(len(task_1))+max_y+7, 'pink', label = 'Task')
                plt.plot(task_2,np.zeros(len(task_2))+max_y+9, 'pink')
                plt.plot(task_3,np.zeros(len(task_3))+max_y+11, 'pink')
                
                plt.vlines(conj_a_reward[0], ymin = 0, ymax = max_y, alpha = 0.3, color = 'grey', label = 'A reward')

                plt.vlines(a_no_reward[0], ymin = 0, ymax = max_y, alpha = 0.3, color = 'darkblue', label = 'A no reward')

                plt.vlines(conj_b_reward[0], ymin = 0, ymax = max_y, alpha = 0.3, color = 'orange', label = 'B reward')

                plt.vlines(b_no_reward[0], ymin = 0, ymax = max_y, alpha = 0.3, color = 'yellow', label = 'B no reward')

                plt.title('HP neuron above 5')

                plt.legend()
                    
    for s,session in enumerate(experiment_aligned_PFC):
        aligned_spikes = session.aligned_rates[:]
        
        if aligned_spikes.shape[1] > 0: # sessions with neurons? 
            n_trials, n_neurons, n_timepoints = aligned_spikes.shape 
            
            for n in range(n_neurons):
                neuron_count_PFC += 1
                
                if neuron_count_PFC == ind_n_PFC[0][0]:
                    spikes = aligned_spikes[:,n,:]
                    spikes = np.mean(spikes, axis = 1)
                    x = np.arange(len(spikes))
                    plt.figure(2)
                    plt.plot(x,spikes)
                    
                    max_y = int(np.max(spikes) + 5)
            
         
                    forced_trials = session.trial_data['forced_trial']
                    outcomes = session.trial_data['outcomes']
        
                    choices = session.trial_data['choices']
                    non_forced_array = np.where(forced_trials == 0)[0]
                               
                    choices = choices[non_forced_array]
                    aligned_spikes = aligned_spikes[:len(choices),:,:]
                    outcomes = outcomes[non_forced_array]
                    states  = session.trial_data['state']
                    states = states[non_forced_array]
                    
                    task = session.trial_data['task']
                    task_non_forced = task[non_forced_array]
                    task_1 = np.where(task_non_forced == 1)[0]
                    task_2 = np.where(task_non_forced == 2)[0] 
                    task_3 = np.where(task_non_forced == 3)[0]  

                    # Getting out task indices
                    reward_ind = np.where(outcomes == 1)
                    plt.plot(reward_ind[0],outcomes[reward_ind]+max_y+2, "v", color = 'red', alpha = 0.7, markersize=1, label = 'reward')
                    choices_ind = np.where(choices == 1 )
                    conj_a_reward =  np.where((outcomes == 1) & (choices == 1))
                    a_no_reward = np.where((outcomes == 0) & (choices == 1))
                    conj_b_reward =  np.where((outcomes == 1) & (choices == 0))
                    b_no_reward = np.where((outcomes == 0) & (choices == 0))
                    plt.plot(choices_ind[0], choices[choices_ind]+max_y+5,"x", color = 'green', alpha = 0.7, markersize=3, label = 'choice')
                    plt.plot(states+max_y, color = 'black', alpha = 0.7, label = 'State')
                    plt.plot(task_1,np.zeros(len(task_1))+max_y+7, 'pink', label = 'Task')
                    plt.plot(task_2,np.zeros(len(task_2))+max_y+9, 'pink')
                    plt.plot(task_3,np.zeros(len(task_3))+max_y+11, 'pink')
                    
                    plt.vlines(conj_a_reward[0], ymin = 0, ymax = max_y, alpha = 0.3, color = 'grey', label = 'A reward')

                    plt.vlines(a_no_reward[0], ymin = 0, ymax = max_y, alpha = 0.3, color = 'darkblue', label = 'A no reward')

                    plt.vlines(conj_b_reward[0], ymin = 0, ymax = max_y, alpha = 0.3, color = 'orange', label = 'B reward')

                    plt.vlines(b_no_reward[0], ymin = 0, ymax = max_y, alpha = 0.3, color = 'yellow', label = 'B no reward')

                    plt.title('PFC neuron above 6')
                
                    plt.legend()
def regression_latent_state(experiment, experiment_sim_Q4_values):  
    
    C_1 = []
    C_coef = []
    cpd_1 = []
    
    # Finding correlation coefficients for task 1 
    for s,session in enumerate(experiment):
        aligned_spikes= session.aligned_rates[:]
        if aligned_spikes.shape[1] > 0: # sessions with neurons? 
            n_trials, n_neurons, n_timepoints = aligned_spikes.shape 
            
            #aligned_spikes = np.mean(aligned_spikes, axis =  2) 
            
            # Getting out task indices
            task = session.trial_data['task']
            forced_trials = session.trial_data['forced_trial']
            non_forced_array = np.where(forced_trials == 0)[0]
            task_non_forced = task[non_forced_array]
            task_1 = np.where(task_non_forced == 1)[0]
            task_2 = np.where(task_non_forced == 2)[0]    
            predictor_A_Task_1, predictor_A_Task_2, predictor_A_Task_3,\
            predictor_B_Task_1, predictor_B_Task_2, predictor_B_Task_3, reward,\
            predictor_a_good_task_1,predictor_a_good_task_2, predictor_a_good_task_3 = re.predictors_pokes(session)    
            # Getting out task indices
            Q4 = experiment_sim_Q4_values[s]
            forced_trials = session.trial_data['forced_trial']
            outcomes = session.trial_data['outcomes']

            choices = session.trial_data['choices']
            non_forced_array = np.where(forced_trials == 0)[0]
                       
            choices = choices[non_forced_array]
            Q4 = Q4[non_forced_array]
            aligned_spikes = aligned_spikes[:len(choices),:]
            outcomes = outcomes[non_forced_array]
            # Getting out task indices
            
            ones = np.ones(len(choices))
            choices = choices[:len(task_1)]
            outcomes = outcomes[:len(task_1)]
            latent_state = np.ones(len(task_1))
            latent_state[predictor_a_good_task_1] = -1
            ones = ones[:len(task_1)]
            aligned_spikes = aligned_spikes[:len(task_1)]
            Q4 = Q4[:len(task_1)]
            choice_Q4 = choices*Q4


            predictors = OrderedDict([#('latent_state',latent_state), 
                                      ('choice', choices),
                                      ('reward', outcomes),
                                      ('Q4', Q4),
                                      ('choice_Q4',choice_Q4),
                                      ('ones', ones)])
        
           
            X = np.vstack(predictors.values()).T[:len(choices),:].astype(float)
            n_predictors = X.shape[1]
            y = aligned_spikes.reshape([len(aligned_spikes),-1]) # Activity matrix [n_trials, n_neurons*n_timepoints]
            tstats = reg_f.regression_code(y, X)
            ols = LinearRegression(copy_X = True,fit_intercept= True)
            ols.fit(X,y)
            C_coef.append(ols.coef_.reshape(n_neurons, n_timepoints, n_predictors)) # Predictor loadings; coef_ rows are targets ordered (neuron, timepoint)
            C_1.append(tstats.reshape(n_predictors,n_neurons,n_timepoints)) # Predictor loadings
            cpd_1.append(re._CPD(X,y).reshape(n_neurons,n_timepoints, n_predictors))

    C_1 = np.concatenate(C_1, axis = 1)
    C_coef = np.concatenate(C_coef, axis = 0)
    cpd_1 = np.nanmean(np.concatenate(cpd_1, 0), axis = 0)

    C_2 = []
    C_coef_2 = []
    cpd_2 = []

    # Finding correlation coefficients for task 2
    for s,session in enumerate(experiment):
        aligned_spikes= session.aligned_rates[:]
        if aligned_spikes.shape[1] > 0: # sessions with neurons? 
            n_trials, n_neurons, n_timepoints = aligned_spikes.shape 
            #aligned_spikes = np.mean(aligned_spikes, axis =  2) 
            Q4 = experiment_sim_Q4_values[s]

            # Getting out task indices
            task = session.trial_data['task']
            forced_trials = session.trial_data['forced_trial']
            non_forced_array = np.where(forced_trials == 0)[0]
            task_non_forced = task[non_forced_array]
            task_1 = np.where(task_non_forced == 1)[0]
            task_2 = np.where(task_non_forced == 2)[0]    
            
            predictor_A_Task_1, predictor_A_Task_2, predictor_A_Task_3,\
            predictor_B_Task_1, predictor_B_Task_2, predictor_B_Task_3, reward,\
            predictor_a_good_task_1,predictor_a_good_task_2, predictor_a_good_task_3 = re.predictors_pokes(session)    

            # Getting out task indices
            forced_trials = session.trial_data['forced_trial']
            outcomes = session.trial_data['outcomes']

            choices = session.trial_data['choices']
            non_forced_array = np.where(forced_trials == 0)[0]
            Q4 = Q4[non_forced_array]

            
            choices = choices[non_forced_array]
            aligned_spikes = aligned_spikes[:len(choices),:]
            outcomes = outcomes[non_forced_array]
            # Getting out task indices

            ones = np.ones(len(choices))
            
            choices = choices[len(task_1):len(task_1)+len(task_2)]
            latent_state = np.ones(len(choices))
            latent_state[predictor_a_good_task_2] = -1
            
            outcomes = outcomes[len(task_1):len(task_1)+len(task_2)]
            ones = ones[len(task_1):len(task_1)+len(task_2)]
            aligned_spikes = aligned_spikes[len(task_1):len(task_1)+len(task_2)]
            Q4 = Q4[len(task_1):len(task_1)+len(task_2)]
            choice_Q4 = choices*Q4

            predictors = OrderedDict([#('latent_state',latent_state),
                                      ('choice', choices),
                                      ('reward', outcomes),
                                      ('Q4',Q4),
                                      ('choice_Q4',choice_Q4),
                                      ('ones', ones)])
        
           
            X = np.vstack(predictors.values()).T[:len(choices),:].astype(float)
            n_predictors = X.shape[1]
            y = aligned_spikes.reshape([len(aligned_spikes),-1]) # Activity matrix [n_trials, n_neurons*n_timepoints]
            tstats = reg_f.regression_code(y, X)
            C_2.append(tstats.reshape(n_predictors,n_neurons,n_timepoints)) # Predictor loadings
            
            ols = LinearRegression(copy_X = True,fit_intercept= True)
            ols.fit(X,y)
            C_coef_2.append(ols.coef_.reshape(n_neurons, n_timepoints, n_predictors)) # Predictor loadings; coef_ rows are targets ordered (neuron, timepoint)
            cpd_2.append(re._CPD(X,y).reshape(n_neurons,n_timepoints, n_predictors))


    C_2 = np.concatenate(C_2, axis = 1)
    C_coef_2 = np.concatenate(C_coef_2, axis = 0)
    cpd_2 = np.nanmean(np.concatenate(cpd_2, 0), axis = 0) # Population CPD is mean over neurons.

    C_3 = []
    C_coef_3 = []
    cpd_3 = []
    # Finding correlation coefficients for task 3
    for s,session in enumerate(experiment):
        aligned_spikes= session.aligned_rates[:]
        if aligned_spikes.shape[1] > 0: # sessions with neurons? 
            n_trials, n_neurons, n_timepoints = aligned_spikes.shape 
            #aligned_spikes = np.mean(aligned_spikes, axis =  2) 

            
            # Getting out task indices
            task = session.trial_data['task']
            forced_trials = session.trial_data['forced_trial']
            non_forced_array = np.where(forced_trials == 0)[0]
            task_non_forced = task[non_forced_array]
            task_1 = np.where(task_non_forced == 1)[0]
            task_2 = np.where(task_non_forced == 2)[0]    
            Q4 = experiment_sim_Q4_values[s]

            predictor_A_Task_1, predictor_A_Task_2, predictor_A_Task_3,\
            predictor_B_Task_1, predictor_B_Task_2, predictor_B_Task_3, reward,\
            predictor_a_good_task_1,predictor_a_good_task_2, predictor_a_good_task_3 = re.predictors_pokes(session)    


            # Getting out task indices
            forced_trials = session.trial_data['forced_trial']
            outcomes = session.trial_data['outcomes']

            choices = session.trial_data['choices']
            non_forced_array = np.where(forced_trials == 0)[0]
            
            Q4 = Q4[non_forced_array]
            choices = choices[non_forced_array]
            aligned_spikes = aligned_spikes[:len(choices),:]
            outcomes = outcomes[non_forced_array]
            # Getting out task indices

            ones = np.ones(len(choices))
  
            choices = choices[len(task_1)+len(task_2):]
            latent_state = np.ones(len(choices))
            latent_state[predictor_a_good_task_3] = -1
            
            outcomes = outcomes[len(task_1)+len(task_2):]
            ones = ones[len(task_1)+len(task_2):]
            Q4 = Q4[len(task_1)+len(task_2):]
            choice_Q4 = choices*Q4
            aligned_spikes = aligned_spikes[len(task_1)+len(task_2):]
            
            predictors = OrderedDict([#('latent_state', latent_state),
                                      ('choice', choices),
                                      ('reward', outcomes),
                                      ('Q4', Q4),
                                      ('choice_Q4',choice_Q4),
                                      ('ones', ones)])
        
           
            X = np.vstack(predictors.values()).T[:len(choices),:].astype(float)
            n_predictors = X.shape[1]
            y = aligned_spikes.reshape([len(aligned_spikes),-1]) # Activity matrix [n_trials, n_neurons*n_timepoints]
            tstats = reg_f.regression_code(y, X)

            C_3.append(tstats.reshape(n_predictors,n_neurons,n_timepoints)) # Predictor loadings
            
            ols = LinearRegression(copy_X = True,fit_intercept= True)
            ols.fit(X,y)
            C_coef_3.append(ols.coef_.reshape(n_neurons,n_timepoints, n_predictors)) # Predictor loadings
            cpd_3.append(re._CPD(X,y).reshape(n_neurons,n_timepoints, n_predictors))


    C_3 = np.concatenate(C_3, axis = 1)
    C_coef_3 = np.concatenate(C_coef_3, axis = 0)
    cpd_3 = np.nanmean(np.concatenate(cpd_3, 0), axis = 0) # Population CPD is mean over neurons.
    
    return C_1, C_2, C_3, C_coef, C_coef_2, C_coef_3, cpd_1, cpd_2, cpd_3, predictors
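# For reference: re._CPD above is assumed to compute the coefficient of partial
# determination (CPD), i.e. the proportional increase in residual variance when
# one predictor is deleted from the design matrix. A minimal sketch under that
# assumption (an illustration, not the project's actual implementation):
def _cpd_sketch(X, y):
    def sse(design, y):
        # Per-column sum of squared OLS residuals
        beta, _, _, _ = np.linalg.lstsq(design, y, rcond=None)
        return np.sum((y - design @ beta)**2, axis=0)
    sse_full = sse(X, y)
    cpd = np.zeros([y.shape[1], X.shape[1]])
    for i in range(X.shape[1]):
        sse_reduced = sse(np.delete(X, i, axis=1), y)
        cpd[:, i] = (sse_reduced - sse_full) / sse_reduced
    return cpd # (n_neurons*n_timepoints, n_predictors), matching the reshapes above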
def remapping_control(experiment):
    # Plots the proportion of cells that were firing in task 1 but stopped firing in
    # task 2 (or started firing in task 2 without firing in task 1), and cells that
    # were firing in task 2 but stopped in task 3 (or started firing only in task 3).
    session_a1_a1 = []
    session_a2_a2 = []
    session_a1_a2 = []
    session_a2_a3 = []
   
    for i,session in enumerate(experiment):
        index_neuron = []
        
        aligned_spikes= session.aligned_rates 
        n_trials, n_neurons, n_timepoints = aligned_spikes.shape
        predictor_A_Task_1, predictor_A_Task_2, predictor_A_Task_3,\
        predictor_B_Task_1, predictor_B_Task_2, predictor_B_Task_3, reward,\
        predictor_a_good_task_1,predictor_a_good_task_2, predictor_a_good_task_3 = re.predictors_pokes(session)
        t_out = session.t_out
        initiate_choice_t = session.target_times # Times of initiation and choice
        
        ind_choice = (np.abs(t_out-initiate_choice_t[-2])).argmin() # Find firing rates around choice
        ind_after_choice = ind_choice + 7 # 1 sec after choice
        spikes_around_choice = aligned_spikes[:,:,ind_choice-2:ind_after_choice] # Find firing rates only around choice      
        mean_spikes_around_choice  = np.mean(spikes_around_choice,axis =2) # Mean firing rates around choice 
        baseline_mean_trial = np.mean(aligned_spikes, axis =2)
        std_trial = np.std(aligned_spikes, axis =2)
        baseline_mean_all_trials = np.mean(baseline_mean_trial, axis =0)
        std_all_trials = np.std(baseline_mean_trial, axis = 0) # Per-neuron std across trials
        index_no_reward = np.where(reward ==0)
        
        a_1 = np.where(predictor_A_Task_1 == 1) # Poke A task 1 indices
        a_2 = np.where(predictor_A_Task_2 == 1) # Poke A task 2 indices
        a_3 = np.where(predictor_A_Task_3 == 1) # Poke A task 3 indices
        
        a1_nR = np.intersect1d(a_1[0], index_no_reward[0]) # Unrewarded A choices in each task
        a2_nR = np.intersect1d(a_2[0], index_no_reward[0])
        a3_nR = np.intersect1d(a_3[0], index_no_reward[0])


        choice_a1 = mean_spikes_around_choice[a1_nR]
        
        half = choice_a1.shape[0] // 2 # Floor division handles odd trial counts
        a_1_first_half = choice_a1[:half]
        a_1_last_half = choice_a1[half:]
            

        choice_a2 = mean_spikes_around_choice[a2_nR]

        half = choice_a2.shape[0] // 2
        a_2_first_half = choice_a2[:half]
        a_2_last_half = choice_a2[half:]
        
        
        choice_a3 = mean_spikes_around_choice[a3_nR]  

        half = choice_a3.shape[0] // 2
        a_3_first_half = choice_a3[:half]
        a_3_last_half = choice_a3[half:]
       

        a1_pokes_mean_1 = np.mean(a_1_first_half, axis = 0)
        a1_pokes_mean_2 = np.mean(a_1_last_half, axis = 0)

                     
        a2_pokes_mean_1 = np.mean(a_2_first_half, axis = 0)
        a2_pokes_mean_2 = np.mean(a_2_last_half, axis = 0)

        a3_pokes_mean_1 = np.mean(a_3_first_half, axis = 0)
        a3_pokes_mean_2 = np.mean(a_3_last_half, axis = 0)
        
        for neuron in range(n_neurons):

            if a1_pokes_mean_1[neuron] > baseline_mean_all_trials[neuron] + 3*std_all_trials[neuron] \
            or a1_pokes_mean_2[neuron] > baseline_mean_all_trials[neuron] + 3*std_all_trials[neuron] \
            or a2_pokes_mean_1[neuron] > baseline_mean_all_trials[neuron] + 3*std_all_trials[neuron] \
            or a2_pokes_mean_2[neuron] > baseline_mean_all_trials[neuron] + 3*std_all_trials[neuron] \
            or a3_pokes_mean_1[neuron] > baseline_mean_all_trials[neuron] + 3*std_all_trials[neuron] \
            or a3_pokes_mean_2[neuron] > baseline_mean_all_trials[neuron] + 3*std_all_trials[neuron]:
                index_neuron.append(neuron)
        if len(index_neuron) > 0:
            a1_a1_angle = re.angle(a1_pokes_mean_1[index_neuron], a1_pokes_mean_2[index_neuron])            
            a2_a2_angle = re.angle(a2_pokes_mean_1[index_neuron], a2_pokes_mean_2[index_neuron])            
            a1_a2_angle = re.angle(a1_pokes_mean_2[index_neuron], a2_pokes_mean_1[index_neuron])
            a2_a3_angle = re.angle(a2_pokes_mean_2[index_neuron], a3_pokes_mean_1[index_neuron])
    
            session_a1_a1.append(a1_a1_angle)
            session_a2_a2.append(a2_a2_angle)
            session_a1_a2.append(a1_a2_angle)
            session_a2_a3.append(a2_a3_angle)
            
    mean_angle_a1_a1 = np.nanmean(session_a1_a1)
    mean_angle_a1_a2 = np.nanmean(session_a1_a2)
    mean_angle_a2_a2 = np.nanmean(session_a2_a2)
    mean_angle_a2_a3 = np.nanmean(session_a2_a3)
    mean_within = np.nanmean([mean_angle_a1_a1, mean_angle_a2_a2])
    std_within = np.nanstd([mean_angle_a1_a1, mean_angle_a2_a2])
    mean_between = np.nanmean([mean_angle_a1_a2, mean_angle_a2_a3])
    std_between = np.nanstd([mean_angle_a1_a2, mean_angle_a2_a3])
    
    
    return mean_within, mean_between, std_within, std_between


#plt.bar([1,2,3,4],[mean_within_HP,mean_within_PFC, mean_between_HP, mean_between_PFC], tick_label =['Within HP', 'Within PFC','Between HP','Between PFC',])
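# A runnable version of the bar plot above with the standard-deviation error bars
# returned by remapping_control (the *_HP / *_PFC names are assumed to come from
# calling remapping_control on the HP and PFC experiments):
#
#     mean_within_HP, mean_between_HP, std_within_HP, std_between_HP = remapping_control(experiment_aligned_HP)
#     mean_within_PFC, mean_between_PFC, std_within_PFC, std_between_PFC = remapping_control(experiment_aligned_PFC)
#     plt.bar([1, 2, 3, 4],
#             [mean_within_HP, mean_within_PFC, mean_between_HP, mean_between_PFC],
#             yerr=[std_within_HP, std_within_PFC, std_between_HP, std_between_PFC],
#             tick_label=['Within HP', 'Within PFC', 'Between HP', 'Between PFC'])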

   
       
def remapping_timecourse(experiment):
    
    # Lists for appending the trials around each task switch (n_trials_of_interest = 50
    # before and 50 after) for neurons that either decreased or increased their firing
    # rates between two tasks around choice time
    
    a1_a2_list_increase = []
    a2_a3_list_increase = []
    a1_a2_list_decrease = []
    a2_a3_list_decrease = []
    
    for session in experiment:
        
        
        aligned_spikes= session.aligned_rates 
        n_trials, n_neurons, n_timepoints = aligned_spikes.shape
        
        # NaN-filled (n_neurons x 100) arrays for the firing rates of each neuron on
        # the trials around the task switches where the A choice was made
        a1_a2_increase = np.full((n_neurons, 100), np.nan)
        a2_a3_increase = np.full((n_neurons, 100), np.nan)

        a1_a2_decrease = np.full((n_neurons, 100), np.nan)
        a2_a3_decrease = np.full((n_neurons, 100), np.nan)
        
        # Trial indices of choices
        predictor_A_Task_1, predictor_A_Task_2, predictor_A_Task_3,\
        predictor_B_Task_1, predictor_B_Task_2, predictor_B_Task_3, reward,\
        predictor_a_good_task_1, predictor_a_good_task_2, predictor_a_good_task_3 = re.predictors_pokes(session)
        t_out = session.t_out
      
        initiate_choice_t = session.target_times #Initiation and Choice Times
        
        ind_choice = (np.abs(t_out-initiate_choice_t[-2])).argmin() # Find firing rates around choice
        ind_after_choice = ind_choice + 7 # 1 sec after choice
        spikes_around_choice = aligned_spikes[:,:,ind_choice-2:ind_after_choice] # Find firing rates only around choice
        mean_spikes_around_choice = np.mean(spikes_around_choice, axis = 0) # Mean across trials: (n_neurons, n_timepoints)
        for i in range(mean_spikes_around_choice.shape[0]):
            plt.figure()
            plt.plot(mean_spikes_around_choice[i, :])
        forced_trials = session.trial_data['forced_trial']
        non_forced_array = np.where(forced_trials == 0)[0]
        task = session.trial_data['task']
        task_non_forced = task[non_forced_array]
        task_2 = np.where(task_non_forced == 2)[0] 
        a_pokes = predictor_A_Task_1 + predictor_A_Task_2 + predictor_A_Task_3 # All A pokes across all three tasks
        
        mean_trial = np.mean(spikes_around_choice,axis = 2) # Mean firing rates around choice 
        
        a_1 = np.where(predictor_A_Task_1 == 1)
        a_2 = np.where(predictor_A_Task_2 == 1)
        a_3 = np.where(predictor_A_Task_3 == 1)
        task_2_start = task_2[0]
        task_3_start = task_2[-1]+1
        n_trials_of_interest = 50
        
        where_a_task_1_2 = np.where(a_pokes[task_2_start - n_trials_of_interest: task_2_start + n_trials_of_interest] == 1) # Indices of A pokes in the 100 trials around the task 1/2 switch
        where_a_task_2_3 = np.where(a_pokes[task_3_start - n_trials_of_interest:task_3_start + n_trials_of_interest] == 1) # Indices of A pokes in the 100 trials around the task 2/3 switch

        for neuron in range(n_neurons):
            trials_firing = mean_trial[:,neuron]  # Firing rate of each neuron
           
            a1_fr = trials_firing[a_1]  # Firing rates on poke A choices in Task 1 
            a1_mean = np.mean(a1_fr, axis = 0)  # Mean rate on poke A choices in Task 1 
            a1_std = np.std(a1_fr) # Standard deviation on poke A choices in Task 1 
            a2_fr = trials_firing[a_2]
            a2_std = np.std(a2_fr)
            a2_mean = np.mean(a2_fr, axis = 0)
            a3_fr = trials_firing[a_3]
            a3_mean = np.mean(a3_fr, axis = 0)
            # If mean firing rate on a2 is higher than on a1 or mean firing rate on a3 
            #is higher than a2 find the firing rate on the trials for that neuron
            if a2_mean > (a1_mean+(3*a1_std)) or a3_mean > (a2_mean+(3*a2_std)):
                t1_t2 = trials_firing[task_2_start - n_trials_of_interest:task_2_start + n_trials_of_interest] 
                t2_t3 = trials_firing[task_3_start - n_trials_of_interest:task_3_start + n_trials_of_interest]
                
                a1_a2_increase[neuron,where_a_task_1_2] = t1_t2[where_a_task_1_2]
                a2_a3_increase[neuron,where_a_task_2_3] = t2_t3[where_a_task_2_3]
            # If mean firing rate on a2 is lower than on a1, or mean firing rate on a3
            # is lower than on a2, find the firing rate on the trials for that neuron
            elif a2_mean < (a1_mean - (3*a1_std)) or a3_mean < (a2_mean - (3*a2_std)):
                t1_t2 = trials_firing[task_2_start - n_trials_of_interest:task_2_start + n_trials_of_interest] 
                t2_t3 = trials_firing[task_3_start - n_trials_of_interest:task_3_start + n_trials_of_interest]
                a1_a2_decrease[neuron,where_a_task_1_2] = t1_t2[where_a_task_1_2]
                a2_a3_decrease[neuron,where_a_task_2_3] = t2_t3[where_a_task_2_3]
                
        a1_a2_list_increase.append(a1_a2_increase)
        a2_a3_list_increase.append(a2_a3_increase)
        a1_a2_list_decrease.append(a1_a2_decrease)
        a2_a3_list_decrease.append(a2_a3_decrease)
             
             
    a1_a2_list_increase = np.array(a1_a2_list_increase)
    a2_a3_list_increase = np.array(a2_a3_list_increase)
    
    a1_a2_list_decrease = np.array(a1_a2_list_decrease)
    a2_a3_list_decrease = np.array(a2_a3_list_decrease)
    
    a_list_increase = np.concatenate([a1_a2_list_increase,a2_a3_list_increase])
    a_list_decrease = np.concatenate([a1_a2_list_decrease,a2_a3_list_decrease])

    flattened_a_list_increase = []
    for x in a_list_increase:
        for y in x:
            index = np.isnan(y) 
            if np.all(index):
                continue
            else:
                flattened_a_list_increase.append(y)
                
    flattened_a_list_decrease = []
    for x in a_list_decrease:
        for y in x:
            index = np.isnan(y) 
            if np.all(index):
                continue
            else:
                flattened_a_list_decrease.append(y)
    flattened_a_list_increase = np.array(flattened_a_list_increase)
    flattened_a_list_decrease = np.array(flattened_a_list_decrease)
    x_array = np.arange(1,101)
    task_change = 50
    mean_increase = np.nanmean(flattened_a_list_increase, axis = 0)
    mean_decrease = np.nanmean(flattened_a_list_decrease, axis = 0)
    plt.figure()
    smoothed_dec = gs(mean_decrease, 15)
    plt.plot(x_array, smoothed_dec)
    plt.axvline(task_change, color='k', linestyle=':')
    plt.xlabel('Number of Trials Before and After Task Change')
    plt.ylabel('Mean Firing Rate')
    plt.title('Cells that Decrease Firing Rates')
    
    plt.figure()
    smoothed_inc = gs(mean_increase, 15)
    plt.plot(x_array, smoothed_inc)
    plt.axvline(task_change, color='k', linestyle=':')
    plt.xlabel('Number of Trials Before and After Task Change')
    plt.ylabel('Mean Firing Rate')
    plt.title('Cells that Increase Firing Rates')
def block_change(experiment):
    neuron_n = 0
    all_neurons = 0
    for session in experiment:
        if session.file_name != 'm486-2018-07-28-171910.txt' and session.file_name != 'm480-2018-08-14-145623.txt':
            aligned_spikes = session.aligned_rates
            n_trials, n_neurons, n_timepoints = aligned_spikes.shape
            all_neurons += n_neurons

    a1_fr_r_last = np.zeros([all_neurons, 4])
    a1_fr_r_first = np.zeros([all_neurons, 4])

    a1_fr_nr_last = np.zeros([all_neurons, 4])
    a1_fr_nr_first = np.zeros([all_neurons, 4])

    a2_fr_r_last = np.zeros([all_neurons, 4])
    a2_fr_r_first = np.zeros([all_neurons, 4])

    a2_fr_nr_last = np.zeros([all_neurons, 4])
    a2_fr_nr_first = np.zeros([all_neurons, 4])

    a3_fr_r_last = np.zeros([all_neurons, 4])
    a3_fr_r_first = np.zeros([all_neurons, 4])

    a3_fr_nr_last = np.zeros([all_neurons, 4])
    a3_fr_nr_first = np.zeros([all_neurons, 4])

    for i, session in enumerate(experiment):
        if session.file_name != 'm486-2018-07-28-171910.txt' and session.file_name != 'm480-2018-08-14-145623.txt':
            aligned_spikes = session.aligned_rates
            n_trials, n_neurons, n_timepoints = aligned_spikes.shape

            # Trial indices of choices
            predictor_A_Task_1, predictor_A_Task_2, predictor_A_Task_3,\
            predictor_B_Task_1, predictor_B_Task_2, predictor_B_Task_3, reward,\
            predictor_a_good_task_1,predictor_a_good_task_2, predictor_a_good_task_3 = re.predictors_pokes(session)
            t_out = session.t_out

            initiate_choice_t = session.target_times  #Initiation and Choice Times

            ind_choice = (np.abs(t_out - initiate_choice_t[-2])
                          ).argmin()  # Find firing rates around choice
            ind_after_choice = ind_choice + 7  # 1 sec after choice
            spikes_around_choice = aligned_spikes[:, :, ind_choice - 2:
                                                  ind_after_choice]  # Find firing rates only around choice
            aligned_spikes = np.mean(spikes_around_choice, axis=2)

            a_r_1 = aligned_spikes[np.where((predictor_A_Task_1 == 1)
                                            & (reward == 1)), :]
            a_nr_1 = aligned_spikes[np.where((predictor_A_Task_1 == 1)
                                             & (reward == 0)), :]

            a_r_2 = aligned_spikes[np.where((predictor_A_Task_2 == 1)
                                            & (reward == 1)), :]
            a_nr_2 = aligned_spikes[np.where((predictor_A_Task_2 == 1)
                                             & (reward == 0)), :]

            a_r_3 = aligned_spikes[np.where((predictor_A_Task_3 == 1)
                                            & (reward == 1)), :]
            a_nr_3 = aligned_spikes[np.where((predictor_A_Task_3 == 1)
                                             & (reward == 0)), :]
            for neuron in range(n_neurons):

                a1_fr_r_last[neuron_n, :] = a_r_1[0, -4:, neuron]
                a1_fr_r_first[neuron_n, :] = a_r_1[0, :4, neuron]

                a1_fr_nr_last[neuron_n, :] = a_nr_1[0, -4:, neuron]
                a1_fr_nr_first[neuron_n, :] = a_nr_1[0, :4, neuron]

                a2_fr_r_last[neuron_n, :] = a_r_2[0, -4:, neuron]
                a2_fr_r_first[neuron_n, :] = a_r_2[0, :4, neuron]

                a2_fr_nr_last[neuron_n, :] = a_nr_2[0, -4:, neuron]
                a2_fr_nr_first[neuron_n, :] = a_nr_2[0, :4, neuron]

                a3_fr_r_last[neuron_n, :] = a_r_3[0, -4:, neuron]
                a3_fr_r_first[neuron_n, :] = a_r_3[0, :4, neuron]

                a3_fr_nr_last[neuron_n, :] = a_nr_3[0, -4:, neuron]
                a3_fr_nr_first[neuron_n, :] = a_nr_3[0, :4, neuron]

                neuron_n += 1

    # After all sessions: stack the first/last four rewarded and unrewarded A-choice
    # trials per task (12 conditions x 4 trials = 48 columns), and the block-transition
    # subset, then return both
    a = np.concatenate([a1_fr_r_first, a1_fr_r_last, a1_fr_nr_first, a1_fr_nr_last,
                        a2_fr_r_first, a2_fr_r_last, a2_fr_nr_first, a2_fr_nr_last,
                        a3_fr_r_first, a3_fr_r_last, a3_fr_nr_first, a3_fr_nr_last], axis = 1)

    transitions = np.concatenate([a1_fr_r_last, a1_fr_nr_last, a2_fr_r_first, a2_fr_nr_first,
                                  a2_fr_nr_last, a2_fr_r_last, a3_fr_r_first, a3_fr_nr_first], axis = 1)

    return a, transitions
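# Usage sketch (hypothetical): block_change assumes every included session has
# at least four rewarded and four unrewarded A choices per task. `transitions`
# stacks the last/first four such trials around each task switch:
#
#     a, transitions = block_change(experiment_aligned_HP)
#     print(transitions.shape)   # (all_neurons, 32)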
def regression_choices_c(experiment, all_sessions):

    C_task_1 = []  # To store predictor loadings for each session in task 1.
    C_task_2 = []  # To store predictor loadings for each session in task 2.
    C_task_3 = []  # To store predictor loadings for each session in task 3.

    # Fit a per-task A-choice regression for each session
    for s, session in enumerate(experiment):
        all_neurons_all_spikes_raster_plot_task = all_sessions[s]

        if all_neurons_all_spikes_raster_plot_task.shape[1] > 0:
            # Select choice events only
            all_neurons_all_spikes_raster_plot_task = all_neurons_all_spikes_raster_plot_task[
                1::2, :, :]
            predictor_A_Task_1, predictor_A_Task_2, predictor_A_Task_3,\
            predictor_B_Task_1, predictor_B_Task_2, predictor_B_Task_3, reward,\
            predictor_a_good_task_1,predictor_a_good_task_2, predictor_a_good_task_3 = re.predictors_pokes(session)

            # Getting out task indices
            task = session.trial_data['task']
            forced_trials = session.trial_data['forced_trial']
            non_forced_array = np.where(forced_trials == 0)[0]
            task_non_forced = task[non_forced_array]
            task_1 = np.where(task_non_forced == 1)[0]
            task_2 = np.where(task_non_forced == 2)[0]

            n_trials, n_neurons, n_timepoints = all_neurons_all_spikes_raster_plot_task.shape

            # For regressions for each task independently
            predictor_A_Task_1 = predictor_A_Task_1[:len(task_1)]
            all_neurons_all_spikes_raster_plot_task_1 = all_neurons_all_spikes_raster_plot_task[:len(
                task_1), :, :]
            all_neurons_all_spikes_raster_plot_task_1 = np.mean(
                all_neurons_all_spikes_raster_plot_task_1, axis=2)

            predictors_task_1 = OrderedDict([('A_task_1', predictor_A_Task_1)])

            X_task_1 = np.vstack(predictors_task_1.values()
                                 ).T[:len(predictor_A_Task_1), :].astype(float)
            n_predictors = X_task_1.shape[1]
            y_t1 = all_neurons_all_spikes_raster_plot_task_1.reshape(
                [all_neurons_all_spikes_raster_plot_task_1.shape[0],
                 -1])  # Activity matrix [n_trials, n_neurons*n_timepoints]

            ols = LinearRegression(copy_X=True, fit_intercept=True)
            ols.fit(X_task_1, y_t1)
            C_task_1.append(ols.coef_.reshape(
                n_neurons, n_predictors))  # Predictor loadings

            # For regressions for each task independently
            predictor_A_Task_2 = predictor_A_Task_2[len(task_1):len(task_1) +
                                                    len(task_2)]
            all_neurons_all_spikes_raster_plot_task_2 = all_neurons_all_spikes_raster_plot_task[
                len(task_1):len(task_1) + len(task_2), :, :]
            all_neurons_all_spikes_raster_plot_task_2 = np.mean(
                all_neurons_all_spikes_raster_plot_task_2, axis=2)

            predictors_task_2 = OrderedDict([('A_task_2', predictor_A_Task_2)])

            X_task_2 = np.vstack(predictors_task_2.values()
                                 ).T[:len(predictor_A_Task_2), :].astype(float)
            n_predictors = X_task_2.shape[1]
            y_t2 = all_neurons_all_spikes_raster_plot_task_2.reshape(
                [all_neurons_all_spikes_raster_plot_task_2.shape[0],
                 -1])  # Activity matrix [n_trials, n_neurons*n_timepoints]
            ols = LinearRegression(copy_X=True, fit_intercept=True)
            ols.fit(X_task_2, y_t2)
            C_task_2.append(ols.coef_.reshape(
                n_neurons, n_predictors))  # Predictor loadings

            # For regressions for each task independently
            predictor_A_Task_3 = predictor_A_Task_3[len(task_1) + len(task_2):]
            all_neurons_all_spikes_raster_plot_task_3 = all_neurons_all_spikes_raster_plot_task[
                len(task_1) + len(task_2):, :, :]
            all_neurons_all_spikes_raster_plot_task_3 = np.mean(
                all_neurons_all_spikes_raster_plot_task_3, axis=2)

            predictors_task_3 = OrderedDict([('A_task_3', predictor_A_Task_3)])

            X_task_3 = np.vstack(predictors_task_3.values()
                                 ).T[:len(predictor_A_Task_3), :].astype(float)
            n_predictors = X_task_3.shape[1]
            y_t3 = all_neurons_all_spikes_raster_plot_task_3.reshape(
                [all_neurons_all_spikes_raster_plot_task_3.shape[0],
                 -1])  # Activity matrix [n_trials, n_neurons*n_timepoints]
            ols = LinearRegression(copy_X=True, fit_intercept=True)
            ols.fit(X_task_3, y_t3)
            C_task_3.append(ols.coef_.reshape(
                n_neurons, n_predictors))  # Predictor loadings

    C_task_1 = np.concatenate(C_task_1, 0)
    C_task_2 = np.concatenate(C_task_2, 0)
    C_task_3 = np.concatenate(C_task_3, 0)

    return C_task_1, C_task_2, C_task_3
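# Usage sketch (illustrative, not part of the original code): the per-neuron
# A-choice loadings can be compared across tasks, e.g. by correlating the
# task 1 and task 2 coefficients over neurons:
#
#     C_task_1, C_task_2, C_task_3 = regression_choices_c(experiment, all_sessions)
#     r_12 = np.corrcoef(C_task_1[:, 0], C_task_2[:, 0])[0, 1]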
def remapping_surprise(experiment, distribution):

    surprise_list = []
    surprise_list_2 = []
    surprise_list_neurons_1_2 = []
    surprise_list_neurons_2_3 = []
    for i, session in enumerate(experiment):

        aligned_spikes = session.aligned_rates
        n_trials, n_neurons, n_timepoints = aligned_spikes.shape

        # Trial indices of choices
        predictor_A_Task_1, predictor_A_Task_2, predictor_A_Task_3,\
        predictor_B_Task_1, predictor_B_Task_2, predictor_B_Task_3, reward,\
        predictor_a_good_task_1,predictor_a_good_task_2, predictor_a_good_task_3 = re.predictors_pokes(session)
        t_out = session.t_out

        initiate_choice_t = session.target_times  #Initiation and Choice Times

        ind_choice = (np.abs(t_out - initiate_choice_t[-2])
                      ).argmin()  # Find firing rates around choice
        ind_after_choice = ind_choice + 7  # 1 sec after choice
        spikes_around_choice = aligned_spikes[:, :, ind_choice - 2:
                                              ind_after_choice]  # Find firing rates only around choice
        mean_spikes_around_choice = np.mean(spikes_around_choice, axis=2)

        a_1 = np.where(predictor_A_Task_1 == 1)
        a_2 = np.where(predictor_A_Task_2 == 1)
        a_3 = np.where(predictor_A_Task_3 == 1)

        baseline_mean_trial = np.mean(aligned_spikes, axis=2)
        baseline_mean_all_trials = np.mean(baseline_mean_trial, axis=0)
        std_all_trials = np.std(baseline_mean_trial, axis=0) # Per-neuron std across trials

        choice_a1 = mean_spikes_around_choice[a_1]
        choice_a2 = mean_spikes_around_choice[a_2]
        choice_a3 = mean_spikes_around_choice[a_3]

        choice_a1_mean = np.mean(choice_a1, axis=0)
        choice_a2_mean = np.mean(choice_a2, axis=0)
        choice_a3_mean = np.mean(choice_a3, axis=0)

        for neuron in range(n_neurons):
            trials_firing = mean_spikes_around_choice[:, neuron]
            if choice_a1_mean[neuron] > baseline_mean_all_trials[neuron] + 3*std_all_trials[neuron] \
            or choice_a2_mean[neuron] > baseline_mean_all_trials[neuron] + 3*std_all_trials[neuron] \
            or choice_a3_mean[neuron] > baseline_mean_all_trials[neuron] + 3*std_all_trials[neuron] :

                a1_fr = trials_firing[
                    a_1]  # Firing rates on poke A choices in Task 1
                a1_poisson = a1_fr.astype(int)
                a2_fr = trials_firing[a_2]
                a2_poisson = a2_fr.astype(int)
                a3_fr = trials_firing[a_3]
                a3_poisson = a3_fr.astype(int)

                a1_mean = np.nanmean(a1_fr)
                a1_poisson = np.nanmean(a1_poisson)
                a1_std = np.nanstd(a1_fr)
                a2_mean = np.nanmean(a2_fr)
                a2_poisson = np.nanmean(a2_poisson)
                a2_std = np.nanstd(a2_fr)
                a3_poisson = np.nanmean(a3_poisson)
                a3_mean = np.nanmean(a3_fr)

                a1_fr_last = a1_fr[-15:]
                a1_fr_last_poisson = a1_fr_last.astype(int)
                a2_fr_first = a2_fr[:15]
                a2_fr_first_poisson = a2_fr_first.astype(int)
                a2_fr_last = a2_fr[-15:]
                a2_fr_last_poisson = a2_fr_last.astype(int)
                a3_fr_first = a3_fr[:15]
                a3_fr_first_poisson = a3_fr_first.astype(int)

                if a1_mean > 0.1 and a2_mean > 0.1 and a3_mean > 0.1:
                    if distribution == 'Normal':

                        surprise_a1 = -norm.logpdf(a1_fr_last, a1_mean, a1_std)

                        surprise_a2 = -norm.logpdf(a2_fr_first, a1_mean,
                                                   a1_std)

                        surprise_a2_last = -norm.logpdf(
                            a2_fr_last, a2_mean, a2_std)

                        surprise_a3_first = -norm.logpdf(
                            a3_fr_first, a2_mean, a2_std)

                    elif distribution == 'Poisson':

                        surprise_a1 = -poisson.logpmf(a1_fr_last_poisson,
                                                      mu=a1_poisson)

                        surprise_a2 = -poisson.logpmf(a2_fr_first_poisson,
                                                      mu=a1_poisson)

                        surprise_a2_last = -poisson.logpmf(a2_fr_last_poisson,
                                                           mu=a2_poisson)

                        surprise_a3_first = -poisson.logpmf(
                            a3_fr_first_poisson, mu=a2_poisson)

                    surprise_array_t1_2 = np.concatenate(
                        [surprise_a1, surprise_a2])

                    surprise_array_t2_3 = np.concatenate(
                        [surprise_a2_last, surprise_a3_first])

                    # Keep the appends inside this block so they only run when
                    # the surprise arrays were computed for this neuron
                    if len(surprise_array_t1_2) > 0 and len(
                            surprise_array_t2_3) > 0:
                        surprise_list_neurons_1_2.append(surprise_array_t1_2)
                        surprise_list_neurons_2_3.append(surprise_array_t2_3)

        surprise_list.append(surprise_list_neurons_1_2)
        surprise_list_2.append(surprise_list_neurons_2_3)

    surprise_list_neurons_1_2 = np.array(surprise_list_neurons_1_2)
    surprise_list_neurons_2_3 = np.array(surprise_list_neurons_2_3)

    mean_1_2 = np.nanmean(surprise_list_neurons_1_2, axis=0)
    std_1_2 = np.nanstd(surprise_list_neurons_1_2, axis=0)
    serr_1_2 = std_1_2 / np.sqrt(len(surprise_list_neurons_1_2))
    mean_2_3 = np.nanmean(surprise_list_neurons_2_3, axis=0)
    std_2_3 = np.nanstd(surprise_list_neurons_2_3, axis=0)
    serr_2_3 = std_2_3 / np.sqrt(len(surprise_list_neurons_2_3))

    x_pos = np.arange(len(mean_2_3))
    task_change = 15

    allmeans = (mean_1_2 + mean_2_3) / 2
    allserr = (serr_1_2 + serr_2_3) / 2

    plt.figure()
    plt.plot(x_pos, mean_1_2)
    plt.fill_between(x_pos,
                     mean_1_2 - serr_1_2,
                     mean_1_2 + serr_1_2,
                     alpha=0.2)
    plt.axvline(task_change, color='k', linestyle=':')
    plt.title('Task 1 and 2')
    plt.ylabel('-log(p(X))')
    plt.xlabel('Trial #')

    plt.figure()
    plt.plot(x_pos, mean_2_3)
    plt.fill_between(x_pos,
                     mean_2_3 - serr_2_3,
                     mean_2_3 + serr_2_3,
                     alpha=0.2)
    plt.axvline(task_change, color='k', linestyle=':')
    plt.title('Task 2 and 3')
    plt.ylabel('-log(p(X))')
    plt.xlabel('Trial #')

    plt.figure()
    plt.plot(x_pos, allmeans)
    plt.fill_between(x_pos, allmeans - allserr, allmeans + allserr, alpha=0.2)
    plt.axvline(task_change, color='k', linestyle=':')
    plt.title('Combined')
    plt.ylabel('-log(p(X))')
    plt.xlabel('Trial #')
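# The "surprise" above is the negative log-likelihood of post-switch firing
# rates under the pre-switch firing distribution: rates consistent with the
# previous task give low values, remapped rates give high values. A worked
# example with made-up numbers:
#
#     baseline_mean, baseline_std = 5.0, 1.0
#     -norm.logpdf(5.0, baseline_mean, baseline_std)   # ~0.92 (expected rate)
#     -norm.logpdf(9.0, baseline_mean, baseline_std)   # ~8.92 (surprising rate)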
def task_specific_regressors(session):
    task = session.trial_data['task']
    forced_trials = session.trial_data['forced_trial']
    non_forced_array = np.where(forced_trials == 0)[0]
    task_non_forced = task[non_forced_array]
    task_1 = np.where(task_non_forced == 1)[0]
    task_2 = np.where(task_non_forced == 2)[0]

    predictor_A_Task_1, predictor_A_Task_2, predictor_A_Task_3,\
    predictor_B_Task_1, predictor_B_Task_2, predictor_B_Task_3, reward,\
    predictor_a_good_task_1,predictor_a_good_task_2, predictor_a_good_task_3 = re.predictors_pokes(session)

    predictor_B_Task_1[len(task_1) + len(task_2):] = -1
    predictor_B_Task_1[len(task_1):len(task_1) + len(task_2)] = -2

    predictor_B_Task_2[len(task_1) + len(task_2):] = -1
    predictor_B_Task_2[:len(task_1)] = -2

    predictor_B_Task_1 = predictor_B_Task_1 + predictor_B_Task_3
    predictor_B_Task_2 = predictor_B_Task_2 + predictor_B_Task_3

    predictor_B_Task_1_choice = []
    predictor_B_Task_2_choice = []

    for c, choice in enumerate(predictor_B_Task_1):
        if choice == 1:
            predictor_B_Task_1_choice.append(0)
            predictor_B_Task_1_choice.append(1)
        elif choice == 0:
            predictor_B_Task_1_choice.append(0)
            predictor_B_Task_1_choice.append(0)
        elif choice == -2:
            predictor_B_Task_1_choice.append(0)
            predictor_B_Task_1_choice.append(0)
        elif choice == -1:
            predictor_B_Task_1_choice.append(0)
            predictor_B_Task_1_choice.append(-1)

    for c, choice in enumerate(predictor_B_Task_2):
        if choice == 1:
            predictor_B_Task_2_choice.append(0)
            predictor_B_Task_2_choice.append(1)
        elif choice == 0:
            predictor_B_Task_2_choice.append(0)
            predictor_B_Task_2_choice.append(0)
        elif choice == -2:
            predictor_B_Task_2_choice.append(0)
            predictor_B_Task_2_choice.append(0)
        elif choice == -1:
            predictor_B_Task_2_choice.append(0)
            predictor_B_Task_2_choice.append(-1)

    predictor_B_Task_1_initiation = []
    predictor_B_Task_2_initiation = []

    for c, choice in enumerate(predictor_B_Task_1):
        if choice == 1:
            predictor_B_Task_1_initiation.append(1)
            predictor_B_Task_1_initiation.append(0)
        elif choice == -2:
            predictor_B_Task_1_initiation.append(0)
            predictor_B_Task_1_initiation.append(0)
        elif choice == 0:
            predictor_B_Task_1_initiation.append(1)
            predictor_B_Task_1_initiation.append(0)
        elif choice == -1:
            predictor_B_Task_1_initiation.append(-1)
            predictor_B_Task_1_initiation.append(0)

    for c, choice in enumerate(predictor_B_Task_2):
        if choice == 1:
            predictor_B_Task_2_initiation.append(1)
            predictor_B_Task_2_initiation.append(0)
        elif choice == -2:
            predictor_B_Task_2_initiation.append(0)
            predictor_B_Task_2_initiation.append(0)
        elif choice == 0:
            predictor_B_Task_2_initiation.append(1)
            predictor_B_Task_2_initiation.append(0)
        elif choice == -1:
            predictor_B_Task_2_initiation.append(-1)
            predictor_B_Task_2_initiation.append(0)

    t_3 = predictor_B_Task_1_initiation[(len(task_1) + len(task_2)) * 2:]
    t_3 = np.asarray(t_3)

    ind = np.where(t_3 == 1)
    ind = np.asarray(ind) + (len(task_1) + len(task_2)) * 2

    predictor_B_Task_1_initiation = np.asarray(predictor_B_Task_1_initiation)
    predictor_B_Task_2_initiation = np.asarray(predictor_B_Task_2_initiation)
    predictor_B_Task_1_initiation[ind[0]] = 0
    predictor_B_Task_2_initiation[ind[0]] = 0

    return predictor_B_Task_1_initiation, predictor_B_Task_2_initiation, predictor_B_Task_1_choice, predictor_B_Task_2_choice
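# Note on task_specific_regressors: each trial expands into two consecutive
# entries (an initiation event followed by a choice event), so the returned
# regressors are twice the length of the non-forced trial sequence; own-task
# B choices are coded +1 at the relevant event, with task-3 events entering
# as a -1 contrast and the remaining task zeroed out.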
def predictors_around_pokes_include_a(session):
    poke_identity, outcomes_non_forced, initation_choice, initiation_time_stamps, poke_list_A, poke_list_B, all_events, constant_poke_a, choices, trial_times = extract_poke_times_include_a(
        session)
    unique_pokes = np.unique(poke_identity)

    predictor_A_Task_1, predictor_A_Task_2, predictor_A_Task_3,\
    predictor_B_Task_1, predictor_B_Task_2, predictor_B_Task_3, reward,\
    predictor_a_good_task_1,predictor_a_good_task_2, predictor_a_good_task_3 = re.predictors_pokes(session)

    poke_A = predictor_A_Task_1 + predictor_A_Task_2 + predictor_A_Task_3

    i = np.where(unique_pokes == constant_poke_a)
    unique_pokes = np.delete(unique_pokes, i)

    poke_1_id = constant_poke_a
    poke_2_id = unique_pokes[0]
    poke_3_id = unique_pokes[1]
    poke_4_id = unique_pokes[2]
    poke_5_id = unique_pokes[3]

    poke_1 = np.zeros(len(poke_identity))
    poke_2 = np.zeros(len(poke_identity))
    poke_3 = np.zeros(len(poke_identity))
    poke_4 = np.zeros(len(poke_identity))
    poke_5 = np.zeros(len(poke_identity))

    # Make a predictor for outcome which codes Initiation as 0
    outcomes = []
    for o, outcome in enumerate(outcomes_non_forced):
        outcomes.append(0) # Initiation event coded as 0
        if outcome == 1:
            outcomes.append(1)
        else:
            outcomes.append(0)
    outcomes = np.asarray(outcomes)

    choices = []
    for c, choice in enumerate(poke_A):
        choices.append(0)
        if choice == 1:
            #choices.append(0)
            choices.append(0.5)
        elif choice == 0:
            #choices.append(0)
            choices.append(-0.5)

    choices = np.asarray(choices)

    init_choices_a_b = []
    for c, choice in enumerate(poke_A):
        if choice == 1:
            init_choices_a_b.append(0)
            init_choices_a_b.append(0)
        elif choice == 0:
            init_choices_a_b.append(1)
            init_choices_a_b.append(1)
    init_choices_a_b = np.asarray(init_choices_a_b)

    choices_initiation = []
    for c, choice in enumerate(poke_A):
        choices_initiation.append(1)
        if choice == 1:
            choices_initiation.append(0)
        elif choice == 0:
            choices_initiation.append(0)

    choices_initiation = np.asarray(choices_initiation)

    for p, poke in enumerate(poke_identity):
        if poke == poke_1_id:
            poke_1[p] = 1
        elif poke == poke_2_id:
            poke_2[p] = 1
        elif poke == poke_3_id:
            poke_3[p] = 1
        elif poke == poke_4_id:
            poke_4[p] = 1
        elif poke == poke_5_id:
            poke_5[p] = 1

    return poke_1, poke_2, poke_3, poke_4, poke_5, outcomes, initation_choice, unique_pokes, constant_poke_a, choices, choices_initiation, init_choices_a_b
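# A minimal, self-contained sketch (assumed usage, not part of the source
# pipeline): regress per-event firing rates onto equal-length predictors from
# predictors_around_pokes_include_a with ordinary least squares.
def fit_poke_regression(predictors, rates):
    # predictors: list of 1-D arrays (one per regressor); rates: (n_events,) array.
    X = np.column_stack([np.ones(len(rates))] + list(predictors))  # add intercept
    betas, _, _, _ = np.linalg.lstsq(X, rates, rcond=None)
    return betas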
def a_b_previous_choice_for_svd(session):

    # Ones are As
    predictor_A_Task_1, predictor_A_Task_2, predictor_A_Task_3,\
    predictor_B_Task_1, predictor_B_Task_2, predictor_B_Task_3, reward,\
    predictor_a_good_task_1,predictor_a_good_task_2, predictor_a_good_task_3 = re.predictors_pokes(session)

    a_after_a_task_1 = []
    a_after_a_task_2 = []
    a_after_a_task_3 = []

    a_after_b_task_1 = []
    a_after_b_task_2 = []
    a_after_b_task_3 = []

    b_after_b_task_1 = []
    b_after_b_task_2 = []
    b_after_b_task_3 = []

    b_after_a_task_1 = []
    b_after_a_task_2 = []
    b_after_a_task_3 = []

    task = session.trial_data['task']
    forced_trials = session.trial_data['forced_trial']
    non_forced_array = np.where(forced_trials == 0)[0]

    task_non_forced = task[non_forced_array]
    task_1 = np.where(task_non_forced == 1)[0]
    task_1_len = len(task_1)

    task_2 = np.where(task_non_forced == 2)[0]
    task_2_len = len(task_2)

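    # Mask trials outside each task's own window with -1 so they fall into the
    # else branches of the transition coding below.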
    predictor_A_Task_1[task_1_len:] = -1
    predictor_A_Task_2[:task_1_len] = -1
    predictor_A_Task_2[(task_1_len + task_2_len):] = -1
    predictor_A_Task_3[:(task_1_len + task_2_len)] = -1

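    # Four mutually exclusive transition indicators per trial (the session's
    # first trial is skipped): stay on A, switch B->A, stay on B, switch A->B.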
    for i, predictor in enumerate(predictor_A_Task_1):
        if i > 0:
            if predictor_A_Task_1[i - 1] == 1 and predictor_A_Task_1[i] == 1:
                a_after_a_task_1.append(1)
                a_after_b_task_1.append(0)
                b_after_b_task_1.append(0)
                b_after_a_task_1.append(0)

            elif predictor_A_Task_1[i - 1] == 0 and predictor_A_Task_1[i] == 1:
                a_after_b_task_1.append(1)
                a_after_a_task_1.append(0)
                b_after_b_task_1.append(0)
                b_after_a_task_1.append(0)

            elif predictor_A_Task_1[i - 1] == 0 and predictor_A_Task_1[i] == 0:
                a_after_b_task_1.append(0)
                a_after_a_task_1.append(0)
                b_after_a_task_1.append(0)
                b_after_b_task_1.append(1)

            elif predictor_A_Task_1[i - 1] == 1 and predictor_A_Task_1[i] == 0:
                a_after_b_task_1.append(0)
                a_after_a_task_1.append(0)
                b_after_b_task_1.append(0)
                b_after_a_task_1.append(1)
            else:
                a_after_b_task_1.append(-1)
                a_after_a_task_1.append(-1)
                b_after_b_task_1.append(-1)
                b_after_a_task_1.append(-1)

    for i, predictor in enumerate(predictor_A_Task_2):
        if i > 0:
            if predictor_A_Task_2[i - 1] == 1 and predictor_A_Task_2[i] == 1:
                a_after_a_task_2.append(1)
                a_after_b_task_2.append(0)
                b_after_b_task_2.append(0)
                b_after_a_task_2.append(0)

            elif predictor_A_Task_2[i - 1] == 0 and predictor_A_Task_2[i] == 1:
                a_after_b_task_2.append(1)
                a_after_a_task_2.append(0)
                b_after_b_task_2.append(0)
                b_after_a_task_2.append(0)

            elif predictor_A_Task_2[i - 1] == 0 and predictor_A_Task_2[i] == 0:
                a_after_b_task_2.append(0)
                a_after_a_task_2.append(0)
                b_after_a_task_2.append(0)
                b_after_b_task_2.append(1)

            elif predictor_A_Task_2[i - 1] == 1 and predictor_A_Task_2[i] == 0:
                a_after_b_task_2.append(0)
                a_after_a_task_2.append(0)
                b_after_b_task_2.append(0)
                b_after_a_task_2.append(1)
            else:
                a_after_b_task_2.append(-1)
                a_after_a_task_2.append(-1)
                b_after_b_task_2.append(-1)
                b_after_a_task_2.append(-1)

    for i, predictor in enumerate(predictor_A_Task_3):

        if i > 0:

            if predictor_A_Task_3[i - 1] == 1 and predictor_A_Task_3[i] == 1:
                a_after_a_task_3.append(1)
                a_after_b_task_3.append(0)
                b_after_b_task_3.append(0)
                b_after_a_task_3.append(0)

            elif predictor_A_Task_3[i - 1] == 0 and predictor_A_Task_3[i] == 1:
                a_after_b_task_3.append(1)
                a_after_a_task_3.append(0)
                b_after_b_task_3.append(0)
                b_after_a_task_3.append(0)

            elif predictor_A_Task_3[i - 1] == 0 and predictor_A_Task_3[i] == 0:
                a_after_b_task_3.append(0)
                a_after_a_task_3.append(0)
                b_after_a_task_3.append(0)
                b_after_b_task_3.append(1)

            elif predictor_A_Task_3[i - 1] == 1 and predictor_A_Task_3[i] == 0:
                a_after_b_task_3.append(0)
                a_after_a_task_3.append(0)
                b_after_b_task_3.append(0)
                b_after_a_task_3.append(1)

            else:
                a_after_b_task_3.append(-1)
                a_after_a_task_3.append(-1)
                b_after_b_task_3.append(-1)
                b_after_a_task_3.append(-1)

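    # Previous-trial reward indicator (built here but not returned below).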
    reward_previous = []
    for i, predictor in enumerate(reward):
        if i > 0:
            if reward[i - 1] == 1:
                reward_previous.append(1)
            else:
                reward_previous.append(0)

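    # Cross each transition indicator with the current trial's reward;
    # reward[1:] aligns with the transition lists, which start at trial 1.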
    a_after_a_task_1_reward = np.where((np.asarray(a_after_a_task_1) == 1)
                                       & (np.asarray(reward[1:]) == 1))
    a_after_a_task_2_reward = np.where((np.asarray(a_after_a_task_2) == 1)
                                       & (np.asarray(reward[1:]) == 1))
    a_after_a_task_3_reward = np.where((np.asarray(a_after_a_task_3) == 1)
                                       & (np.asarray(reward[1:]) == 1))

    a_after_b_task_1_reward = np.where((np.asarray(a_after_b_task_1) == 1)
                                       & (np.asarray(reward[1:]) == 1))
    a_after_b_task_2_reward = np.where((np.asarray(a_after_b_task_2) == 1)
                                       & (np.asarray(reward[1:]) == 1))
    a_after_b_task_3_reward = np.where((np.asarray(a_after_b_task_3) == 1)
                                       & (np.asarray(reward[1:]) == 1))

    b_after_b_task_1_reward = np.where((np.asarray(b_after_b_task_1) == 1)
                                       & (np.asarray(reward[1:]) == 1))
    b_after_b_task_2_reward = np.where((np.asarray(b_after_b_task_2) == 1)
                                       & (np.asarray(reward[1:]) == 1))
    b_after_b_task_3_reward = np.where((np.asarray(b_after_b_task_3) == 1)
                                       & (np.asarray(reward[1:]) == 1))

    b_after_a_task_1_reward = np.where((np.asarray(b_after_a_task_1) == 1)
                                       & (np.asarray(reward[1:]) == 1))
    b_after_a_task_2_reward = np.where((np.asarray(b_after_a_task_2) == 1)
                                       & (np.asarray(reward[1:]) == 1))
    b_after_a_task_3_reward = np.where((np.asarray(b_after_a_task_3) == 1)
                                       & (np.asarray(reward[1:]) == 1))

    a_after_a_task_1_nreward = np.where((np.asarray(a_after_a_task_1) == 1)
                                        & (np.asarray(reward[1:]) == 0))
    a_after_a_task_2_nreward = np.where((np.asarray(a_after_a_task_2) == 1)
                                        & (np.asarray(reward[1:]) == 0))
    a_after_a_task_3_nreward = np.where((np.asarray(a_after_a_task_3) == 1)
                                        & (np.asarray(reward[1:]) == 0))

    a_after_b_task_1_nreward = np.where((np.asarray(a_after_b_task_1) == 1)
                                        & (np.asarray(reward[1:]) == 0))
    a_after_b_task_2_nreward = np.where((np.asarray(a_after_b_task_2) == 1)
                                        & (np.asarray(reward[1:]) == 0))
    a_after_b_task_3_nreward = np.where((np.asarray(a_after_b_task_3) == 1)
                                        & (np.asarray(reward[1:]) == 0))

    b_after_b_task_1_nreward = np.where((np.asarray(b_after_b_task_1) == 1)
                                        & (np.asarray(reward[1:]) == 0))
    b_after_b_task_2_nreward = np.where((np.asarray(b_after_b_task_2) == 1)
                                        & (np.asarray(reward[1:]) == 0))
    b_after_b_task_3_nreward = np.where((np.asarray(b_after_b_task_3) == 1)
                                        & (np.asarray(reward[1:]) == 0))

    b_after_a_task_1_nreward = np.where((np.asarray(b_after_a_task_1) == 1)
                                        & (np.asarray(reward[1:]) == 0))
    b_after_a_task_2_nreward = np.where((np.asarray(b_after_a_task_2) == 1)
                                        & (np.asarray(reward[1:]) == 0))
    b_after_a_task_3_nreward = np.where((np.asarray(b_after_a_task_3) == 1)
                                        & (np.asarray(reward[1:]) == 0))

    return a_after_a_task_1_reward,a_after_a_task_2_reward,a_after_a_task_3_reward,a_after_b_task_1_reward,\
    a_after_b_task_2_reward,a_after_b_task_3_reward,b_after_b_task_1_reward,b_after_b_task_2_reward,\
    b_after_b_task_3_reward,b_after_a_task_1_reward,b_after_a_task_2_reward,b_after_a_task_3_reward,\
    a_after_a_task_1_nreward,a_after_a_task_2_nreward,a_after_a_task_3_nreward,\
    a_after_b_task_1_nreward,a_after_b_task_2_nreward,a_after_b_task_3_nreward,\
    b_after_b_task_1_nreward,b_after_b_task_2_nreward,b_after_b_task_3_nreward,\
    b_after_a_task_1_nreward,b_after_a_task_2_nreward,b_after_a_task_3_nreward
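# A compact alternative sketch (not in the source): the four transition
# indicators can be derived from a 0/1 A-choice vector with vectorised numpy,
# avoiding the repeated per-task loops above.
def transition_indicators(a_choice):
    a_choice = np.asarray(a_choice)
    prev, curr = a_choice[:-1], a_choice[1:]  # previous vs current trial
    a_after_a = ((prev == 1) & (curr == 1)).astype(int)
    a_after_b = ((prev == 0) & (curr == 1)).astype(int)
    b_after_b = ((prev == 0) & (curr == 0)).astype(int)
    b_after_a = ((prev == 1) & (curr == 0)).astype(int)
    return a_after_a, a_after_b, b_after_b, b_after_a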
Example no. 12
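    # (Fragment: begins mid-function in the source. Assumes `session`, `choices`,
    # `non_forced_array` and `aligned_spikes` were defined earlier, e.g., as in
    # the other examples, via session.aligned_rates and forced-trial filtering.)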
    choice_non_forced = choices[non_forced_array]
    n_trials = len(choice_non_forced)

    n_trials, n_neurons, n_timepoints = aligned_spikes.shape
    t_out = session.t_out
    initiate_choice_t = session.target_times  #Initiation and Choice Times
    ind_choice = (np.abs(t_out - initiate_choice_t[-2])
                  ).argmin()  # Find firing rates around choice
    ind_after_choice = ind_choice + 7  # 1 sec after choice
    spikes_around_choice = aligned_spikes[:, :, ind_choice - 2:
                                          ind_after_choice]  # Find firing rates only around choice
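    # Average over the time window: one scalar per trial and neuron.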
    mean_spikes_around_choice = np.mean(spikes_around_choice, axis=2)

    predictor_A_Task_1, predictor_A_Task_2, predictor_A_Task_3,\
    predictor_B_Task_1, predictor_B_Task_2, predictor_B_Task_3, reward,\
    predictor_a_good_task_1,predictor_a_good_task_2, predictor_a_good_task_3 = re.predictors_pokes(session)
    # Check if a choice happened before the end of the session
    if len(predictor_A_Task_1) != len(choice_non_forced):
        predictor_A_Task_1 = predictor_A_Task_1[:len(choice_non_forced)]
        predictor_A_Task_2 = predictor_A_Task_2[:len(choice_non_forced)]
        predictor_A_Task_3 = predictor_A_Task_3[:len(choice_non_forced)]
        predictor_B_Task_1 = predictor_B_Task_1[:len(choice_non_forced)]
        predictor_B_Task_2 = predictor_B_Task_2[:len(choice_non_forced)]
        predictor_B_Task_3 = predictor_B_Task_3[:len(choice_non_forced)]
        reward = reward[:len(choice_non_forced)]

    A = predictor_A_Task_1 + predictor_A_Task_2 + predictor_A_Task_3
    if len(A) != len(mean_spikes_around_choice):
        mean_spikes_around_choice = mean_spikes_around_choice[:len(A)]

    B = predictor_B_Task_1 + predictor_B_Task_2 + predictor_B_Task_3
Example no. 13
def extract_session_a_b_based_on_block(session, tasks_unchanged=True):
    # Extracts A and B trials based on which block and task they occurred in.
    # Takes a session and returns A and B trials for the periods when the A and B ports were good in each task.
    spikes = session.ephys
    spikes = spikes[:, ~np.isnan(spikes[1, :])]  # drop entries with NaN in the second row
    aligned_rates = session.aligned_rates
    
    poke_A, poke_A_task_2, poke_A_task_3, poke_B, poke_B_task_2, poke_B_task_3,poke_I, poke_I_task_2,poke_I_task_3 = ep.extract_choice_pokes(session)
    trial_choice_state_task_1, trial_choice_state_task_2, trial_choice_state_task_3, ITI_task_1, ITI_task_2, ITI_task_3 = ep.initiation_and_trial_end_timestamps(session)
    task_1 = len(trial_choice_state_task_1)
    task_2 = len(trial_choice_state_task_2)
    
    
    # Getting choice indices 
    predictor_A_Task_1, predictor_A_Task_2, predictor_A_Task_3,\
    predictor_B_Task_1, predictor_B_Task_2, predictor_B_Task_3, reward,\
    predictor_a_good_task_1,predictor_a_good_task_2, predictor_a_good_task_3 = re.predictors_pokes(session)
    
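    # Guard against sessions that ended mid-trial: trim the behavioural
    # predictors to the number of aligned trials.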
    if aligned_rates.shape[0] != predictor_A_Task_1.shape[0]:
        predictor_A_Task_1 = predictor_A_Task_1[:aligned_rates.shape[0]] 
        predictor_A_Task_2 = predictor_A_Task_2[:aligned_rates.shape[0]] 
        predictor_A_Task_3 = predictor_A_Task_3[:aligned_rates.shape[0]] 
        predictor_B_Task_1 = predictor_B_Task_1[:aligned_rates.shape[0]] 
        predictor_B_Task_2 = predictor_B_Task_2[:aligned_rates.shape[0]] 
        predictor_B_Task_3 = predictor_B_Task_3[:aligned_rates.shape[0]] 
        reward = reward[:aligned_rates.shape[0]] 
        
    ## Optionally restrict the analysis to the two tasks that share an initiation (I) port
    if not tasks_unchanged:
        if poke_I == poke_I_task_2: 
            aligned_rates_task_1 = aligned_rates[:task_1]
            predictor_A_Task_1 = predictor_A_Task_1[:task_1]
            predictor_B_Task_1 = predictor_B_Task_1[:task_1]
            reward_task_1 = reward[:task_1]
            aligned_rates_task_2 = aligned_rates[:task_1+task_2]
            predictor_A_Task_2 = predictor_A_Task_2[:task_1+task_2]
            predictor_B_Task_2 = predictor_B_Task_2[:task_1+task_2]
            reward_task_2 = reward[:task_1+task_2]
            # Tasks 1 and 2 already use the shared I port; their A-good indices stay as-is.
            
        elif poke_I == poke_I_task_3:
            aligned_rates_task_1 = aligned_rates[:task_1]
            predictor_A_Task_1 = predictor_A_Task_1[:task_1]
            predictor_B_Task_1 = predictor_B_Task_1[:task_1]
            #reward_task_1 = reward[:task_1]
            aligned_rates_task_2 = aligned_rates[task_1+task_2:]
            predictor_A_Task_2 = predictor_A_Task_3[task_1+task_2:]
            predictor_B_Task_2 = predictor_B_Task_3[task_1+task_2:]
            #reward_task_2 = reward[task_1+task_2:]
            
            # Task 1 keeps its own A-good indices; the task-2 slot now holds task 3's.
            predictor_a_good_task_2 = predictor_a_good_task_3
            
        elif poke_I_task_2 == poke_I_task_3:
            aligned_rates_task_1 = aligned_rates[:task_1+task_2]
            predictor_A_Task_1 = predictor_A_Task_2[:task_1+task_2]
            predictor_B_Task_1 = predictor_B_Task_2[:task_1+task_2]
            #reward_task_1 = reward[:task_1+task_2]
            aligned_rates_task_2 = aligned_rates[task_1+task_2:]
            predictor_A_Task_2 = predictor_A_Task_3[task_1+task_2:]
            predictor_B_Task_2 = predictor_B_Task_3[task_1+task_2:]
            #reward_task_2 = reward[task_1+task_2:]
            
            predictor_a_good_task_1 = predictor_a_good_task_2
            predictor_a_good_task_2 = predictor_a_good_task_3
            
    # Get firing rates for each task. Note these unconditional slices assume the
    # full three-task layout and overwrite any branch-specific rate slices above.
    aligned_rates_task_1 = aligned_rates[:task_1]
    aligned_rates_task_2 = aligned_rates[task_1:task_1+task_2]
    aligned_rates_task_3 = aligned_rates[task_1+task_2:]
    
    # Indices of A choices in each task (1s); B choices are 0s
    predictor_A_Task_1_cut = predictor_A_Task_1[:task_1]
    #reward_task_1_cut = reward[:task_1]
          
    predictor_A_Task_2_cut = predictor_A_Task_2[task_1:task_1+task_2]
    #reward_task_2_cut = reward[task_1:task_1+task_2]
          
    predictor_A_Task_3_cut = predictor_A_Task_3[task_1+task_2:]
    #reward_task_3_cut = reward[task_1+task_2:]
    
    
    # Make arrays with 1s and 0s to mark states in the task
    states_task_1 = np.zeros(len(predictor_A_Task_1_cut))
    states_task_1[predictor_a_good_task_1] = 1
    states_task_2 = np.zeros(len(predictor_A_Task_2_cut))
    states_task_2[predictor_a_good_task_2] = 1
    states_task_3 = np.zeros(len(predictor_A_Task_3_cut))
    states_task_3[predictor_a_good_task_3] = 1
    
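    # 2 x 2 split per task: block state (A good vs B good) crossed with the
    # animal's choice (A vs B).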
    state_A_choice_A_t1 = aligned_rates_task_1[np.where((states_task_1 == 1) & (predictor_A_Task_1_cut == 1))]
    state_A_choice_B_t1 = aligned_rates_task_1[np.where((states_task_1 == 1) & (predictor_A_Task_1_cut == 0))]

    state_B_choice_A_t1 = aligned_rates_task_1[np.where((states_task_1 == 0) & (predictor_A_Task_1_cut == 1))]
    state_B_choice_B_t1 = aligned_rates_task_1[np.where((states_task_1 == 0) & (predictor_A_Task_1_cut == 0))]

    state_A_choice_A_t2 = aligned_rates_task_2[np.where((states_task_2 == 1) & (predictor_A_Task_2_cut == 1))]
    state_A_choice_B_t2 = aligned_rates_task_2[np.where((states_task_2 == 1) & (predictor_A_Task_2_cut == 0))]

    state_B_choice_A_t2 = aligned_rates_task_2[np.where((states_task_2 == 0) & (predictor_A_Task_2_cut == 1))]
    state_B_choice_B_t2 = aligned_rates_task_2[np.where((states_task_2 == 0) & (predictor_A_Task_2_cut == 0))]

    state_A_choice_A_t3 = aligned_rates_task_3[np.where((states_task_3 == 1) & (predictor_A_Task_3_cut == 1))]
    state_A_choice_B_t3 = aligned_rates_task_3[np.where((states_task_3 == 1) & (predictor_A_Task_3_cut == 0))]

    state_B_choice_A_t3 = aligned_rates_task_3[np.where((states_task_3 == 0) & (predictor_A_Task_3_cut == 1))]
    state_B_choice_B_t3 = aligned_rates_task_3[np.where((states_task_3 == 0) & (predictor_A_Task_3_cut == 0))]

    return state_A_choice_A_t1,state_A_choice_B_t1,state_B_choice_A_t1,state_B_choice_B_t1,\
        state_A_choice_A_t2, state_A_choice_B_t2,state_B_choice_A_t2,state_B_choice_B_t2,\
        state_A_choice_A_t3, state_A_choice_B_t3, state_B_choice_A_t3, state_B_choice_B_t3, spikes
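# Sketch of assumed downstream usage (illustrative only): trial-average each
# returned condition, e.g. to compare A choices made in A-good vs B-good blocks:
#     outs = extract_session_a_b_based_on_block(session)
#     mean_A_good_choice_A_t1 = np.mean(outs[0], axis=0)  # state A, choice A, task 1
#     mean_B_good_choice_A_t1 = np.mean(outs[2], axis=0)  # state B, choice A, task 1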