Example #1
0
def plot_tuning_curves(HP,subj,s):
    """Split one session's firing rates by choice and task.

    Despite its name, this function does not plot anything; it returns the
    trial subsets so a caller can plot them.

    Parameters
    ----------
    HP : mapping
        Loaded data with ``'Data'`` and ``'DM'`` entries, both indexed as
        ``HP[key][0, columns]`` to select one animal's sessions.
    subj : str
        Animal id: ``'m484'`` (sessions 0-15), ``'m479'`` (16-23) or
        ``'m483'`` (24 onwards).
    s : int
        Session index within the selected animal's sessions.

    Returns
    -------
    tuple
        ``(a1, b1, a2, b2, a3, b3, taskid, task)``. ``aN`` / ``bN`` are the
        rows of the session's rate array for trials with ``choices == 1`` /
        ``choices == 0`` in task ``N`` (task ids from ``rc.task_ind``).
        ``taskid`` is that labelling, and ``task`` is column 5 of the design
        matrix.

    Raises
    ------
    ValueError
        If ``subj`` is not one of the known animals.
    """
    # Session columns belonging to each animal.
    subject_sessions = {
        'm484': slice(None, 16),
        'm479': slice(16, 24),
        'm483': slice(24, None),
    }
    if subj not in subject_sessions:
        # Previously an unknown subject fell through and crashed later with an
        # UnboundLocalError on ``Data``.
        raise ValueError('Unknown subject %r; expected one of %s'
                         % (subj, sorted(subject_sessions)))
    sessions = subject_sessions[subj]
    Data = HP['Data'][0, sessions]
    DM = HP['DM'][0, sessions]

    aligned_rates_choices = Data[s]
    session_DM = DM[s]
    choices = session_DM[:, 1]
    task = session_DM[:, 5]
    a_pokes = session_DM[:, 6]
    b_pokes = session_DM[:, 7]
    taskid = rc.task_ind(task, a_pokes, b_pokes)

    def _select(choice, task_number):
        """Rows of the session's rates for one choice value within one task."""
        rows = np.where((choices == choice) & (taskid == task_number))[0]
        return aligned_rates_choices[rows]

    a1_ind, b1_ind = _select(1, 1), _select(0, 1)
    a2_ind, b2_ind = _select(1, 2), _select(0, 2)
    a3_ind, b3_ind = _select(1, 3), _select(0, 3)

    return a1_ind, b1_ind, a2_ind, b2_ind, a3_ind, b3_ind, taskid, task
Example #2
0
def tim_rewrite(area = 1,
                hp_path='/Users/veronikasamborska/Desktop/HP.mat',
                pfc_path='/Users/veronikasamborska/Desktop/PFC.mat'):
    """Gather choice-sorted trials around each task switch, grouped by switch type.

    Port of a MATLAB routine (``A{s,ch,stype}``). For every session and for each
    of its first two task switches, the switch is classified by the task ids
    two trials before and two trials after it:

    * ``stype`` 1: tasks 1 and 2 (ids sum to 3)
    * ``stype`` 2: tasks 1 and 3 (ids sum to 4)
    * ``stype`` 3: tasks 2 and 3 (ids sum to 5)

    For each choice value (0 and 1), two sets of trials with that choice are
    taken and their ``Data`` rows are stored:

    * the last ``ntrials + baselength + 1`` at or before the switch, and
    * the first ``ntrials`` after it.

    Parameters
    ----------
    area : int
        ``1`` loads ``hp_path``; any other value loads ``pfc_path``.
    hp_path, pfc_path : str
        ``.mat`` files containing ``'Data'`` and ``'DM'`` entries.

    Returns
    -------
    list
        Nested ``3 x 3 x 3`` list indexed ``A[s][choice][stype - 1]``. Each
        entry is a list with one row selection per session that had a matching
        switch. Only indices 0 and 1 are filled on the first two axes.

    NOTE(review): assumes every session has at least two task switches, each
    with at least two trials on either side; otherwise indexing raises
    IndexError.
    """
    # Only the requested region's file is loaded.
    Data = io.loadmat(hp_path if area == 1 else pfc_path)

    ntrials = 20
    baselength = 10
    n = 3
    # MATLAB: A{s,ch,stype} = []
    A = [[[[] for _ in range(n)] for _ in range(n)] for _ in range(n)]

    for DD, DM in zip(Data['Data'][0], Data['DM'][0]):
        choices = DM[:, 1]
        task = DM[:, 5]
        a_pokes = DM[:, 6]
        b_pokes = DM[:, 7]
        taskid = rc.task_ind(task, a_pokes, b_pokes)

        # First trial of each new task block. The original wrote
        # ``abs(np.diff(task) > 0)``, which applied abs to a boolean and only
        # detected switches where ``task`` increased.
        sw_point = np.where(np.abs(np.diff(task)) > 0)[0] + 1

        for stype in [1, 2, 3]:  # 1 is 1-2; 2 is 1-3; 3 is 2-3
            for s in range(2):  # first and second switch of the session
                sw = sw_point[s]
                # Classify the switch from the task ids two trials either side.
                prepost = [taskid[sw - 2], taskid[sw + 2]]
                if sum(prepost) != stype + 2:
                    continue

                for ch in [1, 2]:
                    Aind = np.where(choices == (ch - 1))[0]
                    # Last ntrials + baselength + 1 trials with this choice up
                    # to and including the switch trial ...
                    Aind_pre_sw = Aind[Aind <= sw][-ntrials - baselength - 1:]
                    # ... and the first ntrials after it.
                    Aind_post_sw = Aind[Aind > sw][:ntrials]
                    Atrials = np.hstack([Aind_pre_sw, Aind_post_sw])
                    A[s][ch - 1][stype - 1].append(DD[Atrials])
    return A
def remap_surprise_time(data, task_1_2=False, task_2_3=False, task_1_3=False):
    """Select sessions in which two given tasks were run back to back.

    Parameters
    ----------
    data : mapping
        Loaded data with ``'Data'`` (per-session firing rates) and ``'DM'``
        (per-session design matrices), each read as ``data[key][0]``.
    task_1_2, task_2_3, task_1_3 : bool
        Which task pair to look at. The first flag that is set wins, in the
        order 1-2, 2-3, 1-3.

    Returns
    -------
    (list, list)
        Firing-rate arrays and design matrices of the qualifying sessions. A
        session qualifies in either of two cases:

        * the last trial of the first task is immediately followed by the
          first trial of the second task, or
        * the last trial of the second task is immediately followed by the
          first trial of the first task.

        Sessions missing either task are skipped.

    Raises
    ------
    ValueError
        If none of the task-pair flags is set.
    """
    # Resolve the task pair once. Previously, with no flag set, the ids were
    # never assigned and the loop crashed with UnboundLocalError.
    if task_1_2:
        taskid_1, taskid_2 = 1, 2
    elif task_2_3:
        taskid_1, taskid_2 = 2, 3
    elif task_1_3:
        taskid_1, taskid_2 = 1, 3
    else:
        raise ValueError('Set one of task_1_2, task_2_3 or task_1_3 to True')

    y = data['DM'][0]
    x = data['Data'][0]
    task_time_confound_data = []
    task_time_confound_dm = []

    for sess, DM in zip(x, y):
        task = DM[:, 5]
        a_pokes = DM[:, 6]
        b_pokes = DM[:, 7]
        taskid = rc.task_ind(task, a_pokes, b_pokes)

        trials_1 = np.where(taskid == taskid_1)[0]
        trials_2 = np.where(taskid == taskid_2)[0]
        if trials_1.size == 0 or trials_2.size == 0:
            # The two tasks cannot be adjacent if one is absent (this used to
            # raise IndexError).
            continue

        task_1_then_2 = trials_1[-1] + 1 == trials_2[0]
        task_2_then_1 = trials_2[-1] + 1 == trials_1[0]
        # Both conditions cannot hold at once, so a session is added at most once.
        if task_1_then_2 or task_2_then_1:
            task_time_confound_data.append(sess)
            task_time_confound_dm.append(DM)

    return task_time_confound_data, task_time_confound_dm
