def easy_activity_plot(model_dir, rule):
    """Show heat maps of input, recurrent and output activity for one trial.

    The model stored in `model_dir` is restored, run once on a batch of
    'test' mode trials for `rule`, and one figure is shown per signal.

    Args:
        model_dir: directory where model file is saved
        rule: string, the rule to plot
    """
    model = Model(model_dir)
    hparams = model.hp

    with tf.Session() as sess:
        model.restore()

        test_trial = generate_trials(rule, hparams, mode='test')
        feeds = tools.gen_feed_dict(model, test_trial, hparams)
        recurrent, output = sess.run([model.h, model.y_hat], feed_dict=feeds)
        # All matrices have shape (n_time, n_condition, n_neuron)

    # Only the first trial of the batch is displayed
    example = 0
    panels = (
        ('input', test_trial.x),
        ('recurrent', recurrent),
        ('output', output),
    )
    for name, data in panels:
        plt.figure()
        plt.imshow(data[:, example, :].T, aspect='auto', cmap='hot',
                   interpolation='none', origin='lower')
        plt.title(name)
        plt.colorbar()
        plt.show()
def quick_statespace(model_dir):
    """Quick state space analysis using simply PCA.

    Runs the restored network on 'test' mode trials of 'contextdm1' and
    'contextdm2', takes the recurrent activity at the time index given by
    the last element of ``trial.epochs['stim1']``, fits a 5-component PCA on
    the pooled activity, and scatters the first two components per rule.

    Args:
        model_dir: model directory
    """
    rules = ['contextdm1', 'contextdm2']
    h_lastts = dict()

    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()
        for rule in rules:
            # Generate a batch of trial from the test mode
            trial = generate_trials(rule, hp, mode='test')
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            h = sess.run(model.h, feed_dict=feed_dict)
            lastt = trial.epochs['stim1'][-1]
            h_lastts[rule] = h[lastt, :, :]

    from sklearn.decomposition import PCA
    # Separate name so the network model is not shadowed by the PCA object
    pca = PCA(n_components=5)
    # np.concatenate needs a real sequence: a Python 3 dict view is rejected.
    # Building the list from `rules` also fixes the row order explicitly.
    pca.fit(np.concatenate([h_lastts[rule] for rule in rules], axis=0))

    fig = plt.figure(figsize=(2, 2))
    ax = fig.add_axes([.3, .3, .6, .6])
    for rule, color in zip(rules, ['red', 'blue']):
        data_trans = pca.transform(h_lastts[rule])
        ax.scatter(data_trans[:, 0], data_trans[:, 1], s=1,
                   label=rule_name[rule], color=color)
    plt.tick_params(axis='both', which='major', labelsize=7)
    ax.set_xlabel('PC 1', fontsize=7)
    ax.set_ylabel('PC 2', fontsize=7)
    ax.legend(fontsize=7, ncol=1, bbox_to_anchor=(1, 0.3),
              loc=1, frameon=False)
    # NOTE(review): `save` is not defined inside this function; presumably a
    # module-level flag -- confirm it exists, otherwise this raises NameError.
    if save:
        plt.savefig('figure/choiceatt_quickstatespace.pdf', transparent=True)
def run_network_replacerule(model_dir, rule, replace_rule, rule_strength):
    """Measure performance when the rule inputs are swapped for other ones.

    Args:
        model_dir: model directory
        rule: the rule to test on
        replace_rule: a list of rule input units to use
        rule_strength: the relative strength of each replace rule unit

    Returns:
        A tuple (mean performance over all repetitions, rule_strength).
    """
    # 1000 test trials in total, evaluated as 20 equally sized batches
    total_trials = 1000
    n_batches = 20
    trials_per_batch = int(total_trials / n_batches)

    model = Model(model_dir)
    hp = model.hp
    batch_perfs = []
    with tf.Session() as sess:
        model.restore()

        for _ in range(n_batches):
            trial = generate_trials(
                rule, hp, 'random',
                batch_size=trials_per_batch,
                replace_rule=replace_rule,
                rule_strength=rule_strength)
            outputs = sess.run(
                model.y_hat,
                feed_dict=tools.gen_feed_dict(model, trial, hp))
            batch_perfs.append(np.mean(get_perf(outputs, trial.y_loc)))

    return np.mean(batch_perfs), rule_strength
def _compute_H(self, model, rule, trial, sess):
    """Evaluate ``model.h`` on `trial` and pickle the result to disk.

    The array is written to ``H_<rule>.pkl`` inside ``model.model_dir``.

    Args:
        model: model object providing `h` and `model_dir`
        rule: string, used to build the output file name
        trial: trial object passed to `tools.gen_feed_dict`
        sess: tensorflow session used to run the graph
    """
    activity = sess.run(
        model.h,
        feed_dict=tools.gen_feed_dict(model, trial, self.hp))
    out_path = os.path.join(model.model_dir, 'H_' + rule + '.pkl')
    with open(out_path, 'wb') as out_file:
        pickle.dump(activity, out_file)
def __init__(self, model_dir, rules=None):
    """Run the network on every rule and cache activity summaries.

    Three ordered mappings are stored on the instance:
    `h_stimavg_byrule` (condition-averaged activity per rule, without the
    first 500/dt time steps), `h_stimavg_byepoch` (condition-averaged
    activity per (rule, epoch)), and `h_lastt_byepoch` (activity at the
    epoch's end index per (rule, epoch)). Epochs whose name contains 'fix'
    are skipped.

    Args:
        model_dir: str, model directory
        rules: None or a list of rules
    """
    model = Model(model_dir)
    hp = model.hp

    if rules is None:
        # Default value
        rules = hp['rules']

    # Stimulus-averaged traces
    stimavg_by_rule = OrderedDict()
    stimavg_by_epoch = OrderedDict()
    # Last time points of epochs
    lastt_by_epoch = OrderedDict()

    with tf.Session() as sess:
        model.restore()

        for rule in rules:
            trial = generate_trials(rule=rule, hp=hp, mode='test')
            h = sess.run(
                model.h, feed_dict=tools.gen_feed_dict(model, trial, hp))

            # Average across stimulus conditions
            h_stimavg = h.mean(axis=1)

            # Important: Ignore the initial transition
            t_start = int(500 / hp['dt'])
            stimavg_by_rule[rule] = h_stimavg[t_start:, :]

            for e_name, e_time in trial.epochs.items():
                if 'fix' not in e_name:
                    # Start one step before the epoch onset when possible
                    if e_time[0] > 0:
                        e_time_start = e_time[0] - 1
                    else:
                        e_time_start = 0
                    key = (rule, e_name)
                    stimavg_by_epoch[key] = \
                        h_stimavg[e_time_start:e_time[1], :]
                    # Activity indexed at the epoch's end for all conditions
                    lastt_by_epoch[key] = h[e_time[1], :, :]

    self.rules = rules
    self.h_stimavg_byrule = stimavg_by_rule
    self.h_stimavg_byepoch = stimavg_by_epoch
    self.h_lastt_byepoch = lastt_by_epoch
    self.model_dir = model_dir
def _psychometric_dm(model_dir, rule, params_list, batch_shape):
    """Base function for computing psychometric performance in 2AFC tasks

    For every parameter set the network is run in 'psychometric' mode and
    the final-time-step output location is compared against both stimulus
    locations; a response within THETA of a stimulus counts as choosing it.

    Args:
        model_dir : model name
        rule : task to analyze
        params_list : a list of parameter dictionaries used for the
            psychometric mode
        batch_shape : shape of each batch. Each batch should have shape
            (n_rep, ...), n_rep is the number of repetitions that will be
            averaged over

    Return:
        ydatas: list of performances
    """
    print('Starting psychometric analysis of the {:s} task...'.format(
        rule_name[rule]))

    model = Model(model_dir)
    hp = model.hp
    ydatas = []
    with tf.Session() as sess:
        model.restore()

        for params in params_list:
            trial = generate_trials(rule, hp, 'psychometric', params=params)
            y_loc = sess.run(
                model.y_hat_loc,
                feed_dict=tools.gen_feed_dict(model, trial, hp))
            # Only the last time step's response is used
            response = np.reshape(y_loc[-1], batch_shape)

            def count_choices(stim_locs):
                # Sum over the first (repetition) dimension of the batch
                target = np.reshape(stim_locs, batch_shape)
                return (get_dist(response - target) < THETA).sum(axis=0)

            choose1 = count_choices(params['stim1_locs'])
            choose2 = count_choices(params['stim2_locs'])
            ydatas.append(choose1 / (choose1 + choose2))

    return ydatas
