def get_config():
    """Gets the config for the interpreter model."""
    config = default_config()
    override_names = config.overrides.split(',')
    overrides_lib.apply_overrides(config, override_names=override_names)
    adhoc_overrides(config)
    # default_overrides are the overrides that are already applied when you call
    # get_config. Additional overrides may be specified at the command line or by
    # the launcher by setting config.overrides.
    config.default_overrides = config.overrides
    config.lock()
    return config
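
The overrides mechanism above is internal to this codebase and overrides_lib is not shown. The sketch below is only an assumed shape, inferred from how apply_overrides is called: a registry of named override functions applied in order. The registry contents and the config fields they touch are hypothetical.

# Assumed sketch of overrides_lib, not the actual module.
_OVERRIDE_REGISTRY = {
    # Each override is a function that edits the config in place.
    'small_model': lambda config: config.model.update({'hidden_size': 64}),
}

def apply_overrides(config, override_names):
    """Applies the named override functions to config, in order."""
    for name in override_names:
        if not name:
            continue  # Tolerate the [''] produced by splitting an empty string.
        _OVERRIDE_REGISTRY[name](config)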
Example #2
def configure(data_dir, run_dir, config, override_values, xm_parameters=None):
    """Sets up the Learned Interpreter code with the specified configuration."""
    seed()

    # Apply any overrides set at the command line or in the launcher.
    if config.overrides != config.default_overrides:
        logging.info('Applying overrides set at command line: %s',
                     config.overrides)
        overrides_lib.apply_overrides(
            config, override_names=config.overrides.split(','))
    config.update_from_flattened_dict(override_values)

    # If a checkpoint is specified, it determines the "original run."
    # Otherwise the run_dir, if already present, determines the "original run."
    config_filepath = os.path.join(run_dir, 'config.json')
    if checkpoint_utils.is_checkpoint_specified(config.checkpoint):
        original_checkpoint_path = checkpoint_utils.get_specified_checkpoint_path(
            run_dir, config.checkpoint)
        original_run_dir = checkpoint_utils.get_run_dir(
            original_checkpoint_path)
        original_config_filepath = os.path.join(original_run_dir,
                                                'config.json')
    else:
        checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
        original_checkpoint_path = checkpoint_utils.latest_checkpoint(
            checkpoint_dir)
        original_config_filepath = config_filepath
    original_config_exists = gfile.exists(original_config_filepath)

    # Handle any existing configs.
    if original_config_exists:
        original_config = config_utils.load_config(original_config_filepath)

        # Handle the model config.
        if config.runner.model_config == 'load':
            logging.info('Loading the model config from %s',
                         original_config_filepath)
            config.model.update(original_config.model)
            config.dataset.representation = original_config.dataset.representation
        elif config.runner.model_config == 'assert':
            same_config = config_utils.equals(config.model,
                                              original_config.model)
            # Resolution:
            # Either use a new run_dir, or set model_config to 'load' or 'keep'.
            assert same_config, 'Model config has changed.'
        else:
            assert config.runner.model_config == 'keep'

        # Handle the dataset config.
        if config.runner.dataset_config == 'load':
            logging.info('Loading the data config from %s',
                         original_config_filepath)
            config.dataset.update(original_config.dataset)
        elif config.runner.dataset_config == 'assert':
            same_config = config_utils.equals(config.dataset,
                                              original_config.dataset)
            assert same_config, 'Dataset config has changed.'
        else:
            assert config.runner.dataset_config == 'keep'

    elif (config.runner.model_config == 'load'
          or config.runner.dataset_config == 'load'):
        raise ValueError('Original config not found, but model_config or '
                         'dataset_config is set to "load".')

    # In interactive mode, force batch size 1.
    if config.runner.mode == 'interact':
        config.dataset.batch_size = 1

    config_exists = gfile.exists(config_filepath)
    if not config_exists and config.runner.mode == 'train':
        gfile.makedirs(run_dir)
        config_utils.save_config(config, config_filepath)

    # Load dataset.
    if config.setup.setup_dataset:
        dataset_info = dataset_utils.get_dataset(data_dir, config)
        info = dataset_info.info
    else:
        dataset_info = None
        info = None

    # Create model.
    if config.setup.setup_model:
        model = models_lib.get_model(info, config)
    else:
        model = None

    adapter = adapters_lib.get_default_adapter(info, config)

    return RunConfiguration(mode=config.runner.mode,
                            method=config.runner.method,
                            run_dir=run_dir,
                            data_dir=data_dir,
                            original_checkpoint_path=original_checkpoint_path,
                            model=model,
                            info=info,
                            config=config,
                            adapter=adapter,
                            dataset_info=dataset_info)
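
A minimal sketch of how configure might be driven. The placeholder paths and the example override key are assumptions for illustration; only the configure signature and the RunConfiguration fields used here come from the code above.

# Hypothetical driver; paths and override keys are placeholders.
def main(_):
    config = config_lib.get_config()
    run_configuration = configure(
        data_dir='/tmp/data',      # Placeholder dataset location.
        run_dir='/tmp/runs/exp0',  # Placeholder run directory.
        config=config,
        override_values={},        # Flattened keys, e.g. {'dataset.batch_size': 8}.
    )
    # The returned RunConfiguration bundles everything the runner needs.
    print(run_configuration.mode, run_configuration.run_dir)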
def test_valid_overrides(self):
    config = config_lib.get_config()
    all_overrides = overrides_lib.get_all_overrides()
    all_override_names = all_overrides.keys()
    overrides_lib.apply_overrides(config,
                                  override_names=all_override_names)
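
In the same spirit, a companion test one could add, not part of the original suite, to check that get_config records the overrides it has already applied:

def test_default_overrides_recorded(self):
    # Hypothetical companion test, assuming the same test class as above and
    # that config_lib.get_config is the function shown at the top.
    config = config_lib.get_config()
    self.assertEqual(config.default_overrides, config.overrides)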