def easy_activity_plot(model_dir, rule):
    """Plot input, recurrent and output activity of one example trial.

    Args:
        model_dir: directory where the model checkpoint is saved
        rule: string, the task whose test-mode trials are generated
    """
    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()
        trial = generate_trials(rule, hp, mode='test')
        feed_dict = tools.gen_feed_dict(model, trial, hp)
        h, y_hat = sess.run([model.h, model.y_hat], feed_dict=feed_dict)

    # Arrays are indexed (time, condition, unit); only condition 0 is shown.
    example = 0
    panels = (('input', trial.x), ('recurrent', h), ('output', y_hat))
    for title, activity in panels:
        plt.figure()
        plt.imshow(activity[:, example, :].T, aspect='auto', cmap='hot',
                   interpolation='none', origin='lower')
        plt.title(title)
        plt.colorbar()
        plt.show()
def quick_statespace(model_dir):
    """Quick state-space analysis of the context-DM tasks using PCA.

    Runs test-mode trials of 'contextdm1' and 'contextdm2', takes the hidden
    activity at time index ``trial.epochs['stim1'][-1]``, fits a 5-component
    PCA on the activity pooled over both tasks and scatters PC 1 vs PC 2,
    coloured by task.

    Args:
        model_dir: str, directory of the trained model.
    """
    from sklearn.decomposition import PCA

    rules = ['contextdm1', 'contextdm2']
    h_lastts = dict()
    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()
        for rule in rules:
            # Generate a batch of trials from the test mode
            trial = generate_trials(rule, hp, mode='test')
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            h = sess.run(model.h, feed_dict=feed_dict)
            lastt = trial.epochs['stim1'][-1]
            h_lastts[rule] = h[lastt, :, :]

    # np.concatenate needs a real sequence; passing a dict view raises
    # TypeError on Python 3, so build an explicit list in rule order.
    pca = PCA(n_components=5)
    pca.fit(np.concatenate([h_lastts[rule] for rule in rules], axis=0))

    fig = plt.figure(figsize=(2, 2))
    ax = fig.add_axes([.3, .3, .6, .6])
    for rule, color in zip(rules, ['red', 'blue']):
        data_trans = pca.transform(h_lastts[rule])
        ax.scatter(data_trans[:, 0], data_trans[:, 1], s=1,
                   label=rule_name[rule], color=color)
    plt.tick_params(axis='both', which='major', labelsize=7)
    ax.set_xlabel('PC 1', fontsize=7)
    ax.set_ylabel('PC 2', fontsize=7)
    ax.legend(fontsize=7, ncol=1, bbox_to_anchor=(1, 0.3), loc=1,
              frameon=False)
    # NOTE(review): `save` is not a parameter; this relies on a module-level
    # flag of that name being defined elsewhere in the file.
    if save:
        plt.savefig('figure/choiceatt_quickstatespace.pdf', transparent=True)
def run_network_replacerule(model_dir, rule, replace_rule, rule_strength):
    """Evaluate a task while feeding substitute rule-input units.

    Args:
        model_dir: model directory
        rule: the task to test on
        replace_rule: list of rule input units used instead of ``rule``'s own
        rule_strength: relative strength of each unit in ``replace_rule``

    Returns:
        (mean performance over 20 batches of 50 random trials, rule_strength)
    """
    model = Model(model_dir)
    hp = model.hp
    n_rep = 20
    batch_size_per_rep = int(1000 / n_rep)
    with tf.Session() as sess:
        model.restore()
        perfs = []
        for _ in range(n_rep):
            trial = generate_trials(rule, hp, 'random',
                                    batch_size=batch_size_per_rep,
                                    replace_rule=replace_rule,
                                    rule_strength=rule_strength)
            y_hat_test = sess.run(
                model.y_hat,
                feed_dict=tools.gen_feed_dict(model, trial, hp))
            perfs.append(np.mean(get_perf(y_hat_test, trial.y_loc)))
    return np.mean(perfs), rule_strength
def easy_connectivity_plot(model_dir):
    """Show every 2-D trainable variable of a trained model as a heat map."""
    model = Model(model_dir)
    with tf.Session() as sess:
        model.restore()
        variables = model.var_list
        # Evaluate every variable (weights and biases) after training.
        values = [sess.run(v) for v in variables]
        labels = [v.name for v in variables]

    for value, label in zip(values, labels):
        # Only matrices are drawn; 1-D variables such as biases are skipped.
        if len(value.shape) != 2:
            continue
        limit = np.max(abs(value)) * 0.7
        plt.figure()
        # Transposed so the y axis ('To') indexes the second dimension.
        plt.imshow(value.T, aspect='auto', cmap='bwr', vmin=-limit,
                   vmax=limit, interpolation='none', origin='lower')
        plt.title(label)
        plt.colorbar()
        plt.xlabel('From')
        plt.ylabel('To')
        plt.show()
def __init__(self, model_dir, rules=None):
    """Collect condition-averaged and epoch-wise hidden activity per task.

    Sets:
        self.rules: the rules analysed (hp['rules'] when ``rules`` is None)
        self.h_stimavg_byrule: rule -> condition-averaged activity, dropping
            the first 500 ms
        self.h_stimavg_byepoch: (rule, epoch) -> condition-averaged activity
            over the epoch, starting one step before its onset
        self.h_lastt_byepoch: (rule, epoch) -> activity of every condition at
            the epoch's end index (the final time step if the end is None)
        self.model_dir: model_dir
    Epochs whose name contains 'fix' are skipped.

    Args:
        model_dir: str, model directory
        rules: None or a list of rules
    """
    # Stimulus-averaged traces
    h_stimavg_byrule = OrderedDict()
    h_stimavg_byepoch = OrderedDict()
    # Last time points of epochs
    h_lastt_byepoch = OrderedDict()

    model = Model(model_dir)
    hp = model.hp

    if rules is None:
        # Default value
        rules = hp['rules']

    # Important: ignore the initial transition (first 500 ms).
    t_start = int(500 / hp['dt'])

    with tf.Session() as sess:
        model.restore()
        for rule in rules:
            trial = generate_trials(rule=rule, hp=hp, mode='test')
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            h = sess.run(model.h, feed_dict=feed_dict)

            # Average across stimulus conditions (axis 1)
            h_stimavg = h.mean(axis=1)
            h_stimavg_byrule[rule] = h_stimavg[t_start:, :]

            for e_name, e_time in trial.epochs.items():
                if 'fix' in e_name:
                    continue
                e_start, e_end = e_time[0], e_time[1]

                # Take epoch, including one step before its onset if possible
                e_time_start = e_start - 1 if e_start > 0 else 0
                h_stimavg_byepoch[(rule, e_name)] = \
                    h_stimavg[e_time_start:e_end, :]

                # An epoch whose end bound is None runs to the end of the
                # trial. Indexing h[None] would insert a new axis rather than
                # select a time step, so use the final step in that case.
                t_last = e_end if e_end is not None else -1
                h_lastt_byepoch[(rule, e_name)] = h[t_last, :, :]

    self.rules = rules
    self.h_stimavg_byrule = h_stimavg_byrule
    self.h_stimavg_byepoch = h_stimavg_byepoch
    self.h_lastt_byepoch = h_lastt_byepoch
    self.model_dir = model_dir
def run_simulation(save_name, setting):
    '''Generate simulation data for all trials.

    Args:
        save_name: model directory passed to ``Model``.
        setting: dict; ``setting['sigma_rec']`` is passed to ``Model`` and the
            whole dict is forwarded to ``_run_simulation``.

    Returns:
        Whatever ``_run_simulation(model, setting)`` returns.
    '''
    tf.reset_default_graph()
    model = Model(save_name, sigma_rec=setting['sigma_rec'], dt=10)

    with tf.Session() as sess:
        # Every other call site uses model.restore() with no argument (or a
        # checkpoint directory, as in train's model.restore(load_dir)).
        # Passing the session here would be taken as that directory.
        model.restore()
        data = _run_simulation(model, setting)
    return data
def _compute_variance(model_dir, rules=None, random_rotation=False):
    """Compute the per-task variance for a saved model.

    Builds the model with sigma_rec=0, restores its checkpoint and delegates
    the actual work to ``_compute_variance_bymodel``.

    Args:
        model_dir: str, the path of the model directory
        rules: list of strings, rules to compute variance for
        random_rotation: bool; if True, rotate the neural activity
    """
    quiet_model = Model(model_dir, sigma_rec=0)
    with tf.Session() as session:
        quiet_model.restore()
        _compute_variance_bymodel(quiet_model, session, rules,
                                  random_rotation)