def do_eval_test(sess, model, rule):
    """Evaluate the model once on a batch of 'test' mode trials.

    Args:
        sess: tensorflow session
        model: Model class instance
        rule: string, the rule to generate test trials for

    Returns:
        A tuple (c_lsq, c_reg, perf_test): the values of model.cost_lsq and
        model.cost_reg, and the mean of get_perf over the batch.
    """
    hp = model.hp
    trial = generate_trials(rule, hp, 'test')
    fetches = [model.cost_lsq, model.cost_reg, model.y_hat]
    c_lsq, c_reg, y_hat_test = sess.run(
        fetches, feed_dict=tools.gen_feed_dict(model, trial, hp))

    # Cost is first summed over time,
    # and averaged across batch and units
    # We did the averaging over time through c_mask
    perf_test = np.mean(get_perf(y_hat_test, trial.y_loc))
    sys.stdout.flush()

    return c_lsq, c_reg, perf_test
def activity_histogram(model_dir, rules, title=None, save_name=None):
    """Plot the activity histogram.

    Recurrent activity from 'test' mode trials of every rule is pooled
    (after dropping the first 500/dt time steps) and shown as a 20-bin
    density histogram.

    Args:
        model_dir: model directory
        rules: a rule name or a list of rule names
        title: accepted for interface compatibility; not used by the code
            visible in this function
        save_name: accepted for interface compatibility; not used by the
            code visible in this function

    Raises:
        ValueError: if `rules` is empty.
    """
    if isinstance(rules, str):
        rules = [rules]

    model = Model(model_dir)
    hp = model.hp
    h_by_rule = []
    with tf.Session() as sess:
        model.restore()

        t_start = int(500/hp['dt'])

        for rule in rules:
            # Generate a batch of trial from the test mode
            trial = generate_trials(rule, hp, mode='test')
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            h = sess.run(model.h, feed_dict=feed_dict)
            h_by_rule.append(h[t_start:, :, :])

    if not h_by_rule:
        raise ValueError('rules must contain at least one rule')

    # Concatenate once along the condition axis instead of re-copying the
    # growing array on every rule
    h_plot = np.concatenate(h_by_rule, axis=1).flatten()

    fig = plt.figure(figsize=(1.5, 1.2))
    ax = fig.add_axes([0.2, 0.2, 0.7, 0.6])
    ax.hist(h_plot, bins=20, density=True)
    ax.set_xlabel('Activity', fontsize=7)
    for side in ['left', 'top', 'right']:
        ax.spines[side].set_visible(False)
    ax.set_yticks([])
def lesions(self):
    """Lesion each labelled group of units and measure the effect.

    The intact network is evaluated first as the baseline; then, for every
    label in `self.unique_labels`, the units carrying that label are
    lesioned and performance and cost are re-measured on `self.rules`.

    Returns:
        A tuple (perfs_changes, cost_changes) of arrays with one row per
        label and one column per rule, holding lesioned minus intact
        values.
    """
    from network import get_perf
    from task import generate_trials

    labels = self.labels

    # 256 test trials per rule, evaluated as 16 equally sized batches
    n_rep = 16
    batch_size_test = 256
    batch_size_test_rep = int(batch_size_test / n_rep)

    # The first will be the intact network; the rest are given in
    # original indices
    lesion_units_list = [None] + [
        self.ind_active[np.where(labels == label)[0]]
        for label in self.unique_labels
    ]

    intact_perfs = None
    intact_costs = None
    perfs_changes = []
    cost_changes = []

    for i_lesion, lesion_units in enumerate(lesion_units_list):
        model = Model(self.model_dir)
        hp = model.hp
        rule_perfs = []
        rule_costs = []
        with tf.Session() as sess:
            model.restore()
            model.lesion_units(sess, lesion_units)

            for rule in self.rules:
                rep_costs = []
                rep_perfs = []
                for _ in range(n_rep):
                    trial = generate_trials(
                        rule, hp, 'random', batch_size=batch_size_test_rep)
                    y_hat_test, c_lsq = sess.run(
                        [model.y_hat, model.cost_lsq],
                        feed_dict=tools.gen_feed_dict(model, trial, hp))

                    # Cost is first summed over time, and averaged across
                    # batch and units
                    # We did the averaging over time through c_mask

                    # IMPORTANT CHANGES: take overall mean
                    rep_costs.append(c_lsq)
                    rep_perfs.append(
                        np.mean(get_perf(y_hat_test, trial.y_loc)))

                rule_perfs.append(np.mean(rep_perfs))
                rule_costs.append(np.mean(rep_costs))

        rule_perfs = np.array(rule_perfs)
        rule_costs = np.array(rule_costs)

        if i_lesion == 0:
            intact_perfs = rule_perfs
            intact_costs = rule_costs
        else:
            perfs_changes.append(rule_perfs - intact_perfs)
            cost_changes.append(rule_costs - intact_costs)

    return np.array(perfs_changes), np.array(cost_changes)
