def train_all_tanhgru(seed=0, model_dir='tanhgru'):
    """Train all tasks with a tanh leaky-GRU network, then run analyses.

    Args:
        seed: int, random seed passed to training.
        model_dir: str, subdirectory name (under DATAPATH) for this model;
            the seed is appended as a further subdirectory.
    """
    model_dir = os.path.join(DATAPATH, model_dir, str(seed))
    hp = {'activation': 'tanh', 'rnn_type': 'LeakyGRU'}
    # Oversample the context-dependent DM tasks during training.
    rule_prob_map = {'contextdm1': 5, 'contextdm2': 5}
    train.train(model_dir, hp=hp, ruleset='all',
                rule_prob_map=rule_prob_map, seed=seed)

    # Post-training analyses: task variance, clustering, per-task variance.
    variance.compute_variance(model_dir)
    log = tools.load_log(model_dir)
    log['n_cluster'] = clustering.Analysis(model_dir, 'rule').n_cluster
    tools.save_log(log)
    data_analysis.compute_var_all(model_dir)

    # Task-space geometry and rule-replacement performance for each setup.
    for setup in (1, 2, 3):
        taskset.compute_taskspace(model_dir, setup,
                                  restore=False, representation='rate')
        taskset.compute_replacerule_performance(model_dir, setup, False)


def train_all_mixrule_softplus(seed=0, root_dir='mixrule_softplus'):
    """Train all tasks with a softplus network using mixed rule inputs.

    Args:
        seed: int, random seed passed to training.
        root_dir: str, subdirectory name (under DATAPATH) for this model;
            the seed is appended as a further subdirectory.
    """
    model_dir = os.path.join(DATAPATH, root_dir, str(seed))
    hp = {'activation': 'softplus',
          'w_rec_init': 'diag',
          'use_separate_input': True,
          'mix_rule': True}
    # Oversample the context-dependent DM tasks during training.
    rule_prob_map = {'contextdm1': 5, 'contextdm2': 5}
    train.train(model_dir, hp=hp, ruleset='all',
                rule_prob_map=rule_prob_map, seed=seed)

    # Post-training analyses: task variance and clustering.
    # NOTE(review): unlike train_all_tanhgru, this variant does not call
    # data_analysis.compute_var_all — confirm whether that is intentional.
    variance.compute_variance(model_dir)
    log = tools.load_log(model_dir)
    log['n_cluster'] = clustering.Analysis(model_dir, 'rule').n_cluster
    tools.save_log(log)

    # Task-space geometry and rule-replacement performance for each setup.
    for setup in (1, 2, 3):
        taskset.compute_taskspace(model_dir, setup,
                                  restore=False, representation='rate')
        taskset.compute_replacerule_performance(model_dir, setup, False)


def train_all_analysis(seed=0, root_dir='train_all'):
    """Run the analysis suite for an already-trained all-task model.

    Args:
        seed: int, random seed of the trained model to analyze.
        root_dir: str, subdirectory name (under DATAPATH) for this model;
            the seed is appended as a further subdirectory.
    """
    model_dir = os.path.join(DATAPATH, root_dir, str(seed))

    # Task variance, with and without a random-rotation control.
    variance.compute_variance(model_dir)
    variance.compute_variance(model_dir, random_rotation=True)

    # Clustering of units by task variance; record cluster count in the log.
    log = tools.load_log(model_dir)
    log['n_cluster'] = clustering.Analysis(model_dir, 'rule').n_cluster
    tools.save_log(log)
    data_analysis.compute_var_all(model_dir)

    # Psychometric performance while varying stimulus time on choice tasks.
    for rule in ('dm1', 'contextdm1', 'multidm'):
        performance.compute_choicefamily_varytime(model_dir, rule)

    # Task-space geometry and rule-replacement performance for each setup.
    for setup in (1, 2, 3):
        taskset.compute_taskspace(model_dir, setup,
                                  restore=False, representation='rate')
        taskset.compute_replacerule_performance(model_dir, setup, False)