def decrease_lr_in_optim_config(conf: Config, num_tasks_learnt: int) -> Config:
    """
    Creates a new optim config with a decreased LR.

    Every learning rate found in the config is multiplied by
    `conf.decrease_lr_coef ** num_tasks_learnt`: the one in `conf.kwargs.lr`
    and the `lr` of each entry in `conf.groups` that defines one.

    Args:
        conf: optimizer config. `kwargs` is accessed unconditionally;
            `decrease_lr_coef`, `kwargs.lr` and `groups` are optional.
        num_tasks_learnt: exponent applied to the decay coefficient.

    Returns:
        A clone of `conf` when `num_tasks_learnt <= 0` or no
        `decrease_lr_coef` is set; otherwise the config with decayed LRs
        (`conf` itself if it contains no LR to update).
    """
    if num_tasks_learnt <= 0 or not conf.has('decrease_lr_coef'):
        return conf.clone()

    decrease_coef = conf.decrease_lr_coef ** num_tasks_learnt

    # Updating LR in the main kwargs
    if conf.kwargs.has('lr'):
        target_lr = conf.kwargs.lr * decrease_coef
        conf = conf.overwrite({'kwargs': {'lr': target_lr}})

    # Updating LR in the per-group settings.
    # NOTE(review): the guard used to test `conf.kwargs.has('groups')` while
    # the body only ever read `conf.groups`; the guard now matches the body.
    # Confirm that groups are indeed stored at the top level of the config.
    if conf.has('groups'):
        # Iterate over the group names themselves (the previous version used
        # the loop variable inside its own iterable and raised NameError).
        groups_with_lr = [g for g in conf.groups.keys() if conf.groups[g].has('lr')]

        # Nothing to overwrite when no group carries its own LR
        if groups_with_lr:
            conf = conf.overwrite({
                'groups': {
                    # The decay has to be applied here as well; previously the
                    # group LR was overwritten with its own unchanged value.
                    g: conf.groups[g].overwrite({'lr': conf.groups[g].lr * decrease_coef})
                    for g in groups_with_lr
                }
            })

    return conf
def __init__(self, config: Config):
    """
    Builds the final config and initialises the trainer with it.

    Overrides are layered on top of `config` in this order:
    dataset-dependent hyperparams (`config.datasets[config.dataset]`),
    then the CLI arguments, and finally `datasets` is set to None.
    """
    dataset_hparams = config.datasets[config.dataset]
    with_cli_args = config.overwrite(dataset_hparams).overwrite(Config.read_from_cli())

    # Dropping the per-dataset table so as not to pollute logs
    config = with_cli_args.overwrite(Config({'datasets': None}))

    super(GANTrainer, self).__init__(config)

    # `is_distributed` is not set in this method — presumably provided by the
    # base class; verify against the parent trainer.
    if self.is_distributed:
        torch.set_num_threads(4)
def __init__(self, config: Config):
    """
    Builds the final config, names the experiment and initialises the trainer.

    The entry stored under the dataset name (`config[config.dataset]`) is
    applied first, then the CLI arguments. The experiment name is derived
    from the dataset, the hash of `config.hp` and the random seed.
    """
    dataset_overrides = config[config.dataset]
    config = config.overwrite(dataset_overrides).overwrite(Config.read_from_cli())
    config.exp_name = f'zsl_{config.dataset}_{config.hp.compute_hash()}_{config.random_seed}'

    # Hyperparams are printed unless the config asks for silence
    is_silent = config.get('silent')
    if not is_silent:
        print(config.hp)

    # Own RNG seeded from the config; created before the parent initialiser runs
    self.random = np.random.RandomState(config.random_seed)

    super().__init__(config)
def run_validation_sequence(args: argparse.Namespace, config: Config):
    """
    Runs every HP setup from the validation grid, then trains the best one.

    Each setup generated from `config.validation_sequence.hpo_grid` is trained
    with an `LLLTrainer` limited to `validation_sequence.num_tasks` tasks and
    scored with the metric named in `config.validation_sequence.metric`
    ('harmonic_mean' or 'final_task_wise_acc'). The highest-scoring setup is
    then trained again with `start_task` set to the number of validation tasks.

    Raises:
        NotImplementedError: if the configured metric is not one of the two
            known names (raised after the first setup has been trained).
    """
    raw_grid = generate_experiments_from_hpo_grid(config.validation_sequence.hpo_grid)

    # HP names use '|' as separator in the grid; convert them to dotted paths
    hp_setups = []
    for setup in raw_grid:
        hp_setups.append({name.replace('|', '.'): value for name, value in setup.items()})

    configs = [config.overwrite({'hp': Config(hp_values)}) for hp_values in hp_setups]
    scores = []

    print(f'Number of random experiments: {len(configs)}')

    for hp_values, run_config in zip(hp_setups, configs):
        print('<==== Running HPs ====>')
        print(hp_values)

        # NOTE(review): the experiment name is computed from the base
        # `config.hp`, not from this run's HPs, so all runs receive the same
        # name — confirm that this is intended.
        validation_overrides = Config({
            'experiments_dir': f'{config.experiments_dir}-val-seqs',
            'lll_setup.num_tasks': run_config.validation_sequence.num_tasks,
            'logging.save_train_logits': False,
            'logging.print_accuracy_after_task': False,
            'logging.print_unseen_accuracy': False,
            'logging.print_forgetting': False,
            'exp_name': compute_experiment_name(args, config.hp)
        })

        trainer = LLLTrainer(run_config.overwrite(validation_overrides))
        trainer.start()

        metric_name = config.validation_sequence.metric
        if metric_name == 'harmonic_mean':
            metric_values = trainer.compute_harmonic_mean_accuracy()
        elif metric_name == 'final_task_wise_acc':
            metric_values = trainer.compute_final_tasks_performance()
        else:
            raise NotImplementedError('Unknown metric')

        scores.append(np.mean(metric_values))

    # The winner is taken from `configs`, i.e. without the validation-only overrides
    best_idx = np.argmax(scores)
    best_config = configs[best_idx]
    print('Best found setup:', hp_setups[best_idx])
    print(best_config)

    resume_point = Config({'start_task': config.validation_sequence.num_tasks})
    trainer = LLLTrainer(best_config.overwrite(resume_point))
    trainer.start()