def train_sequential(
        model_dir,
        rule_trains,
        hp=None,
        max_steps=1e7,
        display_step=500,
        ruleset='mante',
        seed=0,
):
    '''Train the network sequentially.

    The groups in `rule_trains` are trained one after another inside a
    single session. From the second group onward, and only when
    hp['c_intsyn'] > 0, model.cost_reg is rebuilt as a weighted quadratic
    penalty that anchors every variable to its value at the start of that
    group; the weights (Omega0) are accumulated from the per-step quantity
    omega0 tracked during the previous groups.

    Args:
        model_dir: str, training directory
        rule_trains: a list of list of tasks to train sequentially
        hp: dictionary of hyperparameters
        max_steps: int, maximum number of training steps for each list of
            tasks. The budget of a group is len(group) * max_steps and is
            compared against step * hp['batch_size_train'].
        display_step: int, display steps
        ruleset: the set of rules to train
        seed: int, random seed to be used

    Returns:
        model is stored at model_dir/model.ckpt
        training configuration is stored at model_dir/hp.json

    Raises:
        NotImplementedError: if hp['l2_weight_init'] > 0.
    '''

    tools.mkdir_p(model_dir)

    # Network parameters: ruleset defaults overlaid with user overrides
    default_hp = get_default_hp(ruleset)
    if hp is not None:
        default_hp.update(hp)
    hp = default_hp
    hp['seed'] = seed
    hp['rng'] = np.random.RandomState(seed)
    hp['rule_trains'] = rule_trains
    # Get all rules by flattening the list of lists
    hp['rules'] = [r for rs in rule_trains for r in rs]

    # Number of training iterations for each rule
    rule_train_iters = [len(r) * max_steps for r in rule_trains]

    tools.save_hp(hp, model_dir)
    # Display hp
    for key, val in hp.items():
        print('{:20s} = '.format(key) + str(val))

    # Using continual learning or not
    c, ksi = hp['c_intsyn'], hp['ksi_intsyn']

    # Build the model
    model = Model(model_dir, hp=hp)

    # Gradient of the unregularized cost only; fetched alongside the train
    # step to update omega0
    grad_unreg = tf.gradients(model.cost_lsq, model.var_list)

    # Store results
    log = defaultdict(list)
    log['model_dir'] = model_dir

    # Record time
    t_start = time.time()

    # tensorboard summaries: one histogram placeholder per (quantity,
    # variable) pair. The ordering here must match the concatenation
    # Omega0 + omega0 + v_delta used when feeding them below.
    placeholders = list()
    for v_name in ['Omega0', 'omega0', 'vdelta']:
        for v in model.var_list:
            placeholder = tf.placeholder(tf.float32, shape=v.shape)
            tf.summary.histogram(v_name + '/' + v.name, placeholder)
            placeholders.append(placeholder)
    merged = tf.summary.merge_all()
    test_writer = tf.summary.FileWriter(model_dir + '/tb')

    def relu(x):
        # Elementwise rectification on numpy arrays
        return x * (x > 0.)

    # Use customized session that launches the graph as well
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # penalty on deviation from initial weight
        if hp['l2_weight_init'] > 0:
            raise NotImplementedError()

        # Looping
        step_total = 0
        for i_rule_train, rule_train in enumerate(hp['rule_trains']):
            step = 0

            # At the beginning of new tasks
            # Only if using intelligent synapses
            v_current = sess.run(model.var_list)

            if i_rule_train == 0:
                # First group: nothing to protect yet, start from zeros
                v_anc0 = v_current
                Omega0 = [np.zeros(v.shape, dtype='float32') for v in v_anc0]
                omega0 = [np.zeros(v.shape, dtype='float32') for v in v_anc0]
                v_delta = [np.zeros(v.shape, dtype='float32') for v in v_anc0]
            elif c > 0:
                # Change of every variable over the group just finished
                v_anc0_prev = v_anc0
                v_anc0 = v_current
                v_delta = [
                    v - v_prev for v, v_prev in zip(v_anc0, v_anc0_prev)
                ]

                # Make sure all elements in omega0 are non-negative
                # Penalty
                Omega0 = [
                    relu(O + o / (v_d**2 + ksi))
                    for O, o, v_d in zip(Omega0, omega0, v_delta)
                ]

                # Update cost: new graph ops are added here, so the
                # optimizer has to be rebuilt afterwards
                model.cost_reg = tf.constant(0.)
                for v, w, v_val in zip(model.var_list, Omega0, v_current):
                    model.cost_reg += c * tf.reduce_sum(
                        tf.multiply(tf.constant(w),
                                    tf.square(v - tf.constant(v_val))))
                model.set_optimizer()

            # Store Omega0 to tf summary
            feed_dict = dict(zip(placeholders, Omega0 + omega0 + v_delta))
            summary = sess.run(merged, feed_dict=feed_dict)
            test_writer.add_summary(summary, i_rule_train)

            # Reset
            omega0 = [np.zeros(v.shape, dtype='float32') for v in v_anc0]

            # Keep training until reach max iterations
            while (step * hp['batch_size_train'] <=
                   rule_train_iters[i_rule_train]):
                # Validation
                if step % display_step == 0:
                    trial = step_total * hp['batch_size_train']
                    log['trials'].append(trial)
                    log['times'].append(time.time() - t_start)
                    log['rule_now'].append(rule_train)
                    log = do_eval(sess, model, log, rule_train)
                    # Stop this group early once the target is exceeded
                    if log['perf_avg'][-1] > model.hp['target_perf']:
                        print('Perf reached the target: {:0.2f}'.format(
                            hp['target_perf']))
                        break

                # Training
                rule_train_now = hp['rng'].choice(rule_train)
                # Generate a random batch of trials.
                # Each batch has the same trial length
                trial = generate_trials(rule_train_now, hp, 'random',
                                        batch_size=hp['batch_size_train'])

                # Generating feed_dict.
                feed_dict = tools.gen_feed_dict(model, trial, hp)

                # Continual learning with intelligent synapses
                v_prev = v_current

                # This will compute the gradient BEFORE train step
                _, v_grad = sess.run([model.train_step, grad_unreg],
                                     feed_dict=feed_dict)
                # Get the weight after train step
                v_current = sess.run(model.var_list)

                # Update synaptic importance
                omega0 = [
                    o - (v_c - v_p) * v_g for o, v_c, v_p, v_g in zip(
                        omega0, v_current, v_prev, v_grad)
                ]

                step += 1
                step_total += 1

        print("Optimization Finished!")
def train_rule_only(
        model_dir,
        rule_trains,
        max_steps,
        hp=None,
        ruleset='all',
        seed=0,
):
    '''Train sequentially, restricting later groups to rule inputs.

    The first group in `rule_trains` is trained with the model's default
    optimizer. For every later group the optimizer is rebuilt so that only
    variables whose name contains 'rule_input' are trained. Evaluation runs
    every 200 steps for the first group and every 50 steps afterwards.

    Args:
        model_dir: str, training directory
        rule_trains: a list of list of tasks to train sequentially
        max_steps: int, maximum number of training steps for each list of
            tasks (the budget of a group is then len(group) * max_steps),
            or an iterable used directly as the per-group budgets
        hp: dictionary of hyperparameters
        ruleset: the set of rules to train
        seed: int, random seed to be used

    Returns:
        model is stored at model_dir/model.ckpt
        training configuration is stored at model_dir/hp.json

    Raises:
        NotImplementedError: if hp['l2_weight_init'] > 0.
    '''
    tools.mkdir_p(model_dir)

    # Network parameters: ruleset defaults overlaid with user overrides
    merged_hp = get_default_hp(ruleset)
    if hp is not None:
        merged_hp.update(hp)
    hp = merged_hp
    hp['seed'] = seed
    hp['rng'] = np.random.RandomState(seed)
    hp['rule_trains'] = rule_trains

    # Get all rules by flattening the list of lists
    flat_rules = []
    for group in rule_trains:
        flat_rules.extend(group)
    hp['rules'] = flat_rules

    # Number of training iterations for each rule
    rule_train_iters = (
        max_steps if hasattr(max_steps, '__iter__')
        else [len(group) * max_steps for group in rule_trains])

    tools.save_hp(hp, model_dir)
    # Display hp
    for name, value in hp.items():
        print('{:20s} = '.format(name) + str(value))

    # Build the model
    model = Model(model_dir, hp=hp)

    # Store results
    log = defaultdict(list)
    log['model_dir'] = model_dir

    # Record time
    t_start = time.time()

    # Use customized session that launches the graph as well
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # penalty on deviation from initial weight
        if hp['l2_weight_init'] > 0:
            raise NotImplementedError()

        step_total = 0
        for group_idx, rule_train in enumerate(hp['rule_trains']):
            step = 0

            if group_idx == 0:
                display_step = 200
            else:
                display_step = 50
                # From the second group on, train rule inputs only
                rule_vars = [
                    v for v in model.var_list if 'rule_input' in v.name
                ]
                model.set_optimizer(var_list=rule_vars)

            # Keep training until reach max iterations
            budget = rule_train_iters[group_idx]
            while step * hp['batch_size_train'] <= budget:
                # Validation
                if step % display_step == 0:
                    log['trials'].append(step_total * hp['batch_size_train'])
                    log['times'].append(time.time() - t_start)
                    log['rule_now'].append(rule_train)
                    log = do_eval(sess, model, log, rule_train)
                    if log['perf_avg'][-1] > model.hp['target_perf']:
                        print('Perf reached the target: {:0.2f}'.format(
                            hp['target_perf']))
                        break

                # Training: pick one rule of the group, then build a random
                # batch of trials for it (same trial length within a batch)
                rule_train_now = hp['rng'].choice(rule_train)
                batch = generate_trials(rule_train_now, hp, 'random',
                                        batch_size=hp['batch_size_train'])

                sess.run(model.train_step,
                         feed_dict=tools.gen_feed_dict(model, batch, hp))

                step += 1
                step_total += 1

        print("Optimization Finished!")
