def real_neuron_PCA(neuron_file='data/adultneuron.txt', time_len=176, direc_num=8, n_components=2, dt_ms=20):
    """Run PCA on recorded neuron responses and plot each component over time.

    The file is parsed as comma-separated floats (one row per line) and then
    transposed, so every *column* of the file becomes one PCA sample. Samples
    are standardised with ``StandardScaler`` before PCA.

    Args:
        neuron_file: path of the comma-separated response file.
        time_len: number of consecutive samples belonging to one location.
        direc_num: number of locations; samples are assumed to be laid out as
            ``direc_num`` consecutive blocks of ``time_len`` samples.
        n_components: number of principal components computed and plotted.
        dt_ms: spacing of consecutive samples in milliseconds, used only for
            the time axis (defaults to the previously hard-coded 20).

    Raises:
        ValueError: if the file provides fewer than ``direc_num * time_len``
            samples, which would otherwise yield mismatched curves.

    Side effects:
        Saves one PDF (a subplot per component) under ``figure/real_neuron_pca/``.
    """
    neuron_reaction = list()
    with open(neuron_file, 'r') as nf:
        for line in nf:
            # Blank lines (e.g. a trailing newline at EOF) would make float('') fail.
            if not line.strip():
                continue
            neuron_reaction.append([float(value) for value in line.rstrip('\n').split(',')])
    neuron_reaction = np.array(neuron_reaction).T

    n_samples_needed = direc_num * time_len
    if neuron_reaction.shape[0] < n_samples_needed:
        raise ValueError(
            'expected at least %d samples (direc_num * time_len) in %s, got %d'
            % (n_samples_needed, neuron_file, neuron_reaction.shape[0]))

    pca = PCA(n_components=n_components)
    n_pca = pca.fit_transform(StandardScaler().fit_transform(neuron_reaction))

    # One subplot per component; every location is drawn as its own time course.
    fig = plt.figure(figsize=(10, 5 * n_components))
    time_line = np.arange(time_len) * dt_ms / 1000  # seconds

    for n_comp in range(n_components):
        ax = fig.add_subplot(n_components, 1, n_comp + 1)
        ax.set_title('component:' + str(n_comp + 1))
        for loc in range(direc_num):
            pc = n_pca[loc * time_len:(loc + 1) * time_len, n_comp]
            ax.plot(time_line, pc, label='loc:' + str(loc))
        ax.legend()

    sample_name = neuron_file.split('/')[-1].split('.')[0]
    save_path = 'figure/real_neuron_pca/'
    mkdir_p(save_path)
    fig.savefig(save_path + sample_name + 'comp' + str(n_components) + '_all_loc_bycomp_conc_std.pdf',
                bbox_inches='tight')
    plt.close(fig)
Пример #2
0
def print_basic_info(hp,
                     log,
                     model_dir,
                     smooth_growth=True,
                     smooth_window=5,
                     auto_range_select=False,
                     avr_window=9,
                     perf_margin=0.05,
                     max_trial_num_limit=30):
    """Print a training summary and plot the performance growth of every rule.

    Args:
        hp: hyperparameter dict; ``hp['rule_trains']`` lists the rules plotted.
        log: training log with ``'trials'`` and one ``'perf_<rule>'`` series per rule.
        model_dir: model directory; its last path component names the figure folder.
        smooth_growth: smooth each curve with ``tools.smooth`` before plotting.
        smooth_window: window passed to ``tools.smooth``.
        auto_range_select: if True, pick early/mid/mature trial ranges per rule
            with ``tools.range_auto_select`` and shade them on the plot.
        avr_window, perf_margin, max_trial_num_limit: forwarded to
            ``tools.range_auto_select``.

    Returns:
        dict mapping rule -> result of ``tools.range_auto_select`` when
        ``auto_range_select`` is True, otherwise None.

    Side effects:
        Saves ``growth_of_performance`` as png/pdf/eps under
        ``figure/figure_<model name>/`` and calls ``plt.show()``.
    """
    trials = log['trials']
    print('rule trained: ', hp['rule_trains'])
    print('minimum trial number: 0')
    print('maximum trial number: ', trials[-1])
    # A log with a single evaluation has no step; report N/A instead of raising IndexError.
    print('minimum trial step  : ', trials[1] if len(trials) > 1 else 'N/A')
    print('total number        : ', len(trials))

    plt.figure(figsize=(12, 9))
    trial_selected = dict()
    for rule in hp['rule_trains']:
        perf = log['perf_' + rule]
        growth = tools.smooth(perf, smooth_window) if smooth_growth else perf

        plt.plot(trials, growth, label=rule)

        if auto_range_select:
            trial_selected[rule] = tools.range_auto_select(
                hp, log, perf,
                avr_window=avr_window, perf_margin=perf_margin,
                max_trial_num_limit=max_trial_num_limit)
            for stage, color in (('early', 'green'), ('mid', 'blue'), ('mature', 'red')):
                # Set lookup keeps the shading mask linear in len(trials).
                chosen = set(trial_selected[rule][stage])
                plt.fill_between(trials, growth, where=[t in chosen for t in trials],
                                 facecolor=color, alpha=0.3)

    fig_dir = 'figure/figure_' + model_dir.rstrip('/').split('/')[-1] + '/'
    tools.mkdir_p(fig_dir)

    plt.xlabel("trial trained")
    plt.ylabel("perf")
    plt.legend(bbox_to_anchor=(1.05, 0), loc=3, borderaxespad=0)
    plt.title('Growth of Performance')
    save_name = fig_dir + 'growth_of_performance'
    plt.tight_layout()
    for ext in ('.png', '.pdf', '.eps'):
        plt.savefig(save_name + ext, transparent=False, bbox_inches='tight')
    plt.show()

    return trial_selected if auto_range_select else None
