Example 1
def decrease_lr_in_optim_config(conf: Config, num_tasks_learnt: int) -> Config:
    """
    Creates a new optim config with a decreased LR
    """
    if num_tasks_learnt <= 0 or not conf.has('decrease_lr_coef'):
        return conf.clone()

    decrease_coef = conf.decrease_lr_coef**num_tasks_learnt

    # Updating LR in the main kwargs
    if conf.kwargs.has('lr'):
        target_lr = conf.kwargs.lr * decrease_coef
        conf = conf.overwrite({'kwargs': {'lr': target_lr}})

    # Updating LR in each param group that defines its own LR
    if conf.has('groups'):
        groups_with_lr = [
            g for g in conf.groups.keys() if conf.groups[g].has('lr')
        ]
        conf = conf.overwrite({
            'groups': {
                g: conf.groups[g].overwrite({'lr': conf.groups[g].lr * decrease_coef})
                for g in groups_with_lr
            }
        })

    return conf
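
A quick usage sketch follows. It uses hypothetical values and assumes the same Config class seen across these examples (dict construction, attribute access, has/overwrite), not any specific library:

optim_conf = Config({
    'decrease_lr_coef': 0.5,
    'kwargs': {'lr': 1e-3},
})

# After two learnt tasks the LR is scaled by 0.5 ** 2 = 0.25
optim_conf = decrease_lr_in_optim_config(optim_conf, num_tasks_learnt=2)
print(optim_conf.kwargs.lr)  # 0.00025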
Example 2
    def __init__(self, config: Config):
        # Overwriting with the dataset-dependent hyperparams
        config = config.overwrite(config.datasets[config.dataset])
        # Overwriting with the CLI arguments
        config = config.overwrite(Config.read_from_cli())
        # Blanking out the datasets block so as not to pollute the logs
        config = config.overwrite(Config({'datasets': None}))

        super(GANTrainer, self).__init__(config)

        if self.is_distributed:
            torch.set_num_threads(4)
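
The three overwrite calls above form a precedence chain: dataset-specific hyperparams override the defaults, CLI arguments override both, and the datasets block is finally blanked out. A minimal sketch of that ordering, with hypothetical keys and the same assumed Config semantics:

base = Config({'dataset': 'cifar10', 'lr': 1e-4, 'datasets': {'cifar10': {'lr': 2e-4}}})
merged = base.overwrite(base.datasets[base.dataset])  # dataset-dependent hyperparams win over defaults
merged = merged.overwrite(Config({'lr': 3e-4}))       # stands in for a CLI argument; wins over both
print(merged.lr)  # 0.0003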
Example 3
    def __init__(self, config: Config):
        # Overwrite with the dataset-specific block, then with the CLI arguments
        config = config.overwrite(config[config.dataset])
        config = config.overwrite(Config.read_from_cli())
        config.exp_name = f'zsl_{config.dataset}_{config.hp.compute_hash()}_{config.random_seed}'
        if not config.get('silent'):
            print(config.hp)
        self.random = np.random.RandomState(config.random_seed)  # Seeded RNG for reproducibility
        super().__init__(config)
Example 4
def run_validation_sequence(args: argparse.Namespace, config: Config):
    experiments_vals = generate_experiments_from_hpo_grid(config.validation_sequence.hpo_grid)
    # Grid keys come '|'-separated; rewrite them into dotted config paths
    experiments_vals = [{p.replace('|', '.'): v for p, v in exp.items()}
                        for exp in experiments_vals]
    configs = [config.overwrite({'hp': Config(hp)}) for hp in experiments_vals]
    scores = []

    print(f'Number of random experiments: {len(configs)}')

    for i, c in enumerate(configs):
        print('<==== Running HPs ====>')
        print(experiments_vals[i])

        c = c.overwrite(
            Config({
                'experiments_dir': f'{config.experiments_dir}-val-seqs',
                'lll_setup.num_tasks': c.validation_sequence.num_tasks,
                'logging.save_train_logits': False,
                'logging.print_accuracy_after_task': False,
                'logging.print_unseen_accuracy': False,
                'logging.print_forgetting': False,
                'exp_name': compute_experiment_name(args, config.hp)
            }))
        trainer = LLLTrainer(c)
        trainer.start()

        if config.validation_sequence.metric == 'harmonic_mean':
            score = np.mean(trainer.compute_harmonic_mean_accuracy())
        elif config.validation_sequence.metric == 'final_task_wise_acc':
            score = np.mean(trainer.compute_final_tasks_performance())
        else:
            raise NotImplementedError(f'Unknown metric: {config.validation_sequence.metric}')

        scores.append(score)

    best_idx = np.argmax(scores)
    best_config = configs[best_idx]
    print('Best found setup:', experiments_vals[best_idx])
    print(best_config)

    # Re-run the best setup on the full sequence, starting after the validation tasks
    best_config = best_config.overwrite(
        Config({'start_task': config.validation_sequence.num_tasks}))
    trainer = LLLTrainer(best_config)
    trainer.start()
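
generate_experiments_from_hpo_grid is not shown in these examples; judging by the '|'-to-'.' key rewriting above, it plausibly expands a grid of value lists into one flat dict per combination. A hypothetical sketch under that assumption:

# Hypothetical sketch -- the real helper is not shown in these examples.
# It expands a grid of value lists into one flat dict per combination,
# with '|'-separated keys that get rewritten to dotted paths above.
import itertools

def generate_experiments_from_hpo_grid(grid: dict) -> list:
    keys = list(grid.keys())
    return [dict(zip(keys, values))
            for values in itertools.product(*(grid[k] for k in keys))]

grid = {'optim|lr': [1e-3, 1e-4], 'model|hid_dim': [128, 256]}
print(len(generate_experiments_from_hpo_grid(grid)))  # 4 combinations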