def do_eval(sess, model, log, rule_train):
    """Do evaluation.

    Every rule in model.hp['rules'] is evaluated on 16 random batches whose
    sizes add up to hp['batch_size_test']; mean cost, regularization cost
    and performance are appended to `log`. The mean and minimum performance
    over the rule(s) currently being trained are appended as 'perf_avg' and
    'perf_min'. Finally the model and the log are saved.

    Args:
        sess: tensorflow session
        model: Model class instance
        log: dictionary that stores the log
        rule_train: string or list of strings, the rules being trained

    Returns:
        The updated `log`.
    """
    hp = model.hp

    # A bare string is ONE rule. On Python 3 `str` defines __iter__, so the
    # hasattr check alone would treat a rule name as a collection of
    # characters and break both the printed name and the perf_* lookups.
    is_single_rule = (isinstance(rule_train, str)
                      or not hasattr(rule_train, '__iter__'))
    if is_single_rule:
        rule_name_print = rule_train
    else:
        rule_name_print = ' & '.join(rule_train)

    print('Trial {:7d}'.format(log['trials'][-1]) +
          ' | Time {:0.2f} s'.format(log['times'][-1]) +
          ' | Now training ' + rule_name_print)

    n_rep = 16
    for rule_test in hp['rules']:
        batch_size_test_rep = int(hp['batch_size_test'] / n_rep)
        clsq_tmp = list()
        creg_tmp = list()
        perf_tmp = list()
        for i_rep in range(n_rep):
            trial = generate_trials(rule_test, hp, 'random',
                                    batch_size=batch_size_test_rep)
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            c_lsq, c_reg, y_hat_test = sess.run(
                [model.cost_lsq, model.cost_reg, model.y_hat],
                feed_dict=feed_dict)

            # Cost is first summed over time,
            # and averaged across batch and units
            # We did the averaging over time through c_mask
            perf_test = np.mean(get_perf(y_hat_test, trial.y_loc))
            clsq_tmp.append(c_lsq)
            creg_tmp.append(c_reg)
            perf_tmp.append(perf_test)

        log['cost_' + rule_test].append(np.mean(clsq_tmp, dtype=np.float64))
        log['creg_' + rule_test].append(np.mean(creg_tmp, dtype=np.float64))
        log['perf_' + rule_test].append(np.mean(perf_tmp, dtype=np.float64))
        print('{:15s}'.format(rule_test) +
              '| cost {:0.6f}'.format(np.mean(clsq_tmp)) +
              '| c_reg {:0.6f}'.format(np.mean(creg_tmp)) +
              ' | perf {:0.2f}'.format(np.mean(perf_tmp)))
        sys.stdout.flush()

    # Summaries are taken over the rule(s) currently being trained only
    rule_tmp = [rule_train] if is_single_rule else rule_train
    perf_tests_mean = np.mean([log['perf_' + r][-1] for r in rule_tmp])
    log['perf_avg'].append(perf_tests_mean)

    perf_tests_min = np.min([log['perf_' + r][-1] for r in rule_tmp])
    log['perf_min'].append(perf_tests_min)

    # Saving the model
    model.save()
    tools.save_log(log)

    return log
def train(
        model_dir,
        hp=None,
        max_steps=1e7,
        display_step=500,
        ruleset='mante',
        rule_trains=None,
        rule_prob_map=None,
        seed=0,
        rich_output=False,
        load_dir=None,
        trainables=None,
):
    """Train the network.

    All rules in hp['rule_trains'] are trained together: at every step one
    rule is drawn according to hp['rule_probs'] and a random batch of
    trials for it is used for one optimizer step. Training stops when
    step * hp['batch_size_train'] exceeds `max_steps`, when the minimum
    performance over the trained rules exceeds hp['target_perf'], or on
    KeyboardInterrupt.

    Args:
        model_dir: str, training directory
        hp: dictionary of hyperparameters
        max_steps: int, maximum number of training steps
        display_step: int, display steps
        ruleset: the set of rules to train
        rule_trains: list of rules to train, if None then all rules possible
        rule_prob_map: None or dictionary of relative rule probability
        seed: int, random seed to be used
        rich_output: bool, if True call display_rich_output at every
            evaluation
        load_dir: None or a directory passed to model.restore; if None all
            variables are freshly initialized instead
        trainables: None, 'all', 'input' or 'rule'; selects which variables
            the optimizer updates

    Returns:
        model is stored at model_dir/model.ckpt
        training configuration is stored at model_dir/hp.json

    Raises:
        ValueError: if `trainables` is not one of the accepted values.
    """

    tools.mkdir_p(model_dir)

    # Network parameters: ruleset defaults overlaid with user overrides
    default_hp = get_default_hp(ruleset)
    if hp is not None:
        default_hp.update(hp)
    hp = default_hp
    hp['seed'] = seed
    hp['rng'] = np.random.RandomState(seed)

    # Rules to train and test. Rules in a set are trained together
    if rule_trains is None:
        # By default, training all rules available to this ruleset
        hp['rule_trains'] = task.rules_dict[ruleset]
    else:
        hp['rule_trains'] = rule_trains
    hp['rules'] = hp['rule_trains']

    # Assign probabilities for rule_trains.
    if rule_prob_map is None:
        rule_prob_map = dict()

    # Turn into rule_trains format
    hp['rule_probs'] = None
    if hasattr(hp['rule_trains'], '__iter__'):
        # Set default as 1., then normalize so the probabilities sum to one
        rule_prob = np.array(
            [rule_prob_map.get(r, 1.) for r in hp['rule_trains']])
        hp['rule_probs'] = list(rule_prob / np.sum(rule_prob))
    tools.save_hp(hp, model_dir)

    # Build the model
    model = Model(model_dir, hp=hp)

    # Display hp
    for key, val in hp.items():
        print('{:20s} = '.format(key) + str(val))

    # Store results
    log = defaultdict(list)
    log['model_dir'] = model_dir

    # Record time
    t_start = time.time()

    with tf.Session() as sess:
        if load_dir is not None:
            model.restore(load_dir)  # complete restore
        else:
            # Nothing to restore: initialize every variable
            sess.run(tf.global_variables_initializer())

        # Set trainable parameters
        if trainables is None or trainables == 'all':
            var_list = model.var_list  # train everything
        elif trainables == 'input':
            # train all inputs (names containing 'input' but not 'rnn')
            var_list = [
                v for v in model.var_list
                if ('input' in v.name) and ('rnn' not in v.name)
            ]
        elif trainables == 'rule':
            # train rule inputs only
            var_list = [v for v in model.var_list if 'rule_input' in v.name]
        else:
            raise ValueError('Unknown trainables')
        model.set_optimizer(var_list=var_list)

        # penalty on deviation from initial weight
        if hp['l2_weight_init'] > 0:
            # Anchor values are the weights at this point in the session
            anchor_ws = sess.run(model.weight_list)
            for w, w_val in zip(model.weight_list, anchor_ws):
                model.cost_reg += (hp['l2_weight_init'] *
                                   tf.nn.l2_loss(w - w_val))

            # cost_reg changed, so the optimizer is rebuilt
            model.set_optimizer(var_list=var_list)

        # partial weight training
        if ('p_weight_train' in hp and
                (hp['p_weight_train'] is not None) and
                hp['p_weight_train'] < 1.0):
            for w in model.weight_list:
                w_val = sess.run(w)
                w_size = sess.run(tf.size(w))
                # Shuffled evenly spaced values in [0, 1]: entries above
                # p_weight_train are penalized for leaving their current
                # value, the remaining entries are left unpenalized
                w_mask_tmp = np.linspace(0, 1, w_size)
                hp['rng'].shuffle(w_mask_tmp)
                ind_fix = w_mask_tmp > hp['p_weight_train']
                w_mask = np.zeros(w_size, dtype=np.float32)
                w_mask[ind_fix] = 1e-1  # will be squared in l2_loss
                w_mask = tf.constant(w_mask)
                w_mask = tf.reshape(w_mask, w.shape)
                model.cost_reg += tf.nn.l2_loss((w - w_val) * w_mask)
            # cost_reg changed, so the optimizer is rebuilt
            model.set_optimizer(var_list=var_list)

        step = 0
        while step * hp['batch_size_train'] <= max_steps:
            try:
                # Validation
                if step % display_step == 0:
                    log['trials'].append(step * hp['batch_size_train'])
                    log['times'].append(time.time() - t_start)
                    log = do_eval(sess, model, log, hp['rule_trains'])
                    # Stop once the minimum performance over the trained
                    # rules is above target
                    if log['perf_min'][-1] > model.hp['target_perf']:
                        print('Perf reached the target: {:0.2f}'.format(
                            hp['target_perf']))
                        break

                    if rich_output:
                        display_rich_output(model, sess, step, log, model_dir)

                # Training
                rule_train_now = hp['rng'].choice(hp['rule_trains'],
                                                  p=hp['rule_probs'])
                # Generate a random batch of trials.
                # Each batch has the same trial length
                trial = generate_trials(rule_train_now, hp, 'random',
                                        batch_size=hp['batch_size_train'])

                # Generating feed_dict.
                feed_dict = tools.gen_feed_dict(model, trial, hp)
                sess.run(model.train_step, feed_dict=feed_dict)

                step += 1

            except KeyboardInterrupt:
                # Allow the user to stop training without a traceback
                print("Optimization interrupted by user")
                break

        print("Optimization finished!")
