Пример #1
0
def plot_replacerule_performance_group(model_dir,
                                       setup=1,
                                       restore=True,
                                       fig_name_addon=None):
    """Plot rule-replacement performance aggregated over a group of models.

    Args:
        model_dir: str, path handed to ``tools.valid_model_dirs`` to locate
            the models to analyse.
        setup: int, task-set identifier forwarded to
            ``compute_replacerule_performance`` and used in the figure name.
        restore: bool, forwarded to ``compute_replacerule_performance``.
        fig_name_addon: str or None, suffix appended to the figure name.

    Raises:
        ValueError: if no valid model directory is found under ``model_dir``.
    """
    model_dirs = tools.valid_model_dirs(model_dir)
    print('Analyzing models : ')
    print(model_dirs)

    # ``rule`` and ``names`` are only bound inside the loop below; without any
    # model the function would otherwise die with an obscure NameError.
    if len(model_dirs) == 0:
        raise ValueError('No valid model found in {}'.format(model_dir))

    perfs_plot = list()
    for one_dir in model_dirs:
        # ``rule`` and ``names`` of the last model are used for plotting;
        # presumably identical across models for a given setup — TODO confirm.
        perfs, rule, names = compute_replacerule_performance(
            one_dir, setup, restore)
        perfs_plot.append(perfs)

    perfs_plot = np.array(perfs_plot)
    # Median across models (axis 0 indexes the model), printed for inspection.
    perfs_median = np.median(perfs_plot, axis=0)

    fig_name = 'taskset{:d}_perf'.format(setup)
    if fig_name_addon is not None:
        fig_name = fig_name + fig_name_addon

    print(perfs_median)
    _plot_replacerule_performance(perfs_plot,
                                  rule,
                                  names,
                                  setup,
                                  fig_name=fig_name)
Пример #2
0
def get_allperformance(model_dirs, param_list=None):
    """Print and return the mean final test performance per hyperparameter set.

    Models are grouped by the values of the hyperparameters named in
    ``param_list``. For each model, the last logged value of every entry of
    ``log['perf_tests']`` is averaged into one final performance. Those
    performances are then averaged within each group.

    Args:
        model_dirs: path(s) handed to ``tools.valid_model_dirs``.
        param_list: list of hp keys used to group models. Defaults to
            ``['param_intsyn', 'easy_task', 'activation']``.

    Returns:
        dict mapping a tuple of hp values (in ``param_list`` order) to the
        mean final performance of the models sharing those values.
    """
    # Resolve the input into the list of valid model directories.
    model_dirs = tools.valid_model_dirs(model_dirs)

    if param_list is None:
        param_list = ['param_intsyn', 'easy_task', 'activation']

    final_perfs = dict()
    filenames = dict()

    for model_dir in model_dirs:
        log = tools.load_log(model_dir)
        hp = tools.load_hp(model_dir)

        perf_tests = log['perf_tests']

        # Last logged value of every test entry, averaged across entries.
        final_perf = np.mean([float(val[-1]) for val in perf_tests.values()])

        key = tuple(hp[p] for p in param_list)
        if key not in final_perfs:
            final_perfs[key] = []
            # Only the first model seen for each key is remembered.
            filenames[key] = model_dir
        final_perfs[key].append(final_perf)

    mean_perfs = {key: np.mean(val) for key, val in final_perfs.items()}
    for key, mean_perf in mean_perfs.items():
        # One line per group (the former ``print(x),`` was a Python 2 idiom
        # that spreads the fields over three lines in Python 3).
        print(key, '{:0.3f}'.format(mean_perf), filenames[key])

    return mean_perfs
