def plot_replacerule_performance_group(model_dir, setup=1, restore=True, fig_name_addon=None):
    """Plot replace-rule performance aggregated over a group of models.

    Args:
        model_dir: str, root directory containing the model directories
        setup: int, the combination of rules to use
        restore: bool, whether to restore previously computed results
        fig_name_addon: optional str appended to the saved figure name
    """
    model_dirs = tools.valid_model_dirs(model_dir)
    print('Analyzing models : ')
    print(model_dirs)

    perfs_plot = list()
    # rule/names are the same for every model in the group; the values from
    # the last model are passed on to the plotting helper.
    rule, names = None, None
    for d in model_dirs:  # 'd' avoids shadowing the model_dir argument
        perfs, rule, names = compute_replacerule_performance(d, setup, restore)
        perfs_plot.append(perfs)

    perfs_plot = np.array(perfs_plot)
    perfs_median = np.median(perfs_plot, axis=0)

    fig_name = 'taskset{:d}_perf'.format(setup)
    if fig_name_addon is not None:
        fig_name = fig_name + fig_name_addon

    print(perfs_median)
    _plot_replacerule_performance(perfs_plot, rule, names, setup, fig_name=fig_name)
def get_allperformance(model_dirs, param_list=None):
    """Print mean final performance of models grouped by hyperparameters.

    Args:
        model_dirs: str or list, pattern(s) of model directories to analyze
        param_list: optional list of hyperparameter names used as group key;
            defaults to ['param_intsyn', 'easy_task', 'activation']
    """
    # Get all model names that match patterns (strip off .ckpt.meta at the end)
    model_dirs = tools.valid_model_dirs(model_dirs)

    final_perfs = dict()
    filenames = dict()

    if param_list is None:
        param_list = ['param_intsyn', 'easy_task', 'activation']

    for model_dir in model_dirs:
        log = tools.load_log(model_dir)
        hp = tools.load_hp(model_dir)
        perf_tests = log['perf_tests']
        # Final performance averaged across all tested rules
        final_perf = np.mean([float(val[-1]) for val in perf_tests.values()])
        key = tuple([hp[p] for p in param_list])
        if key in final_perfs:  # 'in dict' instead of 'in dict.keys()'
            final_perfs[key].append(final_perf)
        else:
            final_perfs[key] = [final_perf]
            filenames[key] = model_dir

    for key, val in final_perfs.items():
        final_perfs[key] = np.mean(val)
        # Fixed: the original had Python-2-style trailing commas after each
        # print() call, which built a useless tuple of Nones.
        print(key)
        print('{:0.3f}'.format(final_perfs[key]))
        print(filenames[key])
def plot_taskspace_group(root_dir, setup=1, restore=True, representation='rate', fig_name_addon=None):
    """Plot task space for a group of networks.

    Args:
        root_dir : the root directory for all models to analyse
        setup: int, the combination of rules to use
        restore: bool, whether to restore results
        representation: 'rate' or 'weight'
    """
    model_dirs = tools.valid_model_dirs(root_dir)
    print('Analyzing models : ')
    print(model_dirs)

    h_trans_all = OrderedDict()
    n_merged = 0
    for model_dir in model_dirs:
        try:
            h_trans = compute_taskspace(model_dir, setup, restore=restore,
                                        representation=representation)
        except ValueError:
            print('Skipping model at ' + model_dir)
            continue

        # When PC1 and PC2 capture similar variances, allow for a rotation:
        # align the first condition's first point with the x-axis (clockwise).
        anchor = list(h_trans.values())[0][0]
        theta = np.arctan2(anchor[1], anchor[0])
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        rot_mat = np.array([[cos_t, -sin_t],
                            [sin_t, cos_t]])

        for key in h_trans:
            h_trans[key] = np.dot(h_trans[key], rot_mat)

        # Flip the y-axis if the second condition's first point lies below it,
        # so all models share the same orientation.
        if list(h_trans.values())[1][0][1] < 0:
            for key in h_trans:
                h_trans[key] = h_trans[key] * np.array([1, -1])

        # Merge this model's points into the running group collection.
        for key, val in h_trans.items():
            if n_merged == 0:
                h_trans_all[key] = val
            else:
                h_trans_all[key] = np.concatenate((h_trans_all[key], val), axis=0)
        n_merged += 1

    fig_name = 'taskset{:d}_{:s}space'.format(setup, representation)
    if fig_name_addon is not None:
        fig_name = fig_name + fig_name_addon

    lxy = _plot_taskspace(h_trans_all, fig_name, setup=setup)
    fig_name = fig_name + '_example'
    lxy = _plot_taskspace(h_trans_all, fig_name, setup=setup,
                          plot_example=True, lxy=lxy)
def compute_variance(model_dir, rules=None, random_rotation=False):
    """Compute variance for all tasks.

    Args:
        model_dir: str, the path of the model directory
        rules: list of rules to compute variance, list of strings
        random_rotation: boolean. If True, rotate the neural activity.
    """
    # Delegate to the per-model worker for every valid model directory found.
    for model_path in tools.valid_model_dirs(model_dir):
        _compute_variance(model_path, rules, random_rotation)
