def get_avg_performance(model_dirs, rule):
    """Get average performance across trials for model_dirs.

    Some networks converge earlier than others. For a network whose log has
    no entry at a given trial (e.g. it stopped training early), its last
    recorded performance is used for that trial.

    Args:
        model_dirs: iterable of model directories, each readable by
            ``tools.load_log``.
        rule: str, task name; performance is read from
            ``log['perf_' + rule]``.

    Returns:
        avg_perfs: list of mean performances, one per entry of ``trials``.
        trials: sorted numpy array of every trial recorded in any log.
    """
    # Load each log exactly once (previously every log was read twice).
    logs = [tools.load_log(model_dir) for model_dir in model_dirs]

    all_trials = []
    for log in logs:
        all_trials += list(log['trials'])
    trials = np.sort(np.unique(all_trials))

    perfs = defaultdict(list)
    for log in logs:
        perf = log['perf_' + rule]
        # First index of each trial (same as list.index) for O(1) lookup,
        # replacing an O(T) scan per trial.
        trial_to_ind = dict()
        for ind, t in enumerate(log['trials']):
            trial_to_ind.setdefault(t, ind)
        for t in trials:
            # Missing trial -> fall back to the last recorded performance.
            perfs[t].append(perf[trial_to_ind.get(t, -1)])

    # average performances
    trials = np.sort(list(perfs.keys()))
    avg_perfs = [np.mean(perfs[t]) for t in trials]
    return avg_perfs, trials
def get_allperformance(model_dirs, param_list=None):
    """Print (and return) mean final performance grouped by hyperparameters.

    The final performance of a model is the mean, over all entries of
    ``log['perf_tests']``, of each entry's last recorded value. Models are
    grouped by the tuple of their hp values for ``param_list`` and the group
    means are printed, one line per group, with one model directory of the
    group.

    Args:
        model_dirs: passed to ``tools.valid_model_dirs`` to get the models.
        param_list: list of hp names used as the grouping key. Defaults to
            ['param_intsyn', 'easy_task', 'activation'].

    Returns:
        dict mapping the hp-value tuple to the mean final performance.
    """
    model_dirs = tools.valid_model_dirs(model_dirs)
    if param_list is None:
        param_list = ['param_intsyn', 'easy_task', 'activation']

    perf_lists = dict()
    filenames = dict()
    for model_dir in model_dirs:
        log = tools.load_log(model_dir)
        hp = tools.load_hp(model_dir)
        perf_tests = log['perf_tests']
        final_perf = np.mean([float(val[-1]) for val in perf_tests.values()])
        key = tuple(hp[p] for p in param_list)
        perf_lists.setdefault(key, []).append(final_perf)
        # Only the last model directory seen for each key is remembered.
        filenames[key] = model_dir

    final_perfs = {key: np.mean(val) for key, val in perf_lists.items()}
    for key, mean_perf in final_perfs.items():
        # One line per group (the old `print(x),` form was a Python 2
        # leftover that split each record over three lines).
        print(key, '{:0.3f}'.format(mean_perf), filenames[key])
    return final_perfs
def train_all_tanhgru(seed=0, model_dir='tanhgru'):
    """Train a Tanh LeakyGRU network on all tasks and run the standard analyses.

    Args:
        seed: int, random seed; also used as the sub-directory name.
        model_dir: str, directory name under DATAPATH.
    """
    save_dir = os.path.join(DATAPATH, model_dir, str(seed))
    hparams = {'activation': 'tanh', 'rnn_type': 'LeakyGRU'}
    # Per-rule sampling weights handed to train.train for the context DM tasks.
    task_weights = {'contextdm1': 5, 'contextdm2': 5}
    train.train(save_dir, hp=hparams, ruleset='all',
                rule_prob_map=task_weights, seed=seed)

    # Post-training analyses
    variance.compute_variance(save_dir)
    log = tools.load_log(save_dir)
    log['n_cluster'] = clustering.Analysis(save_dir, 'rule').n_cluster
    tools.save_log(log)
    data_analysis.compute_var_all(save_dir)

    for setup in (1, 2, 3):
        taskset.compute_taskspace(save_dir, setup, restore=False,
                                  representation='rate')
        taskset.compute_replacerule_performance(save_dir, setup, False)