Пример #3
0
def plot_taskspace_group(root_dir, setup=1, restore=True,
                         representation='rate', fig_name_addon=None):
    """Plot task space for a group of networks.

    Before pooling, each model's coordinates are normalized:
    1. Rotated clockwise so that the first point of its first task lies on
       the positive x-axis.
    2. Mirrored along y if needed, so that the first point of its second
       task has a non-negative y-coordinate.

    Models for which ``compute_taskspace`` raises ValueError are skipped.

    Args:
        root_dir : the root directory for all models to analyse
        setup: int, the combination of rules to use
        restore: bool, whether to restore results
        representation: 'rate' or 'weight'
        fig_name_addon: str or None, suffix appended to the figure name

    Raises:
        ValueError: if no model under ``root_dir`` could be analysed.
    """
    model_dirs = tools.valid_model_dirs(root_dir)
    print('Analyzing models : ')
    print(model_dirs)

    h_trans_all = OrderedDict()
    for model_dir in model_dirs:
        try:
            h_trans = compute_taskspace(model_dir, setup,
                                        restore=restore,
                                        representation=representation)
        except ValueError:
            print('Skipping model at ' + model_dir)
            continue

        # When PC1 and PC2 capture similar variances, allow for a rotation:
        # rotate clockwise by theta (row vectors times rot_mat) so the first
        # point of the first task lands on the positive x-axis.
        first_point = list(h_trans.values())[0][0]
        theta = np.arctan2(first_point[1], first_point[0])
        rot_mat = np.array([[np.cos(theta), -np.sin(theta)],
                            [np.sin(theta),  np.cos(theta)]])
        for key, val in h_trans.items():
            h_trans[key] = np.dot(val, rot_mat)

        # Remove the reflection ambiguity: the second task's first point must
        # end up with a non-negative y-coordinate.
        if list(h_trans.values())[1][0][1] < 0:
            for key, val in h_trans.items():
                h_trans[key] = val * np.array([1, -1])

        # Pool models: the first analysed model seeds the dict, later ones are
        # stacked along axis 0.
        if not h_trans_all:
            for key, val in h_trans.items():
                h_trans_all[key] = val
        else:
            for key, val in h_trans.items():
                h_trans_all[key] = np.concatenate((h_trans_all[key], val),
                                                  axis=0)

    if not h_trans_all:
        raise ValueError('No model could be analysed under {}'.format(root_dir))

    fig_name = 'taskset{:d}_{:s}space'.format(setup, representation)
    if fig_name_addon is not None:
        fig_name = fig_name + fig_name_addon

    lxy = _plot_taskspace(h_trans_all, fig_name, setup=setup)
    # The example figure reuses the axis limits of the first plot.
    fig_name = fig_name + '_example'
    _plot_taskspace(h_trans_all, fig_name, setup=setup,
                    plot_example=True, lxy=lxy)
Пример #4
0
def compute_variance(model_dir, rules=None, random_rotation=False):
    """Run the per-model variance computation for every valid model.

    Args:
        model_dir: str, path handed to ``tools.valid_model_dirs``; each
            directory it yields is processed in turn.
        rules: list of str or None, rules to compute variance for; forwarded
            unchanged to ``_compute_variance``.
        random_rotation: bool, forwarded to ``_compute_variance``; if True the
            neural activity is rotated.
    """
    for single_dir in tools.valid_model_dirs(model_dir):
        _compute_variance(single_dir, rules, random_rotation)
Пример #5
0
def get_n_clusters(root_dir):
    """Collect cluster counts and hyperparameters of models that hit target.

    A model is kept only when the last value of ``log['perf_min']`` is
    strictly greater than ``hp['target_perf']``.

    Args:
        root_dir: path handed to ``tools.valid_model_dirs``.

    Returns:
        (n_clusters, hp_list): two parallel lists holding ``log['n_cluster']``
        and the hp dict of every kept model.
    """
    dirs = tools.valid_model_dirs(root_dir)
    total = len(dirs)
    cluster_counts, kept_hps = [], []
    for idx, mdir in enumerate(dirs):
        # Progress report every 50 models.
        if not idx % 50:
            print('Analyzing model {:d}/{:d}'.format(idx, total))
        hp = tools.load_hp(mdir)
        log = tools.load_log(mdir)
        # Skip models whose final minimum performance does not exceed target.
        if not log['perf_min'][-1] > hp['target_perf']:
            continue
        cluster_counts.append(log['n_cluster'])
        kept_hps.append(hp)
    return cluster_counts, kept_hps