def get_n_clusters(root_dir):
    """Collect cluster counts and hyperparameters for well-trained models.

    Args:
        root_dir: str, root directory containing model directories

    Returns:
        n_clusters: list of cluster counts, one per model passing the target
        hp_list: list of matching hyperparameter dicts
    """
    model_dirs = tools.valid_model_dirs(root_dir)
    n_total = len(model_dirs)

    n_clusters = list()
    hp_list = list()
    for idx, model_dir in enumerate(model_dirs):
        # Progress report every 50 models
        if idx % 50 == 0:
            print('Analyzing model {:d}/{:d}'.format(idx, n_total))
        hp = tools.load_hp(model_dir)
        log = tools.load_log(model_dir)
        # Keep only models whose worst-rule performance beat the target
        if log['perf_min'][-1] > hp['target_perf']:
            n_clusters.append(log['n_cluster'])
            hp_list.append(hp)
    return n_clusters, hp_list
def compute_hist_varprop(model_dir, rule_pair, random_rotation=False):
    """Compute histograms of fractional variance for a rule pair across models.

    Args:
        model_dir: str, root directory containing the model directories
        rule_pair: sequence of exactly two rules; proportion of variance is
            shown for the first rule
        random_rotation: bool. If True, use randomly rotated activity (control).

    Returns:
        hists: array of shape (n_models, n_bins), one histogram per model
        bins_edge: array of bin edges shared by all histograms

    Raises:
        ValueError: if no model produced a histogram.
    """
    # Removed dead code: the original assigned data_type = 'rule' and then
    # asserted data_type == 'rule', which is trivially always true.
    assert len(rule_pair) == 2

    model_dirs = tools.valid_model_dirs(model_dir)

    hists = list()
    bins_edge = None
    for d in model_dirs:  # 'd' avoids shadowing the model_dir argument
        hist, bins_edge_ = _compute_hist_varprop(d, rule_pair, random_rotation)
        if hist is None:
            continue
        bins_edge = bins_edge_  # identical across models; keep the latest
        # Store
        hists.append(hist)

    if not hists:
        # Previously this fell through to an UnboundLocalError on bins_edge;
        # fail with an explicit, actionable error instead.
        raise ValueError('No variance histograms computed for ' + str(rule_pair))

    # Get median of all histogram
    hists = np.array(hists)
    # hist_low, hist_med, hist_high = np.percentile(hists, [10, 50, 90], axis=0)
    return hists, bins_edge
def plot_hist_varprop_all(model_dir, plot_control=True):
    '''
    Plot histogram of proportion of variance for some tasks across units

    :param model_dir: root directory containing the model directories
    :param plot_control: bool, if True overlay the random-rotation control
        histogram in gray
    :return:
    '''
    model_dirs = tools.valid_model_dirs(model_dir)

    # Rules are read from the first model's hyperparameters; presumably all
    # models in the group share the same rule set — TODO confirm
    hp = tools.load_hp(model_dirs[0])
    rules = hp['rules']

    figsize = (7, 7)

    # For testing
    # rules, figsize = ['fdgo','reactgo','delaygo', 'fdanti', 'reactanti'], (4, 4)

    fs = 6  # fontsize

    # One subplot per (rule_i, rule_j) pair in a square grid
    f, axarr = plt.subplots(len(rules), len(rules), figsize=figsize)
    plt.subplots_adjust(left=0.1, right=0.98, bottom=0.02, top=0.9)

    for i in range(len(rules)):
        for j in range(len(rules)):
            ax = axarr[i, j]
            # Column titles on the top row, row labels on the left column
            if i == 0:
                ax.set_title(rule_name[rules[j]], fontsize=fs, rotation=45, va='bottom')
            if j == 0:
                ax.set_ylabel(rule_name[rules[i]], fontsize=fs, rotation=45, ha='right')
            ax.spines["right"].set_visible(False)
            ax.spines["left"].set_visible(False)
            ax.spines["top"].set_visible(False)

            if i == j:
                # Diagonal: a rule compared with itself — leave the cell blank
                ax.spines["bottom"].set_visible(False)
                ax.set_xticks([])
                ax.set_yticks([])
                continue

            hists, bins_edge = compute_hist_varprop(model_dir, (rules[i], rules[j]))
            # Median histogram across models (10th/90th percentiles unused here)
            hist_low, hist_med, hist_high = np.percentile(hists, [10, 50, 90], axis=0)

            # Normalize the median histogram to sum to 1
            hist_med /= hist_med.sum()

            # Control case
            if plot_control:
                hists_ctrl, _ = compute_hist_varprop(model_dir, (rules[i], rules[j]), random_rotation=True)
                _, hist_med_ctrl, _ = np.percentile(hists_ctrl, [10, 50, 90], axis=0)
                hist_med_ctrl /= hist_med_ctrl.sum()
                # Plot at bin centers
                ax.plot((bins_edge[:-1] + bins_edge[1:]) / 2, hist_med_ctrl, color='gray', lw=0.75)

            ax.plot((bins_edge[:-1] + bins_edge[1:]) / 2, hist_med, color='black')
            plt.locator_params(nbins=3)

            # ax.set_ylim(bottom=-0.02*hist_med.max())
            ax.set_ylim([-0.01, 0.6])
            print(hist_med.max())
            ax.set_xticks([-1, 1])
            ax.set_xticklabels([])
            # Show the y-axis on only one representative off-diagonal cell
            if i == 0 and j == 1:
                ax.set_yticks([0, 0.6])
                ax.spines["left"].set_visible(True)
            else:
                ax.set_yticks([])
            ax.set_xlim([-1, 1])
            ax.xaxis.set_ticks_position('bottom')
            ax.tick_params(axis='both',
                           which='major', labelsize=fs, length=2)
    # plt.tight_layout()
    plt.savefig('figure/plot_hist_varprop_all.pdf', transparent=True)