def compute_H(self, rules=None, trial_list=None, recompute=False):
    """Compute hidden activity for every checkpoint and rule.

    Sets self.rules (default: self.hp['rule_trains']) and self.trial_list
    (default: self.log['trials']). For each rule it generates one noise-free
    test batch and records:
        self.in_loc[rule]: argmax of each trial's input_loc
        self.in_loc_set[rule]: sorted unique values of self.in_loc[rule]
        self.epoch_info[rule]: the batch's epochs (also printed)
    Then, for every checkpoint directory ``model_dir/<trial_num>/``, it calls
    self._compute_H for each rule whose ``H_<rule>.pkl`` is missing there,
    or for every rule when ``recompute`` is True.

    Args:
        rules: None or list of rules
        trial_list: None or list of checkpoint identifiers
        recompute: bool, recompute even if the cached file exists
    """
    self.rules = rules if rules is not None else self.hp['rule_trains']
    if trial_list is not None:
        self.trial_list = trial_list
    else:
        self.trial_list = self.log['trials']

    self.in_loc = dict()
    self.in_loc_set = dict()
    self.epoch_info = dict()
    trial_store = dict()

    print("Epoch information:")
    for rule in self.rules:
        trial = generate_trials(rule, self.hp, 'test', noise_on=False)
        trial_store[rule] = trial
        self.in_loc[rule] = np.array([np.argmax(i) for i in trial.input_loc])
        self.in_loc_set[rule] = sorted(set(self.in_loc[rule]))
        self.epoch_info[rule] = trial.epochs
        print('\t'+rule+':')
        for e_name, e_time in trial.epochs.items():
            print('\t\t'+e_name+':', e_time)

    for trial_num in self.trial_list:
        sub_dir = self.model_dir+'/'+str(trial_num)+'/'
        pending = [rule for rule in self.rules
                   if recompute or not os.path.exists(sub_dir+'H_'+rule+'.pkl')]
        if not pending:
            continue
        # Build and restore the checkpoint once and reuse it for every rule,
        # rather than rebuilding the graph once per rule.
        # NOTE(review): assumes _compute_H does not modify the model.
        model = Model(sub_dir, hp=self.hp)
        with tf.Session() as sess:
            model.restore()
            for rule in pending:
                self._compute_H(model, rule, trial_store[rule], sess)
def _psychometric_dm(model_dir, rule, params_list, batch_shape):
    """Base function for psychometric performance in 2AFC tasks.

    Args:
        model_dir: model name
        rule: task to analyze
        params_list: list of parameter dicts for the 'psychometric' mode
        batch_shape: shape of each batch, (n_rep, ...); the first axis holds
            repetitions that are pooled

    Returns:
        ydatas: list with, per params dict, the fraction of choices (within
            THETA of a stimulus) that went to stim 1
    """
    print('Starting psychometric analysis of the {:s} task...'.format(
        rule_name[rule]))

    model = Model(model_dir)
    hp = model.hp

    def count_choices(y_loc, target_locs):
        # Trials whose final output lies within THETA of the target,
        # summed over the repetition axis.
        targets = np.reshape(target_locs, batch_shape)
        return (get_dist(y_loc - targets) < THETA).sum(axis=0)

    with tf.Session() as sess:
        model.restore()
        ydatas = []
        for params in params_list:
            trial = generate_trials(rule, hp, 'psychometric', params=params)
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            y_loc_all = sess.run(model.y_hat_loc, feed_dict=feed_dict)
            y_loc_final = np.reshape(y_loc_all[-1], batch_shape)
            n_choose1 = count_choices(y_loc_final, params['stim1_locs'])
            n_choose2 = count_choices(y_loc_final, params['stim2_locs'])
            ydatas.append(n_choose1 / (n_choose1 + n_choose2))
    return ydatas
def activity_histogram(model_dir, rules, title=None, save_name=None):
    """Plot a histogram of hidden activity pooled over one or more tasks.

    Activity from the first 500 ms of each test batch is discarded.

    Args:
        model_dir: str, model directory
        rules: str or list of str, tasks whose test batches are pooled
        title: optional str, axis title
        save_name: optional path; if given, the figure is saved there

    Raises:
        ValueError: if ``rules`` is empty.
    """
    if isinstance(rules, str):
        rules = [rules]
    if not rules:
        raise ValueError('rules must contain at least one rule')

    model = Model(model_dir)
    hp = model.hp
    # Ignore the initial transient.
    t_start = int(500 / hp['dt'])
    h_list = []
    with tf.Session() as sess:
        model.restore()
        for rule in rules:
            # Generate a batch of trials from the test mode
            trial = generate_trials(rule, hp, mode='test')
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            h = sess.run(model.h, feed_dict=feed_dict)
            h_list.append(h[t_start:, :, :])

    # Concatenate once along the trial axis instead of growing the array on
    # every iteration.
    h_all = np.concatenate(h_list, axis=1)
    h_plot = h_all.flatten()

    fig = plt.figure(figsize=(1.5, 1.2))
    ax = fig.add_axes([0.2, 0.2, 0.7, 0.6])
    ax.hist(h_plot, bins=20, density=True)
    ax.set_xlabel('Activity', fontsize=7)
    if title:
        ax.set_title(title, fontsize=7)
    for s in ['left', 'top', 'right']:
        ax.spines[s].set_visible(False)
    ax.set_yticks([])
    if save_name is not None:
        plt.savefig(save_name, transparent=True)
def networkx_illustration(model_dir):
    """Draw the strongest recurrent connections as a circular graph.

    Keeps the 1% (at least one) of recurrent weights that deviate most from
    the mean weight, zeroes the rest, and draws the resulting directed graph
    with edges coloured by signed weight. Saves
    figure/illustration_networkx.pdf.

    Args:
        model_dir: str, model directory
    """
    import networkx as nx

    model = Model(model_dir)
    with tf.Session() as sess:
        model.restore()
        w_rec = sess.run(model.w_rec)

    w_rec_flat = w_rec.flatten()  # a copy; w_rec itself is not modified
    ind_sort = np.argsort(abs(w_rec_flat - np.mean(w_rec_flat)))
    # At least one weight must be kept: with n_show == 0 the slices [:-0]
    # and [-0:] would remove nothing and keep every weight.
    n_show = max(1, int(0.01 * len(w_rec_flat)))
    ind_gone = ind_sort[:len(ind_sort) - n_show]
    w_rec_flat[ind_gone] = 0
    w_rec2 = np.reshape(w_rec_flat, w_rec.shape)

    G = nx.from_numpy_array(abs(w_rec2), create_using=nx.DiGraph())
    # networkx draws edges in G.edges() order, so the colours must follow
    # that order (not the magnitude order of ind_sort).
    color = np.array([w_rec2[u, v] for u, v in G.edges()])

    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
    nx.draw(G,
            linewidths=0,
            width=0.1,
            alpha=1.0,
            edge_vmin=-3,
            edge_vmax=3,
            arrows=False,
            pos=nx.circular_layout(G),
            node_color=np.array([99. / 255] * 3),
            node_size=10,
            edge_color=color,
            edge_cmap=plt.cm.RdBu_r,
            ax=ax)
    plt.savefig('figure/illustration_networkx.pdf', transparent=True)