Пример #6
0
def compute_hist_varprop(model_dir, rule_pair, random_rotation=False):
    """Collect variance-proportion histograms for a rule pair across models.

    Args:
        model_dir: str, path handed to ``tools.valid_model_dirs``.
        rule_pair: sequence of exactly two rule names.
        random_rotation: bool, forwarded to ``_compute_hist_varprop``.

    Returns:
        hists: np.ndarray stacking one histogram per model for which
            ``_compute_hist_varprop`` returned a histogram. Models returning
            ``None`` are skipped.
        bins_edge: bin edges returned by the last model that produced a
            histogram.

    Raises:
        ValueError: if ``rule_pair`` does not hold exactly two rules, or if no
            model produced a histogram.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if len(rule_pair) != 2:
        raise ValueError('rule_pair must contain exactly two rules, '
                         'got {!r}'.format(rule_pair))

    hists = list()
    bins_edge = None
    for one_dir in tools.valid_model_dirs(model_dir):
        hist, bins_edge_ = _compute_hist_varprop(one_dir, rule_pair,
                                                 random_rotation)
        if hist is None:
            continue
        bins_edge = bins_edge_
        hists.append(hist)

    # Without any histogram there are no bin edges to return; fail with a
    # clear message rather than an UnboundLocalError.
    if not hists:
        raise ValueError('No histogram could be computed for rule pair '
                         '{!r} under {}'.format(rule_pair, model_dir))

    return np.array(hists), bins_edge
Пример #7
0
def plot_hist_varprop_all(model_dir, plot_control=True):
    """Plot a grid of variance-proportion histograms for every pair of rules.

    Cell (i, j) shows the median (across models) histogram returned by
    ``compute_hist_varprop`` for the pair ``(rules[i], rules[j])``, normalized
    to sum to 1. Diagonal cells are left blank. The rule list is read from the
    hyperparameters of the first valid model. The figure is saved to
    ``figure/plot_hist_varprop_all.pdf``.

    Args:
        model_dir: str, path handed to ``tools.valid_model_dirs`` and to
            ``compute_hist_varprop``.
        plot_control: bool, if True also draw (in gray) the median histogram
            obtained with ``random_rotation=True`` as a control.

    Raises:
        ValueError: if no valid model directory is found under ``model_dir``.
    """
    model_dirs = tools.valid_model_dirs(model_dir)
    if len(model_dirs) == 0:
        raise ValueError('No valid model found in {}'.format(model_dir))

    hp = tools.load_hp(model_dirs[0])
    rules = hp['rules']
    n_rules = len(rules)

    figsize = (7, 7)
    fs = 6  # fontsize

    # squeeze=False keeps axarr 2-D even when there is a single rule.
    _, axarr = plt.subplots(n_rules, n_rules, figsize=figsize, squeeze=False)
    plt.subplots_adjust(left=0.1, right=0.98, bottom=0.02, top=0.9)

    for i, rule_i in enumerate(rules):
        for j, rule_j in enumerate(rules):
            ax = axarr[i, j]
            if i == 0:
                ax.set_title(rule_name[rule_j],
                             fontsize=fs,
                             rotation=45,
                             va='bottom')
            if j == 0:
                ax.set_ylabel(rule_name[rule_i],
                              fontsize=fs,
                              rotation=45,
                              ha='right')

            ax.spines["right"].set_visible(False)
            ax.spines["left"].set_visible(False)
            ax.spines["top"].set_visible(False)
            if i == j:
                ax.spines["bottom"].set_visible(False)
                ax.set_xticks([])
                ax.set_yticks([])
                continue

            hists, bins_edge = compute_hist_varprop(model_dir, (rule_i, rule_j))
            bin_centers = (bins_edge[:-1] + bins_edge[1:]) / 2
            # Median across models, normalized to a probability distribution.
            hist_med = np.median(hists, axis=0)
            hist_med /= hist_med.sum()

            # Control case: same statistic on randomly rotated activity.
            if plot_control:
                hists_ctrl, _ = compute_hist_varprop(model_dir,
                                                     (rule_i, rule_j),
                                                     random_rotation=True)
                hist_med_ctrl = np.median(hists_ctrl, axis=0)
                hist_med_ctrl /= hist_med_ctrl.sum()
                ax.plot(bin_centers, hist_med_ctrl, color='gray', lw=0.75)

            ax.plot(bin_centers, hist_med, color='black')

            # Limits and ticks are fixed explicitly, so no automatic tick
            # locator is needed.
            ax.set_ylim([-0.01, 0.6])
            ax.set_xticks([-1, 1])
            ax.set_xticklabels([])
            if i == 0 and j == 1:
                # Only one cell carries a y scale, to avoid clutter.
                ax.set_yticks([0, 0.6])
                ax.spines["left"].set_visible(True)
            else:
                ax.set_yticks([])
            ax.set_xlim([-1, 1])
            ax.xaxis.set_ticks_position('bottom')
            ax.tick_params(axis='both', which='major', labelsize=fs, length=2)

    plt.savefig('figure/plot_hist_varprop_all.pdf', transparent=True)