def train_sequential_orthogonalized(model_dir,
                                    rule_trains,
                                    hp=None,
                                    max_steps=1e7,
                                    display_step=500,
                                    rich_output=False,
                                    ruleset='mante',
                                    applyProj='both',
                                    seed=0,
                                    nEpisodeBatches=100,
                                    projGrad=True,
                                    alpha=0.001,
                                    fixReadout=False):
    '''Train the network sequentially.

    Each entry of `rule_trains` is trained in turn. After a task finishes,
    the network is run on a test-sized batch of that task, covariance
    matrices of its inputs/activity/outputs are computed and averaged with
    those of earlier tasks, and new projection matrices are derived from
    them and handed to model.set_optimizer for the next task.

    Args:
        model_dir: str, training directory
        rule_trains: the tasks to train sequentially; each entry is passed
            directly to generate_trials as a rule
        hp: dictionary of hyperparameters
        max_steps: int, maximum number of training steps for each task
            (compared against step * hp['batch_size_train'])
        display_step: int, display steps; also the checkpoint interval
        rich_output: not used in this function body
        ruleset: the set of rules to train
        applyProj: forwarded to Sequential_Model
        seed: int, random seed to be used
        nEpisodeBatches: not used in this function body
        projGrad: forwarded to Sequential_Model
        alpha: stored as hp['alpha_projection'] and passed to
            model.set_optimizer and compute_projection_matrices
        fixReadout: if True only variables whose name contains
            'rnn/leaky_rnn_cell/kernel:0' are trained; otherwise variables
            containing 'output/weights:0' are trained as well

    Returns:
        model is stored at model_dir/model.ckpt
        training configuration is stored at model_dir/hp.json

    Raises:
        NotImplementedError: if hp['l2_weight_init'] > 0.
    '''

    tools.mkdir_p(model_dir)

    # Network parameters: ruleset defaults overlaid with user overrides
    default_hp = get_default_hp(ruleset)
    if hp is not None:
        default_hp.update(hp)
    hp = default_hp
    hp['seed'] = seed
    hp['rng'] = np.random.RandomState(seed)
    hp['rule_trains'] = rule_trains
    # Unlike a list-of-lists layout, rules are used as given (no flattening)
    hp['rules'] = rule_trains

    # save some other parameters
    hp['alpha_projection'] = alpha
    hp['max_steps'] = max_steps

    # Number of training iterations for each rule
    rule_train_iters = [max_steps for _ in rule_trains]

    tools.save_hp(hp, model_dir)
    # Display hp
    for key, val in hp.items():
        print('{:20s} = '.format(key) + str(val))

    # Build the model
    model = Sequential_Model(model_dir,
                             projGrad=projGrad,
                             applyProj=applyProj,
                             hp=hp)

    # Store results
    log = defaultdict(list)
    log['model_dir'] = model_dir

    # Record time
    t_start = time.time()

    # NOTE(review): relu is defined but not referenced in this function
    def relu(x):
        return x * (x > 0.)

    # -------------------------------------------------------

    # Use customized session that launches the graph as well
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # penalty on deviation from initial weight
        if hp['l2_weight_init'] > 0:
            raise NotImplementedError()

        # Looping
        step_total = 0
        taskNumber = 0

        # Select the trainable variables by name
        if fixReadout is True:
            my_var_list = [
                var for var in model.var_list
                if 'rnn/leaky_rnn_cell/kernel:0' in var.name
            ]
        else:
            my_var_list = [
                var for var in model.var_list
                if 'rnn/leaky_rnn_cell/kernel:0' in var.name
                or 'output/weights:0' in var.name
            ]

        # initialise projection matrices (all zeros before the first task)
        input_proj = tf.zeros(
            (hp['n_rnn'] + hp['n_input'], hp['n_rnn'] + hp['n_input']))
        activity_proj = tf.zeros((hp['n_rnn'], hp['n_rnn']))
        output_proj = tf.zeros((hp['n_output'], hp['n_output']))
        recurrent_proj = tf.zeros((hp['n_rnn'], hp['n_rnn']))

        for i_rule_train, rule_train in enumerate(hp['rule_trains']):
            step = 0

            # Rebuild the optimizer with the projections from earlier tasks
            model.set_optimizer(activity_proj=activity_proj,
                                input_proj=input_proj,
                                output_proj=output_proj,
                                recurrent_proj=recurrent_proj,
                                taskNumber=taskNumber,
                                var_list=my_var_list,
                                alpha=alpha)

            # Keep training until reach max iterations
            while (step * hp['batch_size_train'] <=
                   rule_train_iters[i_rule_train]):
                # Validation
                if step % display_step == 0:
                    trial = step_total * hp['batch_size_train']
                    log['trials'].append(trial)
                    log['times'].append(time.time() - t_start)
                    log['rule_now'].append(rule_train)
                    log = do_eval(sess, model, log, rule_train)
                    if log['perf_avg'][-1] > model.hp['target_perf']:
                        print('Perf reached the target: {:0.2f}'.format(
                            hp['target_perf']))
                        break

                # Training
                # Generate a random batch of trials.
                # Each batch has the same trial length
                trial = generate_trials(rule_train,
                                        hp,
                                        'random',
                                        batch_size=hp['batch_size_train'],
                                        delay_fac=hp['delay_fac'])

                # Generating feed_dict.
                feed_dict = tools.gen_feed_dict(model, trial, hp)

                # update model
                sess.run(model.train_step, feed_dict=feed_dict)

                step += 1
                step_total += 1

                # Periodic checkpoint, keyed by the global step count
                if step % display_step == 0:
                    model.save_ckpt(step_total)

            # ---------- save model after its completed training the current task ----------
            model.save_after_task(taskNumber)

            # ---------- generate task activity for continual learning -------
            trial = generate_trials(rule_train,
                                    hp,
                                    'random',
                                    batch_size=hp['batch_size_test'],
                                    delay_fac=hp['delay_fac'])

            # Generating feed_dict.
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            eval_h, eval_x, eval_y, Wrec, Win = sess.run(
                [model.h, model.x, model.y, model.w_rec, model.w_in],
                feed_dict=feed_dict)
            # Inputs and recurrent activity stacked along the last axis
            full_state = np.concatenate([eval_x, eval_h], -1)

            # get weight matrix after current task
            Wfull = np.concatenate([Win, Wrec], 0)

            # joint covariance matrix of input and activity
            Shx_task = compute_covariance(
                np.reshape(full_state, (-1, hp['n_rnn'] + hp['n_input'])).T)

            # covariance matrix of output
            Sy_task = compute_covariance(
                np.reshape(eval_y, (-1, hp['n_output'])).T)

            # Joint covariance mapped through the weights:
            # Wfull.T @ Shx_task @ Wfull
            Sh_task = np.matmul(np.matmul(Wfull.T, Shx_task), Wfull)

            # ---------- update stored covariance matrices for continual learning -------
            if taskNumber == 0:
                input_cov = Shx_task
                activity_cov = Sh_task
                output_cov = Sy_task
            else:
                # Running average over the taskNumber + 1 tasks seen so far
                input_cov = taskNumber / (
                    taskNumber + 1) * input_cov + Shx_task / (taskNumber + 1)
                activity_cov = taskNumber / (
                    taskNumber + 1) * activity_cov + Sh_task / (taskNumber + 1)
                output_cov = taskNumber / (
                    taskNumber + 1) * output_cov + Sy_task / (taskNumber + 1)

            # ---------- update projection matrices for continual learning ----------
            # The last argument is the trailing n_rnn x n_rnn block of
            # input_cov
            activity_proj, input_proj, output_proj, recurrent_proj = compute_projection_matrices(
                activity_cov, input_cov, output_cov,
                input_cov[-hp['n_rnn']:, -hp['n_rnn']:], alpha)

            # update task number
            taskNumber += 1

        print("Optimization Finished!")
