import dataclasses
import datetime
import json
import os

from absl import logging
from tensorflow.io import gfile

# Project-local modules referenced below (checkpoint_utils, config_utils,
# overrides_lib, dataset_utils, models_lib, adapters_lib, workflows,
# RunConfiguration, seed) are assumed importable from this codebase; their
# exact module paths are not shown here.


def configure(data_dir, run_dir, config, override_values, xm_parameters=None):
    """Sets up the Learned Interpreter code with the specified configuration."""
    seed()  # Seed the RNGs for reproducibility (project-local helper).

    # Apply any overrides set at the command line or in the launcher.
    if config.overrides != config.default_overrides:
        logging.info('Applying overrides set at command line: %s',
                     config.overrides)
        overrides_lib.apply_overrides(
            config, override_names=config.overrides.split(','))
    config.update_from_flattened_dict(override_values)
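    # For illustration: override_values is a flat {dotted_key: value} mapping,
    # e.g. {'model.hidden_size': 512, 'dataset.batch_size': 32} (hypothetical
    # keys), while config.overrides names comma-separated predefined override
    # sets, e.g. 'debug,small_model' (also hypothetical).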

    # If a checkpoint is specified, it determines the "original run."
    # Otherwise the run_dir, if already present, determines the "original run."
    config_filepath = os.path.join(run_dir, 'config.json')
    if checkpoint_utils.is_checkpoint_specified(config.checkpoint):
        original_checkpoint_path = checkpoint_utils.get_specified_checkpoint_path(
            run_dir, config.checkpoint)
        original_run_dir = checkpoint_utils.get_run_dir(
            original_checkpoint_path)
        original_config_filepath = os.path.join(original_run_dir,
                                                'config.json')
    else:
        checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
        original_checkpoint_path = checkpoint_utils.latest_checkpoint(
            checkpoint_dir)
        original_config_filepath = config_filepath
    original_config_exists = gfile.exists(original_config_filepath)

    # Handle any existing configs.
    if original_config_exists:
        original_config = config_utils.load_config(original_config_filepath)

        # Handle the model config. runner.model_config selects how to treat an
        # existing config: 'load' adopts the saved model config, 'assert'
        # requires it to match the current one, and 'keep' uses the current
        # config unchanged.
        if config.runner.model_config == 'load':
            logging.info('Loading the model config from %s',
                         original_config_filepath)
            config.model.update(original_config.model)
            config.dataset.representation = original_config.dataset.representation
        elif config.runner.model_config == 'assert':
            same_config = config_utils.equals(config.model,
                                              original_config.model)
            # To resolve a mismatch: use a new run_dir, or set
            # runner.model_config to 'load' or 'keep'.
            assert same_config, 'Model config has changed.'
        else:
            assert config.runner.model_config == 'keep'

        # Handle the dataset config.
        if config.runner.dataset_config == 'load':
            logging.info('Loading the data config from %s',
                         original_config_filepath)
            config.dataset.update(original_config.dataset)
        elif config.runner.dataset_config == 'assert':
            same_config = config_utils.equals(config.dataset,
                                              original_config.dataset)
            assert same_config, 'Dataset config has changed.'
        else:
            assert config.runner.dataset_config == 'keep'

    elif (config.runner.model_config == 'load'
          or config.runner.dataset_config == 'load'):
        raise ValueError('Original config not found; cannot load the model '
                         'or dataset config.')

    # In interactive mode, force batch size 1.
    if config.runner.mode == 'interact':
        config.dataset.batch_size = 1

    config_exists = gfile.exists(config_filepath)
    if not config_exists and config.runner.mode == 'train':
        gfile.makedirs(run_dir)
        config_utils.save_config(config, config_filepath)

    # Load dataset.
    if config.setup.setup_dataset:
        dataset_info = dataset_utils.get_dataset(data_dir, config)
        info = dataset_info.info
    else:
        dataset_info = None
        info = None

    # Create model.
    if config.setup.setup_model:
        model = models_lib.get_model(info, config)
    else:
        model = None

    adapter = adapters_lib.get_default_adapter(info, config)

    return RunConfiguration(mode=config.runner.mode,
                            method=config.runner.method,
                            run_dir=run_dir,
                            data_dir=data_dir,
                            original_checkpoint_path=original_checkpoint_path,
                            model=model,
                            info=info,
                            config=config,
                            adapter=adapter,
                            dataset_info=dataset_info)
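

# A minimal usage sketch for configure(), assuming an ml_collections-style
# ConfigDict with the fields referenced above; get_config, the paths, and the
# override key below are hypothetical:
#
#   config = get_config()
#   run_configuration = configure(
#       data_dir='/tmp/data',
#       run_dir='/tmp/runs/experiment_0',
#       config=config,
#       override_values={'runner.mode': 'train'},
#   )
#   model = run_configuration.model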


def analyze_once(run_configuration):
  """Analyzes the existing model checkpoint.

  Runs inference with the model on each of the evaluation datasets, and writes
  the metrics to disk as JSON.

  Args:
    run_configuration: The setup.RunConfiguration for the run.
  """
  analysis_run_dir = run_configuration.run_dir
  original_checkpoint_path = run_configuration.original_checkpoint_path
  config = run_configuration.config
  data_dir = run_configuration.data_dir
  adapter = run_configuration.adapter
  optimizer = adapter.create_optimizer(run_configuration)

  original_run_dir = checkpoint_utils.get_run_dir(original_checkpoint_path)
  original_run_name = os.path.basename(original_run_dir)
  original_config_path = os.path.join(original_run_dir, 'config.json')
  checkpoint_basename = os.path.basename(original_checkpoint_path)
  analysis_dir = os.path.join(analysis_run_dir, checkpoint_basename)
  analysis_file = os.path.join(analysis_dir, 'data.json')
  checkpoint_dir = checkpoint_utils.build_checkpoint_dir(analysis_dir)
  checkpoint_path = os.path.join(checkpoint_dir, checkpoint_basename)
  gfile.makedirs(checkpoint_dir)
  gfile.makedirs(analysis_dir)
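
  # Directory layout built up below (the checkpoints subdirectory name is
  # whatever checkpoint_utils.build_checkpoint_dir chooses; <ckpt> is the
  # checkpoint basename):
  #
  #   <analysis_run_dir>/config.json            copy of the original config
  #   <analysis_run_dir>/<ckpt>/metadata.json   name, run_dir, timestamp
  #   <analysis_run_dir>/<ckpt>/data.json       per-dataset metrics
  #   <analysis_run_dir>/<ckpt>/.../<ckpt>      copy of the checkpoint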

  # Save a copy of the original checkpoint.
  gfile.copy(original_checkpoint_path, checkpoint_path)
  # Save analysis metadata.
  metadata = {
      'name': original_run_name,
      'run_dir': original_run_dir,
      'timestamp': datetime.datetime.now().timestamp(),
  }
  metadata_path = os.path.join(analysis_dir, 'metadata.json')
  with gfile.GFile(metadata_path, 'w') as f:
    f.write(json.dumps(metadata))

  # Save a copy of the original config.
  new_config_path = os.path.join(analysis_run_dir, 'config.json')
  # We set overwrite=True to handle preemption.
  gfile.copy(original_config_path, new_config_path, overwrite=True)
  logging.info('Saving results to %s', analysis_file)

  # Load the datasets to analyze.
  analysis_results = []
  eval_dataset_names = config.launcher.eval_dataset_names.split(',')
  for dataset_name in eval_dataset_names:
    logging.info('Evaluating with checkpoint_path: %s', checkpoint_path)
    logging.info('Evaluating on dataset: %s', dataset_name)
    dataset_info = dataset_utils.get_dataset(data_dir, config, dataset_name)
    run_configuration = dataclasses.replace(
        run_configuration,
        run_dir=analysis_dir,
        original_checkpoint_path=checkpoint_path,
        info=dataset_info.info,
        dataset_info=dataset_info,
    )
    metrics = workflows.predict_once(run_configuration, optimizer)
    logging.info('Done evaluating on dataset: %s', dataset_name)

    results = {
        'dataset_name': dataset_name,
        'accuracy': metrics['accuracy'].tolist(),
        'denominator': metrics['denominator'].tolist(),
    }
    analysis_results.append(results)

    # Rewrite the cumulative results after each dataset (text mode, since
    # json.dump writes str) so partial results survive preemption.
    with gfile.GFile(analysis_file, 'w') as f:
      json.dump(analysis_results, f)
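

# The final data.json holds one record per evaluation dataset; a hypothetical
# example of its contents:
#
#   [
#     {"dataset_name": "eval_small",
#      "accuracy": [0.91],
#      "denominator": [1024.0]},
#     {"dataset_name": "eval_large",
#      "accuracy": [0.87],
#      "denominator": [4096.0]}
#   ]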