def _plot_inout_connections(self, conn_type):
    """Plot group-averaged input or output connectivity profiles.

    For each unit of groups '1', '2' and '12' (indices taken from
    self.group_ind_orig), its weight profile is rolled so that its peak sits
    at the centre of the ring. The profiles are then averaged within the
    group and plotted, and the figure is saved to
    figure/conn_<conn_type>_contextdm.pdf.

    Args:
        conn_type: str, 'input' (weights from modality-1 stimulus units) or
            'output' (weights to the output ring units).

    Raises:
        ValueError: if conn_type is neither 'input' nor 'output'.
    """
    model = Model(self.model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()
        w_in, w_out = sess.run([model.w_in, model.w_out])

    n_ring = hp['n_eachring']
    half_ring = int(n_ring / 2)
    groups = ['1', '2', '12']

    if conn_type == 'input':
        w_conn = w_in[1:n_ring + 1, :].T
        xlabel, ylabel, lgtitle = ('Preferred mod 1 input dir.',
                                   'Conn. weight\n from mod 1',
                                   'To group')
    elif conn_type == 'output':
        w_conn = w_out[:, 1:]
        xlabel, ylabel, lgtitle = ('Preferred output dir.',
                                   'Conn. weight to output',
                                   'From group')
    else:
        raise ValueError('Unknown conn type')

    w_aves = dict()
    for group in groups:
        members = self.group_ind_orig[group]
        aligned = np.zeros((len(members), n_ring))
        for row, unit in enumerate(members):
            profile = w_conn[unit, :]
            aligned[row, :] = np.roll(profile, half_ring - np.argmax(profile))
        w_aves[group] = aligned.mean(axis=0)

    fs = 6
    fig = plt.figure(figsize=(1.5, 1.0))
    ax = fig.add_axes([.35, .25, .55, .6])
    for group in groups:
        ax.plot(w_aves[group], color=self.colors[group], label=group, lw=1)
    ax.set_xticks([half_ring])
    ax.set_xticklabels([xlabel])
    ax.set_ylabel(ylabel, fontsize=fs)
    legend = ax.legend(title=lgtitle, fontsize=fs, bbox_to_anchor=(1.2, 1.2),
                       labelspacing=0.2, loc=1, frameon=False)
    plt.setp(legend.get_title(), fontsize=fs)
    ax.tick_params(axis='both', which='major', labelsize=fs)
    plt.locator_params(axis='y', nbins=3)
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    plt.savefig('figure/conn_'+conn_type+'_contextdm.pdf', transparent=True)
def plot_rec_connections(self):
    """Plot mean recurrent weights between unit groups '1', '2' and '12'.

    Each unit's preferred direction is the argmax of its input weights,
    averaged over the two stimulus modalities. For every (from-group,
    to-group) pair, the recurrent weights between units whose preferred
    directions differ by less than pi/6 are averaged; self-connections are
    excluded. The resulting 3x3 matrix is shown as a heat map ('From' on x,
    'To' on y) and saved to figure/conn_rec_contextdm.pdf.

    Takes no arguments besides ``self``; reads self.model_dir and
    self.group_ind_orig.
    """
    # Load the input and recurrent weights of the trained model
    model = Model(self.model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()
        w_in, w_rec = sess.run([model.w_in, model.w_rec])
    # Transposed so that the first index is the unit the weight projects to.
    # NOTE(review): assumes model.w_in / model.w_rec are stored (from, to);
    # TODO confirm against Model.
    w_in, w_rec = w_in.T, w_rec.T

    n_ring = hp['n_eachring']
    groups = ['1', '2', '12']

    # Input weights averaged over modality-1 and modality-2 stimulus units
    w_in_ = (w_in[:, 1:n_ring + 1] + w_in[:, 1+n_ring:2*n_ring+1]) / 2.

    # All ordered (from-group, to-group) index pairs
    i_pairs = list()
    for i1 in range(len(groups)):
        for i2 in range(len(groups)):
            i_pairs.append((i1, i2))

    pref_diffs_list = list()
    w_recs_list = list()

    # w_rec_bygroup[to_group, from_group]
    w_rec_bygroup = np.zeros((len(groups), len(groups)))

    inds = [self.group_ind_orig[g] for g in groups]
    for i_pair in i_pairs:
        ind1, ind2 = inds[i_pair[0]], inds[i_pair[1]]
        # For each neuron get the preferred direction (radians) from the
        # argmax of its averaged input weights
        w_sortby = w_in_
        # w_sortby = w_out_
        prefs1 = np.argmax(w_sortby[ind1, :], axis=1)*2.*np.pi/n_ring
        prefs2 = np.argmax(w_sortby[ind2, :], axis=1)*2.*np.pi/n_ring

        # Compute the pairwise distance based on preference
        # (get_dist is presumably a circular distance), then get the
        # connection weight between the pair
        pref_diffs = list()
        w_recs = list()
        for i, ind_i in enumerate(ind1):
            for j, ind_j in enumerate(ind2):
                if ind_j == ind_i:
                    # Excluding self connections, which tend to be positive
                    continue
                pref_diffs.append(get_dist(prefs1[i]-prefs2[j]))
                # After the transpose, w_rec[ind_j, ind_i] is treated as the
                # weight from ind_i (group1) to ind_j (group2)
                w_recs.append(w_rec[ind_j, ind_i])
        pref_diffs, w_recs = np.array(pref_diffs), np.array(w_recs)
        pref_diffs_list.append(pref_diffs)
        w_recs_list.append(w_recs)
        # Mean over similarly-tuned pairs only; NaN if no pair qualifies
        w_rec_bygroup[i_pair[1], i_pair[0]] = np.mean(w_recs[pref_diffs<np.pi/6.])

    fs = 6
    # Colour limits rounded outward to two decimals
    vmax = np.ceil(np.max(w_rec_bygroup) * 100) / 100.
    vmin = np.floor(np.min(w_rec_bygroup) * 100) / 100.
    fig = plt.figure(figsize=(1.5, 1.5))
    ax = fig.add_axes([0.2, 0.1, 0.5, 0.5])
    im = ax.imshow(w_rec_bygroup, interpolation='nearest',
                   cmap='coolwarm', aspect='auto', vmin=vmin, vmax=vmax)
    # ax.axis('off')
    ax.xaxis.set_label_position("top")
    ax.xaxis.set_ticks_position("top")
    plt.xticks([0, 1, 2], groups, fontsize=6)
    plt.yticks([0, 1, 2], groups, fontsize=6)
    ax.tick_params('both', length=0)
    ax.set_xlabel('From', fontsize=fs, labelpad=2)
    ax.set_ylabel('To', fontsize=fs, labelpad=2)
    for s in ['right', 'left', 'top', 'bottom']:
        ax.spines[s].set_visible(False)

    # Colour bar in a separate narrow axis
    ax = fig.add_axes([0.72, 0.1, 0.03, 0.5])
    cb = plt.colorbar(im, cax=ax, ticks=[vmin,vmax])
    cb.outline.set_linewidth(0.5)
    cb.set_label(r'Rec. weight', fontsize=fs, labelpad=-7)
    plt.tick_params(axis='both', which='major', labelsize=fs)
    plt.locator_params(nbins=3)
    plt.savefig('figure/conn_rec_contextdm.pdf', transparent=True)
def compute_taskspace(model_dir, setup, restore=False, representation='rate'):
    """Compute a low-dimensional task-space embedding for a set of tasks.

    Args:
        model_dir: str, model directory
        setup: int in 1..6, selects which set of rules is embedded
        restore: bool; for representation='rate', reload cached results from
            ``model_dir/taskset{setup}_space.pkl`` when that file exists
        representation: 'rate' to embed 'stim1'-epoch activity through
            TaskSetAnalysis.compute_taskspace (PCA, result pickled to the file
            above); 'weight' to run a 2-component PCA on the input weights
            of the rule units

    Returns:
        h_trans: the embedding. For 'weight' it is an OrderedDict mapping
            (rule, 'stim1') to an array of shape (1, 2).

    Raises:
        ValueError: for an unknown setup or representation.
    """
    rules_by_setup = {
        1: ['fdgo', 'fdanti', 'delaygo', 'delayanti'],
        2: ['contextdelaydm1', 'contextdelaydm2', 'contextdm1', 'contextdm2'],
        3: ['dmsgo', 'dmcgo', 'dmsnogo', 'dmcnogo'],
        4: ['contextdelaydm1', 'contextdelaydm2', 'multidelaydm',
            'contextdm1', 'contextdm2', 'multidm'],
        5: ['contextdelaydm1', 'contextdelaydm2', 'multidelaydm',
            'delaydm1', 'delaydm2', 'contextdm1', 'contextdm2', 'multidm',
            'dm1', 'dm2'],
        6: ['fdgo', 'delaygo', 'contextdm1', 'contextdelaydm1'],
    }
    # Without this check an unknown setup left `rules` unbound and failed
    # later with an UnboundLocalError.
    if setup not in rules_by_setup:
        raise ValueError('Unknown setup: {}'.format(setup))
    rules = rules_by_setup[setup]

    if representation == 'rate':
        fname = 'taskset{:d}_space'.format(setup) + '.pkl'
        fname = os.path.join(model_dir, fname)

        if restore and os.path.isfile(fname):
            print('Reloading results from ' + fname)
            h_trans = tools.load_pickle(fname)
        else:
            tsa = TaskSetAnalysis(model_dir, rules=rules)
            h_trans = tsa.compute_taskspace(rules=rules, epochs=['stim1'],
                                            dim_reduction_type='PCA',
                                            setup=setup)
            with open(fname, 'wb') as f:
                pickle.dump(h_trans, f)
            print('Results stored at : ' + fname)

    elif representation == 'weight':
        from sklearn.decomposition import PCA
        from task import get_rule_index

        model = Model(model_dir)
        hp = model.hp
        with tf.Session() as sess:
            model.restore()
            w_in = sess.run(model.w_in).T

        rule_indices = [get_rule_index(r, hp) for r in rules]
        w_rules = w_in[:, rule_indices]

        pca = PCA(n_components=2)
        # One row per rule, one column per principal component
        data_trans = pca.fit_transform(w_rules.T)

        # Same dictionary layout as the 'rate' representation
        h_trans = OrderedDict()
        for i, r in enumerate(rules):
            # Shape (1, 2); the 'stim1' key is only for consistency
            h_trans[(r, 'stim1')] = np.array([data_trans[i]])
    else:
        raise ValueError('Unknown representation: {}'.format(representation))

    return h_trans
def load_data(model_dir=None, sigma_rec=0, lesion_units=None, n_rep=1):
    """Simulate the context-DM tasks and return per-unit data.

    For each of 12 evenly spaced target locations, one psychometric batch of
    'contextdm1' and one of 'contextdm2' are run together. The 'stim1'-epoch
    hidden activity is downsampled to 50 ms bins.

    Returns:
        data: list with one dict per (target location, hidden unit); the
            outer order is over target locations. Each dict has
            'rate': array (n_trial, n_time) of the unit's activity
            'task_var': dict of per-trial task variables
    """
    if model_dir is None:
        model_dir = './mantetemp'  # TEMPORARY SETTING

    rules = ['contextdm1', 'contextdm2']
    n_rule = len(rules)

    data = list()

    model = Model(model_dir, sigma_rec=sigma_rec)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()
        if lesion_units is not None:
            model.lesion_units(sess, lesion_units)

        for stim1_loc in np.arange(0, 2*np.pi, 2*np.pi/12):
            params, batch_size = _gen_taskparams(stim1_loc=stim1_loc,
                                                 n_rep=n_rep)
            stim1_locs_tmp = np.tile(params['stim1_locs'], n_rule)  # unused

            # Same params for both rules; trials are stacked along axis 1
            trials = [generate_trials(rule, hp, 'psychometric',
                                      params=params, noise_on=True)
                      for rule in rules]
            trial = trials[-1]  # epochs are read from the last rule's batch
            x = np.concatenate([t.x for t in trials], axis=1)
            y_loc = np.concatenate([t.y_loc for t in trials], axis=1)

            # Coherences, normalised by their maximum
            stim_mod1_cohs = (params['stim1_mod1_strengths']
                              - params['stim2_mod1_strengths'])
            stim_mod2_cohs = (params['stim1_mod2_strengths']
                              - params['stim2_mod2_strengths'])
            stim_mod1_cohs /= stim_mod1_cohs.max()
            stim_mod2_cohs /= stim_mod2_cohs.max()

            H, y_sample, y_sample_loc = sess.run(
                [model.h, model.y_hat, model.y_hat_loc],
                feed_dict={model.x: x})

            # Keep the 'stim1' epoch, sampled every 50 ms starting mid-bin
            dt_new = 50
            every_t = int(dt_new / hp['dt'])
            epoch = trial.epochs['stim1']
            H = H[epoch[0]:epoch[1], ...][int(every_t / 2)::every_t, ...]

            # +1 when the (final) choice is near stim1_loc, otherwise -1
            y_actual_choice = 2*(get_dist(y_sample_loc[-1]-stim1_loc)<np.pi/2)-1
            y_target_choice = 2*(get_dist(y_loc[-1]-stim1_loc)<np.pi/2)-1

            stim_dir = np.tile(stim_mod1_cohs, n_rule)
            stim_col2dir = np.tile(stim_mod2_cohs, n_rule)
            task_var = {
                'targ_dir': y_actual_choice,
                'stim_dir': stim_dir,
                'stim_col2dir': stim_col2dir,
                'context': np.repeat([1, -1], batch_size),
                'correct': (y_actual_choice == y_target_choice).astype(int),
                'stim_dir_sign': (stim_dir>0).astype(int)*2-1,
                'stim_col2dir_sign': (stim_col2dir>0).astype(int)*2-1,
            }

            data.extend(
                {'rate': H[:, :, unit].T,  # (n_trial, n_time)
                 'task_var': copy.deepcopy(task_var)}
                for unit in range(H.shape[-1]))

    return data
def lesions(self):
    """Measure how lesioning each cluster changes performance and cost.

    The intact network is evaluated first. Then, for each label in
    self.unique_labels, the network is evaluated with that cluster's units
    lesioned (mapped to original indices through self.ind_active). For
    every rule in self.rules, performance and model.cost_lsq are averaged
    over 16 batches of 16 random trials.

    Returns:
        perfs_changes: array, one row per cluster, of per-rule performance
            minus the intact network's
        cost_changes: array, same layout, for cost_lsq
    """
    labels = self.labels
    from network import get_perf
    from task import generate_trials

    # None (intact network) must come first: it is the reference.
    lesion_units_list = [None] + [
        self.ind_active[np.where(labels == label)[0]]
        for label in self.unique_labels]

    perfs_store_list, perfs_changes = [], []
    cost_store_list, cost_changes = [], []
    n_rep = 16
    batch_size_per_rep = int(256 / n_rep)

    for idx, lesion_units in enumerate(lesion_units_list):
        model = Model(self.model_dir)
        hp = model.hp
        with tf.Session() as sess:
            model.restore()
            model.lesion_units(sess, lesion_units)

            perfs_store, cost_store = [], []
            for rule in self.rules:
                perf_tmp, clsq_tmp = [], []
                for _ in range(n_rep):
                    trial = generate_trials(rule, hp, 'random',
                                            batch_size=batch_size_per_rep)
                    feed_dict = tools.gen_feed_dict(model, trial, hp)
                    y_hat_test, c_lsq = sess.run(
                        [model.y_hat, model.cost_lsq], feed_dict=feed_dict)
                    clsq_tmp.append(c_lsq)
                    perf_tmp.append(
                        np.mean(get_perf(y_hat_test, trial.y_loc)))
                perfs_store.append(np.mean(perf_tmp))
                cost_store.append(np.mean(clsq_tmp))

        perfs_store = np.array(perfs_store)
        cost_store = np.array(cost_store)
        perfs_store_list.append(perfs_store)
        cost_store_list.append(cost_store)

        if idx > 0:
            perfs_changes.append(perfs_store - perfs_store_list[0])
            cost_changes.append(cost_store - cost_store_list[0])

    return np.array(perfs_changes), np.array(cost_changes)
def plot_connectivity_byclusters(self):
    """Plot connectivity of the active units, sorted by cluster.

    Recurrent, input (fixation, mod 1, mod 2, rule) and output weights, plus
    recurrent and output biases, are restricted to the units in
    self.ind_active. Units are ordered by cluster label (self.labels) and,
    within a cluster, by preferred input direction; each matrix is drawn as
    a heat map. Coloured bars along the top and left mark the clusters.
    The figure is saved when a module-level ``save`` flag is truthy
    (NOTE(review): ``save`` is not a parameter; confirm it is defined).
    """
    ind_active = self.ind_active

    # Load weights and biases; matrices are transposed after evaluation
    model = Model(self.model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()
        w_in = sess.run(model.w_in).T
        w_rec = sess.run(model.w_rec).T
        w_out = sess.run(model.w_out).T
        b_rec = sess.run(model.b_rec)
        b_out = sess.run(model.b_out)

    # Keep only the active units
    w_rec = w_rec[ind_active, :][:, ind_active]
    w_in = w_in[ind_active, :]
    w_out = w_out[:, ind_active]
    b_rec = b_rec[ind_active]

    # nx, nh, ny = hp['shape']
    nr = hp['n_eachring']

    # Preferred direction of each unit (only the 'w_in' branch is used)
    sort_by = 'w_in'
    if sort_by == 'w_in':
        w_in_mod1 = w_in[:, 1:nr + 1]
        w_in_mod2 = w_in[:, nr + 1:2 * nr + 1]
        w_in_modboth = w_in_mod1 + w_in_mod2
        w_prefs = np.argmax(w_in_modboth, axis=1)
    elif sort_by == 'w_out':
        w_prefs = np.argmax(w_out[1:], axis=0)

    # sort by labels then by prefs (lexsort uses the last key as primary)
    ind_sort = np.lexsort((w_prefs, self.labels))

    ######################### Plotting Connectivity ###############################
    # Sizes used for the axes layout (read from self.hp, not model.hp)
    nx = self.hp['n_input']
    ny = self.hp['n_output']
    nh = len(self.ind_active)
    nr = self.hp['n_eachring']
    nrule = len(self.hp['rules'])

    # Plot active units
    _w_rec = w_rec[ind_sort, :][:, ind_sort]
    _w_in = w_in[ind_sort, :]
    _w_out = w_out[:, ind_sort]
    _b_rec = b_rec[ind_sort, np.newaxis]
    _b_out = b_out[:, np.newaxis]
    labels = self.labels[ind_sort]

    # l: left/bottom margin of the recurrent matrix; l0: size of one unit cell
    l = 0.3
    l0 = (1 - 1.5 * l) / nh

    # (matrix, [left, bottom, width, height]) for each panel
    plot_infos = [
        (_w_rec, [l, l, nh * l0, nh * l0]),
        (_w_in[:, [0]], [l - (nx + 15) * l0, l, 1 * l0, nh * l0]),  # Fixation input
        (_w_in[:, 1:nr + 1], [l - (nx + 11) * l0, l, nr * l0, nh * l0]),  # Mod 1 stimulus
        (_w_in[:, nr + 1:2 * nr + 1], [l - (nx - nr + 8) * l0, l, nr * l0, nh * l0]),  # Mod 2 stimulus
        (_w_in[:, 2 * nr + 1:], [l - (nx - 2 * nr + 5) * l0, l, nrule * l0, nh * l0]),  # Rule inputs
        (_w_out[[0], :], [l, l - (4) * l0, nh * l0, 1 * l0]),  # Fixation output
        (_w_out[1:, :], [l, l - (ny + 6) * l0, nh * l0, (ny - 1) * l0]),  # Ring outputs
        (_b_rec, [l + (nh + 6) * l0, l, l0, nh * l0]),  # Recurrent bias
        (_b_out, [l + (nh + 6) * l0, l - (ny + 6) * l0, l0, ny * l0])  # Output bias
    ]

    # cmap = sns.diverging_palette(220, 10, sep=80, as_cmap=True)
    cmap = 'coolwarm'
    fig = plt.figure(figsize=(6, 6))
    for plot_info in plot_infos:
        ax = fig.add_axes(plot_info[1])
        # Colour range: centred on the median, as wide as the 5-95
        # percentile spread of this panel
        vmin, vmid, vmax = np.percentile(plot_info[0].flatten(), [5, 50, 95])
        _ = ax.imshow(plot_info[0], interpolation='nearest', cmap=cmap,
                      aspect='auto', vmin=vmid - (vmax - vmin) / 2,
                      vmax=vmid + (vmax - vmin) / 2)
        ax.axis('off')

    # Cluster bars above (ax1) and left of (ax2) the recurrent matrix
    ax1 = fig.add_axes([l, l + nh * l0, nh * l0, 6 * l0])
    ax2 = fig.add_axes([l - 6 * l0, l, 6 * l0, nh * l0])
    # NOTE: `l` is rebound to a cluster label here; the layout value is no
    # longer needed at this point.
    for il, l in enumerate(self.unique_labels):
        # First and one-past-last sorted position of this cluster
        ind_l = np.where(labels == l)[0][[0, -1]] + np.array([0, 1])
        ax1.plot(ind_l, [0, 0], linewidth=2, solid_capstyle='butt',
                 color=kelly_colors[il + 1])
        ax2.plot([0, 0], len(labels) - ind_l, linewidth=2,
                 solid_capstyle='butt', color=kelly_colors[il + 1])
    ax1.set_xlim([0, len(labels)])
    ax2.set_ylim([0, len(labels)])
    ax1.axis('off')
    ax2.axis('off')
    if save:
        plt.savefig('figure/connectivity_by' + self.data_type + '.pdf',
                    transparent=True)
    plt.show()
def train(
        model_dir,
        hp=None,
        max_steps=1e7,
        display_step=500,
        ruleset='mante',
        rule_trains=None,
        rule_prob_map=None,
        seed=0,
        rich_output=False,
        load_dir=None,
        trainables=None,
):
    """Train the network.

    Training stops when the evaluated minimum performance
    (log['perf_min'][-1]) exceeds hp['target_perf'], when the number of
    trials seen (step * hp['batch_size_train']) exceeds max_steps, or on
    KeyboardInterrupt.

    Args:
        model_dir: str, training directory
        hp: dictionary of hyperparameters, overriding the ruleset defaults
        max_steps: maximum number of training trials (compared against
            step * batch_size_train, not the number of steps)
        display_step: int, evaluate (do_eval) every this many steps
        ruleset: the set of rules to train
        rule_trains: list of rules to train, if None then all rules possible
        rule_prob_map: None or dictionary of relative rule probability
            (rules not listed get weight 1)
        seed: int, random seed to be used
        rich_output: bool, call display_rich_output after each evaluation
        load_dir: None, or a directory to restore the model from instead of
            initializing variables
        trainables: None/'all' (all variables), 'input' (variables with
            'input' but not 'rnn' in the name) or 'rule' (variables with
            'rule_input' in the name)

    Returns:
        None. Hyperparameters are written by tools.save_hp(hp, model_dir).
        The checkpoint is not saved by the code in this function itself
        (NOTE(review): presumably do_eval saves model_dir/model.ckpt;
        confirm).

    Raises:
        ValueError: for an unknown ``trainables`` value.
    """

    tools.mkdir_p(model_dir)

    # Network parameters
    default_hp = get_default_hp(ruleset)
    if hp is not None:
        default_hp.update(hp)
    hp = default_hp
    hp['seed'] = seed
    hp['rng'] = np.random.RandomState(seed)

    # Rules to train and test. Rules in a set are trained together
    if rule_trains is None:
        # By default, training all rules available to this ruleset
        hp['rule_trains'] = task.rules_dict[ruleset]
    else:
        hp['rule_trains'] = rule_trains
    hp['rules'] = hp['rule_trains']

    # Assign probabilities for rule_trains.
    if rule_prob_map is None:
        rule_prob_map = dict()

    # Turn into rule_trains format: normalised probabilities in rule order
    hp['rule_probs'] = None
    if hasattr(hp['rule_trains'], '__iter__'):
        # Set default as 1.
        rule_prob = np.array(
                [rule_prob_map.get(r, 1.) for r in hp['rule_trains']])
        hp['rule_probs'] = list(rule_prob / np.sum(rule_prob))
    tools.save_hp(hp, model_dir)

    # Build the model
    model = Model(model_dir, hp=hp)

    # Display hp
    for key, val in hp.items():
        print('{:20s} = '.format(key) + str(val))

    # Store results
    log = defaultdict(list)
    log['model_dir'] = model_dir

    # Record time
    t_start = time.time()

    with tf.Session() as sess:
        if load_dir is not None:
            model.restore(load_dir)  # complete restore
        else:
            # Nothing to load: initialize all variables from scratch
            sess.run(tf.global_variables_initializer())

        # Set trainable parameters
        if trainables is None or trainables == 'all':
            var_list = model.var_list  # train everything
        elif trainables == 'input':
            # train all inputs (except variables inside the rnn)
            var_list = [
                v for v in model.var_list
                if ('input' in v.name) and ('rnn' not in v.name)
            ]
        elif trainables == 'rule':
            # train rule inputs only
            var_list = [v for v in model.var_list if 'rule_input' in v.name]
        else:
            raise ValueError('Unknown trainables')
        model.set_optimizer(var_list=var_list)

        # penalty on deviation from initial weight
        if hp['l2_weight_init'] > 0:
            # Current weight values become fixed constants in the penalty
            anchor_ws = sess.run(model.weight_list)
            for w, w_val in zip(model.weight_list, anchor_ws):
                model.cost_reg += (hp['l2_weight_init'] *
                                   tf.nn.l2_loss(w - w_val))

            # Rebuild the optimizer so it sees the updated cost
            model.set_optimizer(var_list=var_list)

        # partial weight training: about a (1 - p_weight_train) fraction of
        # the entries of each weight get a 0.1 mask, so their deviation from
        # the current value is penalised; the remaining entries (mask 0) are
        # left free
        if ('p_weight_train' in hp and
            (hp['p_weight_train'] is not None) and
            hp['p_weight_train'] < 1.0):
            for w in model.weight_list:
                w_val = sess.run(w)
                w_size = sess.run(tf.size(w))
                w_mask_tmp = np.linspace(0, 1, w_size)
                hp['rng'].shuffle(w_mask_tmp)
                ind_fix = w_mask_tmp > hp['p_weight_train']
                w_mask = np.zeros(w_size, dtype=np.float32)
                w_mask[ind_fix] = 1e-1  # will be squared in l2_loss
                w_mask = tf.constant(w_mask)
                w_mask = tf.reshape(w_mask, w.shape)
                model.cost_reg += tf.nn.l2_loss((w - w_val) * w_mask)
            model.set_optimizer(var_list=var_list)

        step = 0
        while step * hp['batch_size_train'] <= max_steps:
            try:
                # Validation
                if step % display_step == 0:
                    log['trials'].append(step * hp['batch_size_train'])
                    log['times'].append(time.time() - t_start)
                    log = do_eval(sess, model, log, hp['rule_trains'])
                    #if log['perf_avg'][-1] > model.hp['target_perf']:
                    #check if minimum performance is above target
                    if log['perf_min'][-1] > model.hp['target_perf']:
                        print('Perf reached the target: {:0.2f}'.format(
                            hp['target_perf']))
                        break

                    if rich_output:
                        display_rich_output(model, sess, step, log, model_dir)

                # Training: sample one rule per step with hp['rule_probs']
                rule_train_now = hp['rng'].choice(hp['rule_trains'],
                                                  p=hp['rule_probs'])
                # Generate a random batch of trials.
                # Each batch has the same trial length
                trial = generate_trials(
                        rule_train_now, hp, 'random',
                        batch_size=hp['batch_size_train'])

                # Generating feed_dict.
                feed_dict = tools.gen_feed_dict(model, trial, hp)
                sess.run(model.train_step, feed_dict=feed_dict)

                step += 1

            except KeyboardInterrupt:
                print("Optimization interrupted by user")
                break

        print("Optimization finished!")
def psychometric_choicefamily_2D(model_dir, rule, lesion_units=None,
                                 n_coh=8, n_stim_loc=20, coh_range=0.1):
    """Compute 2-D psychometric data for a decision-making task.

    The batch covers every combination of n_stim_loc stim-1 locations
    (stim 2 opposite), n_coh modality-1 coherences and n_coh modality-2
    coherences.

    Args:
        model_dir: model directory
        rule: one of 'dm1', 'dm2', 'contextdm1', 'contextdm2',
            'contextdelaydm1', 'contextdelaydm2', 'multidm'
        lesion_units: passed to model.lesion_units
        n_coh: number of coherence values per modality
        n_stim_loc: number of stim-1 locations
        coh_range: coherences are linearly spaced in [-coh_range, coh_range]

    Returns:
        perf: fraction correct among trials where the final output is within
            THETA of either stimulus (trials without a decision are dropped)
        prop1s: array (n_coh, n_coh), fraction choosing stim 1 per
            (mod-1 coherence, mod-2 coherence), pooled over locations
        cohs: array (n_coh,), the coherence values

    Raises:
        KeyError: if ``rule`` is not supported.
    """
    cohs = np.linspace(-coh_range, coh_range, n_coh)

    batch_size = n_stim_loc * n_coh**2
    batch_shape = (n_stim_loc, n_coh, n_coh)
    ind_stim_loc, ind_stim_mod1, ind_stim_mod2 = np.unravel_index(
        range(batch_size), batch_shape)

    # Looping target location
    stim1_locs = 2 * np.pi * ind_stim_loc / n_stim_loc
    stim2_locs = (stim1_locs + np.pi) % (2 * np.pi)
    stim_mod1_cohs = cohs[ind_stim_mod1]
    stim_mod2_cohs = cohs[ind_stim_mod2]

    locs = {'stim1_locs': stim1_locs, 'stim2_locs': stim2_locs}
    # Single-modality tasks: just use the mod 1 coherence
    single_mod = dict(locs,
                      stim1_strengths=1 + stim_mod1_cohs,
                      stim2_strengths=1 - stim_mod1_cohs,
                      stim_time=800)
    # Context tasks: independent coherences in the two modalities
    two_mod = dict(locs,
                   stim1_mod1_strengths=1 + stim_mod1_cohs,
                   stim2_mod1_strengths=1 - stim_mod1_cohs,
                   stim1_mod2_strengths=1 + stim_mod2_cohs,
                   stim2_mod2_strengths=1 - stim_mod2_cohs,
                   stim_time=800)
    # Multi-sensory task: mod 2 carries the same coherence as mod 1
    multi_mod = dict(locs,
                     stim1_mod1_strengths=1 + stim_mod1_cohs,
                     stim2_mod1_strengths=1 - stim_mod1_cohs,
                     stim1_mod2_strengths=1 + stim_mod1_cohs,
                     stim2_mod2_strengths=1 - stim_mod1_cohs,
                     stim_time=800)

    # Every rule gets its own dict, so changing one entry (e.g. stim_time
    # for the delay variants) cannot silently change another rule's params.
    params_dict = {
        'dm1': dict(single_mod),
        'dm2': dict(single_mod),
        'contextdm1': dict(two_mod),
        'contextdm2': dict(two_mod),
        'contextdelaydm1': dict(two_mod),
        'contextdelaydm2': dict(two_mod),
        'multidm': dict(multi_mod),
    }
    if rule not in params_dict:
        raise KeyError('Unsupported rule for psychometric_choicefamily_2D: '
                       '{}'.format(rule))
    params = params_dict[rule]

    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()
        model.lesion_units(sess, lesion_units)

        trial = generate_trials(rule, hp, 'psychometric', params=params)
        feed_dict = tools.gen_feed_dict(model, trial, hp)
        y_loc_sample = sess.run(model.y_hat_loc, feed_dict=feed_dict)

    y_loc_last = y_loc_sample[-1]  # output location at the last time point

    # Compute the overall performance.
    # Importantly, discard trials where no decision was made
    loc_cor = trial.y_loc[-1]  # last time point, correct locations
    loc_err = (loc_cor + np.pi) % (2 * np.pi)
    choose_cor = (get_dist(y_loc_last - loc_cor) < THETA).sum()
    choose_err = (get_dist(y_loc_last - loc_err) < THETA).sum()
    perf = choose_cor / (choose_cor + choose_err)

    # Proportion choosing stim 1, keeping batch_shape and pooling over the
    # location axis
    stim1_locs_ = np.reshape(stim1_locs, batch_shape)
    stim2_locs_ = np.reshape(stim2_locs, batch_shape)
    y_loc_grid = np.reshape(y_loc_last, batch_shape)
    choose1 = (get_dist(y_loc_grid - stim1_locs_) < THETA).sum(axis=0)
    choose2 = (get_dist(y_loc_grid - stim2_locs_) < THETA).sum(axis=0)
    prop1s = choose1 / (choose1 + choose2)

    return perf, prop1s, cohs
def pretty_inputoutput_plot(model_dir, rule, save=False, plot_ylabel=False):
    """Plot the input and output activity for a sample trial from one task.

    Draws five stacked panels for the first trial of a test batch: fixation
    input, modality-1 stimulus ring, modality-2 stimulus ring, fixation output
    (target in green, network in blue) and the network's ring output.

    Args:
        model_dir: model directory
        rule: string, the rule
        save: bool, whether to save plots
        plot_ylabel: bool, whether to plot ylabel
    """
    fs = 7

    model = Model(model_dir)
    hp = model.hp

    with tf.Session() as sess:
        model.restore()
        trial = generate_trials(rule, hp, mode='test')
        x, y = trial.x, trial.y
        feed_dict = tools.gen_feed_dict(model, trial, hp)
        # Only the network output is plotted, so the hidden state is not fetched
        y_hat = sess.run(model.y_hat, feed_dict=feed_dict)

    t_plot = np.arange(x.shape[0]) * hp['dt'] / 1000
    # Real trial duration in seconds, used to label the last time tick
    # instead of assuming every trial lasts 2 s
    t_end = y_hat.shape[0] * hp['dt'] / 1000.

    assert hp['num_ring'] == 2

    n_eachring = hp['n_eachring']
    ring_ticks = [0, (n_eachring - 1) / 2, n_eachring - 1]
    ring_labels = [r'0$\degree$', r'180$\degree$', r'360$\degree$']

    fig = plt.figure(figsize=(1.3, 2))
    ylabels = ['fix. in', 'stim. mod1', 'stim. mod2', 'fix. out', 'out']
    heights = np.array([0.03, 0.2, 0.2, 0.03, 0.2]) + 0.01
    for i in range(5):
        ax = fig.add_axes(
            [0.15, sum(heights[i + 1:] + 0.02) + 0.1, 0.8, heights[i]])
        cmap = 'Purples'
        plt.xticks([])
        ax.tick_params(axis='both', which='major', labelsize=fs,
                       width=0.5, length=2, pad=3)

        if plot_ylabel:
            ax.spines["right"].set_visible(False)
            ax.spines["bottom"].set_visible(False)
            ax.spines["top"].set_visible(False)
            ax.xaxis.set_ticks_position('bottom')
            ax.yaxis.set_ticks_position('left')
        else:
            ax.spines["left"].set_visible(False)
            ax.spines["right"].set_visible(False)
            ax.spines["bottom"].set_visible(False)
            ax.spines["top"].set_visible(False)
            ax.xaxis.set_ticks_position('none')

        if i == 0:
            plt.plot(t_plot, x[:, 0, 0], color='xkcd:blue')
            if plot_ylabel:
                plt.yticks([0, 1], ['', ''], rotation='vertical')
            plt.ylim([-0.1, 1.5])
            plt.title(rule_name[rule], fontsize=fs)
        elif i == 1:
            plt.imshow(x[:, 0, 1:1 + n_eachring].T, aspect='auto', cmap=cmap,
                       vmin=0, vmax=1, interpolation='none', origin='lower')
            if plot_ylabel:
                plt.yticks(ring_ticks, ring_labels, rotation='vertical')
        elif i == 2:
            plt.imshow(x[:, 0, 1 + n_eachring:1 + 2 * n_eachring].T,
                       aspect='auto', cmap=cmap, vmin=0, vmax=1,
                       interpolation='none', origin='lower')
            if plot_ylabel:
                plt.yticks(ring_ticks, ring_labels, rotation='vertical')
        elif i == 3:
            plt.plot(t_plot, y[:, 0, 0], color='xkcd:green')
            plt.plot(t_plot, y_hat[:, 0, 0], color='xkcd:blue')
            if plot_ylabel:
                plt.yticks([0.05, 0.8], ['', ''], rotation='vertical')
            plt.ylim([-0.1, 1.1])
        elif i == 4:
            plt.imshow(y_hat[:, 0, 1:].T, aspect='auto', cmap=cmap,
                       vmin=0, vmax=1, interpolation='none', origin='lower')
            if plot_ylabel:
                plt.yticks(ring_ticks, ring_labels, rotation='vertical')
            # Label the end of the time axis with the actual duration
            plt.xticks([0, y_hat.shape[0]], ['0', '{:g}'.format(t_end)])
            plt.xlabel('Time (s)', fontsize=fs, labelpad=-3)
            ax.spines["bottom"].set_visible(True)

        if plot_ylabel:
            plt.ylabel(ylabels[i], fontsize=fs)
        else:
            plt.yticks([])
        ax.get_yaxis().set_label_coords(-0.12, 0.5)

    if save:
        save_name = 'figure/sample_' + rule_name[rule].replace(' ', '') + '.pdf'
        plt.savefig(save_name, transparent=True)
    plt.show()
def plot_rule_connections(self):
    """Box-plot the input weights from rule units onto each unit group.

    For every group in ``['1', '2', '12']`` and every rule in
    ``['contextdm1', 'contextdm2', 'dm1', 'dm2', 'multidm']``, collects the
    input weights from that rule's input unit(s) to the units listed in
    ``self.group_ind_orig[group]``. It draws one box per (group, rule) pair,
    colored by ``self.colors[group]``. The figure is saved to
    ``figure/conn_rule_contextdm.pdf``.
    """
    model = Model(self.model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()
        w_in = sess.run(model.w_in)
    # After transposing, rows are indexed by recurrent unit and columns by
    # input index (see the w_in[ind, ind_rule] lookup below)
    w_in = w_in.T

    groups = ['1', '2', '12']
    rules = ['contextdm1', 'contextdm2', 'dm1', 'dm2', 'multidm']

    width = 0.15
    # Offset that centers the boxes of all groups on each rule's tick
    group_center = (len(groups) - 1) / 2.

    w_all = list()
    pos = list()
    colors = list()
    for i_group, group in enumerate(groups):
        ind = self.group_ind_orig[group]
        for i_rule, rule in enumerate(rules):
            ind_rule = get_rule_index(rule, hp)
            w_all.append(w_in[ind, ind_rule].flatten())
            pos.append(i_rule + (i_group - group_center) * width)
            colors.append(self.colors[group])

    fs = 6
    fig = plt.figure(figsize=(2.5, 1.2))
    ax = fig.add_axes([0.17, 0.45, 0.8, 0.4])
    bp = ax.boxplot(w_all, notch=True, sym='', bootstrap=10000,
                    showcaps=False, patch_artist=True, widths=width,
                    positions=pos, whiskerprops={'linewidth': 1.0})

    for patch, c in zip(bp['boxes'], colors):
        plt.setp(patch, color=c)
    # Each box has two whiskers
    for i_whisker, patch in enumerate(bp['whiskers']):
        plt.setp(patch, color=colors[i_whisker // 2])
    for element in ['means', 'medians']:
        plt.setp(bp[element], color='white')

    ax.set_xticks(np.arange(len(rules)))
    ax.set_xticklabels([rule_name[r] for r in rules], rotation=25)
    ax.set_xlabel('Input from rule units', fontsize=fs, labelpad=3)
    ax.set_ylabel('Conn. weight', fontsize=fs)
    ax.tick_params(axis='both', which='major', labelsize=fs)
    plt.locator_params(axis='y', nbins=2)
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    ax.set_xlim([-0.8, len(rules) - 0.2])
    # Zero-weight reference line
    ax.plot([-0.5, len(rules) - 0.5], [0, 0], color='gray', linewidth=0.5)
    plt.savefig('figure/conn_rule_contextdm.pdf', transparent=True)
    plt.show()
def _style_schematic_axes(ax, fontsize, show_bottom=False):
    """Apply the minimalist axis style shared by all schematic panels."""
    ax.tick_params(axis='both', which='major', labelsize=fontsize,
                   width=0.5, length=2, pad=3)
    ax.spines["left"].set_linewidth(0.5)
    ax.spines["right"].set_visible(False)
    if show_bottom:
        ax.spines["bottom"].set_linewidth(0.5)
    else:
        ax.spines["bottom"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')


def schematic_plot(model_dir, rule=None):
    """Save schematic panels of one sample trial (inputs, rules, units, outputs).

    The model is rebuilt with ``dt=1``, so one time step is 1 ms. Four figures
    are saved under ``figure/``: ``schematic_input.pdf``,
    ``schematic_rule.pdf``, ``schematic_units.pdf`` and
    ``schematic_outputs.pdf``.

    Args:
        model_dir: model directory
        rule: str, rule to plot; defaults to 'dm1'
    """
    fontsize = 6
    cmap = 'Purples'

    rule = rule or 'dm1'

    model = Model(model_dir, dt=1)
    hp = model.hp

    with tf.Session() as sess:
        model.restore()
        trial = generate_trials(rule, hp, mode='test')
        feed_dict = tools.gen_feed_dict(model, trial, hp)
        x = trial.x
        h, y_hat = sess.run([model.h, model.y_hat], feed_dict=feed_dict)

    n_eachring = hp['n_eachring']
    n_hidden = hp['n_rnn']
    ring_ticks = [0, (n_eachring - 1) / 2, n_eachring - 1]

    # Plot Stimulus
    fig = plt.figure(figsize=(1.0, 1.2))
    heights = np.array([0.06, 0.25, 0.25])
    for i in range(3):
        ax = fig.add_axes(
            [0.2, sum(heights[i + 1:] + 0.1) + 0.05, 0.7, heights[i]])
        plt.xticks([])
        _style_schematic_axes(ax, fontsize)

        if i == 0:
            plt.plot(x[:, 0, 0], color='xkcd:blue')
            plt.yticks([0, 1], ['', ''], rotation='vertical')
            plt.ylim([-0.1, 1.5])
            plt.title('Fixation input', fontsize=fontsize, y=0.9)
        elif i == 1:
            plt.imshow(x[:, 0, 1:1 + n_eachring].T, aspect='auto', cmap=cmap,
                       vmin=0, vmax=1, interpolation='none', origin='lower')
            plt.yticks(ring_ticks, [r'0$\degree$', '', r'360$\degree$'],
                       rotation='vertical')
            plt.title('Stimulus mod 1', fontsize=fontsize, y=0.9)
        elif i == 2:
            plt.imshow(x[:, 0, 1 + n_eachring:1 + 2 * n_eachring].T,
                       aspect='auto', cmap=cmap, vmin=0, vmax=1,
                       interpolation='none', origin='lower')
            plt.yticks(ring_ticks, ['', '', ''], rotation='vertical')
            plt.title('Stimulus mod 2', fontsize=fontsize, y=0.9)
        ax.get_yaxis().set_label_coords(-0.12, 0.5)
    plt.savefig('figure/schematic_input.pdf', transparent=True)
    plt.show()

    # Plot Rule Inputs (every input column after the two stimulus rings)
    fig = plt.figure(figsize=(1.0, 0.5))
    ax = fig.add_axes([0.2, 0.3, 0.7, 0.45])
    X = x[:, 0, 1 + 2 * n_eachring:]
    plt.imshow(X.T, aspect='auto', vmin=0, vmax=1, cmap=cmap,
               interpolation='none', origin='lower')
    plt.xticks([0, X.shape[0]])
    ax.set_xlabel('Time (ms)', fontsize=fontsize, labelpad=-5)
    _style_schematic_axes(ax, fontsize, show_bottom=True)
    plt.yticks([0, X.shape[-1] - 1], ['1', str(X.shape[-1])],
               rotation='vertical')
    plt.title('Rule inputs', fontsize=fontsize, y=0.9)
    ax.get_yaxis().set_label_coords(-0.12, 0.5)
    plt.savefig('figure/schematic_rule.pdf', transparent=True)
    plt.show()

    # Plot Units
    fig = plt.figure(figsize=(1.0, 0.8))
    ax = fig.add_axes([0.2, 0.1, 0.7, 0.75])
    plt.xticks([])
    _style_schematic_axes(ax, fontsize)
    plt.imshow(h[:, 0, :].T, aspect='auto', cmap=cmap, vmin=0, vmax=1,
               interpolation='none', origin='lower')
    plt.yticks([0, n_hidden - 1], ['1', str(n_hidden)], rotation='vertical')
    plt.title('Recurrent units', fontsize=fontsize, y=0.95)
    ax.get_yaxis().set_label_coords(-0.12, 0.5)
    plt.savefig('figure/schematic_units.pdf', transparent=True)
    plt.show()

    # Plot Outputs
    fig = plt.figure(figsize=(1.0, 0.8))
    heights = np.array([0.1, 0.45]) + 0.01
    for i in range(2):
        ax = fig.add_axes(
            [0.2, sum(heights[i + 1:] + 0.15) + 0.1, 0.7, heights[i]])
        plt.xticks([])
        _style_schematic_axes(ax, fontsize)

        if i == 0:
            plt.plot(y_hat[:, 0, 0], color='xkcd:blue')
            plt.yticks([0.05, 0.8], ['', ''], rotation='vertical')
            plt.ylim([-0.1, 1.1])
            plt.title('Fixation output', fontsize=fontsize, y=0.9)
        elif i == 1:
            plt.imshow(y_hat[:, 0, 1:].T, aspect='auto', cmap=cmap,
                       vmin=0, vmax=1, interpolation='none', origin='lower')
            plt.yticks(ring_ticks, [r'0$\degree$', '', r'360$\degree$'],
                       rotation='vertical')
            plt.title('Response', fontsize=fontsize, y=0.9)
        ax.get_yaxis().set_label_coords(-0.12, 0.5)
    plt.savefig('figure/schematic_outputs.pdf', transparent=True)
    plt.show()
def pretty_singleneuron_plot(model_dir,
                             rules,
                             neurons,
                             epoch=None,
                             save=False,
                             ylabel_firstonly=True,
                             trace_only=False,
                             plot_stim_avg=False,
                             save_name=''):
    """Plot the activity of a single neuron in time across many trials.

    One figure is drawn per (neuron, rule) pair. Each shows every test trial's
    trace in gray, starting 500 ms into the trial.

    Args:
        model_dir: model directory
        rules: str or list of str, rules to plot
        neurons: int or iterable of int, indices of recurrent units to plot
        epoch: str or None, name of a trial epoch to mark with a black bar
            above the traces
        save: bool, whether to save each figure as a PDF under figure/
        ylabel_firstonly: if True, only plot ylabel for the first rule in rules
        trace_only: if True, hide axes, ticks, labels and title
        plot_stim_avg: if True, overlay the trial-averaged trace in black
        save_name: str, suffix appended to the saved file names
    """
    if isinstance(rules, str):
        rules = [rules]

    try:
        _ = iter(neurons)
    except TypeError:
        neurons = [neurons]

    h_tests = dict()
    trials = dict()
    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()
        for rule in rules:
            # Generate a batch of trial from the test mode
            trial = generate_trials(rule, hp, mode='test')
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            h_tests[rule] = sess.run(model.h, feed_dict=feed_dict)
            # Keep each rule's own trial so its epochs are used for its plot
            trials[rule] = trial

    # Skip the first 500 ms of every trial
    t_start = int(500 / hp['dt'])
    fs = 6

    for neuron in neurons:
        h_max = np.max([h_tests[r][t_start:, :, neuron].max() for r in rules])
        for j, rule in enumerate(rules):
            h_rule = h_tests[rule][t_start:, :, neuron]

            fig = plt.figure(figsize=(1.0, 0.8))
            ax = fig.add_axes([0.35, 0.25, 0.55, 0.55])
            # Time in seconds, measured from t_start
            t_plot = np.arange(h_rule.shape[0]) * hp['dt'] / 1000
            ax.plot(t_plot, h_rule, lw=0.5, color='gray')

            if plot_stim_avg:
                # Plot stimulus averaged trace
                ax.plot(t_plot, h_rule.mean(axis=1), lw=1, color='black')

            if epoch is not None:
                e0, e1 = trials[rule].epochs[epoch]
                e0 = e0 if e0 is not None else 0
                e1 = e1 if e1 is not None else h_tests[rule].shape[0]
                # Epoch bounds are time-step indices of the full trial; convert
                # them to seconds on the t_start-shifted axis used by t_plot
                e0 = max(e0 - t_start, 0) * hp['dt'] / 1000.
                e1 = max(e1 - t_start, 0) * hp['dt'] / 1000.
                ax.plot([e0, e1], [h_max * 1.15] * 2,
                        color='black', linewidth=1.5)
                # Include the unit so several neurons don't overwrite each other
                figname = ('figure/trace_unit' + str(neuron) + rule_name[rule]
                           + epoch + save_name + '.pdf')
            else:
                figname = ('figure/trace_unit' + str(neuron) + rule_name[rule]
                           + save_name + '.pdf')

            plt.ylim(np.array([-0.1, 1.2]) * h_max)
            plt.xticks([0, np.floor(np.max(t_plot) + 0.01)])
            plt.xlabel('Time (s)', fontsize=fs, labelpad=-5)
            plt.locator_params(axis='y', nbins=4)
            if j > 0 and ylabel_firstonly:
                ax.set_yticklabels([])
            else:
                plt.ylabel('Activity (a.u.)', fontsize=fs, labelpad=2)
            plt.title('Unit {:d} '.format(neuron) + rule_name[rule],
                      fontsize=5)
            ax.tick_params(axis='both', which='major', labelsize=fs)
            ax.spines["right"].set_visible(False)
            ax.spines["top"].set_visible(False)
            ax.xaxis.set_ticks_position('bottom')
            ax.yaxis.set_ticks_position('left')
            if trace_only:
                ax.spines["left"].set_visible(False)
                ax.spines["bottom"].set_visible(False)
                ax.xaxis.set_ticks_position('none')
                ax.set_xlabel('')
                ax.set_ylabel('')
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_title('')

            if save:
                plt.savefig(figname, transparent=True)
            plt.show()