def plot_all(dataset, model_index=1):
    """Plot all statistics for a dataset.

    Loads the data and computes variance statistics with a variance
    threshold of 0.0 ('or' threshold type). It then plots the rate
    distribution and the fractional-variance summary. For the 'model'
    dataset it also plots training progress and the variance-proportion
    histogram for the ('contextdm1', 'contextdm2') rule pair.

    Args:
        dataset: str. Can be mante_ar, mante_single_ar, mante_fe,
            mante_single_fe, siegel, model.
        model_index: int, only used when dataset == 'model'. Index into the
            list of model directories that match the hyperparameter target
            and pass perf_min=0.8. Defaults to 1 (the second match), which
            is the model this function always selected before the
            parameter existed.

    Raises:
        IndexError: if dataset == 'model' and there is no model at
            model_index after filtering. The message reports how many
            models were found.
    """
    if dataset == 'model':
        # Alternative model sources used here previously (for reference):
        #   root_dir = './data/vary_l2init_mante' with hp_target
        #     {'activation': 'softplus', 'rnn_type': 'LeakyRNN',
        #      'w_rec_init': 'randortho', 'l2_weight_init': 0*1e-4}
        #   root_dir = './data/mante_tanh' with hp_target {}
        #   model_dir = tools.find_model(root_dir, hp_target, perf_min=0.8)
        #   model_dir = 'data/mante_l2init'
        # The original also noted [0, 3.*1e-6, 1e-5, 3*1e-4, 1e-4, 3*1e-3],
        # presumably swept l2_weight_init values -- TODO confirm.
        root_dir = './data/vary_pweighttrain_mante'
        hp_target = {'activation': 'softplus',
                     'rnn_type': 'LeakyRNN',
                     'p_weight_train': 0.1}
        model_dirs = tools.find_all_models(root_dir, hp_target)
        model_dirs = tools.select_by_perf(model_dirs, perf_min=0.8)
        print(len(model_dirs))
        try:
            model_dir = model_dirs[model_index]
        except IndexError:
            # Replace the bare "list index out of range" with a message that
            # says which search came up short and by how much.
            raise IndexError(
                'plot_all: requested model_index={} but only {} model(s) '
                'in {!r} match {!r} with perf_min=0.8'.format(
                    model_index, len(model_dirs), root_dir, hp_target))
    else:
        model_dir = None

    data = load_data(dataset=dataset,
                     model_dir=model_dir)

    # For the siegel dataset, only records whose 'area' is 'PFC' are used.
    if dataset == 'siegel':
        data_area = [d for d in data if d['area'] == 'PFC']
    else:
        data_area = data

    if dataset == 'model':
        var_dict = compute_var_all(model_dir)
    else:
        var_dict = _compute_var_all(data_area, var_method='time_avg_late')
    var_thr, thr_type = 0.0, 'or'
    frac_var = compute_frac_var(var_dict, var_thr=var_thr, thr_type=thr_type)

    plot_rate_distribution(data_area)

    plot_frac_var(frac_var, save_name=dataset)

    if dataset == 'model':
        performance.plot_performanceprogress(model_dir, save=False)
        variance.plot_hist_varprop(model_dir=model_dir,
                                   rule_pair=('contextdm1', 'contextdm2'))
# Example No. 2
# 0
def display_rich_output(model, sess, step, log, model_dir):
    """Show intermediate training diagnostics for the given step.

    Recomputes variance statistics for ``model`` in ``sess``. It then plots
    the variance-proportion histogram for the contextdm1/contextdm2 rule
    pair, titled with the step number and the most recent average
    performance. All open matplotlib figures are closed at the end.

    Args:
        model: model passed to variance._compute_variance_bymodel.
        sess: session passed together with ``model``.
        step: training step; rendered with str() in the figure name suffix
            and the title.
        log: mapping with a 'perf_avg' sequence; its last entry is shown
            with two decimals.
        model_dir: directory passed to variance.plot_hist_varprop.
    """
    variance._compute_variance_bymodel(model, sess)

    step_label = str(step)
    latest_perf = log['perf_avg'][-1]
    title = 'Step ' + step_label + ' Perf. {:0.2f}'.format(latest_perf)

    variance.plot_hist_varprop(
        model_dir,
        ['contextdm1', 'contextdm2'],
        figname_extra='_atstep' + step_label,
        title=title,
    )
    plt.close('all')