with tf.Session() as sess: model.restore() model._sigma = 0 # get all connection weights and biases as tensorflow variables var_list = model.var_list # evaluate the parameters after training params = [sess.run(var) for var in var_list] # get hparams hparams = model.hp # create a trial trial = generate_trials(rule, hparams, mode='test', noise_on=False, batch_size=40) # get feed_dict feed_dict = tools.gen_feed_dict(model, trial, hparams) # run model h_tf, y_hat_tf = sess.run( [model.h, model.y_hat], feed_dict=feed_dict) # (n_time, n_condition, n_neuron) ################################################################## # get shapes n_steps, n_trials, n_input_dim = np.shape(trial.x) n_rnn = np.shape(h_tf)[2] n_output = np.shape(y_hat_tf)[2] # Fixed point finder hyperparameters # See FixedPointFinder.py for detailed descriptions of available # hyperparameters. fpf_hps = {}
def pretty_inputoutput_plot(model_dir, rule, save=False, plot_ylabel=False):
    """Draw inputs and outputs of the first 'test' trial of one task.

    Five stacked panels are drawn: fixation input, the two stimulus rings,
    fixation output (target and network), and the network output ring.

    Args:
        model_dir: model directory
        rule: string, the rule
        save: bool, whether to save plots
        plot_ylabel: bool, whether to plot ylable
    """
    fs = 7
    cmap = 'Purples'

    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()

        trial = generate_trials(rule, hp, mode='test')
        x, y = trial.x, trial.y
        feed_dict = tools.gen_feed_dict(model, trial, hp)
        # The recurrent activity is fetched but not plotted
        _, y_hat = sess.run([model.h, model.y_hat], feed_dict=feed_dict)

    t_plot = np.arange(x.shape[0]) * hp['dt'] / 1000

    assert hp['num_ring'] == 2

    n_eachring = hp['n_eachring']

    def ring_heatmap(data):
        # Heat map of one ring over time, with degree labels if requested
        plt.imshow(data.T, aspect='auto', cmap=cmap, vmin=0, vmax=1,
                   interpolation='none', origin='lower')
        if plot_ylabel:
            plt.yticks(
                [0, (n_eachring - 1) / 2, n_eachring - 1],
                [r'0$\degree$', r'180$\degree$', r'360$\degree$'],
                rotation='vertical')

    if plot_ylabel:
        hidden_spines = ("right", "bottom", "top")
    else:
        hidden_spines = ("left", "right", "bottom", "top")

    fig = plt.figure(figsize=(1.3, 2))
    ylabels = ['fix. in', 'stim. mod1', 'stim. mod2', 'fix. out', 'out']
    heights = np.array([0.03, 0.2, 0.2, 0.03, 0.2]) + 0.01
    for i, ylabel in enumerate(ylabels):
        # Panels are stacked top to bottom: the bottom edge is the total
        # height (plus gaps) of all panels still to come
        bottom = sum(heights[i + 1:] + 0.02) + 0.1
        ax = fig.add_axes([0.15, bottom, 0.8, heights[i]])
        plt.xticks([])
        ax.tick_params(axis='both', which='major', labelsize=fs,
                       width=0.5, length=2, pad=3)

        for side in hidden_spines:
            ax.spines[side].set_visible(False)
        if plot_ylabel:
            ax.xaxis.set_ticks_position('bottom')
            ax.yaxis.set_ticks_position('left')
        else:
            ax.xaxis.set_ticks_position('none')

        if i == 0:
            plt.plot(t_plot, x[:, 0, 0], color='xkcd:blue')
            if plot_ylabel:
                plt.yticks([0, 1], ['', ''], rotation='vertical')
            plt.ylim([-0.1, 1.5])
            plt.title(rule_name[rule], fontsize=fs)
        elif i == 1:
            ring_heatmap(x[:, 0, 1:1 + n_eachring])
        elif i == 2:
            ring_heatmap(x[:, 0, 1 + n_eachring:1 + 2 * n_eachring])
        elif i == 3:
            plt.plot(t_plot, y[:, 0, 0], color='xkcd:green')
            plt.plot(t_plot, y_hat[:, 0, 0], color='xkcd:blue')
            if plot_ylabel:
                plt.yticks([0.05, 0.8], ['', ''], rotation='vertical')
            plt.ylim([-0.1, 1.1])
        elif i == 4:
            ring_heatmap(y_hat[:, 0, 1:])
            # Only the bottom panel carries the time axis
            plt.xticks([0, y_hat.shape[0]], ['0', '2'])
            plt.xlabel('Time (s)', fontsize=fs, labelpad=-3)
            ax.spines["bottom"].set_visible(True)

        if plot_ylabel:
            plt.ylabel(ylabel, fontsize=fs)
        else:
            plt.yticks([])
        ax.get_yaxis().set_label_coords(-0.12, 0.5)

    if save:
        save_name = 'figure/sample_' + rule_name[rule].replace(' ', '') + '.pdf'
        plt.savefig(save_name, transparent=True)
    plt.show()
