def run_config_hooks(self, config, config_updates, command_name, logger):
    """Execute every registered config hook and merge their results.

    Each hook receives a deep copy of ``config`` so it cannot mutate the
    original. Hook outputs are merged in registration order, and the
    explicitly supplied ``config_updates`` are merged last so they take
    precedence over anything a hook produced.
    """
    merged_updates = {}
    for hook in self.config_hooks:
        hook_result = hook(deepcopy(config), command_name, logger)
        if hook_result:
            recursive_update(merged_updates, hook_result)
    # explicit updates win over hook-produced values
    recursive_update(merged_updates, config_updates)
    return merged_updates
def main(config_file, update, _run, _log, _config): working_dir = _run.observers[0].dir # load the config file config = yaml_load(config_file) recursive_update(config, update) yaml_dump(config, path.join(working_dir, 'config.yaml')) _config = config print(_config) print(working_dir) dataset = _config['dataset'] dataset['device'] = update['device'] # load the dataset and the vocab train_loader, dev_loader, test_loader, vocab = load_data(**dataset) vocab_size = len(vocab.itos) # model _config['hidden']['features'][0] = vocab_size # trainer batch test_sample = _config['trainer_batch']['test_sample'] _config['trainer_batch']['test_sample'] = 1 config = extend_config_reference(_config) trainer = config['trainer'] trainer['evaluate_interval'] = len( train_loader) * trainer['evaluate_interval'] trainer['save_checkpoint_interval'] = trainer['evaluate_interval'] trainer['base_dir'] = working_dir yaml_dump(trainer, path.join(working_dir, 'trainer.yaml')) trainer['train_iterator'] = train_loader trainer['dev_iterator'] = dev_loader trainer['test_iterator'] = None callback = EvaluationCallback(working_dir, vocab, corpus_dir=path.join(dataset['data_dir'], 'corpus'), **config['callback']) trainer['callbacks'] = callback trainer['logger'] = _log print(config) trainer = Trainer.from_config(trainer) _log.info("model architecture") print(trainer.trainer_batch.model) # train the model trainer.train() # testing and save results trainer.dev_iterator = test_loader trainer.trainer_batch.test_sample = test_sample # test using many samples, but not in development dataset trainer.restore_from_basedir(best=True) stat = trainer._evaluate_epoch().get_dict() callback.evaluate_topic_coherence() # topic coherence of best checkpoint stat.update(callback.get_dict()) yaml_dump(stat, path.join(working_dir, 'result.yaml')) _log.info('test result of best evaluation {}'.format(stat))
def run_config_hooks(self, config, command_name, logger):
    """Run all config hooks on a copy of ``config`` and collect their updates.

    Hook outputs are merged in registration order. ``self.config_updates``
    is merged last, so updates that were set explicitly override any value
    a hook produced for the same key.
    """
    collected = {}
    for hook in self.config_hooks:
        update = hook(deepcopy(config), command_name, logger)
        if update:
            recursive_update(collected, update)
    # final update fills in config_update; if a key is already set there,
    # the explicit config update is used over the hook result
    recursive_update(collected, self.config_updates)
    return collected
def load_component_default_config(component_config, all_default_configs):
    """Resolve the default configuration for a component, recursively.

    If ``component_config`` names a registered default via its ``'_name'``
    key, that named default — itself resolved recursively so chains of
    named defaults are followed — seeds the result. Every dict-valued
    entry is then resolved the same way and included only when non-empty.
    Returns a new dict; the inputs are not mutated.
    """
    resolved = {}
    if '_name' in component_config:
        named_default = deepcopy(
            all_default_configs.get(component_config['_name'], {}))
        # resolve the named default's own inherited defaults first, then
        # let its explicit entries override the inherited ones
        base = load_component_default_config(named_default, all_default_configs)
        recursive_update(base, named_default)
        resolved.update(base)
    for key, value in component_config.items():
        if isinstance(value, dict):
            nested = load_component_default_config(value, all_default_configs)
            if nested:
                resolved[key] = nested
    return resolved
