def train_all_tanhgru(seed=0, model_dir='tanhgru'):
    """Train a tanh ``LeakyGRU`` network on the ``'all'`` ruleset, then analyse it.

    Args:
        seed: Seed forwarded to ``train.train``. It is also used as the name
            of the run's subdirectory.
        model_dir: Directory name under ``DATAPATH`` that groups runs of this
            configuration. The run itself lives in
            ``DATAPATH/<model_dir>/<seed>``.
    """
    run_dir = os.path.join(DATAPATH, model_dir, str(seed))

    # NOTE(review): the weight 5 presumably raises how often these two rules
    # are sampled during training relative to the others -- confirm in train.train.
    train.train(
        run_dir,
        hp={'activation': 'tanh', 'rnn_type': 'LeakyGRU'},
        ruleset='all',
        rule_prob_map={'contextdm1': 5, 'contextdm2': 5},
        seed=seed,
    )

    # ---- Post-training analyses ----
    variance.compute_variance(run_dir)

    # Store the cluster count from clustering.Analysis (mode 'rule') in the run log.
    training_log = tools.load_log(run_dir)
    rule_clusters = clustering.Analysis(run_dir, 'rule')
    training_log['n_cluster'] = rule_clusters.n_cluster
    tools.save_log(training_log)

    data_analysis.compute_var_all(run_dir)

    for setup_id in (1, 2, 3):
        taskset.compute_taskspace(
            run_dir, setup_id, restore=False, representation='rate')
        # NOTE(review): the positional False presumably means restore=False -- confirm.
        taskset.compute_replacerule_performance(run_dir, setup_id, False)
def train_all_mixrule_softplus(seed=0, root_dir='mixrule_softplus'):
    """Train a softplus network with mixed rule inputs on ``'all'``, then analyse it.

    Args:
        seed: Seed forwarded to ``train.train``. It is also used as the name
            of the run's subdirectory.
        root_dir: Directory name under ``DATAPATH`` that groups runs of this
            configuration. The run itself lives in
            ``DATAPATH/<root_dir>/<seed>``.
    """
    run_dir = os.path.join(DATAPATH, root_dir, str(seed))

    hyperparams = dict(
        activation='softplus',
        w_rec_init='diag',
        use_separate_input=True,
        mix_rule=True,
    )
    # NOTE(review): the weight 5 presumably raises how often these two rules
    # are sampled during training relative to the others -- confirm in train.train.
    sampling_weights = dict(contextdm1=5, contextdm2=5)
    train.train(run_dir, hp=hyperparams, ruleset='all',
                rule_prob_map=sampling_weights, seed=seed)

    # ---- Post-training analyses ----
    variance.compute_variance(run_dir)

    # Store the cluster count from clustering.Analysis (mode 'rule') in the run log.
    training_log = tools.load_log(run_dir)
    rule_clusters = clustering.Analysis(run_dir, 'rule')
    training_log['n_cluster'] = rule_clusters.n_cluster
    tools.save_log(training_log)

    for setup_id in (1, 2, 3):
        taskset.compute_taskspace(
            run_dir, setup_id, restore=False, representation='rate')
        # NOTE(review): the positional False presumably means restore=False -- confirm.
        taskset.compute_replacerule_performance(run_dir, setup_id, False)
def train_all_analysis(seed=0, root_dir='train_all'):
    """Run the post-training analysis pipeline on an existing run directory.

    Nothing is trained here. The directory ``DATAPATH/<root_dir>/<seed>`` is
    assumed to already contain a trained model and its log.

    Args:
        seed: Used only to select the run's subdirectory.
        root_dir: Directory name under ``DATAPATH`` that groups the runs.
    """
    run_dir = os.path.join(DATAPATH, root_dir, str(seed))

    # Variance is computed twice: once plainly, once with random_rotation=True.
    variance.compute_variance(run_dir)
    variance.compute_variance(run_dir, random_rotation=True)

    # Store the cluster count from clustering.Analysis (mode 'rule') in the run log.
    training_log = tools.load_log(run_dir)
    rule_clusters = clustering.Analysis(run_dir, 'rule')
    training_log['n_cluster'] = rule_clusters.n_cluster
    tools.save_log(training_log)

    data_analysis.compute_var_all(run_dir)

    choice_rules = ('dm1', 'contextdm1', 'multidm')
    for rule_name in choice_rules:
        performance.compute_choicefamily_varytime(run_dir, rule_name)

    for setup_id in (1, 2, 3):
        taskset.compute_taskspace(
            run_dir, setup_id, restore=False, representation='rate')
        # NOTE(review): the positional False presumably means restore=False -- confirm.
        taskset.compute_replacerule_performance(run_dir, setup_id, False)