Пример #3
0
def plot_PSTH(
    hp,
    log,
    model_dir,
    rule,
    epoch,
    trial_list,
    n_types=('exh_neurons', 'mix_neurons'),
    plot_oppo_dir=False,
    norm=True,
    PSTH_log=None,
):
    """Plot population PSTHs averaged within each training stage.

    Each checkpoint (trial number) is assigned to ``early``/``mid``/``mature``:
    from the perf thresholds ``hp['mid_target_perf']`` / ``hp['early_target_perf']``
    when ``trial_list`` is a list, or from the ``trial_list[rule]`` stage lists
    when it is a dict. Per checkpoint the PSTH is averaged over axis 0
    (presumably neurons), then the checkpoint curves are averaged per stage.

    Args:
        hp, log, model_dir, rule, epoch: training config, log, model directory,
            task rule and selectivity epoch.
        trial_list: list of trial numbers, or dict ``{rule: {stage: [trials]}}``.
        n_types: neuron types passed to ``gen_PSTH_log``.
        plot_oppo_dir: also plot the PSTH for the opposite selective direction
            (dashed), generated with ``gen_PSTH_log(..., oppo_sel_dir=True)``.
        norm: forwarded to ``gen_PSTH_log``.
        PSTH_log: optional precomputed ``gen_PSTH_log`` output; it is not modified.

    Raises:
        TypeError: if ``trial_list`` is neither a dict nor a list.

    Side effects:
        Saves the figure as pdf/eps/png under
        ``figure/figure_<model>/<rule>/<epoch>/<n_types>/``.
    """
    print("Start ploting PSTH")
    print("\trule: " + rule + " selective epoch: " + epoch)

    if isinstance(trial_list, dict):
        is_dict, is_list = True, False
        temp_list = sorted({t for stage_trials in trial_list[rule].values() for t in stage_trials})
    elif isinstance(trial_list, list):
        is_dict, is_list = False, True
        temp_list = trial_list
    else:
        raise TypeError('trial_list must be a dict or a list, got ' + type(trial_list).__name__)

    if PSTH_log is None:
        PSTH_log = gen_PSTH_log(hp,
                                trial_list,
                                model_dir,
                                rule,
                                epoch,
                                n_types=n_types,
                                norm=norm)

    # Build new dicts rather than overwriting PSTH_log in place, so a
    # caller-supplied log is left intact and can be reused safely.
    n_number = {key: np.size(value, 0) for key, value in PSTH_log.items()}
    PSTH_mean = {key: value.mean(axis=0) for key, value in PSTH_log.items()}

    data_types = ["PSTH", "n_num", "growth"]
    PSTH_oppo_mean = None
    if plot_oppo_dir:
        # Needed whenever requested, even if PSTH_log was supplied by the caller.
        PSTH_log_oppo = gen_PSTH_log(hp,
                                     trial_list,
                                     model_dir,
                                     rule,
                                     epoch,
                                     n_types=n_types,
                                     norm=norm,
                                     oppo_sel_dir=plot_oppo_dir)
        PSTH_oppo_mean = {key: value.mean(axis=0) for key, value in PSTH_log_oppo.items()}
        data_types.append("PSTH_oppo")

    data_to_plot = dict()
    for trial_num in temp_list:
        growth = log['perf_' + rule][trial_num // log['trials'][1]]

        if (is_list and growth > hp['mid_target_perf']) or (
                is_dict and trial_num in trial_list[rule]['mature']):
            m_key = "mature"
        elif (is_list and growth > hp['early_target_perf']) or (
                is_dict and trial_num in trial_list[rule]['mid']):
            m_key = "mid"
        elif is_list or (is_dict and trial_num in trial_list[rule]['early']):
            m_key = "early"
        else:
            # Trial listed only under a non-stage key: skip it instead of
            # silently reusing the previous trial's stage.
            continue

        stage_data = data_to_plot.setdefault(
            m_key, {data_type: list() for data_type in data_types})
        stage_data["PSTH"].append(PSTH_mean[trial_num])
        stage_data["growth"].append(growth)
        stage_data["n_num"].append(n_number[trial_num])
        if plot_oppo_dir:
            stage_data["PSTH_oppo"].append(PSTH_oppo_mean[trial_num])

    # Average every quantity over the checkpoints of each stage.
    for stage_data in data_to_plot.values():
        for data_type in data_types:
            stage_data[data_type] = np.array(stage_data[data_type]).mean(axis=0)

    # plot #
    save_path = 'figure/figure_' + model_dir.rstrip('/').split(
        '/')[-1] + '/' + rule + '/' + epoch + '/' + '_'.join(n_types) + '/'

    if is_dict or len(temp_list) == 1:
        step = 'None'
    else:
        step = str(temp_list[1] - temp_list[0])

    if is_dict:
        trial_range = 'auto_choose'
    else:
        trial_range = str((temp_list[0], temp_list[-1]))

    title = 'Rule:' + rule + ' Epoch:' + epoch + ' Neuron_type:' + '_'.join(
        n_types) + ' trial range:' + trial_range + ' step:' + step

    colors = {
        "early": "green",
        "mid": "blue",
        "mature": "red",
    }

    fig, ax = plt.subplots(figsize=(14, 10))
    fig.suptitle(title)

    for m_key, stage_data in data_to_plot.items():
        psth = stage_data["PSTH"]
        ax.plot(np.arange(len(psth)) * hp['dt'] / 1000, psth,
                label=m_key + '_%.2f' % (stage_data["growth"]) + '_n%d' % (stage_data["n_num"]),
                color=colors[m_key])
        if plot_oppo_dir:
            psth_oppo = stage_data["PSTH_oppo"]
            ax.plot(np.arange(len(psth_oppo)) * hp['dt'] / 1000, psth_oppo,
                    label=m_key + '_opposite_sel_dir', color=colors[m_key], linestyle='--')

    ax.set_xlabel("time/s")
    ax.set_ylabel("activity")
    ax.legend(bbox_to_anchor=(1.05, 0), loc=3, borderaxespad=0)

    mkdir_p(save_path)
    file_stem = save_path + rule + '_' + epoch + '_' + trial_range + '_step_' + step + '_PSTH'
    for ext in ('.pdf', '.eps', '.png'):
        fig.savefig(file_stem + ext, bbox_inches='tight')

    plt.close(fig)
Пример #4
0
def _odrd_mean_perf_by_distance(dist_perf_pairs, n_dist):
    """Average performance over all (stim1, distractor) pairs sharing a distance bin.

    Args:
        dist_perf_pairs: iterable of ``(distance_index, perf)`` with
            ``0 <= distance_index < n_dist``.
        n_dist: number of distance bins.

    Returns:
        float array of length ``n_dist``; bins without any pair are 0.
    """
    totals = np.zeros(n_dist)
    counts = np.zeros(n_dist)
    for dist_idx, perf in dist_perf_pairs:
        totals[dist_idx] += perf
        counts[dist_idx] += 1
    return totals / np.maximum(counts, 1)


def odrd_distractor_analysis(hp,log,model_dir,trial_list,):
    """Evaluate 'odrd' checkpoints on every stim1/distractor pair and plot by training stage.

    For each checkpoint in ``trial_list`` the model under ``<model_dir>/<trial_num>/``
    is restored and run on ``'test-<stim1>-<distrac>'`` trials for all
    ``hp['n_eachring']**2`` location pairs. Performance is averaged per
    distance index returned by ``get_abs_dist``; the ``dist`` values returned
    by ``get_perf`` are pooled per distance. Checkpoints are grouped into
    early/mid/mature (perf thresholds in ``hp`` for list input, the
    ``trial_list['odrd']`` stage lists for dict input).

    Raises:
        TypeError: if ``trial_list`` is neither a dict nor a list.

    Side effects:
        Saves a perf-vs-distance figure and one saccade histogram figure per
        distance (png/eps/pdf) under ``figure/figure_<model name>/odrd/``.
    """
    n_ring = hp['n_eachring']
    n_dist = n_ring // 2 + 1
    stages = ('early', 'mid', 'mature')
    colors = {'early': 'green', 'mid': 'blue', 'mature': 'red'}

    # Per stage: one perf-by-distance vector per checkpoint, and pooled dist values per distance.
    perf_records = {stage: list() for stage in stages}
    saccade_dirs = {stage: {d: list() for d in range(n_dist)} for stage in stages}

    if isinstance(trial_list, dict):
        is_dict, is_list = True, False
        temp_list = sorted({t for stage_trials in trial_list['odrd'].values() for t in stage_trials})
    elif isinstance(trial_list, list):
        is_dict, is_list = False, True
        temp_list = trial_list
    else:
        raise TypeError('trial_list must be a dict or a list, got ' + type(trial_list).__name__)

    for trial_num in temp_list:
        saccade_dir_temp = {d: list() for d in range(n_dist)}
        dist_perf_pairs = list()

        model = Model(model_dir+'/'+str(trial_num)+'/', hp=hp)
        with tf.Session() as sess:
            model.restore()

            for stim1 in range(n_ring):
                for distrac in range(n_ring):
                    task_mode = 'test-'+str(stim1)+'-'+str(distrac)
                    trial = generate_trials('odrd', hp, task_mode)
                    feed_dict = tools.gen_feed_dict(model, trial, hp)
                    y_hat = sess.run(model.y_hat, feed_dict=feed_dict)
                    pair_perf, dist = get_perf(y_hat, trial.y_loc)
                    dist_idx = get_abs_dist(stim1, distrac, n_ring)
                    dist_perf_pairs.append((dist_idx, np.mean(pair_perf)))
                    saccade_dir_temp[dist_idx] += dist

        # Normalise by the real number of pairs in each bin. Assuming get_abs_dist
        # is the circular distance, distance 0 (and n/2 for even n) has n pairs,
        # every other bin has 2n; counting avoids getting that parity wrong.
        temp_pref_list = _odrd_mean_perf_by_distance(dist_perf_pairs, n_dist)

        matur = log['perf_odrd'][trial_num//log['trials'][1]]
        if (is_list and matur > hp['mid_target_perf']) or (is_dict and trial_num in trial_list['odrd']['mature']):
            stage = 'mature'
        elif (is_list and matur > hp['early_target_perf']) or (is_dict and trial_num in trial_list['odrd']['mid']):
            stage = 'mid'
        elif is_list or (is_dict and trial_num in trial_list['odrd']['early']):
            stage = 'early'
        else:
            continue

        perf_records[stage].append(temp_pref_list)
        for key, value in saccade_dir_temp.items():
            saccade_dirs[stage][key] += value

    abs_dist_list = np.arange(n_dist)

    fig, ax = plt.subplots(figsize=(10,6))
    for stage in stages:
        records = perf_records[stage]
        # A stage without checkpoints has nothing to average; skip it explicitly.
        if not records:
            continue
        ax.plot(abs_dist_list, np.array(records).mean(axis=0), color=colors[stage],
                label=stage+' trial_num:'+str(len(records)))
    ax.set_xticks(abs_dist_list)
    ax.set_ylabel("perf")
    ax.set_xlabel("distance between distractor and stim1 ($\\times$%.1f$\\degree$)"%(360/n_ring))
    ax.legend()

    save_folder = 'figure/figure_'+model_dir.rstrip('/').split('/')[-1]+'/odrd/'
    tools.mkdir_p(save_folder)
    save_pic = save_folder+'odrd_distractor_analysis_by_growth'
    for ext in ('.png', '.eps', '.pdf'):
        fig.savefig(save_pic+ext, transparent=False)

    plt.close(fig)

    for d in range(n_dist):
        fig, axes = plt.subplots(1,3,figsize=(18,6))
        for panel, stage in zip(axes, stages):
            panel.hist(saccade_dirs[stage][d], bins=30, range=(0,180), histtype="stepfilled",
                       alpha=0.6, color=colors[stage])
            panel.set_title(stage)
            panel.set_xlabel("distance to stim1($\\degree$)")
        fig.suptitle("distractor distance: "+str(d))

        save_pic = save_folder+'saccade_distribut_analysis_by_growth_dis_'+str(d)
        for ext in ('.png', '.eps', '.pdf'):
            fig.savefig(save_pic+ext, transparent=False)

        plt.close(fig)
Пример #5
0
def tunning_analysis(
    hp,
    log,
    model_dir,
    rule,
    epoch,
    trial_list,
    n_types=('exh_neurons', 'mix_neurons'),
    gaussion_fit=True,
    height_ttest=True,
):
    """Plot stage-averaged tuning curves, optional Gaussian fits and height t-tests.

    For every checkpoint, neurons of the requested ``n_types`` are read from
    ``<model_dir>/<trial_num>/neuron_info_<rule>_<epoch>.pkl``. Each neuron's
    ``firerate_loc_order`` curve is re-centred with ``max_central`` on its
    preferred direction and averaged over neurons. These checkpoint curves are
    then averaged per stage (early/mid/mature) and scattered.

    Args:
        hp, log, model_dir, rule, epoch: training config, log, model directory,
            task rule and selectivity epoch.
        trial_list: list of trial numbers, or dict ``{rule: {stage: [trials]}}``.
        n_types: neuron types whose union is analysed.
        gaussion_fit: overlay a Gaussian fit (on the curve minus its minimum)
            per stage; the legend shows twice the fitted sigma. A failed fit is
            reported and skipped.
        height_ttest: add independent t-test p-values between stages of the
            per-neuron tuning heights (max - min) to the title.

    Raises:
        TypeError: if ``trial_list`` is neither a dict nor a list.

    Side effects:
        Saves the figure as pdf/eps/png under
        ``figure/figure_<model>/<rule>/<epoch>/<n_types>/``.
    """
    if gaussion_fit:
        # curve fit #
        from scipy.optimize import curve_fit
        import math

        def gaussian(x, a, u, sig):
            return a*np.exp(-(x - u) ** 2 / (2 * sig ** 2)) / (sig * math.sqrt(2 * math.pi))

    if height_ttest:
        # independent t-test #
        from scipy.stats import ttest_ind

    tuning_store = dict()
    info_store = dict()
    height_store = dict()

    if isinstance(trial_list, dict):
        is_dict, is_list = True, False
        temp_list = sorted({t for stage_trials in trial_list[rule].values() for t in stage_trials})
    elif isinstance(trial_list, list):
        is_dict, is_list = False, True
        temp_list = trial_list
    else:
        raise TypeError('trial_list must be a dict or a list, got ' + type(trial_list).__name__)

    for trial_num in temp_list:
        growth = log['perf_'+rule][trial_num//log['trials'][1]]
        if (is_list and growth > hp['mid_target_perf']) or (is_dict and trial_num in trial_list[rule]['mature']):
            mature_key = "mature"
        elif (is_list and growth > hp['early_target_perf']) or (is_dict and trial_num in trial_list[rule]['mid']):
            mature_key = "mid"
        elif is_list or (is_dict and trial_num in trial_list[rule]['early']):
            mature_key = "early"
        else:
            # Not in any stage list: skip rather than reuse the previous trial's stage.
            continue

        read_name = model_dir+'/'+str(trial_num)+'/neuron_info_'+rule+'_'+epoch+'.pkl'
        with open(read_name, 'rb') as nf:
            ninf = pickle.load(nf)

        n_list = list(set(n for ntype in n_types for n in ninf[ntype]))
        if not n_list:
            # An empty neuron set would average to a NaN scalar and break stacking below.
            print('No neuron of type(s) '+'_'.join(n_types)+' at trial '+str(trial_num)+', skipped')
            continue

        if mature_key not in tuning_store:
            tuning_store[mature_key] = list()
            info_store[mature_key] = list()
            height_store[mature_key] = list()

        trial_avrg_tuning = list()
        for neuron_inf in n_list:
            max_dir = neuron_inf[2]
            tuning = ninf['firerate_loc_order'][neuron_inf[0]]
            trial_avrg_tuning.append(max_central(max_dir, tuning))
            height_store[mature_key].append(tuning.max()-tuning.min())

        tuning_store[mature_key].append(np.array(trial_avrg_tuning).mean(axis=0))
        info_store[mature_key].append((len(n_list), growth))

    for key in tuning_store:
        tuning_store[key] = np.array(tuning_store[key])
        height_store[key] = np.array(height_store[key])

    fig, ax = plt.subplots(figsize=(16,10))

    if is_dict or len(temp_list) == 1:
        step = 'None'
    else:
        step = str(temp_list[1]-temp_list[0])

    if is_dict:
        trial_range = 'auto_choose'
    else:
        trial_range = str((temp_list[0], temp_list[-1]))
    title = 'Rule:'+rule+' Epoch:'+epoch+' trial range:'+trial_range+' step:'+step

    colors = {'mature': 'red', 'mid': 'blue', 'early': 'green'}

    for mature_key in tuning_store:
        color = colors[mature_key]

        temp_tuning = tuning_store[mature_key].mean(axis=0)
        temp_x = np.arange(len(temp_tuning))

        avg_n_number = int(np.array([x[0] for x in info_store[mature_key]]).mean())
        avg_growth = np.array([x[1] for x in info_store[mature_key]]).mean()

        ax.scatter(temp_x, temp_tuning, marker='+', color=color, s=70,
                   label=mature_key+' avg_n_num(int):'+str(avg_n_number)+' avg_growth:%.2f'%(avg_growth))

        if gaussion_fit:
            # Fit on the curve shifted to a zero minimum, then shift the fit back.
            tuning_min = np.min(temp_tuning)
            try:
                paras, _ = curve_fit(gaussian, temp_x, temp_tuning - tuning_min,
                                     p0=[np.max(temp_tuning)+1, len(temp_tuning)//2, 1])
            except RuntimeError as err:
                print('Gaussian fit failed for '+mature_key+': '+str(err))
            else:
                gaussian_x = np.arange(-0.1, len(temp_tuning)-0.9, 0.1)
                gaussian_y = gaussian(gaussian_x, paras[0], paras[1], paras[2]) + tuning_min
                width = paras[2]
                ax.plot(gaussian_x, gaussian_y, color=color,
                        label=mature_key+' curve_width:%.2f'%(width*2))

    if height_ttest:
        maturation = list(tuning_store.keys())
        title += '\nHeight independent T-test p-value: '
        for i in range(len(maturation)-1):
            for j in range(i+1, len(maturation)):
                _, p_h = ttest_ind(height_store[maturation[i]], height_store[maturation[j]])
                title += maturation[i]+'-'+maturation[j]+':%.1e '%(p_h)

    fig.suptitle(title)
    ax.legend(bbox_to_anchor=(1.05, 0), loc=3, borderaxespad=0)
    ax.set_ylabel('activity')

    save_path = 'figure/figure_'+model_dir.rstrip('/').split('/')[-1]+'/'+rule+'/'+epoch+'/'+'_'.join(n_types)+'/'
    mkdir_p(save_path)
    file_stem = save_path+rule+'_'+epoch+'_'+trial_range+'_step_'+step+'_tuning_analysis'
    for ext in ('.pdf', '.eps', '.png'):
        fig.savefig(file_stem+ext, bbox_inches='tight')

    plt.close(fig)
Пример #6
0
def train(
    model_dir,
    hp=None,
    max_steps=1e7,
    display_step=500,
    ruleset='all_new',
    rule_trains=None,
    rule_prob_map=None,
    seed=0,
    load_dir=None,
    trainables=None,
):
    """Train the network until the evaluation target is reached or the user interrupts.

    Every ``display_step`` steps the current trial count
    (``step * hp['batch_size_train']``) is appended to ``log['trials']``, the
    directory ``model_dir/<trial count>/`` is created and ``do_eval`` is run.
    Training stops once ``log['perf_min'][-1]`` exceeds ``hp['target_perf']``,
    or on ``KeyboardInterrupt``.

    Args:
        model_dir: str, training directory
        hp: dictionary of hyperparameters; overrides the defaults returned by
            ``get_default_hp(ruleset)``
        max_steps: maximum number of training trials. NOTE(review): currently
            unused -- the loop condition that enforced it is commented out, so
            training only ends on reaching the target or on interrupt.
        display_step: int, number of training steps between evaluations
        ruleset: the set of rules to train
        rule_trains: list of rules to train; if None, all rules of ``ruleset``
            (``task.rules_dict[ruleset]``)
        rule_prob_map: None or dictionary of relative rule probability;
            rules missing from it get weight 1 before normalisation
        seed: int, random seed for ``hp['rng']``
        load_dir: optional checkpoint location passed to ``model.restore``
            instead of initialising all variables
        trainables: None or 'all' (all variables), 'input' (variables whose
            name contains 'input' but not 'rnn'), or 'rule' (variables whose
            name contains 'rule_input')

    Returns:
        None. hp is saved via ``tools.save_hp(hp, model_dir)`` (presumably
        model_dir/hp.json -- verify). Per-evaluation directories
        model_dir/<trial count>/ are created here; writing checkpoints into
        them is presumably done by ``do_eval`` (not visible here).

    Raises:
        ValueError: if ``trainables`` is not one of the values above.
    """

    tools.mkdir_p(model_dir)

    # Network parameters: defaults for the ruleset, overridden by the caller's hp
    default_hp = get_default_hp(ruleset)
    if hp is not None:
        default_hp.update(hp)
    hp = default_hp
    hp['seed'] = seed
    hp['rng'] = np.random.RandomState(seed)

    # Rules to train and test. Rules in a set are trained together
    if rule_trains is None:
        # By default, training all rules available to this ruleset
        hp['rule_trains'] = task.rules_dict[ruleset]
    else:
        hp['rule_trains'] = rule_trains
    hp['rules'] = hp['rule_trains']

    # Assign probabilities for rule_trains.
    if rule_prob_map is None:
        rule_prob_map = dict()

    # Turn into rule_trains format: normalised sampling probabilities
    hp['rule_probs'] = None
    if hasattr(hp['rule_trains'], '__iter__'):
        # Set default as 1.
        rule_prob = np.array(
            [rule_prob_map.get(r, 1.) for r in hp['rule_trains']])
        hp['rule_probs'] = list(rule_prob / np.sum(rule_prob))
    tools.save_hp(hp, model_dir)

    # Build the model
    model = Model(model_dir, hp=hp)

    # Display hp
    for key, val in hp.items():
        print('{:20s} = '.format(key) + str(val))

    # Store results
    log = defaultdict(list)
    log['model_dir'] = model_dir

    # Record time
    t_start = time.time()

    with tf.Session() as sess:
        if load_dir is not None:
            model.restore(load_dir)  # complete restore
        else:
            # Fresh start: initialise every variable
            sess.run(tf.global_variables_initializer())

        # Set trainable parameters
        if trainables is None or trainables == 'all':
            var_list = model.var_list  # train everything
        elif trainables == 'input':
            # train all inputs (excluding variables inside 'rnn')
            var_list = [
                v for v in model.var_list
                if ('input' in v.name) and ('rnn' not in v.name)
            ]
        elif trainables == 'rule':
            # train rule inputs only
            var_list = [v for v in model.var_list if 'rule_input' in v.name]
        else:
            raise ValueError('Unknown trainables')
        model.set_optimizer(var_list=var_list)

        # penalty on deviation from initial weight; optimizer is rebuilt so the
        # extended cost_reg takes effect
        if hp['l2_weight_init'] > 0:
            anchor_ws = sess.run(model.weight_list)
            for w, w_val in zip(model.weight_list, anchor_ws):
                model.cost_reg += (hp['l2_weight_init'] *
                                   tf.nn.l2_loss(w - w_val))

            model.set_optimizer(var_list=var_list)

        # partial weight training: a random fraction (1 - p_weight_train) of each
        # weight gets a penalty pulling it back to its current value
        if ('p_weight_train' in hp and (hp['p_weight_train'] is not None)
                and hp['p_weight_train'] < 1.0):
            for w in model.weight_list:
                w_val = sess.run(w)
                w_size = sess.run(tf.size(w))
                w_mask_tmp = np.linspace(0, 1, w_size)
                hp['rng'].shuffle(w_mask_tmp)
                ind_fix = w_mask_tmp > hp['p_weight_train']
                w_mask = np.zeros(w_size, dtype=np.float32)
                w_mask[ind_fix] = 1e-1  # will be squared in l2_loss
                w_mask = tf.constant(w_mask)
                w_mask = tf.reshape(w_mask, w.shape)
                model.cost_reg += tf.nn.l2_loss((w - w_val) * w_mask)
            model.set_optimizer(var_list=var_list)

        step = 0
        while 1:  # NOTE(review): max_steps check disabled: step * hp['batch_size_train'] <= max_steps
            try:
                # Validation
                if step % display_step == 0:
                    trial_number = step * hp[
                        'batch_size_train']  # add by yichen
                    log['trials'].append(trial_number)
                    tools.mkdir_p(model_dir + '/' +
                                  str(trial_number))  # add by yichen
                    log['times'].append(time.time() - t_start)
                    log = do_eval(sess, model, log, hp['rule_trains'])
                    # check if minimum performance across rules is above target

                    if log['perf_min'][-1] > model.hp['target_perf']:
                        print('Perf reached the target: {:0.2f}'.format(
                            hp['target_perf']))
                        break

                # Training
                rule_train_now = hp['rng'].choice(hp['rule_trains'],
                                                  p=hp['rule_probs'])
                # Generate a random batch of trials.
                # Each batch has the same trial length
                trial = generate_trials(rule_train_now,
                                        hp,
                                        'random',
                                        batch_size=hp['batch_size_train'])

                # Generating feed_dict.
                feed_dict = tools.gen_feed_dict(model, trial, hp)
                sess.run(model.train_step, feed_dict=feed_dict)

                step += 1

            except KeyboardInterrupt:
                print("Optimization interrupted by user")
                break

        print("Optimization finished!")
Пример #7
0
def sample_neuron_by_trial(hp,log,model_dir,rule,epoch,trial_list,n_type,):
    """Plot per-location PSTHs and a polar tuning curve for every neuron of one type.

    For each checkpoint in ``trial_list``, activity ``H`` is obtained from
    ``Get_H(..., task_mode='test')`` and the neurons listed under ``n_type`` in
    ``neuron_info_<rule>_<epoch>.pkl`` (entries ``(neuron, p, sel_dir)``) are
    plotted. Each figure is a 3x3 grid: the 8 outer panels show the PSTH for
    each input location, and the centre is a polar plot of the epoch-mean
    activity per location (normalised by its maximum) with a black marker at
    ``sel_dir * 45`` degrees. A polar-only copy is saved as well.

    Args:
        hp, log, model_dir, rule, epoch: training config, log, model directory,
            task rule and selectivity epoch.
        trial_list: list of trial numbers, or dict ``{rule: {stage: [trials]}}``.
        n_type: key of the neuron list in the neuron-info pickle.

    Raises:
        TypeError: if ``trial_list`` is neither a dict nor a list.

    Side effects:
        Writes png/eps/pdf files under
        ``figure/figure_<model>/<rule>/<epoch>/sample_neuron/<trial><stage>/<n_type>/``.
    """
    with open(model_dir+'/task_info.pkl','rb') as tinf:
        task_info = pickle.load(tinf)

    save_root_folder = 'figure/figure_'+model_dir.rstrip('/').split('/')[-1]+'/'+rule+'/'+epoch+'/sample_neuron/'
    tools.mkdir_p(save_root_folder)

    if isinstance(trial_list, dict):
        is_dict, is_list = True, False
        temp_list = sorted({t for stage_trials in trial_list[rule].values() for t in stage_trials})
    elif isinstance(trial_list, list):
        is_dict, is_list = False, True
        temp_list = trial_list
    else:
        raise TypeError('trial_list must be a dict or a list, got ' + type(trial_list).__name__)

    # Grid slot (row*3 + col) of each location's PSTH panel; slot 4 (centre) holds the polar plot.
    posi_list = [1,2,5,8,7,6,3,0]
    # 8 location angles plus a repeat of 0 so the polar curve closes.
    theta1 = np.arange(0, 2 * np.pi + 0.00000001, np.pi / 4)
    theta2 = np.arange(0, 2*np.pi, 2*np.pi/360)

    for trial_num in temp_list:
        H = Get_H(hp,model_dir,trial_num,rule,save_H=False,task_mode='test',)

        with open(model_dir+'/'+str(trial_num)+'/neuron_info_'+rule+'_'+epoch+'.pkl','rb') as inf:
            neuron_info = pickle.load(inf)

        n_list = neuron_info[n_type]
        if not n_list:  # no neuron of this type at this checkpoint
            continue

        perf = log['perf_'+rule][trial_num//log['trials'][1]]
        if (is_list and perf > hp['mid_target_perf']) or (is_dict and trial_num in trial_list[rule]['mature']):
            color, stage = 'red', 'mature'
        elif (is_list and perf > hp['early_target_perf']) or (is_dict and trial_num in trial_list[rule]['mid']):
            color, stage = 'blue', 'mid'
        elif is_list or (is_dict and trial_num in trial_list[rule]['early']):
            color, stage = 'green', 'early'
        else:
            # Not in any stage list: skip rather than reuse the previous trial's folder/colour.
            continue

        save_folder = save_root_folder+str(trial_num)+stage+'/'+n_type+'/'
        tools.mkdir_p(save_folder)

        in_loc = task_info[rule]['in_loc']
        epoch_start = task_info[rule]['epoch_info'][epoch][0]
        epoch_end = task_info[rule]['epoch_info'][epoch][1]
        # Named time_axis so the module-level `time` is not shadowed.
        time_axis = np.arange(len(H[:,0,0]))*hp['dt']/1000

        for sample_n in n_list:
            neuron = sample_n[0]
            fig, ax = plt.subplots(3,3,figsize=(10,10))

            max_ = 0
            min_ = 0
            psth = dict()
            # Fresh per neuron so no value carries over from the previous neuron.
            period_mean = np.zeros(9)

            for loc in task_info[rule]['in_loc_set']:
                loc_mask = in_loc == loc
                psth[loc] = H[:, loc_mask, neuron].mean(axis=1)
                period_mean[loc] = H[epoch_start:epoch_end, loc_mask, neuron].mean()
                max_ = max(max_, np.max(psth[loc]))
                min_ = min(min_, np.min(psth[loc]))

            period_mean[-1] = period_mean[0]
            peak = period_mean.max()
            if peak != 0:  # avoid NaNs for a neuron silent in this epoch
                period_mean /= peak

            for loc in task_info[rule]['in_loc_set']:
                panel = ax[posi_list[loc]//3][posi_list[loc]%3]
                panel.set_ylim(min_-0.1*abs(max_), max_+0.1*abs(max_))
                panel.plot(time_axis, psth[loc], color=color)
                panel.set_xticks(np.arange(0, np.max(time_axis), 1))
                if loc in [1,2,3]:
                    panel.yaxis.set_ticks_position('right')
                if loc in [7,0,1]:
                    panel.xaxis.set_ticks_position('top')

            # Remove the unused centre Cartesian panel explicitly instead of relying
            # on plt.subplot deleting overlapping axes (not done by newer Matplotlib).
            ax[1][1].remove()
            axis = fig.add_subplot(3,3,5,projection='polar')
            axis.set_theta_zero_location('N')
            axis.set_theta_direction(-1)
            axis.set_yticks([])
            axis.plot(theta1, period_mean, color=color)

            sel_dir_point = np.zeros(360)
            sel_dir_point[sample_n[2]*45] = 1
            axis.plot(theta2, sel_dir_point, color='black')

            title = 'Rule:'+rule+' Epoch:'+epoch+' Neuron:'+str(neuron)+' SelectDir:'+str(sample_n[2])+\
                ' Perf:'+str(perf)[:4]
            fig.suptitle(title)

            for ext in ('.png', '.eps', '.pdf'):
                fig.savefig(save_folder+str(neuron)+ext, transparent=False)
            plt.close(fig)

            # polar only
            fig_p, ax_p = plt.subplots(figsize=(10,11), subplot_kw=dict(projection='polar'))
            ax_p.set_theta_zero_location('N')
            ax_p.set_theta_direction(-1)
            ax_p.set_yticks([])
            ax_p.plot(theta1, period_mean, color=color)
            ax_p.plot(theta2, sel_dir_point, color='black')

            fig_p.suptitle(title)
            for ext in ('.png', '.eps', '.pdf'):
                fig_p.savefig(save_folder+str(neuron)+'_polar'+ext, transparent=False)
            plt.close(fig_p)
def _psth_selected_load_info(model_dir, trial_num, rule, epoch):
    """Load the pickled neuron-info dict saved under ``model_dir/<trial_num>/``."""
    n_info_file = model_dir + '/' + str(
        trial_num) + '/neuron_info_' + rule + '_' + epoch + '.pkl'
    # NOTE: pickle is only safe on files this project wrote itself.
    with open(n_info_file, 'rb') as ninf:
        return pickle.load(ninf)


def _psth_selected_indices(n_info, n_types):
    """Return the set of neuron indices (``entry[0]``) found under any of ``n_types``."""
    indices = set()
    for n_type in n_types:
        indices.update(n[0] for n in n_info[n_type])
    return indices


def _psth_selected_stage(trial_num, growth, trial_list, rule, hp, is_dict):
    """Return "early", "mid" or "mature" for one trial, or None if unclassifiable.

    Dict mode looks the trial up in ``trial_list[rule]`` (checking "mature",
    then "mid", then "early"). List mode compares ``growth`` against
    ``hp['mid_target_perf']`` and ``hp['early_target_perf']``.
    """
    if is_dict:
        stage_lists = trial_list[rule]
        for stage in ("mature", "mid", "early"):
            if trial_num in stage_lists.get(stage, ()):
                return stage
        return None
    if growth > hp['mid_target_perf']:
        return "mature"
    if growth > hp['early_target_perf']:
        return "mid"
    return "early"


def plot_PSTH_neuron_selected(
    hp,
    log,
    model_dir,
    rule,
    epoch,
    trial_list,
    selected_n_list=None,
    n_types_selected=('exh_neurons', 'mix_neurons'),
    n_types_intersection_limit=(
        'selective_neurons',
    ),  # alternatives: ('exh_neurons','mix_neurons') or ('active_neurons',)
    plot_oppo_dir=False,
    norm=True,
    PSTH_log=None,
):
    """Plot the neuron-averaged PSTH of a selected neuron set per maturation stage.

    Trials are grouped into "early", "mid" and "mature". Each stage's averaged
    PSTH is plotted and saved as pdf/eps/png, and the selected neuron indices
    are written to a .txt file.

    Args:
        hp: hyperparameter dict. Reads ``'dt'``, which is divided by 1000 to
            label the x axis in seconds (presumably ms per step — confirm).
            In list mode it also reads ``'mid_target_perf'`` and
            ``'early_target_perf'`` as stage thresholds.
        log: training log. Reads ``log['perf_' + rule]``, indexed by
            ``trial_num // log['trials'][1]``.
        model_dir: directory holding ``<trial>/neuron_info_<rule>_<epoch>.pkl``.
        rule, epoch: rule and epoch names (strings).
        trial_list: either a dict ``{rule: {'early': [...], 'mid': [...],
            'mature': [...]}}`` (stages taken from the dict), or any iterable
            of trial numbers (stages derived from performance thresholds).
        selected_n_list: candidate neuron indices. If None, they are taken as
            the union of ``n_types_selected`` at the last trial.
        n_types_selected: neuron types used when ``selected_n_list`` is None.
        n_types_intersection_limit: the selection is intersected, at every
            trial, with the union of these neuron types.
        plot_oppo_dir: also plot the PSTH for the opposite selective direction
            (dashed lines).
        norm: forwarded to ``gen_PSTH_log_neuron_selected``.
        PSTH_log: optional precomputed dict ``trial_num -> array``. Axis 0 is
            averaged over and its length is reported as the neuron count.
            The dict is not modified.

    Raises:
        ValueError: if ``trial_list`` contains no trials.
    """
    print("Start ploting neuron selected PSTH")
    print("\trule: " + rule + " selective epoch: " + epoch)

    is_dict = isinstance(trial_list, dict)
    if is_dict:
        temp_list = sorted({t for value in trial_list[rule].values() for t in value})
    else:
        # Accept any iterable of trial numbers (list, tuple, range, array).
        trial_list = list(trial_list)
        temp_list = trial_list
    if not temp_list:
        raise ValueError("trial_list contains no trials for rule " + rule)

    if selected_n_list is None:
        last_info = _psth_selected_load_info(model_dir, temp_list[-1], rule, epoch)
        selected = _psth_selected_indices(last_info, n_types_selected)
    else:
        selected = set(selected_n_list)

    # Keep only neurons that belong to the limiting types at every trial.
    for trial_num in temp_list:
        n_info = _psth_selected_load_info(model_dir, trial_num, rule, epoch)
        selected &= _psth_selected_indices(n_info, n_types_intersection_limit)
    # Sorted so the saved neuron list is deterministic.
    selected_n_list = sorted(selected)

    print("%d neuron(s) selected" % (len(selected_n_list)))

    if PSTH_log is None:
        PSTH_log = gen_PSTH_log_neuron_selected(
            trial_list, model_dir, rule, epoch,
            selected_n_list=selected_n_list,
            n_types=n_types_intersection_limit, norm=norm)

    # Build new dicts rather than overwriting a caller-supplied PSTH_log.
    n_number = {key: np.size(value, 0) for key, value in PSTH_log.items()}
    PSTH_mean = {key: value.mean(axis=0) for key, value in PSTH_log.items()}

    if plot_oppo_dir:
        # Generated regardless of whether PSTH_log was supplied, so it is always defined.
        PSTH_log_oppo = gen_PSTH_log_neuron_selected(
            trial_list, model_dir, rule, epoch,
            selected_n_list=selected_n_list,
            n_types=n_types_intersection_limit, norm=norm,
            oppo_sel_dir=plot_oppo_dir)
        PSTH_oppo_mean = {key: value.mean(axis=0)
                          for key, value in PSTH_log_oppo.items()}

    maturation = ["early", "mid", "mature"]
    data_types = ["PSTH", "n_num", "growth"]
    if plot_oppo_dir:
        data_types.append("PSTH_oppo")
    collected = {m_key: {d: [] for d in data_types} for m_key in maturation}

    perf_step = log['trials'][1]
    for trial_num in temp_list:
        growth = log['perf_' + rule][trial_num // perf_step]
        m_key = _psth_selected_stage(trial_num, growth, trial_list, rule, hp, is_dict)
        if m_key is None:
            # Listed only under an unknown stage key: do not misattribute it.
            continue
        stage = collected[m_key]
        stage["PSTH"].append(PSTH_mean[trial_num])
        stage["growth"].append(growth)
        stage["n_num"].append(n_number[trial_num])
        if plot_oppo_dir:
            stage["PSTH_oppo"].append(PSTH_oppo_mean[trial_num])

    # Average across trials. Stages without trials are skipped: averaging an
    # empty list gives a NaN scalar that cannot be plotted.
    data_to_plot = dict()
    for m_key in maturation:
        if collected[m_key]["PSTH"]:
            data_to_plot[m_key] = {
                d: np.array(values).mean(axis=0)
                for d, values in collected[m_key].items()
            }

    # plot #
    save_path = 'figure/figure_' + model_dir.rstrip('/').split(
        '/')[-1] + '/' + rule + '/' + epoch + '/'
    if is_dict or len(temp_list) == 1:
        step = 'None'
    else:
        step = str(temp_list[1] - temp_list[0])

    if is_dict:
        trial_range = 'auto_choose'
    else:
        trial_range = str((temp_list[0], temp_list[-1]))
    title = 'Rule:' + rule + ' Epoch:' + epoch + ' trial range:' + trial_range + ' step:' + step

    colors = {
        "early": "green",
        "mid": "blue",
        "mature": "red",
    }

    fig, ax = plt.subplots(figsize=(14, 10))
    fig.suptitle(title)

    for m_key in maturation:
        if m_key not in data_to_plot:
            continue
        stage = data_to_plot[m_key]
        ax.plot(np.arange(len(stage["PSTH"])) * hp['dt'] / 1000, stage["PSTH"],
                label=m_key + '_%.2f' % (stage["growth"]) + '_n%d' % (stage["n_num"]),
                color=colors[m_key])
        if plot_oppo_dir:
            ax.plot(np.arange(len(stage["PSTH_oppo"])) * hp['dt'] / 1000, stage["PSTH_oppo"],
                    label=m_key + '_opposite_sel_dir', color=colors[m_key], linestyle='--')

    ax.set_xlabel("time/s")
    ax.set_ylabel("activity")
    ax.legend(bbox_to_anchor=(1.05, 0), loc=3, borderaxespad=0)

    mkdir_p(save_path)
    file_prefix = save_path + rule + '_' + epoch + '_' + trial_range + '_step_' + step
    for ext in ('pdf', 'eps', 'png'):
        plt.savefig(file_prefix + 'neuron_selected_PSTH.' + ext, bbox_inches='tight')

    plt.close(fig)

    with open(file_prefix + "selected_neurons.txt", "w") as tf:
        tf.write(str(selected_n_list))