Example #1
def run_config_hooks(self, config, config_updates, command_name, logger):
    final_cfg_updates = {}
    for ch in self.config_hooks:
        cfg_upup = ch(deepcopy(config), command_name, logger)
        if cfg_upup:
            recursive_update(final_cfg_updates, cfg_upup)
    recursive_update(final_cfg_updates, config_updates)
    return final_cfg_updates
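For reference, each callable in self.config_hooks receives a deep copy of the current config together with the command name and a logger, and returns a (possibly partial) dict of updates that is merged into final_cfg_updates. The sketch below shows a hypothetical hook with that call shape; the names are illustrative rather than taken from the original code, and a flat dict.update stands in for recursive_update, whose merge behavior is pinned down by the tests in Examples #6-#8.

from copy import deepcopy


def cuda_hook(config, command_name, logger):
    # Hypothetical config hook: same (config, command_name, logger)
    # signature as the callables stored in self.config_hooks above.
    if command_name == 'train' and config.get('device') == 'cpu':
        return {'device': 'cuda'}  # partial update; untouched keys keep their values
    return {}


config = {'device': 'cpu', 'batch_size': 32}
final_cfg_updates = {}
for ch in [cuda_hook]:
    cfg_update = ch(deepcopy(config), 'train', logger=None)
    if cfg_update:
        # the original merges with recursive_update so hooks may return
        # nested partial configs; a flat update suffices for this illustration
        final_cfg_updates.update(cfg_update)
print(final_cfg_updates)  # {'device': 'cuda'}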
Example #2
def run_config_hooks(self, config, config_updates, command_name, logger):
    final_cfg_updates = {}
    for ch in self.config_hooks:
        cfg_upup = ch(deepcopy(config), command_name, logger)
        if cfg_upup:
            recursive_update(final_cfg_updates, cfg_upup)
    recursive_update(final_cfg_updates, config_updates)
    return final_cfg_updates
Example #3
def main(config_file, update, _run, _log, _config):
    working_dir = _run.observers[0].dir
    # load the config file
    config = yaml_load(config_file)
    recursive_update(config, update)
    yaml_dump(config, path.join(working_dir, 'config.yaml'))
    _config = config
    print(_config)
    print(working_dir)
    dataset = _config['dataset']
    dataset['device'] = update['device']
    # load the dataset and the vocab
    train_loader, dev_loader, test_loader, vocab = load_data(**dataset)
    vocab_size = len(vocab.itos)

    # model
    _config['hidden']['features'][0] = vocab_size

    # trainer batch
    test_sample = _config['trainer_batch']['test_sample']
    _config['trainer_batch']['test_sample'] = 1

    config = extend_config_reference(_config)
    trainer = config['trainer']
    trainer['evaluate_interval'] = len(
        train_loader) * trainer['evaluate_interval']
    trainer['save_checkpoint_interval'] = trainer['evaluate_interval']
    trainer['base_dir'] = working_dir
    yaml_dump(trainer, path.join(working_dir, 'trainer.yaml'))
    trainer['train_iterator'] = train_loader
    trainer['dev_iterator'] = dev_loader
    trainer['test_iterator'] = None
    callback = EvaluationCallback(working_dir,
                                  vocab,
                                  corpus_dir=path.join(dataset['data_dir'],
                                                       'corpus'),
                                  **config['callback'])
    trainer['callbacks'] = callback
    trainer['logger'] = _log

    print(config)
    trainer = Trainer.from_config(trainer)
    _log.info("model architecture")
    print(trainer.trainer_batch.model)

    # train the model
    trainer.train()

    # testing and save results
    trainer.dev_iterator = test_loader
    trainer.trainer_batch.test_sample = test_sample  # restore the configured sample count for testing (it was reduced to 1 during development evaluation)
    trainer.restore_from_basedir(best=True)
    stat = trainer._evaluate_epoch().get_dict()
    callback.evaluate_topic_coherence()  # topic coherence of best checkpoint
    stat.update(callback.get_dict())
    yaml_dump(stat, path.join(working_dir, 'result.yaml'))
    _log.info('test result of best evaluation {}'.format(stat))
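The first few lines of this script follow a common pattern: load a YAML config, fold command-line overrides into it, and dump the effective config into the run directory. A minimal sketch of that pattern, assuming yaml_load and yaml_dump above are thin wrappers around PyYAML; the keys and values below are made up for illustration.

import yaml

config = yaml.safe_load("""
device: cpu
dataset:
  data_dir: data/20news
  batch_size: 64
""")

update = {'device': 'cuda'}   # command-line style override, as passed in through update
config.update(update)         # the original calls recursive_update, so nested overrides merge key by key

config['dataset']['device'] = update['device']  # propagated into the dataset section, as in the script above
print(yaml.safe_dump(config))                   # roughly what gets written to <working_dir>/config.yaml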
Example #4
def run_config_hooks(self, config, command_name, logger):
    final_cfg_updates = {}
    for ch in self.config_hooks:
        cfg_upup = ch(deepcopy(config), command_name, logger)
        if cfg_upup:
            recursive_update(final_cfg_updates, cfg_upup)
    # merge the explicit config updates last so that they take precedence
    # over any values already set by the config hooks
    recursive_update(final_cfg_updates, self.config_updates)
    return final_cfg_updates
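The final recursive_update call is what gives the explicitly supplied updates precedence over values produced by the hooks, simply because they are merged last. A tiny illustration with made-up values, using flat dicts so a plain dict.update shows the same ordering effect:

final_cfg_updates = {'batch_size': 64, 'device': 'cuda'}  # accumulated from config hooks
config_updates = {'batch_size': 128}                      # e.g. supplied on the command line
final_cfg_updates.update(config_updates)                  # merged last, so it wins on conflicts
assert final_cfg_updates == {'batch_size': 128, 'device': 'cuda'}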
Example #5
def load_component_default_config(component_config, all_default_configs):
    component_default_config = {}
    if '_name' in component_config:
        elt_default = deepcopy(all_default_configs.get(component_config['_name'], {}))
        default = load_component_default_config(elt_default, all_default_configs)
        recursive_update(default, elt_default)
        component_default_config.update(default)
    for key, val in component_config.items():
        if isinstance(val, dict):
            conf = load_component_default_config(val, all_default_configs)
            if conf:
                component_default_config[key] = conf

    return component_default_config
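A hypothetical usage, assuming deepcopy and a recursive_update with the semantics asserted in Examples #6-#8 are in scope (the function above already requires both). The registry and config below are made up; the point is that the '_name' key selects a named default block, and the user's own values are layered on top afterwards.

all_defaults = {'mlp': {'hidden': 128, 'dropout': 0.1}}   # hypothetical named defaults
user_config = {'model': {'_name': 'mlp', 'hidden': 256}}  # component referencing a default by name

defaults = load_component_default_config(user_config, all_defaults)
assert defaults == {'model': {'hidden': 128, 'dropout': 0.1}}

# merging the user config over the resolved defaults gives the effective configuration
effective = recursive_update(defaults, user_config)
assert effective == {'model': {'_name': 'mlp', 'hidden': 256, 'dropout': 0.1}}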
Example #6
def test_recursive_update():
    d = {'a': {'b': 1}}
    res = recursive_update(d, {'c': 2, 'a': {'d': 3}})
    assert d is res
    assert res == {'a': {'b': 1, 'd': 3}, 'c': 2}
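The test pins down the two properties callers rely on: the merge happens in place (d is res) and nested dicts are merged key by key rather than replaced wholesale. A minimal implementation consistent with these assertions, though not necessarily the project's exact code, is:

def recursive_update(d, u):
    # merge mapping u into mapping d in place, descending into nested dicts
    for key, value in u.items():
        if isinstance(value, dict):
            d[key] = recursive_update(d.get(key, {}), value)
        else:
            d[key] = value
    return d


d = {'a': {'b': 1}}
assert recursive_update(d, {'c': 2, 'a': {'d': 3}}) is d
assert d == {'a': {'b': 1, 'd': 3}, 'c': 2}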
Example #7
def test_recursive_update():
    d = {"a": {"b": 1}}
    res = recursive_update(d, {"c": 2, "a": {"d": 3}})
    assert d is res
    assert res == {"a": {"b": 1, "d": 3}, "c": 2}
Example #8
def test_recursive_update():
    d = {'a': {'b': 1}}
    res = recursive_update(d, {'c': 2, 'a': {'d': 3}})
    assert d is res
    assert res == {'a': {'b': 1, 'd': 3}, 'c': 2}
Example #9
def create_run(experiment,
               command_name,
               config_updates=None,
               named_configs=(),
               force=False,
               log_level=None):

    sorted_ingredients = gather_ingredients_topological(experiment)
    scaffolding = create_scaffolding(experiment, sorted_ingredients)
    # get all split non-empty prefixes sorted from deepest to shallowest
    prefixes = sorted([s.split('.') for s in scaffolding if s != ''],
                      reverse=True,
                      key=lambda p: len(p))

    # --------- configuration process -------------------

    # Phase 1: Config updates
    config_updates = config_updates or {}
    config_updates = convert_to_nested_dict(config_updates)
    root_logger, run_logger = initialize_logging(experiment, scaffolding,
                                                 log_level)
    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 2: Named Configs
    for ncfg in named_configs:
        scaff, cfg_name = get_scaffolding_and_config_name(ncfg, scaffolding)
        scaff.gather_fallbacks()
        ncfg_updates = scaff.run_named_config(cfg_name)
        distribute_presets(prefixes, scaffolding, ncfg_updates)
        for ncfg_key, value in iterate_flattened(ncfg_updates):
            set_by_dotted_path(config_updates,
                               join_paths(scaff.path, ncfg_key), value)

    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 3: Normal config scopes
    for scaffold in scaffolding.values():
        scaffold.gather_fallbacks()
        scaffold.set_up_config()

        # update global config
        config = get_configuration(scaffolding)
        # run config hooks
        config_hook_updates = scaffold.run_config_hooks(
            config, command_name, run_logger)
        recursive_update(scaffold.config, config_hook_updates)

    # Phase 4: finalize seeding
    for scaffold in reversed(list(scaffolding.values())):
        scaffold.set_up_seed()  # partially recursive

    config = get_configuration(scaffolding)
    config_modifications = get_config_modifications(scaffolding)

    # ----------------------------------------------------

    experiment_info = experiment.get_experiment_info()
    host_info = get_host_info()
    main_function = get_command(scaffolding, command_name)
    pre_runs = [pr for ing in sorted_ingredients for pr in ing.pre_run_hooks]
    post_runs = [pr for ing in sorted_ingredients for pr in ing.post_run_hooks]

    run = Run(config, config_modifications, main_function,
              copy(experiment.observers), root_logger, run_logger,
              experiment_info, host_info, pre_runs, post_runs,
              experiment.captured_out_filter)

    if hasattr(main_function, 'unobserved'):
        run.unobserved = main_function.unobserved

    run.force = force

    for scaffold in scaffolding.values():
        scaffold.finalize_initialization(run=run)

    return run
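One small piece worth isolating is the prefix computation near the top: every non-empty scaffolding key is split on dots and the resulting paths are ordered deepest first, presumably so the distribution steps can match the most specific ingredient prefix before its parents. A standalone sketch with hypothetical ingredient paths:

# hypothetical scaffolding keys: '' is the experiment itself, the rest are ingredient paths
scaffolding = {'': None, 'dataset': None, 'dataset.reader': None, 'model': None}

prefixes = sorted([s.split('.') for s in scaffolding if s != ''],
                  reverse=True, key=lambda p: len(p))
print(prefixes)  # [['dataset', 'reader'], ['dataset'], ['model']]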
Example #10
def create_run(experiment, command_name, config_updates=None,
               named_configs=(), force=False, log_level=None):

    sorted_ingredients = gather_ingredients_topological(experiment)
    scaffolding = create_scaffolding(experiment, sorted_ingredients)
    # get all split non-empty prefixes sorted from deepest to shallowest
    prefixes = sorted([s.split('.') for s in scaffolding if s != ''],
                      reverse=True, key=lambda p: len(p))

    # --------- configuration process -------------------

    # Phase 1: Config updates
    config_updates = config_updates or {}
    config_updates = convert_to_nested_dict(config_updates)
    root_logger, run_logger = initialize_logging(experiment, scaffolding,
                                                 log_level)
    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 2: Named Configs
    for ncfg in named_configs:
        scaff, cfg_name = get_scaffolding_and_config_name(ncfg, scaffolding)
        scaff.gather_fallbacks()
        ncfg_updates = scaff.run_named_config(cfg_name)
        distribute_presets(prefixes, scaffolding, ncfg_updates)
        for ncfg_key, value in iterate_flattened(ncfg_updates):
            set_by_dotted_path(config_updates,
                               join_paths(scaff.path, ncfg_key),
                               value)

    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 3: Normal config scopes
    for scaffold in scaffolding.values():
        scaffold.gather_fallbacks()
        scaffold.set_up_config()

        # update global config
        config = get_configuration(scaffolding)
        # run config hooks
        config_hook_updates = scaffold.run_config_hooks(
            config, command_name, run_logger)
        recursive_update(scaffold.config, config_hook_updates)

    # Phase 4: finalize seeding
    for scaffold in reversed(list(scaffolding.values())):
        scaffold.set_up_seed()  # partially recursive

    config = get_configuration(scaffolding)
    config_modifications = get_config_modifications(scaffolding)

    # ----------------------------------------------------

    experiment_info = experiment.get_experiment_info()
    host_info = get_host_info()
    main_function = get_command(scaffolding, command_name)
    pre_runs = [pr for ing in sorted_ingredients for pr in ing.pre_run_hooks]
    post_runs = [pr for ing in sorted_ingredients for pr in ing.post_run_hooks]

    run = Run(config, config_modifications, main_function,
              copy(experiment.observers), root_logger, run_logger,
              experiment_info, host_info, pre_runs, post_runs,
              experiment.captured_out_filter)

    if hasattr(main_function, 'unobserved'):
        run.unobserved = main_function.unobserved

    run.force = force

    for scaffold in scaffolding.values():
        scaffold.finalize_initialization(run=run)

    return run