def schematic_plot(model_dir, rule=None):
    """Draw the schematic figures for one sample trial.

    The model is run on one test-mode batch of ``rule`` and the first trial
    is plotted as four separate figures (stimulus inputs, rule inputs,
    recurrent units, outputs). Each figure is saved under ``figure/`` and
    then shown.

    Args:
        model_dir: model directory
        rule: string, the rule to run; 'dm1' is used when no rule is given
    """
    fontsize = 6
    cmap = 'Purples'

    rule = rule or 'dm1'

    model = Model(model_dir, dt=1)
    hp = model.hp

    with tf.Session() as sess:
        model.restore()
        trial = generate_trials(rule, hp, mode='test')
        feed_dict = tools.gen_feed_dict(model, trial, hp)
        x = trial.x
        h, y_hat = sess.run([model.h, model.y_hat], feed_dict=feed_dict)

    n_eachring = hp['n_eachring']
    n_hidden = hp['n_rnn']
    ring_ticks = [0, (n_eachring - 1) / 2, n_eachring - 1]

    def apply_fixed_style(axes, keep_bottom=False):
        """Apply the tick/spine style shared by all schematic panels."""
        axes.tick_params(axis='both', which='major', labelsize=fontsize,
                         width=0.5, length=2, pad=3)
        axes.spines["left"].set_linewidth(0.5)
        axes.spines["right"].set_visible(False)
        if keep_bottom:
            axes.spines["bottom"].set_linewidth(0.5)
        else:
            axes.spines["bottom"].set_visible(False)
        axes.spines["top"].set_visible(False)
        axes.xaxis.set_ticks_position('bottom')
        axes.yaxis.set_ticks_position('left')

    def heatmap(activity):
        """Show a (time, unit) array on the current axes, units along y."""
        plt.imshow(activity.T, aspect='auto', cmap=cmap, vmin=0, vmax=1,
                   interpolation='none', origin='lower')

    # ---- Stimulus inputs: fixation trace on top of the two rings ----
    fig = plt.figure(figsize=(1.0, 1.2))
    heights = np.array([0.06, 0.25, 0.25])
    for panel in range(3):
        bottom = sum(heights[panel + 1:] + 0.1) + 0.05
        ax = fig.add_axes([0.2, bottom, 0.7, heights[panel]])
        plt.xticks([])
        apply_fixed_style(ax)

        if panel == 0:
            plt.plot(x[:, 0, 0], color='xkcd:blue')
            plt.yticks([0, 1], ['', ''], rotation='vertical')
            plt.ylim([-0.1, 1.5])
            plt.title('Fixation input', fontsize=fontsize, y=0.9)
        elif panel == 1:
            heatmap(x[:, 0, 1:1 + n_eachring])
            plt.yticks(ring_ticks,
                       [r'0$\degree$', '', r'360$\degree$'],
                       rotation='vertical')
            plt.title('Stimulus mod 1', fontsize=fontsize, y=0.9)
        elif panel == 2:
            heatmap(x[:, 0, 1 + n_eachring:1 + 2 * n_eachring])
            plt.yticks(ring_ticks, ['', '', ''], rotation='vertical')
            plt.title('Stimulus mod 2', fontsize=fontsize, y=0.9)
        ax.get_yaxis().set_label_coords(-0.12, 0.5)
    plt.savefig('figure/schematic_input.pdf', transparent=True)
    plt.show()

    # ---- Rule inputs: every input unit after the two rings ----
    fig = plt.figure(figsize=(1.0, 0.5))
    ax = fig.add_axes([0.2, 0.3, 0.7, 0.45])
    rule_inputs = x[:, 0, 1 + 2 * n_eachring:]
    n_time, n_rule_units = rule_inputs.shape[0], rule_inputs.shape[-1]
    heatmap(rule_inputs)
    plt.xticks([0, n_time])
    ax.set_xlabel('Time (ms)', fontsize=fontsize, labelpad=-5)
    # This is the only panel that keeps its bottom spine.
    apply_fixed_style(ax, keep_bottom=True)
    plt.yticks([0, n_rule_units - 1], ['1', str(n_rule_units)],
               rotation='vertical')
    plt.title('Rule inputs', fontsize=fontsize, y=0.9)
    ax.get_yaxis().set_label_coords(-0.12, 0.5)
    plt.savefig('figure/schematic_rule.pdf', transparent=True)
    plt.show()

    # ---- Recurrent units ----
    fig = plt.figure(figsize=(1.0, 0.8))
    ax = fig.add_axes([0.2, 0.1, 0.7, 0.75])
    plt.xticks([])
    apply_fixed_style(ax)
    heatmap(h[:, 0, :])
    plt.yticks([0, n_hidden - 1], ['1', str(n_hidden)], rotation='vertical')
    plt.title('Recurrent units', fontsize=fontsize, y=0.95)
    ax.get_yaxis().set_label_coords(-0.12, 0.5)
    plt.savefig('figure/schematic_units.pdf', transparent=True)
    plt.show()

    # ---- Outputs: fixation output trace on top of the response ring ----
    fig = plt.figure(figsize=(1.0, 0.8))
    heights = np.array([0.1, 0.45]) + 0.01
    for panel in range(2):
        bottom = sum(heights[panel + 1:] + 0.15) + 0.1
        ax = fig.add_axes([0.2, bottom, 0.7, heights[panel]])
        plt.xticks([])
        apply_fixed_style(ax)

        if panel == 0:
            plt.plot(y_hat[:, 0, 0], color='xkcd:blue')
            plt.yticks([0.05, 0.8], ['', ''], rotation='vertical')
            plt.ylim([-0.1, 1.1])
            plt.title('Fixation output', fontsize=fontsize, y=0.9)
        elif panel == 1:
            heatmap(y_hat[:, 0, 1:])
            plt.yticks(ring_ticks,
                       [r'0$\degree$', '', r'360$\degree$'],
                       rotation='vertical')
            plt.xticks([])
            plt.title('Response', fontsize=fontsize, y=0.9)
        ax.get_yaxis().set_label_coords(-0.12, 0.5)
    plt.savefig('figure/schematic_outputs.pdf', transparent=True)
    plt.show()
def pretty_singleneuron_plot(model_dir, rules, neurons, epoch=None,
                             save=False, ylabel_firstonly=True,
                             trace_only=False, plot_stim_avg=False,
                             save_name=''):
    """Plot the activity of a single neuron in time across many trials.

    One figure is produced per (neuron, rule) pair. Every trial of a
    test-mode batch is drawn as a gray trace, starting 500 time units
    (presumably ms, same unit as hp['dt']) into the trial.

    Args:
        model_dir: model directory
        rules: a rule name or a list of rule names to plot
        neurons: index or iterable of indices of neurons to plot
        epoch: name of an epoch in trial.epochs to mark with a bar above
            the traces, or None for no bar
        save: save figure?
        ylabel_firstonly: if True, only plot ylabel for the first rule in rules
        trace_only: if True, strip axes, ticks, labels and title
        plot_stim_avg: if True, overlay the trace averaged across trials
        save_name: string appended to the saved figure's file name
    """
    if isinstance(rules, str):
        rules = [rules]

    try:
        _ = iter(neurons)
    except TypeError:
        neurons = [neurons]

    h_tests = dict()
    # Epochs are kept per rule: different rules may have different epoch
    # timings, so the last generated trial must not be used for all of them.
    epochs_byrule = dict()

    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()

        for rule in rules:
            # Generate a batch of trial from the test mode
            trial = generate_trials(rule, hp, mode='test')
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            h_tests[rule] = sess.run(model.h, feed_dict=feed_dict)
            epochs_byrule[rule] = trial.epochs

    # Number of initial time steps that are skipped in every plot
    t_start = int(500 / hp['dt'])
    fs = 6

    for neuron in neurons:
        # Common y scale across rules for this neuron
        h_max = np.max([h_tests[r][t_start:, :, neuron].max() for r in rules])
        for j, rule in enumerate(rules):
            h_rule = h_tests[rule]
            traces = h_rule[t_start:, :, neuron]  # (time, trial)

            fig = plt.figure(figsize=(1.0, 0.8))
            ax = fig.add_axes([0.35, 0.25, 0.55, 0.55])
            # Time in seconds relative to t_start
            t_plot = np.arange(traces.shape[0]) * hp['dt'] / 1000
            _ = ax.plot(t_plot, traces, lw=0.5, color='gray')

            if plot_stim_avg:
                # Plot stimulus averaged trace
                _ = ax.plot(t_plot, traces.mean(axis=1), lw=1, color='black')

            if epoch is not None:
                # assumes epoch boundaries are scalar time-step indices into
                # the full trial (None meaning open-ended) -- TODO confirm
                e0, e1 = epochs_byrule[rule][epoch]
                e0 = e0 if e0 is not None else 0
                e1 = e1 if e1 is not None else h_rule.shape[0]
                # Convert to the units of t_plot (seconds after t_start);
                # anything before t_start is not shown, so clamp at 0.
                bar = [np.maximum(e - t_start, 0) * hp['dt'] / 1000
                       for e in (e0, e1)]
                ax.plot(bar, [h_max * 1.15] * 2,
                        color='black', linewidth=1.5)
                # NOTE(review): this name does not contain the neuron index,
                # so saving several neurons with an epoch reuses one file.
                figname = 'figure/trace_' + rule_name[
                    rule] + epoch + save_name + '.pdf'
            else:
                figname = 'figure/trace_unit' + str(
                    neuron) + rule_name[rule] + save_name + '.pdf'

            plt.ylim(np.array([-0.1, 1.2]) * h_max)
            plt.xticks([0, np.floor(np.max(t_plot) + 0.01)])
            plt.xlabel('Time (s)', fontsize=fs, labelpad=-5)
            plt.locator_params(axis='y', nbins=4)
            if j > 0 and ylabel_firstonly:
                ax.set_yticklabels([])
            else:
                plt.ylabel('Activity (a.u.)', fontsize=fs, labelpad=2)
            plt.title('Unit {:d} '.format(neuron) + rule_name[rule],
                      fontsize=5)
            ax.tick_params(axis='both', which='major', labelsize=fs)
            ax.spines["right"].set_visible(False)
            ax.spines["top"].set_visible(False)
            ax.xaxis.set_ticks_position('bottom')
            ax.yaxis.set_ticks_position('left')

            if trace_only:
                ax.spines["left"].set_visible(False)
                ax.spines["bottom"].set_visible(False)
                ax.xaxis.set_ticks_position('none')
                ax.set_xlabel('')
                ax.set_ylabel('')
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_title('')

            if save:
                plt.savefig(figname, transparent=True)
            plt.show()