def __init__(self, model_dir):
    """Load a model's training log and hyperparameters, then print a summary.

    Args:
        model_dir: str, directory whose training log is loaded.
    """
    self.model_dir = model_dir
    log = tools.load_log(model_dir)
    self.log = log
    # NOTE(review): hp is read from the directory recorded inside the log,
    # not from `model_dir`; confirm this is intended if models were moved.
    self.hp = tools.load_hp(log['model_dir'])
    self.neurons = self.hp['n_rnn']
    self.print_basic_info()
def train_all_mixrule_softplus(seed=0, root_dir='mixrule_softplus'):
    """Train a softplus network on all tasks with mixed rule inputs, then analyse it.

    Args:
        seed: int, random seed; also the sub-directory name under root_dir.
        root_dir: str, directory name under DATAPATH.
    """
    out_dir = os.path.join(DATAPATH, root_dir, str(seed))
    hparams = dict(
        activation='softplus',
        w_rec_init='diag',
        use_separate_input=True,
        mix_rule=True,
    )
    train.train(out_dir, hp=hparams, ruleset='all',
                rule_prob_map={'contextdm1': 5, 'contextdm2': 5}, seed=seed)

    # Analyses
    variance.compute_variance(out_dir)
    log = tools.load_log(out_dir)
    log['n_cluster'] = clustering.Analysis(out_dir, 'rule').n_cluster
    tools.save_log(log)

    for setup in (1, 2, 3):
        taskset.compute_taskspace(out_dir, setup, restore=False,
                                  representation='rate')
        taskset.compute_replacerule_performance(out_dir, setup, False)
def mante_tanh(seed=0, model_dir='mante_tanh'):
    """Train a tanh network on the Mante task only, then run analyses.

    Args:
        seed: int, random seed; also the sub-directory name.
        model_dir: str, directory name under DATAPATH.
    """
    target_dir = os.path.join(DATAPATH, model_dir, str(seed))
    train.train(target_dir, hp={'activation': 'tanh', 'target_perf': 0.9},
                ruleset='mante', seed=seed)

    # Analyses
    variance.compute_variance(target_dir)
    log = tools.load_log(target_dir)
    log['n_cluster'] = clustering.Analysis(target_dir, 'rule').n_cluster
    tools.save_log(log)
    data_analysis.compute_var_all(target_dir)
def get_n_clusters(root_dir):
    """Collect cluster counts for the models under root_dir that beat their target.

    A model qualifies when its last ``perf_min`` is strictly greater than
    ``hp['target_perf']``.

    Returns:
        n_clusters: list of ``log['n_cluster']`` for qualifying models.
        hp_list: list of the corresponding hp dicts.
    """
    model_dirs = tools.valid_model_dirs(root_dir)
    n_clusters, hp_list = [], []
    for idx, model_dir in enumerate(model_dirs):
        if not idx % 50:
            print('Analyzing model {:d}/{:d}'.format(idx, len(model_dirs)))
        hp = tools.load_hp(model_dir)
        log = tools.load_log(model_dir)
        # skip models whose performance does not exceed the target
        if not (log['perf_min'][-1] > hp['target_perf']):
            continue
        n_clusters.append(log['n_cluster'])
        hp_list.append(hp)
    return n_clusters, hp_list
def plot_performanceprogress(model_dir, rule_plot=None):
    """Plot training progress (log10 cost and performance) per rule.

    Args:
        model_dir: str, model directory whose log and hp are loaded.
        rule_plot: sequence of rule names to plot; defaults to hp['rules'].

    Saves the figure to 'figure/Performance_Progresss.pdf' and shows it.
    """
    log = tools.load_log(model_dir)
    hp = tools.load_hp(model_dir)
    trials = log['trials']

    fs = 6  # fontsize
    fig = plt.figure(figsize=(3.5, 1.2))
    ax = fig.add_axes([0.1, 0.25, 0.35, 0.6])
    lines = list()
    labels = list()

    x_plot = np.array(trials) / 1000.
    # `is None`, not `== None`: with a multi-element numpy array of rule
    # names `==` is elementwise and the `if` would raise.
    if rule_plot is None:
        rule_plot = hp['rules']

    for rule in rule_plot:
        # NOTE(review): log10(cost) shares the axes with performance, whose
        # y-limits are fixed to [0, 1] below, so cost values outside that
        # range are clipped; confirm this is intended.
        line = ax.plot(x_plot, np.log10(log['cost_' + rule]),
                       color=rule_color[rule])
        ax.plot(x_plot, log['perf_' + rule], color=rule_color[rule])
        lines.append(line[0])
        labels.append(rule_name[rule])

    ax.tick_params(axis='both', which='major', labelsize=fs)
    ax.set_ylim([0, 1])
    ax.set_xlabel('Total trials (1,000)', fontsize=fs, labelpad=2)
    ax.set_ylabel('Performance', fontsize=fs, labelpad=0)
    ax.locator_params(axis='x', nbins=3)
    ax.set_yticks([0, 1])
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    lg = fig.legend(lines, labels, title='Task', ncol=2,
                    bbox_to_anchor=(0.47, 0.5), fontsize=fs,
                    labelspacing=0.3, loc=6, frameon=False)
    plt.setp(lg.get_title(), fontsize=fs)
    plt.savefig('figure/Performance_Progresss.pdf', transparent=True)
    plt.show()