def test_recursive_update():
    """recursive_update merges nested dicts in place and returns the target."""
    target = {'a': {'b': 1}}
    merged = recursive_update(target, {'c': 2, 'a': {'d': 3}})
    # in-place: the returned object is the very dict that was passed in
    assert merged is target
    expected = {'a': {'b': 1, 'd': 3}, 'c': 2}
    assert merged == expected
def test_recursive_update():
    """Nested keys are merged rather than replaced, and the update is in place."""
    base = {"a": {"b": 1}}
    result = recursive_update(base, {"c": 2, "a": {"d": 3}})
    assert result is base  # mutated and returned, not copied
    assert result == {"a": {"b": 1, "d": 3}, "c": 2}
def create_run(experiment, command_name, config_updates=None,
               named_configs=(), force=False, log_level=None):
    """Build a fully configured Run object for ``command_name`` on ``experiment``.

    Walks the experiment's ingredient hierarchy, resolves configuration in
    four phases (explicit updates, named configs, config scopes + hooks,
    seeding), and wires loggers, hooks, and observers into the Run.

    Parameters
    ----------
    experiment : the experiment to run.
    command_name : dotted name of the command to execute.
    config_updates : optional (possibly dotted-key) mapping of config overrides.
    named_configs : iterable of named-config identifiers to apply.
    force : propagated onto the created run as ``run.force``.
    log_level : passed to logging initialization.

    Returns
    -------
    A Run instance, initialized against every scaffold.
    """
    sorted_ingredients = gather_ingredients_topological(experiment)
    scaffolding = create_scaffolding(experiment, sorted_ingredients)
    # get all split non-empty prefixes sorted from deepest to shallowest
    prefixes = sorted([s.split('.') for s in scaffolding if s != ''],
                      reverse=True, key=lambda p: len(p))

    # --------- configuration process -------------------

    # Phase 1: Config updates
    config_updates = config_updates or {}
    config_updates = convert_to_nested_dict(config_updates)
    root_logger, run_logger = initialize_logging(experiment, scaffolding,
                                                 log_level)
    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 2: Named Configs
    for ncfg in named_configs:
        scaff, cfg_name = get_scaffolding_and_config_name(ncfg, scaffolding)
        scaff.gather_fallbacks()
        ncfg_updates = scaff.run_named_config(cfg_name)
        distribute_presets(prefixes, scaffolding, ncfg_updates)
        # named-config values become config updates, keyed under the
        # scaffold's dotted path so they land in the right ingredient
        for ncfg_key, value in iterate_flattened(ncfg_updates):
            set_by_dotted_path(config_updates,
                               join_paths(scaff.path, ncfg_key),
                               value)

    # redistribute after named configs may have added new updates
    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 3: Normal config scopes
    for scaffold in scaffolding.values():
        scaffold.gather_fallbacks()
        scaffold.set_up_config()

        # update global config
        config = get_configuration(scaffolding)
        # run config hooks
        config_hook_updates = scaffold.run_config_hooks(
            config, command_name, run_logger)
        recursive_update(scaffold.config, config_hook_updates)

    # Phase 4: finalize seeding
    # reversed so child scaffolds are seeded before their parents
    for scaffold in reversed(list(scaffolding.values())):
        scaffold.set_up_seed()  # partially recursive

    config = get_configuration(scaffolding)
    config_modifications = get_config_modifications(scaffolding)

    # ----------------------------------------------------

    experiment_info = experiment.get_experiment_info()
    host_info = get_host_info()
    main_function = get_command(scaffolding, command_name)
    # collect pre-/post-run hooks from every ingredient, in topological order
    pre_runs = [pr for ing in sorted_ingredients
                for pr in ing.pre_run_hooks]
    post_runs = [pr for ing in sorted_ingredients
                 for pr in ing.post_run_hooks]

    run = Run(config, config_modifications, main_function,
              copy(experiment.observers), root_logger, run_logger,
              experiment_info, host_info, pre_runs, post_runs,
              experiment.captured_out_filter)

    # commands marked unobserved opt out of observer notifications
    if hasattr(main_function, 'unobserved'):
        run.unobserved = main_function.unobserved

    run.force = force

    for scaffold in scaffolding.values():
        scaffold.finalize_initialization(run=run)

    return run