def _compute_variance_bymodel(model, sess, rules=None, random_rotation=False):
    """Compute and save the task variance of every recurrent unit.

    For each rule the network is run on a noise-free test batch. The variance
    of each unit across trials is averaged over time, once per rule (all time
    points after the 'fix1' epoch) and once per (rule, epoch) pair (epochs
    whose name contains 'fix' are skipped). The two results are pickled to
    'variance_rule.pkl' and 'variance_epoch.pkl' in model.model_dir, with an
    '_rr' suffix on the name when random_rotation is set.

    Args:
        model: network.Model instance
        sess: tensorflow session
        rules: list of rules to compute variance, list of strings.
            Defaults to hp['rules'].
        random_rotation: boolean. If True, rotate the neural activity.
    """
    hp = model.hp
    if rules is None:
        rules = hp['rules']
    print(rules)

    n_hidden = hp['n_rnn']

    if random_rotation:
        # Generate random orthogonal matrix
        from scipy.stats import ortho_group
        random_ortho_matrix = ortho_group.rvs(dim=n_hidden)

    activity_byrule = OrderedDict()
    activity_byepoch = OrderedDict()
    for rule in rules:
        trial = generate_trials(rule, hp, 'test', noise_on=False)
        h = sess.run(model.h,
                     feed_dict=tools.gen_feed_dict(model, trial, hp))
        if random_rotation:
            h = np.dot(h, random_ortho_matrix)  # randomly rotate

        for e_name, e_time in trial.epochs.items():
            if 'fix' in e_name:
                continue  # Ignore fixation period
            activity_byepoch[(rule, e_name)] = h[e_time[0]:e_time[1], :, :]

        # Everything after the first fixation epoch
        activity_byrule[rule] = h[trial.epochs['fix1'][1]:, :, :]

    # Reorder the (rule, epoch) entries epoch-first. Mergesort is stable, so
    # rules keep their relative order within each epoch.
    epoch_keys = list(activity_byepoch.keys())
    epoch_names = list(zip(*epoch_keys))[1]
    order = np.argsort(epoch_names, kind='mergesort')
    activity_byepoch = OrderedDict(
        [(epoch_keys[i], activity_byepoch[epoch_keys[i]]) for i in order])

    groupings = (('rule', activity_byrule), ('epoch', activity_byepoch))
    for data_type, h_all in groupings:
        keys = list(h_all.keys())
        h_var_all = np.zeros((n_hidden, len(keys)))
        for col, val in enumerate(h_all.values()):
            # val is Time, Batch, Units
            # Variance across stimulus, then averaged across time
            h_var_all[:, col] = val.var(axis=1).mean(axis=0)

        result = {'h_var_all': h_var_all, 'keys': keys}
        suffix = '_rr' if random_rotation else ''
        save_name = 'variance_' + data_type + suffix

        fname = os.path.join(model.model_dir, save_name + '.pkl')
        print('Variance saved at {:s}'.format(fname))
        with open(fname, 'wb') as f:
            pickle.dump(result, f)
def psychometric_choicefamily_2D(model_dir, rule, lesion_units=None,
                                 n_coh=8, n_stim_loc=20, coh_range=0.1):
    """Run a 2D psychometric test on one of the choice-family tasks.

    A batch covering every combination of stimulus location, modality-1
    coherence and modality-2 coherence is generated. Stimulus 2 is always
    placed opposite (pi away from) stimulus 1, and the two stimuli have
    strengths 1 + coherence and 1 - coherence.

    Args:
        model_dir: model directory
        rule: string, one of 'dm1', 'dm2', 'contextdm1', 'contextdm2',
            'contextdelaydm1', 'contextdelaydm2', 'multidm'
        lesion_units: passed unchanged to model.lesion_units
        n_coh: number of coherence levels per modality
        n_stim_loc: number of stimulus-1 locations, evenly spaced on the ring
        coh_range: coherences are evenly spaced in [-coh_range, coh_range]

    Returns:
        perf: fraction of decided trials whose final response is within
            THETA of the correct location; trials within THETA of neither
            the correct nor the opposite location are not counted
        prop1s: array of shape (n_coh, n_coh), indexed by (mod 1 coherence,
            mod 2 coherence); proportion of decided trials, pooled over
            stimulus locations, that chose stimulus 1
        cohs: the coherence values used

    Raises:
        KeyError: if rule is not one of the supported rules
    """
    # Generate task parameters for choice tasks
    cohs = np.linspace(-coh_range, coh_range, n_coh)

    batch_size = n_stim_loc * n_coh**2
    batch_shape = (n_stim_loc, n_coh, n_coh)
    ind_stim_loc, ind_stim_mod1, ind_stim_mod2 = np.unravel_index(
        np.arange(batch_size), batch_shape)

    # Looping target location
    stim1_locs = 2 * np.pi * ind_stim_loc / n_stim_loc
    stim2_locs = (stim1_locs + np.pi) % (2 * np.pi)

    stim_mod1_cohs = cohs[ind_stim_mod1]
    stim_mod2_cohs = cohs[ind_stim_mod2]

    def single_modality_params():
        """Parameters for rules with one strength per stimulus."""
        return {'stim1_locs': stim1_locs,
                'stim2_locs': stim2_locs,
                'stim1_strengths': 1 + stim_mod1_cohs,  # Just use mod 1 value
                'stim2_strengths': 1 - stim_mod1_cohs,
                'stim_time': 800}

    def two_modality_params(mod2_cohs):
        """Parameters for rules with separate mod 1 / mod 2 strengths."""
        return {'stim1_locs': stim1_locs,
                'stim2_locs': stim2_locs,
                'stim1_mod1_strengths': 1 + stim_mod1_cohs,
                'stim2_mod1_strengths': 1 - stim_mod1_cohs,
                'stim1_mod2_strengths': 1 + mod2_cohs,
                'stim2_mod2_strengths': 1 - mod2_cohs,
                'stim_time': 800}

    # Every rule gets its own dict so that no entry aliases another.
    params_dict = {
        'dm1': single_modality_params(),
        'dm2': single_modality_params(),
        'contextdm1': two_modality_params(stim_mod2_cohs),
        'contextdm2': two_modality_params(stim_mod2_cohs),
        'contextdelaydm1': two_modality_params(stim_mod2_cohs),
        'contextdelaydm2': two_modality_params(stim_mod2_cohs),
        # multidm uses the mod 1 coherence for both modalities
        'multidm': two_modality_params(stim_mod1_cohs),
    }
    # Look the rule up before loading the model so a bad rule fails early.
    params = params_dict[rule]

    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()
        model.lesion_units(sess, lesion_units)

        trial = generate_trials(rule, hp, 'psychometric', params=params)
        feed_dict = tools.gen_feed_dict(model, trial, hp)
        y_sample, y_loc_sample = sess.run([model.y_hat, model.y_hat_loc],
                                          feed_dict=feed_dict)

    # Compute the overall performance.
    # Importantly, discard trials where no decision was made
    loc_cor = trial.y_loc[-1]  # last time point, correct locations
    loc_err = (loc_cor + np.pi) % (2 * np.pi)
    choose_cor = (get_dist(y_loc_sample[-1] - loc_cor) < THETA).sum()
    choose_err = (get_dist(y_loc_sample[-1] - loc_err) < THETA).sum()
    perf = choose_cor / (choose_cor + choose_err)

    # Compute the proportion of choosing choice 1 and maintain the batch_shape
    stim1_locs_ = np.reshape(stim1_locs, batch_shape)
    stim2_locs_ = np.reshape(stim2_locs, batch_shape)

    y_loc_sample = np.reshape(y_loc_sample[-1], batch_shape)
    # Summing over axis 0 pools the stimulus locations
    choose1 = (get_dist(y_loc_sample - stim1_locs_) < THETA).sum(axis=0)
    choose2 = (get_dist(y_loc_sample - stim2_locs_) < THETA).sum(axis=0)
    prop1s = choose1 / (choose1 + choose2)

    return perf, prop1s, cohs