def train_vary_hp(i):
    """Vary the hyperparameters.

    This experiment loops over a set of hyperparameters. ``i`` modulo the
    grid size selects one combination, and ``i // grid_size`` is used as the
    random seed, so successive blocks of indices repeat the grid with new
    seeds.

    Args:
        i: int, the index of the hyperparameters list
    """
    # Ranges of hyperparameters to loop over
    hp_ranges = OrderedDict()
    # hp_ranges['activation'] = ['softplus', 'relu', 'tanh', 'retanh']
    # hp_ranges['rnn_type'] = ['LeakyRNN', 'LeakyGRU']
    # hp_ranges['w_rec_init'] = ['diag', 'randortho']
    hp_ranges['activation'] = ['softplus']
    hp_ranges['rnn_type'] = ['LeakyRNN']
    hp_ranges['w_rec_init'] = ['randortho']
    hp_ranges['l1_h'] = [0, 1e-9, 1e-8, 1e-7, 1e-6]  # TODO(gryang): Change this?
    hp_ranges['l2_h'] = [0]
    hp_ranges['l1_weight'] = [0, 1e-7, 1e-6, 1e-5]
    # TODO(gryang): add the level of overtraining

    # Unravel the input index
    keys = list(hp_ranges.keys())
    dims = [len(hp_ranges[k]) for k in keys]
    n_max = np.prod(dims)
    # Shape is passed positionally: the `dims=` keyword was deprecated in
    # NumPy 1.16 (renamed to `shape`); positional works on every version.
    indices = np.unravel_index(i % n_max, dims)

    # Set up new hyperparameter
    hp = {key: hp_ranges[key][index] for key, index in zip(keys, indices)}

    model_dir = os.path.join(DATAPATH, 'varyhp_reg2', str(i))
    rule_prob_map = {'contextdm1': 5, 'contextdm2': 5}
    train.train(model_dir, hp, ruleset='all',
                rule_prob_map=rule_prob_map, seed=i // n_max)

    # Analyses
    variance.compute_variance(model_dir)
    log = tools.load_log(model_dir)
    analysis = clustering.Analysis(model_dir, 'rule')
    log['n_cluster'] = analysis.n_cluster
    tools.save_log(log)
    data_analysis.compute_var_all(model_dir)
def compute_n_cluster(model_dirs):
    """Run the rule clustering analysis and store n_cluster in each log.

    For each directory, ``n_cluster`` and ``model_dir`` are written into the
    model's log. If the analysis raises IOError, the model is treated as one
    whose training never finished. This is checked by asserting that its
    final ``perf_min`` is not above ``hp['target_perf']``.
    """
    for model_dir in model_dirs:
        print(model_dir)
        log = tools.load_log(model_dir)
        hp = tools.load_hp(model_dir)
        try:
            log['n_cluster'] = clustering.Analysis(model_dir, 'rule').n_cluster
            log['model_dir'] = model_dir
            tools.save_log(log)
        except IOError:
            # Training never finished
            assert log['perf_min'][-1] <= hp['target_perf']
    print("done")
def train_all_analysis(seed=0, root_dir='train_all'):
    """Run the post-training analyses for the network in DATAPATH/root_dir/seed.

    Args:
        seed: int, sub-directory name of the trained model.
        root_dir: str, directory name under DATAPATH.
    """
    model_path = os.path.join(DATAPATH, root_dir, str(seed))

    # Variance, computed once plainly and once with random_rotation=True
    variance.compute_variance(model_path)
    variance.compute_variance(model_path, random_rotation=True)

    log = tools.load_log(model_path)
    log['n_cluster'] = clustering.Analysis(model_path, 'rule').n_cluster
    tools.save_log(log)
    data_analysis.compute_var_all(model_path)

    for task in ('dm1', 'contextdm1', 'multidm'):
        performance.compute_choicefamily_varytime(model_path, task)

    for setup in (1, 2, 3):
        taskset.compute_taskspace(model_path, setup, restore=False,
                                  representation='rate')
        taskset.compute_replacerule_performance(model_path, setup, False)
def _base_vary_hp_mante(i, hp_ranges, base_name):
    """Vary hyperparameters for mante tasks.

    ``i`` modulo the grid size selects one combination from ``hp_ranges``.
    ``i // grid_size`` is the random seed.

    Args:
        i: int, run index.
        hp_ranges: mapping hp name -> list of candidate values (its key order
            defines the grid axes).
        base_name: str, directory name under DATAPATH for the runs.
    """
    # Unravel the input index
    keys = list(hp_ranges.keys())
    dims = [len(hp_ranges[k]) for k in keys]
    n_max = np.prod(dims)
    # Shape is passed positionally: the `dims=` keyword was deprecated in
    # NumPy 1.16 (renamed to `shape`); positional works on every version.
    indices = np.unravel_index(i % n_max, dims)

    # Set up new hyperparameter
    hp = {key: hp_ranges[key][index] for key, index in zip(keys, indices)}

    model_dir = os.path.join(DATAPATH, base_name, str(i))
    train.train(model_dir, hp, ruleset='mante',
                max_steps=1e7, seed=i // n_max)

    # Analyses
    variance.compute_variance(model_dir)
    log = tools.load_log(model_dir)
    analysis = clustering.Analysis(model_dir, 'rule')
    log['n_cluster'] = analysis.n_cluster
    tools.save_log(log)
    data_analysis.compute_var_all(model_dir)
def get_finalperformance(model_dirs):
    """Get lists of final performance.

    The rule list is taken from the hp of ``model_dirs[0]`` (so an empty
    ``model_dirs`` raises IndexError). Models whose log loads as None are
    skipped.

    Returns:
        final_cost: OrderedDict rule -> list of last ``cost_<rule>`` values.
        final_perf: OrderedDict rule -> list of last ``perf_<rule>`` values.
        rule_plot: the rule list from the first model's hp.
        training_time_plot: list of last ``times`` entries, one per model.
    """
    rule_plot = tools.load_hp(model_dirs[0])['rules']

    final_cost = OrderedDict()
    final_perf = OrderedDict()
    for task in rule_plot:
        final_cost[task] = []
        final_perf[task] = []
    training_time_plot = []

    # Recording performance and cost for networks
    for model_dir in model_dirs:
        log = tools.load_log(model_dir)
        if log is None:
            continue
        for task in rule_plot:
            final_perf[task].append(float(log['perf_' + task][-1]))
            final_cost[task].append(float(log['cost_' + task][-1]))
        training_time_plot.append(log['times'][-1])

    return final_cost, final_perf, rule_plot, training_time_plot
def plot_posttrain_performance(posttrain_setup, trainables):
    """Plot post-training performance for networks pretrained on set B vs A.

    For each pretrain setup (1 -> 'B' in blue, 0 -> 'A' in red), every
    matching model's curve is drawn faintly. The trial-aligned average from
    get_avg_performance is drawn on top. Saves
    'figure/Posttrain_post{posttrain_setup}train{trainables}.pdf'.
    """
    from task import rule_name

    hp_target = {'posttrain_setup': posttrain_setup,
                 'trainables': trainables}
    fs = 7
    fig = plt.figure(figsize=(1.5, 1.2))
    ax = fig.add_axes([0.25, 0.3, 0.7, 0.65])

    palette = ['xkcd:blue', 'xkcd:red']
    set_names = ['B', 'A']
    for pretrain_setup in (1, 0):
        color = palette[pretrain_setup]
        hp_target['pretrain_setup'] = pretrain_setup
        model_dirs = tools.find_all_models(DATAPATH, hp_target)
        first_hp = tools.load_hp(model_dirs[0])
        rule = first_hp['rule_trains'][0]  # depends on posttrain setup
        for model_dir in model_dirs:
            log = tools.load_log(model_dir)
            x = np.array(log['trials']) / 1000.
            ax.plot(x, log['perf_' + rule], color=color, alpha=0.1)
        avg_perfs, trials = get_avg_performance(model_dirs, rule)
        ax.plot(trials / 1000., avg_perfs, color=color,
                label=set_names[pretrain_setup])

    ax.set_ylim([0, 1])
    ax.set_xlabel('Total trials (1,000)', fontsize=fs, labelpad=2)
    ax.set_yticks([0, 1])
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    # The y label uses the rule from the last pretrain setup plotted.
    plt.ylabel('Perf. of ' + rule_name[rule])
    plt.savefig('figure/Posttrain_post{:d}train{:s}.pdf'.format(
        posttrain_setup, trainables), transparent=True)
def plot_fracvar_hist_byhp(hp_vary, save_name=None, mode='all_var', legend=True):
    """Plot how fractional variance distribution depends on hparams.

    For each value of ``hp_vary``, models in the corresponding data
    directory with ``perf_min`` >= 0.8 (via tools.select_by_perf) are
    analysed. Their per-model histograms of FTV(Ctx DM 1, Ctx DM 2) are
    normalised, and the median histogram is plotted with a 95% bootstrap
    band (400 resamples per bin). hp values with no qualifying models are
    skipped.

    Args:
        hp_vary: one of 'l2_weight_init', 'l2_weight', 'p_weight_train',
            'c_intsyn'.
        save_name: optional str appended to the output file name.
        mode: 'all_var' (variance.compute_hist_varprop) or 'mante_var'
            (compute_var_all + compute_frac_var per model).
        legend: bool, whether to draw a legend.

    Returns:
        list of median histograms, one per hp value that had models.

    Raises:
        ValueError: if ``hp_vary`` or ``mode`` is not recognised.
    """
    hp_target = {'activation': 'softplus',
                 'rnn_type': 'LeakyRNN',
                 'w_rec_init': 'randortho',
                 }
    if hp_vary == 'l2_weight_init':
        root_dir = './data/vary_l2init_mante'
        title = r'$L_2$ initial weight'
        hp_vary_vals = [0, 8 * 1e-4]
        ylim = [0, 0.3]
    elif hp_vary == 'l2_weight':
        root_dir = './data/vary_l2weight_mante'
        title = r'$L_2$ weight'
        hp_vary_vals = [0, 8 * 1e-4]
        ylim = [0, 0.3]
    elif hp_vary == 'p_weight_train':
        root_dir = './data/vary_pweighttrain_mante'
        title = r'$P_{\mathrm{train}}$'
        hp_vary_vals = [1, 0.1]
        ylim = [0, 0.15]
    elif hp_vary == 'c_intsyn':
        root_dir = 'data/seq'
        title = '$c$'
        hp_vary_vals = [0, 1]
        ylim = [0, 0.2]
        hp_target = {}
    else:
        # Previously an unknown value surfaced later as a NameError.
        raise ValueError('Unknown hp_vary: {}'.format(hp_vary))
    if mode not in ('all_var', 'mante_var'):
        raise ValueError('Unknown mode')

    n = len(hp_vary_vals)
    if hp_vary == 'c_intsyn':
        colors = ['gray', 'red']
    else:
        colors = [mpl.cm.cool(i * 1.0 / n) for i in range(n)]

    rule_pair = ('contextdm1', 'contextdm2')
    hists, xs, bottoms, tops, labels, curve_colors = [], [], [], [], [], []
    for color, hp_val in zip(colors, hp_vary_vals):
        target = dict(hp_target, **{hp_vary: hp_val})
        model_dirs = tools.find_all_models(root_dir, target)
        print([tools.load_log(d)['perf_min'][-1] for d in model_dirs])

        # Only analyze models that trained
        # Perf_min applies to the last rule in sequential trained networks
        model_dirs = tools.select_by_perf(model_dirs, perf_min=0.8)
        if not model_dirs:
            continue

        if mode == 'all_var':
            hist_tmp, bins_edge = variance.compute_hist_varprop(model_dirs, rule_pair)
        else:
            per_model = list()
            for d in model_dirs:
                var_dict = compute_var_all(d)
                frac_var = compute_frac_var(var_dict, var_thr=0.5, thr_type='or')
                h, bins_edge = np.histogram(frac_var, bins=20, range=(-1, 1))
                per_model.append(h)
            hist_tmp = np.array(per_model)

        # `np.float` was removed in NumPy 1.24; the builtin is equivalent.
        hist_tmp = hist_tmp.astype(float)
        # Normalise each model's histogram (rows) to sum to 1.
        hist_density = (hist_tmp.T / hist_tmp.sum(axis=1)).T
        hist = np.median(hist_density, axis=0)

        # Get the confidence interval with bootstrapping
        bottom, top = list(), list()
        n_model, n_point = hist_density.shape
        for i in range(n_point):
            medians = list()
            for _ in range(400):
                h_sample = np.random.choice(hist_density[:, i], size=n_model)
                medians.append(np.median(h_sample))
            bottom_tmp, top_tmp = np.percentile(medians, (2.5, 97.5))
            bottom.append(bottom_tmp)
            top.append(top_tmp)

        hists.append(hist)
        xs.append((bins_edge[1:] + bins_edge[:-1]) / 2)
        bottoms.append(bottom)
        tops.append(top)
        labels.append(hp_val)
        # Keep each curve's color with its data so skipped hp values do not
        # shift colors or cause out-of-range indexing below.
        curve_colors.append(color)

        # Diagnostic figure: per-model histograms for this hp value.
        plt.figure(figsize=(3, 3))
        plt.plot(xs[-1], hist_tmp.T)
        plt.title(str(hp_val))

    fs = 7
    fig = plt.figure(figsize=(2.0, 1.2))
    ax = fig.add_axes([0.3, 0.3, 0.5, 0.5])
    for x, h, b, t, lab, c in zip(xs, hists, bottoms, tops, labels, curve_colors):
        ax.plot(x, h, color=c, label=lab)
        ax.fill_between(x, b, t, alpha=0.2, color=c)
    if legend:
        lg = ax.legend(title=title, fontsize=fs, frameon=False, loc=1,
                       bbox_to_anchor=(1.2, 1.2))
        plt.setp(lg.get_title(), fontsize=fs)
    ax.set_ylim(ylim)
    ax.set_xlim([-1.1, 1.1])
    ax.tick_params(axis='both', which='major', labelsize=fs)
    ax.locator_params(nbins=3)
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    ax.set_xlabel('FTV(Ctx DM 1, Ctx DM 2)', fontsize=fs)
    ax.set_ylabel('Proportion', fontsize=fs)

    fig_name = 'figure/fracvar_by' + hp_vary
    if save_name is not None:
        fig_name += save_name
    plt.savefig(fig_name + '.pdf', transparent=True)

    return hists
def _plot_performanceprogress_cont(model_dir, model_dir2=None, save=True):
    """Plot per-rule performance during sequential (continual) training.

    One subplot per test rule in hp['rules'] (grid of 4 x 2; unused axes are
    hidden). For every training phase in hp['rule_trains'], an outlined box
    spans the trials where ``log['rule_now']`` equals that phase. The box is
    black when the phase equals the subplot's rule and faint otherwise.

    Args:
        model_dir: str, model plotted in gray.
        model_dir2: optional str, second model overlaid in red.
        save: bool, save to 'figure/TrainingCont_Progress[2].pdf'.
    """
    log = tools.load_log(model_dir)
    hp = tools.load_hp(model_dir)
    trials = np.array(log['trials']) / 1000.
    times = log['times']
    rule_now = log['rule_now']

    if model_dir2 is not None:
        log2 = tools.load_log(model_dir2)
        trials2 = np.array(log2['trials']) / 1000.

    fs = 7  # fontsize
    rule_train_plot = hp['rule_trains']
    rule_test_plot = hp['rules']

    nx, ny = 4, 2
    fig, axarr = plt.subplots(nx, ny, figsize=(3, 3), sharex=True)
    for i in range(nx * ny):
        # Fill the grid column by column.
        iy, ix = divmod(i, nx)
        ax = axarr[ix, iy]

        if i >= len(rule_test_plot):
            ax.axis('off')
            continue

        rule = rule_test_plot[i]

        # Plot fills
        trials_rule_prev_end = 0  # end of previous rule training time
        for rule_ in rule_train_plot:
            ec = 'black' if rule == rule_ else (0, 0, 0, 0.1)
            trials_rule_now = [trials_rule_prev_end] + [
                trials[ii] for ii, r in enumerate(rule_now) if r == rule_]
            trials_rule_prev_end = trials_rule_now[-1]
            ax.fill_between(trials_rule_now, 0, 1, facecolor='none',
                            edgecolor=ec, linewidth=0.5)

        # Plot lines ('perf_' + rule requires rule to be a str)
        ax.plot(trials, log['perf_' + rule], lw=1, color='gray')
        if model_dir2 is not None:
            ax.plot(trials2, log2['perf_' + rule], lw=1, color='red')

        ax.tick_params(axis='both', which='major', labelsize=fs)
        ax.set_ylim([0, 1.05])
        ax.set_xlim([0, trials_rule_prev_end])
        ax.set_yticks([0, 1])
        ax.set_xticks([0, np.floor(trials_rule_prev_end / 100.) * 100])
        if (ix == nx - 1) and (iy == 0):
            ax.set_xlabel('Total trials (1,000)', fontsize=fs, labelpad=1)
        if i == 0:
            ax.set_ylabel('Performance', fontsize=fs, labelpad=1)
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)
        ax.xaxis.set_ticks_position('none')
        ax.yaxis.set_ticks_position('left')
        ax.set_title(rule_name[rule], fontsize=fs, y=0.87, color='black')

    print('Training time {:0.1f} hours'.format(times[-1] / 3600.))
    if save:
        name = 'TrainingCont_Progress'
        if model_dir2 is not None:
            name = name + '2'
        plt.savefig('figure/' + name + '.pdf', transparent=True)
    plt.show()
def plot_histogram():
    """Plot histograms of cluster counts grouped by hyperparameter settings.

    For every model whose final perf_min exceeds hp['target_perf'], the
    model's log['n_cluster'] is added to several groups:
      - initdict: keyed by w_rec_init, by activation, by rnn_type (non-tanh
        models only), and by l1_h / l1_weight level ('..._0' or
        '..._1emin<k>' where k = |log10(value)|);
      - initdictother: keyed by rnn_type + activation (non-tanh only);
      - initdictotherother: keyed by rnn_type + activation + w_rec_init
        (non-tanh only).
    One figure of stacked histograms is drawn per grouping.

    NOTE(review): `model_dirs`, `n_clusters` and `hp_list` are not defined
    in this function. They must already exist at module level (n_clusters
    and hp_list are appended to), otherwise a NameError is raised. Verify
    against the caller.
    NOTE(review): each figure has a hard-coded number of rows (7, 4, 4, 4,
    4, 4). If a grouping produces more matching keys than rows,
    `axarr[u]` raises IndexError.
    """
    initdict = defaultdict(list)
    initdictother = defaultdict(list)
    initdictotherother = defaultdict(list)
    for model_dir in model_dirs:
        hp = tools.load_hp(model_dir)
        #check if performance exceeds target
        log = tools.load_log(model_dir)
        #if log['perf_avg'][-1] > hp['target_perf']:
        if log['perf_min'][-1] > hp['target_perf']:
            print('no. of clusters', log['n_cluster'])
            n_clusters.append(log['n_cluster'])
            hp_list.append(hp)
            initdict[hp['w_rec_init']].append(log['n_cluster'])
            initdict[hp['activation']].append(log['n_cluster'])
            #initdict[hp['rnn_type']].append(log['n_cluster'])
            # rnn_type-based groupings only include non-tanh networks.
            if hp['activation'] != 'tanh':
                initdict[hp['rnn_type']].append(log['n_cluster'])
                initdictother[hp['rnn_type']+hp['activation']].append(log['n_cluster'])
                initdictotherother[hp['rnn_type']+hp['activation']+hp['w_rec_init']].append(log['n_cluster'])
            if hp['l1_h'] == 0:
                initdict['l1_h_0'].append(log['n_cluster'])
            else: #hp['l1_h'] == 1e-3 or 1e-4 or 1e-5:
                # Key encodes the exponent, e.g. 1e-7 -> 'l1_h_1emin7'.
                keyvalstr = 'l1_h_1emin'+str(int(abs(np.log10(hp['l1_h']))))
                initdict[keyvalstr].append(log['n_cluster'])
            if hp['l1_weight'] == 0:
                initdict['l1_weight_0'].append(log['n_cluster'])
            else: #hp['l1_h'] == 1e-3 or 1e-4 or 1e-5:
                keyvalstr = 'l1_weight_1emin'+str(int(abs(np.log10(hp['l1_weight']))))
                initdict[keyvalstr].append(log['n_cluster'])
            #initdict[hp['l1_weight']].append(log['n_cluster'])
    # Check no of clusters under various conditions.
    # Figure 1: initdict groups that are not regularization levels.
    f, axarr = plt.subplots(7, 1, figsize=(3,12), sharex=True)
    u = 0
    for key in initdict.keys():
        if 'l1_' not in key:
            title = (key + ' ' + str(len(initdict[key])) + ' mean: '+str(round(np.mean(initdict[key]),2)))
            axarr[u].set_title(title)
            axarr[u].hist(initdict[key])
            u += 1
    f.subplots_adjust(wspace=.3, hspace=0.3)
    # plt.savefig('./figure/histforcases_96nets.png')
    # plt.savefig('./figure/histforcases__pt9_192nets.pdf')
    # plt.savefig('./figure/histforcases___leakygrunotanh_pt9_192nets.pdf')
    # Figure 2: rnn_type + activation groups.
    f, axarr = plt.subplots(4, 1, figsize=(3,8), sharex=True)
    u = 0
    for key in initdictother.keys():
        if 'l1_' not in key:
            axarr[u].set_title(key + ' ' + str(len(initdictother[key]))+ ' mean: '+str(round(np.mean(initdictother[key]),2)) )
            axarr[u].hist(initdictother[key])
            u += 1
    f.subplots_adjust(wspace=.3, hspace=0.3)
    # plt.savefig('./figure/histforcases__leakyrnngrurelusoftplus_pt9_192nets.pdf')
    # Figure 3: rnn_type + activation + w_rec_init groups, excluding 'diag'.
    f, axarr = plt.subplots(4, 1, figsize=(3,6), sharex=True)
    u = 0
    for key in initdictotherother.keys():
        if 'l1_' not in key and 'diag' not in key:
            axarr[u].set_title(key + ' ' + str(len(initdictotherother[key]))+ ' mean: '+str(round(np.mean(initdictotherother[key]),2)) )
            axarr[u].hist(initdictotherother[key])
            u += 1
    f.subplots_adjust(wspace=.3, hspace=0.3)
    # plt.savefig('./figure/histforcases_randortho_notanh_pt9_192nets.pdf')
    # Figure 4: rnn_type + activation + w_rec_init groups, excluding 'randortho'.
    f, axarr = plt.subplots(4, 1, figsize=(3,6),sharex=True)
    u = 0
    for key in initdictotherother.keys():
        if 'l1_' not in key and 'randortho' not in key:
            axarr[u].set_title(key + ' ' + str(len(initdictotherother[key]))+ ' mean: '+str(round(np.mean(initdictotherother[key]),2)) )
            axarr[u].hist(initdictotherother[key])
            u += 1
    f.subplots_adjust(wspace=.3, hspace=0.3)
    # plt.savefig('./figure/histforcases_diag_notanh_pt9_192nets.pdf')
    #regu--
    # Figure 5: l1_h regularization levels.
    f, axarr = plt.subplots(4, 1,figsize=(3,8),sharex=True)
    u = 0
    for key in initdict.keys():
        if 'l1_h_' in key:
            axarr[u].set_title(key + ' ' + str(len(initdict[key]))+ ' mean: '+str(round(np.mean(initdict[key]),2)) )
            axarr[u].hist(initdict[key])
            u += 1
    f.subplots_adjust(wspace=.3, hspace=0.3)
    #plt.savefig('./figure/noofclusters_pt9_l1_h_192nets.pdf')
    # Figure 6: l1_weight regularization levels.
    f, axarr = plt.subplots(4, 1,figsize=(3,8),sharex=True)
    u = 0
    for key in initdict.keys():
        if 'l1_weight_' in key:
            axarr[u].set_title(key + ' ' + str(len(initdict[key])) + ' mean: '+str(round(np.mean(initdict[key]),2)) )
            axarr[u].hist(initdict[key])
            u += 1
    f.subplots_adjust(wspace=.3, hspace=0.3)