Example #4
0
def remap(Data, DM, n_perm):
    """Count neurons whose choice-0 ("A") trial activity changes within/between tasks.

    For every session, trials with ``choices == 0`` (called "A" here) are split
    into a first and second half per task (task ids from ``rc.task_ind``). For
    each neuron, ``perm_test`` is run on time bins 25-29 for:

    * within: first vs second half of task-1 A trials;
    * between: second half of task-1 A trials vs first half of task-2 A trials.

    Significant neurons (largest absolute activity difference above the
    permutation threshold) are plotted, 20 per 4x5 figure. "Between" figures
    are numbered from 1 and "within" figures from 30.

    Parameters
    ----------
    Data : sequence of ndarray
        Per-session firing-rate arrays indexed ``[trial, neuron, time]``.
    DM : sequence of ndarray
        Per-session design matrices. Columns 1, 5, 6 and 7 are read as choice,
        task, A pokes and B pokes.
    n_perm : int
        Number of permutations passed to ``perm_test``.

    Returns
    -------
    tuple of int
        ``(remapped_between_not_within, remapped_between, remapped_within, ns)``,
        where ``ns`` is the total number of neurons tested.
    """
    # Time bins compared by the permutation tests.
    init_bins = np.arange(25, 30)

    fg_n = 1                 # current figure number for "between" plots
    fg_nr = 30               # current figure number for "within" plots
    all_ns = 0               # subplot slot in the current "between" figure
    all_ns_non_remapped = 0  # subplot slot in the current "within" figure
    remapped_between = 0
    remapped_within = 0
    remapped_between_not_within = 0
    ns = 0

    def _first_half(idx):
        return idx[:len(idx) // 2]

    def _second_half(idx):
        return idx[len(idx) // 2:]

    def _plot_diff(diff, bound, color):
        """Plot a difference trace with dashed +/- permutation bounds."""
        plt.plot(diff, color)
        plt.plot(bound, 'red', linestyle='--', alpha=0.5, linewidth=0.5)
        plt.plot(-bound, 'red', linestyle='--', alpha=0.5, linewidth=0.5)

    for s, firing in enumerate(Data):
        session_dm = DM[s]
        choices = session_dm[:, 1]
        task = session_dm[:, 5]
        # Column 6 = A pokes, 7 = B pokes, as everywhere else in this file
        # (this function previously read them swapped).
        a_pokes = session_dm[:, 6]
        b_pokes = session_dm[:, 7]
        taskid = rc.task_ind(task, a_pokes, b_pokes)

        task_1_a = np.where((taskid == 1) & (choices == 0))[0]
        task_2_a = np.where((taskid == 2) & (choices == 0))[0]
        task_3_a = np.where((taskid == 3) & (choices == 0))[0]
        task_1_a_1, task_1_a_2 = _first_half(task_1_a), _second_half(task_1_a)
        task_2_a_1, task_2_a_2 = _first_half(task_2_a), _second_half(task_2_a)
        # First half of task-3 A trials. The original sliced this from an
        # already-halved array and so took the second quarter.
        task_3_a_1 = _first_half(task_3_a)

        n_trials, n_neurons, n_time = firing.shape

        for neuron in range(n_neurons):
            n_firing = firing[:, neuron]
            ns += 1

            a1_fr_within = n_firing[task_1_a_1][:, init_bins]
            a2_fr_within = n_firing[task_1_a_2][:, init_bins]
            a1_fr_between = n_firing[task_1_a_2][:, init_bins]
            a2_fr_between = n_firing[task_2_a_1][:, init_bins]

            # Call order kept as in the original (within first), in case
            # perm_test draws random numbers.
            _, p_max_within, x_max_within, activity_diff_within = perm_test(
                a1_fr_within, a2_fr_within, n_perm=n_perm)
            _, p_max_between, x_max_between, activity_diff_between = perm_test(
                a1_fr_between, a2_fr_between, n_perm=n_perm)

            peak_within = np.max(np.abs(activity_diff_within))
            peak_between = np.max(np.abs(activity_diff_between))

            if peak_between > p_max_between:
                remapped_between += 1
                all_ns += 1
                fig = plt.figure(fg_n)
                if all_ns > 20:
                    # Current 4x5 figure is full; start a new one.
                    fg_n += 1
                    fig = plt.figure(fg_n)
                    all_ns = 1
                fig.add_subplot(4, 5, all_ns)
                # NOTE(review): these traces show task-2 (second half) and
                # task-3 (first half) A trials labelled 'task 1 A'/'task 2 A',
                # while the test above compares task 1 with task 2. Confirm
                # which trials were meant to be drawn.
                plt.plot(np.mean(firing[task_2_a_2, neuron, :], axis=0),
                         color='coral', label='task 1 A')
                plt.plot(np.mean(firing[task_3_a_1, neuron, :], axis=0),
                         color='red', label='task 2 A')
                _plot_diff(activity_diff_between, x_max_between, 'black')

            if peak_within > p_max_within:
                remapped_within += 1
                all_ns_non_remapped += 1
                fig = plt.figure(fg_nr)
                if all_ns_non_remapped > 20:
                    fg_nr += 1
                    fig = plt.figure(fg_nr)
                    all_ns_non_remapped = 1
                fig.add_subplot(4, 5, all_ns_non_remapped)
                # NOTE(review): plots task-2 halves although the within test
                # uses task-1 halves; confirm intent.
                plt.plot(np.mean(firing[task_2_a_2, neuron, :], axis=0),
                         color='blue', label='task 1 A')
                plt.plot(np.mean(firing[task_2_a_1, neuron, :], axis=0),
                         color='lightblue', label='task 1 A')
                _plot_diff(activity_diff_within, x_max_within, 'grey')

            if peak_between > p_max_between and peak_within < p_max_within:
                remapped_between_not_within += 1

    return remapped_between_not_within, remapped_between, remapped_within, ns
Example #5
0
def plotting_a_b_i(Data, DM, all_session_b1, all_session_a1, all_session_i1,
                   all_session_b2, all_session_a2, all_session_i2,
                   all_session_b3, all_session_a3, all_session_i3):
    """Plot mean per-neuron rates for B-choice, A-choice and all trials per task.

    One figure per session (figure number ``s + 1``). Every neuron listed in
    at least one of that session's ``all_session_*`` selections gets a
    subplot on a 5x5 grid (slot ``neuron + 1``). Each condition whose selection
    contains the neuron is drawn there.

    Here "b" means ``choices == 1``, "a" means ``choices == 0`` and "i" means
    all trials of the task (task ids from ``rc.task_ind``).

    Parameters
    ----------
    Data : sequence of ndarray
        Per-session firing-rate arrays indexed ``[trial, neuron, time]``.
    DM : sequence of ndarray
        Per-session design matrices. Columns 1, 5, 6 and 7 are read as choice,
        task, A pokes and B pokes.
    all_session_{b,a,i}{1,2,3} : sequence
        Per-session collections of neuron indices to draw for each condition.

    NOTE(review): the 5x5 grid fails for sessions with more than 25 neurons
    that are selected; confirm whether a larger grid is needed.
    """
    for s, (x, session_dm) in enumerate(zip(Data, DM)):
        choices = session_dm[:, 1]
        b_pokes = session_dm[:, 7]
        a_pokes = session_dm[:, 6]
        task = session_dm[:, 5]
        taskid = rc.task_ind(task, a_pokes, b_pokes)

        n_trials, n_neurons, n_time = x.shape

        def mean_rate(mask):
            """Trial-averaged rates (neuron x time) over trials where mask holds."""
            return np.mean(x[np.where(mask)[0]], axis=0)

        # (neuron selection for this session, mean rates, line style), in the
        # original drawing order.
        conditions = [
            (all_session_b1[s], mean_rate((taskid == 1) & (choices == 1)), dict(color='blue')),
            (all_session_a1[s], mean_rate((taskid == 1) & (choices == 0)), dict(color='red')),
            (all_session_i1[s], mean_rate(taskid == 1), dict(color='yellow')),
            (all_session_a2[s], mean_rate((taskid == 2) & (choices == 0)), dict(color='red', linestyle='--')),
            (all_session_b2[s], mean_rate((taskid == 2) & (choices == 1)), dict(color='blue', linestyle='--')),
            (all_session_i2[s], mean_rate(taskid == 2), dict(color='yellow', linestyle='--')),
            (all_session_a3[s], mean_rate((taskid == 3) & (choices == 0)), dict(color='pink')),
            (all_session_b3[s], mean_rate((taskid == 3) & (choices == 1)), dict(color='lightblue')),
            (all_session_i3[s], mean_rate(taskid == 3), dict(color='orange')),
        ]

        fig = plt.figure(s + 1)
        for neuron in range(n_neurons):
            selected = [(rates, style) for members, rates, style in conditions
                        if neuron in members]
            if not selected:
                continue
            # Create the subplot once per neuron; calling add_subplot for every
            # condition stacks duplicate axes in current matplotlib.
            ax = fig.add_subplot(5, 5, neuron + 1)
            for rates, style in selected:
                ax.plot(rates[neuron], **style)
def correlations(data, task_1_2=False, task_2_3=False, task_1_3=False):
    """Correlate early- and late-trial activity across choice/state conditions of two tasks.

    For each session, trials of the selected two tasks are split into eight
    groups by task, choice and state. A is ``choices == 1`` and B is
    ``choices == 0``. Following the original naming, A is "good" in state 1
    and B is "good" in state 0.

    For each group, the rates of its first 10 and its last 10 trials are
    averaged over axis 2, giving 16 stacked blocks per session. Only sessions
    where every block has 10 rows (160 rows in total) are kept. Kept sessions
    are concatenated along axis 1, and the 160 x 160 row correlation matrix is
    plotted and returned.

    Parameters
    ----------
    data : mapping
        Loaded data with ``'Data'`` and ``'DM'`` entries, read as
        ``data[key][0]``. DM columns 0, 1, 5, 6 and 7 are read as state,
        choice, task, A pokes and B pokes.
    task_1_2, task_2_3, task_1_3 : bool
        Which task pair to use. The first flag that is set wins.

    Returns
    -------
    (ndarray, list)
        The correlation matrix and the list of per-session stacked arrays.

    Raises
    ------
    ValueError
        If no task-pair flag is set, or no session has at least 10 trials in
        every group.
    """
    if task_1_2:
        taskid_1, taskid_2 = 1, 2
    elif task_2_3:
        taskid_1, taskid_2 = 2, 3
    elif task_1_3:
        taskid_1, taskid_2 = 1, 3
    else:
        raise ValueError('Set one of task_1_2, task_2_3 or task_1_3 to True')

    y = data['DM'][0]
    x = data['Data'][0]

    # (task id, choice value, state value), in the order of the tick labels.
    groups = []
    for task_value in (taskid_1, taskid_2):
        groups += [
            (task_value, 1, 0),  # A bad
            (task_value, 1, 1),  # A good
            (task_value, 0, 1),  # B bad
            (task_value, 0, 0),  # B good
        ]

    stack_array = []
    for sess, DM in zip(x, y):
        state = DM[:, 0]
        choices = DM[:, 1]
        task = DM[:, 5]
        a_pokes = DM[:, 6]
        b_pokes = DM[:, 7]
        taskid = rc.task_ind(task, a_pokes, b_pokes)

        blocks = []
        for task_value, choice_value, state_value in groups:
            idx = np.where((taskid == task_value) & (choices == choice_value)
                           & (state == state_value))[0]
            blocks.append(np.mean(sess[idx[:10]], axis=2))   # first 10 trials
            blocks.append(np.mean(sess[idx[-10:]], axis=2))  # last 10 trials
        stack_first_last = np.vstack(blocks)
        print(stack_first_last.shape)

        # Keep only sessions with a full 10 trials in all 16 blocks.
        if stack_first_last.shape[0] == 160:
            stack_array.append(stack_first_last)

    if not stack_array:
        raise ValueError('No session had at least 10 trials in every condition')

    all_conc = np.concatenate(stack_array, 1)
    corr = np.corrcoef(all_conc)

    plt.figure()
    plt.imshow(corr)
    plt.xticks(np.arange(0, 160, 10),
               ['A bad T1 Start', 'A bad T1 End', 'A good T1 Start', 'A good T1 End',
                'B bad T1 Start', 'B bad T1 End', 'B good T1 Start', 'B good T1 End',
                'A bad T2 Start', 'A bad T2 End', 'A good T2 Start', 'A good T2 End',
                'B bad T2 Start', 'B bad T2 End', 'B good T2 Start', 'B good T2 End'],
               rotation=90)

    return corr, stack_array
def shuffle_block_start(data, task_1_2 = False, task_2_3 = False, task_1_3 = False, n_perms = 5):
    """95th-percentile differences between within- and between-task surprise.

    Sessions where the two tasks were run back to back are selected with
    ``remap_surprise_time``. For each neuron:

    * Each trial is smoothed with ``gaussian_filter1d`` (sigma 2, axis 1).
    * Choice-A (``choices == 1``) and choice-B (``choices == 0``) trials are
      summarised as baseline means/stds and test means.
    * Surprise is the negative Gaussian log-density of a test mean under a
      baseline, with ``min_std`` added to the baseline std.
    * Both a forward and a reversed direction are computed. They are averaged
      unless ``task_2_3`` is set, in which case only the forward direction is
      used.

    Per iteration, surprise is ``nanmean``-ed over neurons and mapped to
    ``-sqrt``. Then the absolute within-vs-between difference is taken.

    NOTE(review): despite the name and ``n_perms``, nothing is shuffled
    between iterations of the permutation loop. Every iteration computes the
    same values, so the 95th percentile equals that value. Confirm the
    intended shuffling step.

    Returns
    -------
    tuple of ndarray
        ``(a_a, b_b, b_init)``: across-session ``nanmean`` of each session's
        95th percentile of

        * for A, ``|diag(within) - diag(between)|``;
        * the same for B;
        * for B, the absolute difference of row 36 of the transposed within
          and between blocks.

    Raises
    ------
    ValueError
        If none of the task-pair flags is set.
    """
    if task_1_2:
        taskid_1, taskid_2 = 1, 2
    elif task_2_3:
        taskid_1, taskid_2 = 2, 3
    elif task_1_3:
        taskid_1, taskid_2 = 1, 3
    else:
        raise ValueError('Set one of task_1_2, task_2_3 or task_1_3 to True')

    fr, dm = remap_surprise_time(data, task_1_2=task_1_2, task_2_3=task_2_3, task_1_3=task_1_3)

    ind_pre = 31
    ind_post = 20
    min_std = 2
    b_init_bin = 36  # time bin whose B-surprise profile feeds the third output

    def _row_tiled(values, stat):
        """``stat`` over trials (axis 0), repeated as identical rows into a square array."""
        profile = stat(values, axis=0)
        return np.tile(profile, [profile.shape[0], 1])

    def _surprise(test, base_mean, base_std):
        """Negative log-density of ``test`` under the (transposed) baseline."""
        return -norm.logpdf(test, np.transpose(base_mean, (1, 0)),
                            np.transpose(base_std + min_std))

    surprise_list_neurons_a_a_p = []
    surprise_list_neurons_b_b_p = []
    surprise_list_neurons_b_init_p = []

    for s, sess in tqdm(enumerate(fr)):
        DM = dm[s]
        choices = DM[:, 1]
        task = DM[:, 5]
        a_pokes = DM[:, 6]
        b_pokes = DM[:, 7]
        # Trial selection does not depend on the permutation iteration, so it
        # is computed once per session.
        taskid = rc.task_ind(task, a_pokes, b_pokes)

        task_1_a = np.where((taskid == taskid_1) & (choices == 1))[0]
        task_1_b = np.where((taskid == taskid_1) & (choices == 0))[0]
        task_2_a = np.where((taskid == taskid_2) & (choices == 1))[0]
        task_2_b = np.where((taskid == taskid_2) & (choices == 0))[0]

        # Forward direction:
        #   baseline = task-1 trials [-31:-21];
        #   tests = task-1 trials [-21:] and the first ind_post task-2 trials.
        task_1_a_pre_baseline = task_1_a[-ind_pre:-ind_pre + 10]
        task_1_a_pre = task_1_a[-ind_pre + 10:]
        task_1_b_pre_baseline = task_1_b[-ind_pre:-ind_pre + 10]
        task_1_b_pre = task_1_b[-ind_pre + 10:]
        task_2_a_post = task_2_a[:ind_post]
        task_2_b_post = task_2_b[:ind_post]

        # Reverse direction:
        #   baselines = the last 10 trials of each task;
        #   task-1 test window = [-31:-11].
        task_1_a_pre_baseline_rev = task_1_a[-10:]
        task_1_a_pre_rev = task_1_a[-ind_pre:-ind_pre + 20]
        task_1_b_pre_baseline_rev = task_1_b[-10:]
        task_1_b_pre_rev = task_1_b[-ind_pre:-ind_pre + 20]
        task_2_a_post_rev_baseline = task_2_a[-10:]
        task_2_b_post_rev_baseline = task_2_b[-10:]

        n_trials, n_neurons, n_time = sess.shape

        surprise_list_neurons_a_a_perm = []
        surprise_list_neurons_b_b_perm = []
        surprise_list_neurons_b_init_perm = []

        for perm in range(n_perms):
            surprise_list_neurons_a_a = []
            surprise_list_neurons_b_b = []

            for neuron in range(n_neurons):
                n_firing = gaussian_filter1d(sess[:, neuron, :].astype(float), 2, 1)

                # Baselines (means and stds).
                task_1_mean_a = _row_tiled(n_firing[task_1_a_pre_baseline], np.mean)
                task_1_std_a = _row_tiled(n_firing[task_1_a_pre_baseline], np.std)
                task_1_mean_a_rev = _row_tiled(n_firing[task_1_a_pre_baseline_rev], np.mean)
                task_1_std_a_rev = _row_tiled(n_firing[task_1_a_pre_baseline_rev], np.std)
                task_1_mean_b = _row_tiled(n_firing[task_1_b_pre_baseline], np.mean)
                task_1_std_b = _row_tiled(n_firing[task_1_b_pre_baseline], np.std)
                task_1_mean_b_rev = _row_tiled(n_firing[task_1_b_pre_baseline_rev], np.mean)
                task_1_std_b_rev = _row_tiled(n_firing[task_1_b_pre_baseline_rev], np.std)
                task_2_mean_a_rev = _row_tiled(n_firing[task_2_a_post_rev_baseline], np.mean)
                task_2_std_a_rev = _row_tiled(n_firing[task_2_a_post_rev_baseline], np.std)
                task_2_mean_b_rev = _row_tiled(n_firing[task_2_b_post_rev_baseline], np.mean)
                task_2_std_b_rev = _row_tiled(n_firing[task_2_b_post_rev_baseline], np.std)

                # Test means.
                task_1_mean_a_l = _row_tiled(n_firing[task_1_a_pre], np.mean)
                task_1_mean_a_l_rev = _row_tiled(n_firing[task_1_a_pre_rev], np.mean)
                task_1_mean_b_l = _row_tiled(n_firing[task_1_b_pre], np.mean)
                task_1_mean_b_l_rev = _row_tiled(n_firing[task_1_b_pre_rev], np.mean)
                task_2_mean_a = _row_tiled(n_firing[task_2_a_post], np.mean)
                task_2_mean_b = _row_tiled(n_firing[task_2_b_post], np.mean)

                # Neurons with an exact zero in any of these means get NaN
                # surprise. The *_l_rev means are not part of this check, as
                # in the original.
                checked_means = (task_1_mean_a_l, task_1_mean_a, task_1_mean_a_rev,
                                 task_1_mean_b_l, task_1_mean_b, task_1_mean_b_rev,
                                 task_2_mean_a, task_2_mean_a_rev,
                                 task_2_mean_b, task_2_mean_b_rev)
                if not any(np.any(m == 0) for m in checked_means):
                    a_within_1 = _surprise(task_1_mean_a_l, task_1_mean_a, task_1_std_a)
                    a_within_1_rev = _surprise(task_1_mean_a_l_rev, task_1_mean_a_rev, task_1_std_a_rev)
                    b_within_1 = _surprise(task_1_mean_b_l, task_1_mean_b, task_1_std_b)
                    b_within_1_rev = _surprise(task_1_mean_b_l_rev, task_1_mean_b_rev, task_1_std_b_rev)

                    a_between = _surprise(task_2_mean_a, task_1_mean_a, task_1_std_a)
                    a_between_rev = _surprise(task_1_mean_a_l, task_2_mean_a_rev, task_2_std_a_rev)
                    b_between = _surprise(task_2_mean_b, task_1_mean_b, task_1_std_b)
                    b_between_rev = _surprise(task_1_mean_b_l, task_2_mean_b_rev, task_2_std_b_rev)
                else:
                    # np.nan rather than np.NaN, which NumPy 2.0 removed.
                    (a_within_1, a_within_1_rev, b_within_1, b_within_1_rev,
                     a_between, a_between_rev, b_between, b_between_rev) = (
                        np.full(task_1_mean_a_l.shape, np.nan) for _ in range(8))

                if task_2_3:
                    surprise_array_a = np.concatenate([a_within_1, a_between], axis=0)
                    surprise_array_b = np.concatenate([b_within_1, b_between], axis=0)
                else:
                    within_a = np.mean([a_within_1, a_within_1_rev], 0)
                    within_b = np.mean([b_within_1, b_within_1_rev], 0)
                    between_a = np.mean([a_between, a_between_rev], 0)
                    between_b = np.mean([b_between, b_between_rev], 0)
                    surprise_array_a = np.concatenate([within_a, between_a], axis=0)
                    surprise_array_b = np.concatenate([within_b, between_b], axis=0)

                surprise_list_neurons_a_a.append(surprise_array_a)
                surprise_list_neurons_b_b.append(surprise_array_b)

            a_mean_t = (-np.sqrt(np.nanmean(surprise_list_neurons_a_a, 0))).T
            b_mean_t = (-np.sqrt(np.nanmean(surprise_list_neurons_b_b, 0))).T

            # After transposing, columns [:n_time] hold the within block and
            # [n_time:] the between block (previously hard-coded as 63).
            # The A statistic now compares A with A; the original subtracted
            # the B between-task diagonal here.
            surprise_list_neurons_a_a_perm.append(
                abs(np.diag(a_mean_t[:, :n_time]) - np.diag(a_mean_t[:, n_time:])))
            surprise_list_neurons_b_b_perm.append(
                abs(np.diag(b_mean_t[:, :n_time]) - np.diag(b_mean_t[:, n_time:])))
            surprise_list_neurons_b_init_perm.append(
                abs(b_mean_t[b_init_bin, :n_time] - b_mean_t[b_init_bin, n_time:]))

        surprise_list_neurons_a_a_p.append(np.percentile(np.asarray(surprise_list_neurons_a_a_perm), 95, axis=0))
        surprise_list_neurons_b_b_p.append(np.percentile(np.asarray(surprise_list_neurons_b_b_perm), 95, axis=0))
        surprise_list_neurons_b_init_p.append(np.percentile(np.asarray(surprise_list_neurons_b_init_perm), 95, axis=0))

    surprise_list_neurons_a_a_p = np.nanmean(surprise_list_neurons_a_a_p, 0)
    surprise_list_neurons_b_b_p = np.nanmean(surprise_list_neurons_b_b_p, 0)
    surprise_list_neurons_b_init_p = np.nanmean(surprise_list_neurons_b_init_p, 0)

    return surprise_list_neurons_a_a_p, surprise_list_neurons_b_b_p, surprise_list_neurons_b_init_p
Example #8
0
def trials_surprise(data, task_1_2=False, task_2_3=False, task_1_3=False):
    """Average per-neuron "surprise" across a switch between two tasks.

    Exactly one of ``task_1_2``, ``task_2_3`` or ``task_1_3`` selects the
    (first, second) pair of task ids returned by ``rc.task_ind`` that is
    compared.  For every session returned by ``rtf.remap_surprise_time``
    (firing arrays indexed as (trials, neurons, time)) and every neuron:

    * firing is smoothed with ``gaussian_filter1d(..., sigma=2)``;
    * a Gaussian baseline (mean, std + 2) is fitted per time bin on the
      first-task trials ``[-20:-10]`` of each choice (A: ``choices == 1``,
      B: ``choices == 0``);
    * surprise is ``-norm.logpdf`` under that baseline of the last 10
      first-task trials ("within") followed by the first 21 second-task
      trials ("between");
    * a neuron whose smoothed firing is exactly 0 anywhere in those windows
      contributes NaN-filled arrays instead.

    Returns:
        (mean_b, mean_a, sem_b, sem_a): neuron-wise ``nanmean`` of the B and
        A surprise arrays, and ``nanstd / sqrt(n)`` where ``n`` counts all
        neurons (including NaN-filled ones).

    Raises:
        ValueError: if none of the three task-pair flags is set.
    """
    if task_1_2:
        taskid_1, taskid_2 = 1, 2
    elif task_2_3:
        taskid_1, taskid_2 = 2, 3
    elif task_1_3:
        taskid_1, taskid_2 = 1, 3
    else:
        # Without a flag the task ids would be unbound further down.
        raise ValueError('One of task_1_2, task_2_3 or task_1_3 must be True')

    x, y = rtf.remap_surprise_time(data,
                                   task_1_2=task_1_2,
                                   task_2_3=task_2_3,
                                   task_1_3=task_1_3)

    ind_pre = 20  # first-task trials taken from the end: 10 baseline + 10 within
    ind_post = 21  # second-task trials taken from the start (between)
    min_std = 2  # floor added to the baseline std so the Gaussian never collapses

    surprise_list_neurons_a_a = []
    surprise_list_neurons_b_b = []

    for s, firing_rates_mean_time in enumerate(x):
        DM = y[s]
        choices = DM[:, 1]
        taskid = rc.task_ind(DM[:, 5], DM[:, 6], DM[:, 7])

        task_1_a = np.where((taskid == taskid_1) & (choices == 1))[0]
        task_1_b = np.where((taskid == taskid_1) & (choices == 0))[0]
        task_2_a = np.where((taskid == taskid_2) & (choices == 1))[0]
        task_2_b = np.where((taskid == taskid_2) & (choices == 0))[0]

        # Baseline: trials [-ind_pre, -ind_pre + 10); within: the rest up to
        # the switch; between: the first ind_post trials of the second task.
        task_1_a_pre_baseline = task_1_a[-ind_pre:-ind_pre + 10]
        task_1_a_pre = task_1_a[-ind_pre + 10:]
        task_1_b_pre_baseline = task_1_b[-ind_pre:-ind_pre + 10]
        task_1_b_pre = task_1_b[-ind_pre + 10:]
        task_2_a_post = task_2_a[:ind_post]
        task_2_b_post = task_2_b[:ind_post]

        _, n_neurons, _ = firing_rates_mean_time.shape

        for neuron in range(n_neurons):
            # NOTE(review): axis 1 of the transposed (time x trials) matrix is
            # the trial axis, so this smooths across trials -- confirm intended.
            n_firing = gaussian_filter1d(
                firing_rates_mean_time[:, neuron, :].T.astype(float), 2, 1).T

            a_base = n_firing[task_1_a_pre_baseline]
            b_base = n_firing[task_1_b_pre_baseline]
            task_1_mean_a = np.mean(a_base, axis=0)
            task_1_std_a = np.std(a_base, axis=0) + min_std
            task_1_mean_b = np.mean(b_base, axis=0)
            task_1_std_b = np.std(b_base, axis=0) + min_std

            a_pre = n_firing[task_1_a_pre]
            b_pre = n_firing[task_1_b_pre]
            a_post = n_firing[task_2_a_post]
            b_post = n_firing[task_2_b_post]

            if not any(np.any(w == 0) for w in (a_pre, b_pre, a_post, b_post)):
                a_within = -norm.logpdf(a_pre, task_1_mean_a, task_1_std_a)
                b_within = -norm.logpdf(b_pre, task_1_mean_b, task_1_std_b)
                a_between = -norm.logpdf(a_post, task_1_mean_a, task_1_std_a)
                b_between = -norm.logpdf(b_post, task_1_mean_b, task_1_std_b)
            else:
                # np.nan (not the np.NaN alias, removed in NumPy 2.0).
                a_within = np.full(a_pre.shape, np.nan)
                b_within = np.full(b_pre.shape, np.nan)
                a_between = np.full(a_post.shape, np.nan)
                b_between = np.full(b_post.shape, np.nan)

            surprise_list_neurons_a_a.append(
                np.concatenate([a_within, a_between], axis=0))
            surprise_list_neurons_b_b.append(
                np.concatenate([b_within, b_between], axis=0))

    surprise_a = np.asarray(surprise_list_neurons_a_a)
    surprise_b = np.asarray(surprise_list_neurons_b_b)

    surprise_list_neurons_a_a_p = np.nanmean(surprise_a, axis=0)
    surprise_list_neurons_b_b_p = np.nanmean(surprise_b, axis=0)

    surprise_list_neurons_a_a_std = np.nanstd(surprise_a, axis=0) / np.sqrt(
        surprise_a.shape[0])
    surprise_list_neurons_b_b_std = np.nanstd(surprise_b, axis=0) / np.sqrt(
        surprise_b.shape[0])

    return surprise_list_neurons_b_b_p, surprise_list_neurons_a_a_p, surprise_list_neurons_b_b_std, surprise_list_neurons_a_a_std
Example #9
0
def shuffle_block_start_trials(data,
                               task_1_2=False,
                               task_2_3=False,
                               task_1_3=False,
                               n_perms=5):
    """Permutation null for the surprise step across a task switch.

    Exactly one of ``task_1_2``, ``task_2_3`` or ``task_1_3`` selects the
    (first, second) pair of ``rc.task_ind`` task ids.  For each session
    returned by ``rtf.remap_surprise_time``, the task labels are circularly
    shifted by a random offset ``n_perms`` times (global ``np.random`` RNG).
    For each shift, and for every neuron:

    * the baseline is first-task trials ``[-31:-21]`` of each choice
      (A: ``choices == 1``, B: ``choices == 0``);
    * within is first-task ``[-21:]``;
    * between is the first 40 second-task trials;
    * surprise is ``-norm.logpdf`` under a Gaussian (baseline mean,
      baseline std + 2).

    A neuron with an exact 0 in any window, or in a baseline mean, gets
    NaNs.  Surprise is averaged over neurons, and ``|row 21 - row 20|`` of
    the within+between array is recorded.  That is the first post-switch vs
    last pre-switch trial when 21 within trials exist.

    Returns:
        (a_p, b_p): per session, the 95th percentile over shifts; then
        ``nanmean`` across sessions.

    Raises:
        ValueError: if none of the three task-pair flags is set.
    """
    if task_1_2:
        taskid_1, taskid_2 = 1, 2
    elif task_2_3:
        taskid_1, taskid_2 = 2, 3
    elif task_1_3:
        taskid_1, taskid_2 = 1, 3
    else:
        # Without a flag the task ids would be unbound further down.
        raise ValueError('One of task_1_2, task_2_3 or task_1_3 must be True')

    x, y = rtf.remap_surprise_time(data,
                                   task_1_2=task_1_2,
                                   task_2_3=task_2_3,
                                   task_1_3=task_1_3)
    ind_pre = 31  # baseline = first-task trials [-31:-21], within = [-21:]
    ind_post = 40  # between = first 40 second-task trials
    min_std = 2  # floor added to the baseline std

    surprise_list_neurons_a_a_p = []
    surprise_list_neurons_b_b_p = []

    for s, firing_rates_mean_time in tqdm(enumerate(x)):
        DM = y[s]
        task = DM[:, 5]
        choices = DM[:, 1]

        # None of this depends on the shuffle, so compute it once per session.
        taskid_session = rc.task_ind(task, DM[:, 6], DM[:, 7])
        _, n_neurons, _ = firing_rates_mean_time.shape
        # NOTE(review): axis 1 of the transposed (time x trials) matrix is
        # the trial axis, so this smooths across trials -- confirm intended.
        smoothed = [
            gaussian_filter1d(
                firing_rates_mean_time[:, neuron, :].T.astype(float), 2, 1).T
            for neuron in range(n_neurons)
        ]

        surprise_list_neurons_a_a_perm = []
        surprise_list_neurons_b_b_perm = []

        for _ in range(n_perms):
            # Shift the rc.task_ind labels -- the labels the windows are
            # selected with in the unshuffled analysis -- not the raw
            # DM[:, 5] column.
            taskid = np.roll(taskid_session,
                             np.random.randint(len(task)),
                             axis=0)

            task_1_a = np.where((taskid == taskid_1) & (choices == 1))[0]
            task_1_b = np.where((taskid == taskid_1) & (choices == 0))[0]
            task_2_a = np.where((taskid == taskid_2) & (choices == 1))[0]
            task_2_b = np.where((taskid == taskid_2) & (choices == 0))[0]

            task_1_a_pre_baseline = task_1_a[-ind_pre:-ind_pre + 10]
            task_1_a_pre = task_1_a[-ind_pre + 10:]
            task_1_b_pre_baseline = task_1_b[-ind_pre:-ind_pre + 10]
            task_1_b_pre = task_1_b[-ind_pre + 10:]
            task_2_a_post = task_2_a[:ind_post]
            task_2_b_post = task_2_b[:ind_post]

            surprise_list_neurons_a_a = []
            surprise_list_neurons_b_b = []

            for n_firing in smoothed:
                a_base = n_firing[task_1_a_pre_baseline]
                b_base = n_firing[task_1_b_pre_baseline]
                task_1_mean_a = np.mean(a_base, axis=0)
                task_1_std_a = np.std(a_base, axis=0) + min_std
                task_1_mean_b = np.mean(b_base, axis=0)
                task_1_std_b = np.std(b_base, axis=0) + min_std

                a_pre = n_firing[task_1_a_pre]
                b_pre = n_firing[task_1_b_pre]
                a_post = n_firing[task_2_a_post]
                b_post = n_firing[task_2_b_post]

                checked = (a_pre, b_pre, a_post, b_post, task_1_mean_a,
                           task_1_mean_b)
                if not any(np.any(w == 0) for w in checked):
                    a_within = -norm.logpdf(a_pre, task_1_mean_a, task_1_std_a)
                    b_within = -norm.logpdf(b_pre, task_1_mean_b, task_1_std_b)
                    a_between = -norm.logpdf(a_post, task_1_mean_a,
                                             task_1_std_a)
                    b_between = -norm.logpdf(b_post, task_1_mean_b,
                                             task_1_std_b)
                else:
                    # np.nan (not the np.NaN alias, removed in NumPy 2.0).
                    a_within = np.full(a_pre.shape, np.nan)
                    b_within = np.full(b_pre.shape, np.nan)
                    a_between = np.full(a_post.shape, np.nan)
                    b_between = np.full(b_post.shape, np.nan)

                surprise_list_neurons_a_a.append(
                    np.concatenate([a_within, a_between], axis=0))
                surprise_list_neurons_b_b.append(
                    np.concatenate([b_within, b_between], axis=0))

            mean_a = np.nanmean(surprise_list_neurons_a_a, 0)
            mean_b = np.nanmean(surprise_list_neurons_b_b, 0)
            surprise_list_neurons_a_a_perm.append(abs(mean_a[21] - mean_a[20]))
            surprise_list_neurons_b_b_perm.append(abs(mean_b[21] - mean_b[20]))

        surprise_list_neurons_a_a_p.append(
            np.percentile(np.asarray(surprise_list_neurons_a_a_perm),
                          95,
                          axis=0))
        surprise_list_neurons_b_b_p.append(
            np.percentile(np.asarray(surprise_list_neurons_b_b_perm),
                          95,
                          axis=0))

    surprise_list_neurons_a_a_p = np.nanmean(surprise_list_neurons_a_a_p, 0)
    surprise_list_neurons_b_b_p = np.nanmean(surprise_list_neurons_b_b_p, 0)

    return surprise_list_neurons_a_a_p, surprise_list_neurons_b_b_p
Example #10
0
def select_trials(Data, DM, max_number_per_block, ind_time=np.arange(0, 63)):
    """Resample each neuron's firing in 12 task/state trial sets to a fixed length.

    For every ``(data, dm)`` session pair:

    * ``data`` is averaged over the time bins ``ind_time`` (axis 2), giving
      one value per trial and neuron.
    * Block starts are found from changes in ``dm[:, 4]``.  Sessions with
      fewer than 12 blocks are skipped; trials from the 13th block on are
      dropped.
    * "Correct" trials (choice 1 in state 1, or choice 0 in state 0) are
      smoothed with ``ut.exp_mov_ave(tau=8)``.  The first trial of each of
      the 12 blocks whose average exceeds 0.65 is that block's "change"
      trial.
    * Block starts and change trials are reordered so that slots 0-3, 4-7
      and 8-11 belong to ``rc.task_ind`` task ids 1, 2 and 3.  This assumes
      each task occupies 4 consecutive blocks.
    * ``state_behaviour_ind`` / ``choose_a_a_b_b`` (defined elsewhere) pick
      12 trial-index sets.  Each neuron's rate over each set is linearly
      interpolated to ``max_number_per_block`` points and smoothed with
      ``gaussian_filter1d(sigma=10)``.

    Returns:
        ndarray of shape (12, total_neurons, max_number_per_block).  Sets are
        ordered data_t1_1 ... data_t3_4; neurons are concatenated across
        sessions.
    """
    # Output keys, in the order they are stacked into the returned array.
    keys = ['data_t1_1', 'data_t1_2', 'data_t1_3', 'data_t1_4',
            'data_t2_1', 'data_t2_2', 'data_t2_3', 'data_t2_4',
            'data_t3_1', 'data_t3_2', 'data_t3_3', 'data_t3_4']

    all_sessions = []
    for data, dm in zip(Data, DM):
        choices = dm[:, 1]
        block = dm[:, 4]
        task = dm[:, 5]
        state = dm[:, 0]

        data = np.mean(data[:, :, ind_time], axis=2)
        taskid = rc.task_ind(task, dm[:, 6], dm[:, 7])

        task_1 = np.where(taskid == 1)
        task_2 = np.where(taskid == 2)
        task_3 = np.where(taskid == 3)

        correct_a = 1 * choices.astype(bool) & state.astype(bool)
        choices_b = (choices - 1) * -1
        state_b = (state - 1) * -1
        correct_b = 1 * choices_b.astype(bool) & state_b.astype(bool)
        correct = correct_a + correct_b
        exp_choices = ut.exp_mov_ave(correct, tau=8, initValue=0.5, alpha=None)
        ind_choosing_correct = np.where(exp_choices > 0.65)[0]

        # Trial index at which each block starts (0 included).
        state_change = np.sort(
            np.append(np.where(np.diff(block) != 0)[0] + 1, 0))

        choice_a_state_a = np.where((choices == 1) & (state == 1))[0]
        choice_b_state_b = np.where((choices == 0) & (state == 0))[0]

        # Truncate only when a 13th block exists.  Previously the truncation
        # index was used unconditionally, so a short session either raised
        # UnboundLocalError or reused the previous session's index.
        if len(state_change) > 12:
            data = data[:state_change[12]]
            state_change = state_change[:12]
        elif len(state_change) < 12:
            continue

        # First above-threshold trial in each of the 12 blocks.
        change = np.asarray([
            np.intersect1d(ind_choosing_correct, np.where(block == b)[0])[0]
            for b in range(12)
        ])

        # Put each task's 4 blocks into the slot for its task id.  state_change
        # and change are in session order, so task k's segment is the one at
        # its position in the session (rank of its last trial).  The previous
        # hand-written branches swapped the 3<1<2 and 2<3<1 orderings.
        last_trial = [task_1[0][-1], task_2[0][-1], task_3[0][-1]]
        session_rank = np.argsort(np.argsort(last_trial))
        block_ch = np.concatenate(
            [state_change[4 * r:4 * r + 4] for r in session_rank]).astype(float)
        ch = np.concatenate(
            [change[4 * r:4 * r + 4] for r in session_rank]).astype(float)

        state_change_t1 = ch[:4]
        state_change_t2 = ch[4:8]
        state_change_t3 = ch[8:]

        state_change_t1_1_ind,state_change_t1_2_ind, state_change_t1_3_ind,state_change_t1_4_ind,\
        state_change_t2_1_ind,state_change_t2_2_ind, state_change_t2_3_ind,state_change_t2_4_ind,\
        state_change_t3_1_ind,state_change_t3_2_ind, state_change_t3_3_ind,state_change_t3_4_ind = state_behaviour_ind(state_change_t1,state_change_t2,state_change_t3, change, data)

        t1_a_state_1, t1_a_state_2 ,t1_b_state_1, t1_b_state_2 ,t2_a_state_1,t2_a_state_2, t2_b_state_1,t2_b_state_2,t3_a_state_1,t3_a_state_2,\
        t3_b_state_1,t3_b_state_2 = choose_a_a_b_b(choice_a_state_a, choice_b_state_b,state_change_t1_1_ind,state_change_t1_2_ind, state_change_t1_3_ind,state_change_t1_4_ind,\
        state_change_t2_1_ind,state_change_t2_2_ind, state_change_t2_3_ind,state_change_t2_4_ind,\
        state_change_t3_1_ind,state_change_t3_2_ind, state_change_t3_3_ind,state_change_t3_4_ind, block, block_ch)

        selected = [t1_a_state_1, t1_a_state_2, t1_b_state_1, t1_b_state_2,
                    t2_a_state_1, t2_a_state_2, t2_b_state_1, t2_b_state_2,
                    t3_a_state_1, t3_a_state_2, t3_b_state_1, t3_b_state_2]
        subsets = [data[idx, :] for idx in selected]
        n_neurons = subsets[0].shape[1]

        all_dict = {}
        for key, subset in zip(keys, subsets):
            resampled = np.full((n_neurons, max_number_per_block), np.nan)
            n_points = subset.shape[0]
            for n in range(subset.shape[1]):
                f = interpolate.interp1d(np.arange(n_points), subset[:, n])
                xnew = np.arange(0, n_points - 1,
                                 (n_points - 1) / max_number_per_block)
                ynew = gaussian_filter1d(f(xnew), 10)
                resampled[n, :] = ynew[:max_number_per_block]
            all_dict[key] = resampled
        all_sessions.append(all_dict)

    session_list = [
        np.asarray([session[key] for key in keys]) for session in all_sessions
    ]
    return np.concatenate(session_list, 1)
Example #11
0
def a_b_i_coding(Data,Design):
    """Per-task, per-neuron choice/initiation firing relative to the task mean.

    ``Data[s]`` is session ``s``'s firing array, indexed as
    (trials, neurons, time).  ``Design[s]`` supplies choices (col 1), task
    (col 5) and A/B pokes (cols 6/7) for ``rc.task_ind``.  For each task id
    k in 1..3, every neuron gets three ratios.  Each numerator is a mean
    over trials then over the given time bins; each denominator is that
    neuron's mean over all task-k trials and all time bins.

    * ``b_k``: ``choices == 1`` trials of task k, choice bins 33-36.
    * ``a_k``: ``choices == 0`` trials of task k, choice bins 33-36.
    * ``i_k``: all task-k trials, initiation bins 23-26.

    NOTE(review): "A" is ``choices == 0`` here, whereas ``trials_surprise``
    in this file labels ``choices == 1`` as A -- confirm which is intended.

    Returns:
        Nine lists, each holding one array per session, in the order
        b1, a1, i1, b2, a2, i2, b3, a3, i3.
    """
    init_bins = np.arange(23, 27)    # initiation period
    choice_bins = np.arange(33, 37)  # choice period

    # task id -> (b ratios, a ratios, i ratios), one entry per session
    collected = {k: ([], [], []) for k in (1, 2, 3)}

    for s, rates in enumerate(Data):
        design = Design[s]
        choices = design[:, 1]
        taskid = rc.task_ind(design[:, 5], design[:, 6], design[:, 7])
        _trials, _neurons, _bins = rates.shape  # expects a 3-D array

        for k, (b_ratios, a_ratios, i_ratios) in collected.items():
            in_task = taskid == k
            task_trials = np.where(in_task)[0]
            a_trials = np.where(in_task & (choices == 0))[0]
            b_trials = np.where(in_task & (choices == 1))[0]

            task_mean = np.mean(np.mean(rates[task_trials], axis=0), axis=1)

            b_rate = np.mean(np.mean(rates[b_trials][:, :, choice_bins], axis=0), axis=1)
            a_rate = np.mean(np.mean(rates[a_trials][:, :, choice_bins], axis=0), axis=1)
            i_rate = np.mean(np.mean(rates[task_trials][:, :, init_bins], axis=0), axis=1)

            b_ratios.append(b_rate / task_mean)
            a_ratios.append(a_rate / task_mean)
            i_ratios.append(i_rate / task_mean)

    return (*collected[1], *collected[2], *collected[3])
        
        
Example #12
0
def _count_out_of_sequence_pokes(events, task_pokes):
    """Return per-trial counts of out-of-sequence A/B pokes for one session.

    Args:
        events: ordered event names for the session.
        task_pokes: one ``(trial_indices, poke_A, poke_B, label)`` tuple per
            task. ``trial_indices`` are the trial numbers of that task,
            ``poke_A`` / ``poke_B`` its port event names, and ``label`` is a
            tag used to keep the remembered choice task-specific.

    Each ``'choice_state'`` event appends the count accumulated since the
    previous one (the first entry covers events before the first choice
    state), so pokes after the last choice state are not recorded. A poke is
    out of sequence when it goes to the current task's A (B) port after the
    first A/B poke following the choice state went to B (A).

    Only the current task's A/B ports are scored. Some ports are shared
    across tasks (e.g. task 2's initiation port is task 3's B port), so
    cross-task pokes are ambiguous.
    """
    session_wrong_choice = []
    wrong_count = 0
    choice_state_count = 0
    prev_choice = 'forced_state'
    choice_state = False

    for event in events:
        if event == 'choice_state':
            # Close the running tally and start a new trial.
            session_wrong_choice.append(wrong_count)
            wrong_count = 0
            choice_state = True
            choice_state_count += 1
        elif event in ('a_forced_state', 'b_forced_state'):
            prev_choice = 'forced_state'

        for trials, poke_A, poke_B, label in task_pokes:
            if choice_state_count not in trials:
                continue
            # Don't carry a choice over from the previous task. This runs on
            # every event of the task's first trial, so that trial is never
            # scored.
            if choice_state_count == trials[0]:
                prev_choice = 'forced_state'

            if event == poke_A:
                if choice_state:
                    prev_choice = 'Poke_A_' + label
                    choice_state = False
                elif prev_choice == 'Poke_B_' + label:
                    # Bug fix: task 1 previously required event == poke_B_1
                    # here, which can never hold, so these were never counted.
                    wrong_count += 1
            elif event == poke_B:
                if choice_state:
                    prev_choice = 'Poke_B_' + label
                    choice_state = False
                elif prev_choice == 'Poke_A_' + label:
                    wrong_count += 1
            break  # task trial sets are disjoint; at most one applies

    return session_wrong_choice


def out_of_sequence(m484, m479, m483, m478, m486, m480, m481, data_HP,
                    data_PFC):
    """Count out-of-sequence pokes per task/reversal for each subject and plot them.

    Each ``m4xx`` argument is a list of behavioural session objects for one
    subject. Each object must expose ``.datetime`` and ``.events``, and each
    event must have a ``.name``. ``data_HP`` and ``data_PFC`` are mappings
    whose ``['DM'][0]`` holds one design-matrix array per session. These are
    split per subject by fixed session ranges: HP covers m484, m479 and m483;
    PFC covers m478, m486, m480 and m481.

    Only sessions with at least 12 block transitions are analysed. Forced
    trials are removed. Each remaining choice gets an out-of-sequence poke
    count, a running task number and a reversal-within-task number. Counts
    are averaged per (task number, reversal) for task numbers
    1..``max_tasks`` and reversals 0..``n_reversals - 1``.

    Side effects: writes ``task{i}_rab_recording.csv`` files and opens three
    matplotlib figures.

    Returns:
        ``(aovrm_es, posthoc)``: a pingouin repeated-measures ANOVA over
        task number and reversal, plus FDR-corrected pairwise t-tests over
        reversal.
    """
    max_tasks = 18   # task numbers above this are discarded
    n_reversals = 4  # reversal numbers 0..n_reversals-1 are kept

    # HP = m484 + m479 + m483; PFC = m478 + m486 + m480 + m481
    all_subjects = [
        data_HP['DM'][0][:16], data_HP['DM'][0][16:24], data_HP['DM'][0][24:],
        data_PFC['DM'][0][:9], data_PFC['DM'][0][9:25],
        data_PFC['DM'][0][25:39], data_PFC['DM'][0][39:]
    ]
    subj = [m484, m479, m483, m478, m486, m480, m481]
    all_subj_mean = []
    all_subj_std = []

    for subj_sessions, subject in zip(subj, all_subjects):
        subj_sessions = np.asarray(subj_sessions)
        ind_sort = np.argsort([ses.datetime for ses in subj_sessions])
        sessions_beh_event = subj_sessions[ind_sort]
        subject = np.asarray(subject)

        all_sessions_wrong_ch = []  # one count per choice, all sessions
        task_number = 0
        all_reversals = []
        all_tasks = []

        # Pair each design matrix with its behavioural session. The former
        # ``zip(sessions_beh_event)`` yielded 1-tuples and could not be
        # unpacked into two names.
        # NOTE(review): ``subject`` is not re-ordered with ``ind_sort``; this
        # assumes the DM arrays are already in chronological order.
        for session, session_event in zip(subject, sessions_beh_event):
            reversal_number = 0
            # Drop forced-choice trials from every per-trial column.
            forced_array = np.where(session[:, 3] == 1)[0]
            sessions_block = np.delete(session[:, 4], forced_array)
            Block_transitions = sessions_block[1:] - sessions_block[:-1]
            task = np.delete(session[:, 5], forced_array)

            poke_I = np.delete(session[:, 8], forced_array)
            poke_A = np.delete(session[:, 6], forced_array)
            poke_B = np.delete(session[:, 7], forced_array)

            taskid = rc.task_ind(task, poke_A, poke_B)
            task_1 = np.where(taskid == 1)[0]
            task_2 = np.where(taskid == 2)[0]
            task_3 = np.where(taskid == 3)[0]

            # Port event names for each task, read from its first trial.
            poke_A_1 = 'poke_' + str(int(poke_A[task_1[0]]))
            poke_B_1 = 'poke_' + str(int(poke_B[task_1[0]]))
            poke_A_2 = 'poke_' + str(int(poke_A[task_2[0]]))
            poke_B_2 = 'poke_' + str(int(poke_B[task_2[0]]))
            poke_A_3 = 'poke_' + str(int(poke_A[task_3[0]]))
            poke_B_3 = 'poke_' + str(int(poke_B[task_3[0]]))

            poke_I_1 = 'poke_' + str(int(poke_I[task_1[0]]))
            poke_I_2 = 'poke_' + str(int(poke_I[task_2[0]]))
            poke_I_3 = 'poke_' + str(int(poke_I[task_3[0]]))

            reversal_trials = np.where(Block_transitions == 1)[0]
            if len(reversal_trials) < 12:
                continue
            task_number += 1

            wanted_events = ('choice_state', 'init_trial', 'a_forced_state',
                             'b_forced_state', poke_A_1, poke_B_1, poke_A_2,
                             poke_B_2, poke_A_3, poke_B_3, poke_I_1,
                             poke_I_2, poke_I_3)
            events = [event.name for event in session_event.events
                      if event.name in wanted_events]

            session_wrong_choice = _count_out_of_sequence_pokes(
                events,
                [(task_1, poke_A_1, poke_B_1, '1'),
                 (task_2, poke_A_2, poke_B_2, '2'),
                 (task_3, poke_A_3, poke_B_3, '3')])
            all_sessions_wrong_ch.extend(session_wrong_choice[:len(task)])

            # Label each trial with the running task number and the number of
            # reversals since that task started.
            # NOTE(review): np.diff indices mark the LAST trial before a
            # change, so both counters advance one trial early; confirm.
            reversal_set = set(reversal_trials.tolist())
            task_change = set(np.where(np.diff(task) != 0)[0].tolist())
            for i in range(len(task)):
                if i in reversal_set:
                    reversal_number += 1
                if i in task_change:
                    task_number += 1
                    reversal_number = 0
                all_reversals.append(reversal_number)
                all_tasks.append(task_number)

        # Keep only reversals 0..n_reversals-1.
        rev_over = np.where(np.asarray(all_reversals) > n_reversals - 1)[0]
        all_reversals_np = np.delete(np.asarray(all_reversals), rev_over)
        all_tasks_np = np.delete(np.asarray(all_tasks), rev_over)
        pokes_np = np.delete(np.asarray(all_sessions_wrong_ch), rev_over)

        # np.nan rather than np.NaN, which NumPy 2.0 removed.
        task_pl = np.full((max_tasks, n_reversals), np.nan)
        std_plt = np.full((max_tasks, n_reversals), np.nan)

        beyond_max_task = np.where(all_tasks_np > max_tasks)[0]
        all_reversals_np = np.delete(all_reversals_np, beyond_max_task)
        all_tasks_np = np.delete(all_tasks_np, beyond_max_task)
        pokes_np = np.delete(pokes_np, beyond_max_task)
        for t in np.unique(all_tasks_np):
            for r in np.unique(all_reversals_np):
                plot_task = pokes_np[(all_tasks_np == t) &
                                     (all_reversals_np == r)]
                task_pl[t - 1, r] = np.mean(plot_task)
                std_plt[t - 1, r] = np.std(plot_task)

        all_subj_mean.append(task_pl)
        all_subj_std.append(std_plt)

    all_subj_mean_np = np.asarray(all_subj_mean)
    n_subjects, tasks, n_revs = all_subj_mean_np.shape
    for task in range(tasks):
        pd.DataFrame(data=all_subj_mean_np[:, task, :]).to_csv(
            'task{}_rab_recording.csv'.format(task))

    # Long format for the ANOVA. Labels follow the C-order flattening of
    # [subject, task, reversal].
    data = all_subj_mean_np.reshape(-1)
    rev = np.tile(np.arange(n_revs), n_subjects * tasks)
    task_n = np.tile(np.repeat(np.arange(tasks), n_revs), n_subjects)
    n_subj = np.repeat(np.arange(n_subjects), tasks * n_revs)

    anova = {'Data': data, 'Sub_id': n_subj, 'cond1': task_n, 'cond2': rev}
    anova_pd = pd.DataFrame.from_dict(data=anova)

    aovrm_es = pg.rm_anova(anova_pd,
                           dv='Data',
                           within=['cond1', 'cond2'],
                           subject='Sub_id')
    posthoc = pg.pairwise_ttests(data=anova_pd, dv='Data', within=['cond2'],
                                 subject='Sub_id', parametric=True,
                                 padjust='fdr_bh', effsize='hedges')

    all_rev = np.mean(all_subj_mean, axis=0)
    # Standard error of the mean across subjects is std / sqrt(n). The former
    # code divided by n, which shrank the error bars by a factor of sqrt(n).
    std_err = np.std(all_subj_mean, axis=0) / np.sqrt(n_subjects)
    x = np.arange(n_revs)
    plt.figure(figsize=(10, 5))

    for i in range(tasks):
        plt.plot(i * n_revs + x, all_rev[i])
        plt.fill_between(i * n_revs + x,
                         all_rev[i] - std_err[i],
                         all_rev[i] + std_err[i],
                         alpha=0.2)

    xs = list(range(1, n_revs + 1))
    rev = [np.nanmean(all_rev[:, r]) for r in range(n_revs)]
    st = [np.nanmean(std_err[:, r]) for r in range(n_revs)]
    z = np.polyfit(xs, rev, 1)
    p = np.poly1d(z)
    plt.figure()
    plt.plot(xs, p(xs), "--", color='grey', label='Trend Line')

    plt.errorbar(x=xs,
                 y=rev,
                 yerr=st,
                 alpha=0.8,
                 linestyle='None',
                 marker='*',
                 color='Black')

    plt.ylabel('Number of Trials Till Threshold')
    plt.xlabel('Reversal Number')
    mean_rev = np.mean(all_subj_mean, axis=2)
    std_rev = np.std(all_subj_mean, axis=2)

    med_sub = np.mean(mean_rev, axis=0)
    # NOTE(review): this is the spread of per-subject standard deviations,
    # not of the per-subject means plotted here; confirm it is intended.
    std_sub = np.std(std_rev, axis=0)
    std_err_median = std_sub / np.sqrt(n_subjects)
    x_pos = np.arange(len(med_sub))
    plt.figure(figsize=(10, 10))
    sns.set(style="white", palette="muted", color_codes=True)
    plt.errorbar(x=x_pos,
                 y=med_sub,
                 yerr=std_err_median,
                 alpha=0.8,
                 linestyle='None',
                 marker='*',
                 color='Black')

    z = np.polyfit(x_pos, med_sub, 1)
    p = np.poly1d(z)
    plt.plot(x_pos, p(x_pos), "--", color='grey', label='Trend Line')

    return aovrm_es, posthoc
Example #13
0
def perm_A(n_perms: int = 1000, area: int = 1) -> list:
    """Collect switch-aligned firing rates per choice, repeated ``n_perms`` times.

    For each iteration, every session of the HP (``area == 1``) or PFC
    (any other value) dataset is visited. For each of the first two task
    switches and each switch type, the function collects ``DD[Atrials]``.
    ``Atrials`` are the last ``ntrials + baselength + 1`` trials with a given
    ``choices`` value at or before the switch, followed by the first
    ``ntrials`` such trials after it.

    Args:
        n_perms: number of iterations; each appends one nested list ``A``.
        area: 1 selects the HP data, anything else selects PFC.

    Returns:
        A list of length ``n_perms``. Each element ``A`` is indexed as
        ``A[s][ch - 1][stype - 1]`` and holds a list of per-session arrays:

        * ``s`` is the switch index (0 or 1).
        * ``ch - 1`` is the ``choices`` value (0 or 1).
        * ``stype`` is the switch type (1: task 1-2, 2: task 1-3, 3: task 2-3).
    """
    # Both files are loaded regardless of ``area``; the paths are hard-coded.
    HP = io.loadmat('/Users/veronikasamborska/Desktop/HP.mat')
    PFC = io.loadmat('/Users/veronikasamborska/Desktop/PFC.mat')
    
    A_perms = []
    for perm in range(n_perms):

      
         
        ntrials=20;
        baselength=10;
        n = 3
        # A is rebuilt on every iteration, so entries of A_perms are independent.
        A = [[[ [] for _ in range(n)] for _ in range(n)] for _ in range(n)]
                   
        if area==1:
            Data = HP
        else:
            Data = PFC
        neuron_num=0  # unused
        
        for  i, ii in enumerate(Data['DM'][0]):
             
            
            DD = Data['Data'][0][i]
            DM = Data['DM'][0][i]
      
            choices = DM[:,1]
            b_pokes = DM[:,7]
            a_pokes = DM[:,6]
            task = DM[:,5]
            # NOTE(review): taskid comes from the unrolled ``task`` and is not
            # recomputed after the np.roll calls below, so ``prepost`` looks
            # up real task ids at (possibly) shuffled switch positions.
            taskid = rc.task_ind(task,a_pokes,b_pokes)
            # abs() is applied to the boolean ``diff > 0``, so only increases
            # in ``task`` count as switch points.
            sw_point=np.where(abs(np.diff(task)>0))[0]+1
            # Trial indices with choices == 0 and choices == 1.
            b_s =np.where(choices==(0))[0]
            a_s =np.where(choices==(1))[0]
        
                
                          
            for stype in [1,2,3]: #1 is 1-2; 2 is 1-3; 3 is 2-3
            
                for s in range(2):
               
                    #figure out type of switch. 
                    # Task ids two trials either side of switch s; their sum
                    # equals stype + 2 for the matching switch type.
                    # NOTE(review): sw_point[s] is indexed here, before the
                    # ``len(sw_point)<2`` guard below, so a session with fewer
                    # than two switches raises IndexError at this line.

                    prepost=[taskid[sw_point[s]-2], taskid[sw_point[s]+2]]
                             
                    if(sum(prepost)==stype+2):
                
                                   
                        #FIND LAST ntrials A before switch and first ntrials as after switch 
                        for ch in [1,2]:
                            
                            while  len(sw_point)<2:
                                task = np.roll(task,np.random.randint(len(task)), axis=0)
                                sw_point = np.where(abs(np.diff(task)>0))[0]+1
                            # Re-roll ``task`` until both switch points have
                            # at least 31 trials of each choice before (<=)
                            # and after (>) them. Each clause appears twice
                            # (redundant but harmless).
                            # NOTE(review): ``task`` is only rolled when the
                            # current switch points FAIL this check. If the
                            # real switch points already pass, nothing is
                            # shuffled for this session and every iteration
                            # reproduces the unpermuted data. The shuffled
                            # sw_point also persists across later s/stype/ch
                            # iterations of the same session. Confirm intended.
                            while len(np.where(b_s > sw_point[0])[0]) < 31 or len(np.where(b_s > sw_point[0])[0])< 31  or\
                                 len(np.where(a_s > sw_point[0])[0])< 31 or len(np.where(a_s > sw_point[0])[0])< 31 or\
                                     len(np.where(b_s > sw_point[1])[0]) < 31 or len(np.where(b_s > sw_point[1])[0])< 31  or\
                                 len(np.where(a_s > sw_point[1])[0])< 31 or len(np.where(a_s > sw_point[1])[0])< 31 or\
                                     len(np.where(b_s <= sw_point[0])[0]) < 31 or len(np.where(b_s <= sw_point[0])[0])< 31  or\
                                 len(np.where(a_s <= sw_point[0])[0])< 31 or len(np.where(a_s <= sw_point[0])[0])< 31 or\
                                     len(np.where(b_s <= sw_point[1])[0]) < 31 or len(np.where(b_s <= sw_point[1])[0])< 31  or\
                                 len(np.where(a_s <= sw_point[1])[0])< 31 or len(np.where(a_s <= sw_point[1])[0])< 31  :
                                    task = np.roll(task,np.random.randint(len(task)), axis=0)
                                    sw_point = np.where(abs(np.diff(task)>0))[0]+1
                                    while  len(sw_point)<2:
                                        task = np.roll(task,np.random.randint(len(task)), axis=0)
                                        sw_point = np.where(abs(np.diff(task)>0))[0]+1
                          
                        
                                                            
                            
                            # Trials whose choice equals ch - 1 (ch=1 -> 0, ch=2 -> 1).
                            Aind=np.where(choices==(ch-1))[0]
                   
                            Aind_pre_sw = Aind[Aind<=sw_point[s]]
                            
                            # Last ntrials + baselength + 1 such trials at or before the switch.
                            Aind_pre_sw = Aind_pre_sw[-ntrials-baselength-1:]
                   
                            Aind_post_sw = Aind[Aind>sw_point[s]]
                            
                            # First ntrials such trials after the switch.
                            Aind_post_sw = Aind_post_sw[:ntrials]
                           
                            Atrials= np.hstack([Aind_pre_sw,Aind_post_sw])
                    
                            A[s][ch-1][stype-1].append((DD[Atrials]))
                            # Debug: a full window is ntrials + baselength + 1 + ntrials = 51 trials.
                            if DD[Atrials].shape[0] !=51:
                                print(sw_point, len(Aind_post_sw),len(Aind_pre_sw))
                                
                            
        A_perms.append(A)
    return A_perms
def _trials_since_state_change(state, choices):
    """Return ``(trials_since_block, a_since_block)`` lists, one entry per trial.

    A trial starts a new block when it is the first trial or its ``state``
    differs from the previous trial's. The former inline code compared trial
    0 with ``state[-1]``, so the first value depended on the last trial.

    * ``trials_since_block`` is 0 on block-start trials and otherwise the
      previous value plus 1.
    * ``a_since_block`` is 0 on block-start trials, where its running count
      also resets. On other trials with ``choices == 1`` the running count
      increments and is recorded; all other trials record 0 without resetting
      the running count.
    """
    trials_since_block = []
    a_since_block = []
    trial_count = 0
    a_count = 0
    for st in range(len(state)):
        if st == 0 or state[st - 1] != state[st]:
            trial_count = 0
            a_count = 0
            a_since_block.append(0)
        else:
            trial_count += 1
            if choices[st] == 1:
                a_count += 1
                a_since_block.append(a_count)
            else:
                a_since_block.append(0)
        trials_since_block.append(trial_count)
    return trials_since_block, a_since_block


def _reward_run_counts(reward, block_df):
    """Return ``(negative_reward_count, positive_reward_count)`` arrays.

    ``block_df`` is aligned with ``reward``. A non-zero entry resets both
    counters *after* the current trial has been recorded. The negative
    counter adds 1 on unrewarded (0) trials and subtracts 1 on rewarded (1)
    trials. The positive counter adds 1 on rewarded trials only. Trials with
    any other reward value get no entry.
    """
    # NOTE(review): the "no reward" counter decrements on rewarded trials
    # (a net count) rather than holding steady; confirm this is intended.
    negative, positive = [], []
    neg = 0
    pos = 0
    for r, b in zip(reward, block_df):
        if r == 0:
            neg += 1
            negative.append(neg)
            positive.append(pos)
        elif r == 1:
            neg -= 1
            pos += 1
            negative.append(neg)
            positive.append(pos)
        if b != 0:
            neg = 0
            pos = 0
    return np.asarray(negative), np.asarray(positive)


def _task_design_matrix(columns, idx, firing_rates):
    """Return ``(X, y)`` restricted to the trials in ``idx``.

    Each whole-session column in ``columns`` is indexed with ``idx``, and a
    constant ``'ones'`` column is appended last. ``y`` is
    ``firing_rates[idx]`` reshaped to ``[len(idx), n_neurons * n_timepoints]``.
    """
    predictors = OrderedDict(
        (name, np.asarray(values)[idx]) for name, values in columns.items())
    predictors['ones'] = np.ones(len(idx))
    X = np.vstack(list(predictors.values())).T.astype(float)
    y = firing_rates[idx].reshape([len(idx), -1])
    return X, y


def regression_general(data):
    """Regress per-session firing rates on behavioural predictors.

    ``data`` is a mapping with keys ``'DM'`` (per-session design matrices)
    and ``'Data'`` (per-session firing-rate arrays, unpacked as
    ``n_trials, n_neurons, n_timepoints``). Only sessions with more than 10
    neurons are used. The design-matrix columns read are: 0 state, 1 choice,
    2 reward, 4 block, 5 task, 6 A pokes, 7 B pokes.

    For each such session the function:

    * fits one whole-session model on ``predictors_all``, computing
      t-statistics via ``reg_f.regression_code`` and CPD via ``re._CPD``;
    * fits one model per task id (1, 2, 3 from ``rc.task_ind``) on a larger
      predictor set;
    * computes cross-task CPDs for tasks 1 vs 2 and 2 vs 3 via
      ``_CPD_cross_task``.

    Returns:
        C: whole-session t-stats concatenated over neurons, with shape
            ``[n_predictors, n_neurons_total, n_timepoints]``.
        cpd: whole-session CPD averaged (nanmean) over neurons, with shape
            ``[n_timepoints, n_predictors]``.
        C_1, C_2, C_3: per-task t-stats, same layout as ``C``.
        cpd_1_2, cpd_2_3: cross-task CPDs averaged over neurons.
        predictors_all: predictor dict of the last session processed.
        session_trials_since_block: trials-since-block list of the last
            session only, because the list is re-created per session.

    Raises:
        ValueError: if no session has more than 10 neurons.
    """
    C = []
    cpd = []

    C_1 = []
    C_2 = []
    C_3 = []

    cpd_1_2 = []
    cpd_2_3 = []

    dm = data['DM']
    firing = data['Data']

    for s, DM in enumerate(dm):
        firing_rates = firing[s]
        _, n_neurons, n_timepoints = firing_rates.shape

        if n_neurons <= 10:
            continue

        # NOTE(review): re-created for every session, so only the last
        # session's counter is returned; confirm this is intended.
        session_trials_since_block = []

        state = DM[:, 0]
        choices = DM[:, 1]
        reward = DM[:, 2]
        b_pokes = DM[:, 7]
        a_pokes = DM[:, 6]
        task = DM[:, 5]
        block = DM[:, 4]
        # Block-change indicator aligned with trials (trailing 0 for the last).
        block_df = np.append(np.diff(block), 0)
        taskid = rc.task_ind(task, a_pokes, b_pokes)

        correct = np.zeros(len(choices))
        correct[np.where(choices == state)[0]] = 1

        trials_since_block, a_since_block = _trials_since_state_change(
            state, choices)
        session_trials_since_block.append(trials_since_block)

        negative_reward_count, positive_reward_count = _reward_run_counts(
            reward, block_df)

        # Choice coded +1 (choices != 0) / -1 (choices == 0).
        choices_int = np.ones(len(reward))
        choices_int[np.where(choices == 0)] = -1
        reward_choice_int = choices_int * reward

        trials_arr = np.asarray(trials_since_block)
        a_arr = np.asarray(a_since_block)
        interaction_trial_latent = trials_arr * state
        interaction_a_latent = a_arr * state
        int_a_reward = a_arr * reward
        interaction_trial_choice = trials_arr * choices_int
        reward_trial_in_block = trials_arr * positive_reward_count
        negative_reward_count_st = negative_reward_count * correct
        positive_reward_count_st = positive_reward_count * correct
        negative_reward_count_ch = negative_reward_count * choices
        positive_reward_count_ch = positive_reward_count * choices

        # Whole-session model. The remaining candidate predictors are used
        # only in the per-task models below.
        predictors_all = OrderedDict([
            ('Reward', reward),
            ('Choice', choices),
            ('State', state),
            ('Trial in Block', trials_since_block),
            ('Choice x Trials in Block', interaction_trial_choice),
            ('Reward x Choice', reward_choice_int),
            ('ones', np.ones(len(choices))),
        ])

        # np.vstack needs a real sequence; NumPy >= 1.24 rejects dict views.
        X = np.vstack(list(predictors_all.values())).T[:len(choices), :].astype(float)
        n_predictors = X.shape[1]
        # Activity matrix [n_trials, n_neurons*n_timepoints]
        y = firing_rates.reshape([len(firing_rates), -1])
        tstats = reg_f.regression_code(y, X)

        C.append(tstats.reshape(n_predictors, n_neurons, n_timepoints))
        cpd.append(re._CPD(X, y).reshape(n_neurons, n_timepoints, n_predictors))

        # Per-task models: same predictors, restricted to each task's trials.
        task_columns = OrderedDict([
            ('Reward', reward), ('Choice', choices),
            ('Correct', correct), ('A in Block', a_since_block),
            ('A in Block x Reward', int_a_reward), ('State', state),
            ('Trial in Block', trials_since_block),
            ('Interaction State x Trial in Block', interaction_trial_latent),
            ('Interaction State x A count', interaction_a_latent),
            ('Choice x Trials in Block', interaction_trial_choice),
            ('Reward x Choice', reward_choice_int),
            ('No Reward Count in a Block', negative_reward_count),
            ('No Reward x Correct', negative_reward_count_st),
            ('Reward Count in a Block', positive_reward_count),
            ('Reward Count x Correct', positive_reward_count_st),
            ('No reward Count x Choice', negative_reward_count_ch),
            ('Reward Count x Choice', positive_reward_count_ch),
            ('Reward x Trial in Block', reward_trial_in_block),
        ])

        task_X = []
        task_y = []
        for task_id, C_task in zip((1, 2, 3), (C_1, C_2, C_3)):
            idx = np.where(taskid == task_id)[0]
            X_t, y_t = _task_design_matrix(task_columns, idx, firing_rates)
            tstats = reg_f.regression_code(y_t, X_t)
            C_task.append(tstats.reshape(X_t.shape[1], n_neurons, n_timepoints))
            task_X.append(X_t)
            task_y.append(y_t)

        X_1, X_2, X_3 = task_X
        y_1, y_2, y_3 = task_y
        n_task_predictors = X_3.shape[1]

        cpd_1_2.append(
            _CPD_cross_task(X_1, X_2, y_1, y_2).reshape(
                n_neurons, n_timepoints, n_task_predictors))
        cpd_2_3.append(
            _CPD_cross_task(X_2, X_3, y_2, y_3).reshape(
                n_neurons, n_timepoints, n_task_predictors))

        print(n_neurons)

    if not C:
        raise ValueError(
            'regression_general: no session has more than 10 neurons')

    cpd = np.nanmean(np.concatenate(cpd, 0), axis=0)
    C = np.concatenate(C, 1)
    C_1 = np.concatenate(C_1, 1)
    C_2 = np.concatenate(C_2, 1)
    C_3 = np.concatenate(C_3, 1)

    cpd_1_2 = np.nanmean(np.concatenate(cpd_1_2, 0), axis=0)
    cpd_2_3 = np.nanmean(np.concatenate(cpd_2_3, 0), axis=0)

    return C, cpd, C_1, C_2, C_3, cpd_1_2, cpd_2_3, predictors_all, session_trials_since_block