Example #1
    def test_dataset_standard_batching(self):
        dataset_name = 'control_flow_programs/decimal-L10'

        data_dir = tempfile.mkdtemp()
        config = config_lib.get_config()

        with config.unlocked():
            config.dataset.name = dataset_name
            config.dataset.in_memory = True
            config.dataset.batch_size = 5
            config.dataset.representation = 'trace'
        config = ml_collections.FrozenConfigDict(config)

        dataset_info = dataset_utils.get_dataset(data_dir, config)
        item = next(iter(dataset_info.dataset))

        self.assertEqual(item['cfg']['data'].shape[0], 5)
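The project-specific pieces here (config_lib, dataset_utils) are not shown, but the config handling is plain ml_collections. The following is a minimal, self-contained sketch of that lock/unlock/freeze pattern; the dataset field names mirror the test, while the default values are assumptions rather than the project's real defaults.

# Minimal sketch of the ml_collections pattern used in the test above.
# The dataset fields mirror the test; the defaults are assumed, not real.
import ml_collections

config = ml_collections.ConfigDict({
    'dataset': {
        'name': '',
        'in_memory': False,
        'batch_size': 32,
        'representation': 'code',
    },
})
config.lock()  # adding *new* fields now raises an error

with config.unlocked():  # temporarily allow additions and overrides
    config.dataset.name = 'control_flow_programs/decimal-L10'
    config.dataset.in_memory = True
    config.dataset.batch_size = 5
    config.dataset.representation = 'trace'

config = ml_collections.FrozenConfigDict(config)  # immutable snapshot
print(config.dataset.batch_size)  # -> 5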
Example #2
def configure(data_dir, run_dir, config, override_values, xm_parameters=None):
    """Sets up the Learned Interpreter code with the specified configuration."""
    seed()

    # Apply any overrides set at the command line or in the launcher.
    if config.overrides != config.default_overrides:
        logging.info('Applying overrides set at command line: %s',
                     config.overrides)
        overrides_lib.apply_overrides(
            config, override_names=config.overrides.split(','))
    config.update_from_flattened_dict(override_values)

    # If a checkpoint is specified, it determines the "original run."
    # Otherwise the run_dir, if already present, determines the "original run."
    config_filepath = os.path.join(run_dir, 'config.json')
    if checkpoint_utils.is_checkpoint_specified(config.checkpoint):
        original_checkpoint_path = checkpoint_utils.get_specified_checkpoint_path(
            run_dir, config.checkpoint)
        original_run_dir = checkpoint_utils.get_run_dir(
            original_checkpoint_path)
        original_config_filepath = os.path.join(original_run_dir,
                                                'config.json')
    else:
        checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
        original_checkpoint_path = checkpoint_utils.latest_checkpoint(
            checkpoint_dir)
        original_config_filepath = config_filepath
    original_config_exists = gfile.exists(original_config_filepath)

    # Handle any existing configs.
    if original_config_exists:
        original_config = config_utils.load_config(original_config_filepath)

        # Handle the model config.
        if config.runner.model_config == 'load':
            logging.info('Loading the model config from %s',
                         original_config_filepath)
            config.model.update(original_config.model)
            config.dataset.representation = original_config.dataset.representation
        elif config.runner.model_config == 'assert':
            same_config = config_utils.equals(config.model,
                                              original_config.model)
            # Resolution:
            # Either use a new run_dir, or set model_config to 'load' or 'keep'.
            assert same_config, 'Model config has changed.'
        else:
            assert config.runner.model_config == 'keep'

        # Handle the dataset config.
        if config.runner.dataset_config == 'load':
            logging.info('Loading the data config from %s',
                         original_config_filepath)
            config.dataset.update(original_config.dataset)
        elif config.runner.dataset_config == 'assert':
            same_config = config_utils.equals(config.dataset,
                                              original_config.dataset)
            assert same_config, 'Dataset config has changed.'
        else:
            assert config.runner.dataset_config == 'keep'

    elif (config.runner.model_config == 'load'
          or config.runner.dataset_config == 'load'):
        raise ValueError('Original config not found.')

    # In interactive mode, force batch size 1.
    if config.runner.mode == 'interact':
        config.dataset.batch_size = 1

    config_exists = gfile.exists(config_filepath)
    if not config_exists and config.runner.mode == 'train':
        gfile.makedirs(run_dir)
        config_utils.save_config(config, config_filepath)

    # Load dataset.
    if config.setup.setup_dataset:
        dataset_info = dataset_utils.get_dataset(data_dir, config)
        info = dataset_info.info
    else:
        dataset_info = None
        info = None

    # Create model.
    if config.setup.setup_model:
        model = models_lib.get_model(info, config)
    else:
        model = None

    adapter = adapters_lib.get_default_adapter(info, config)

    return RunConfiguration(mode=config.runner.mode,
                            method=config.runner.method,
                            run_dir=run_dir,
                            data_dir=data_dir,
                            original_checkpoint_path=original_checkpoint_path,
                            model=model,
                            info=info,
                            config=config,
                            adapter=adapter,
                            dataset_info=dataset_info)
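
The checkpoint/run_dir branch at the top of configure() decides which config.json is treated as authoritative. Below is a stdlib-only sketch of that resolution order; the helper name and the assumption that checkpoints live under <run_dir>/checkpoints/ are illustrative, not the project's actual layout or API.

# Sketch of the "original run" resolution performed by configure():
# an explicitly specified checkpoint points back to the run directory that
# produced it; otherwise the current run_dir's own config.json is the original.
import os

def original_config_path(run_dir, specified_checkpoint_path=None):
    if specified_checkpoint_path:
        # Assumed layout: <original_run_dir>/checkpoints/<checkpoint_file>.
        original_run_dir = os.path.dirname(
            os.path.dirname(specified_checkpoint_path))
        return os.path.join(original_run_dir, 'config.json')
    return os.path.join(run_dir, 'config.json')

print(original_config_path('/runs/exp2', '/runs/exp1/checkpoints/ckpt_5000'))
# -> /runs/exp1/config.json
print(original_config_path('/runs/exp2'))
# -> /runs/exp2/config.json
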
def analyze_once(run_configuration):
  """Analyzes the existing model checkpoint.

  Runs inference for the model for each of the evaluation datasets, and writes
  the metrics to disk as JSON.

  Args:
    run_configuration: The setup.RunConfiguration for the run.
  """
  analysis_run_dir = run_configuration.run_dir
  original_checkpoint_path = run_configuration.original_checkpoint_path
  config = run_configuration.config
  data_dir = run_configuration.data_dir
  adapter = run_configuration.adapter
  optimizer = adapter.create_optimizer(run_configuration)

  original_run_dir = checkpoint_utils.get_run_dir(original_checkpoint_path)
  original_run_name = os.path.basename(original_run_dir)
  original_config_path = os.path.join(original_run_dir, 'config.json')
  checkpoint_basename = os.path.basename(original_checkpoint_path)
  analysis_dir = os.path.join(analysis_run_dir, checkpoint_basename)
  analysis_file = os.path.join(analysis_dir, 'data.json')
  checkpoint_dir = checkpoint_utils.build_checkpoint_dir(analysis_dir)
  checkpoint_path = os.path.join(checkpoint_dir, checkpoint_basename)
  gfile.makedirs(checkpoint_dir)
  gfile.makedirs(analysis_dir)

  # Save a copy of the original checkpoint.
  gfile.copy(original_checkpoint_path, checkpoint_path)
  # Save analysis metadata.
  metadata = {
      'name': original_run_name,
      'run_dir': original_run_dir,
      'timestamp': datetime.datetime.now().timestamp(),
  }
  metadata_path = os.path.join(analysis_dir, 'metadata.json')
  with gfile.GFile(metadata_path, 'w') as f:
    f.write(json.dumps(metadata))

  # Save a copy of the original config.
  new_config_path = os.path.join(analysis_run_dir, 'config.json')
  # We set overwrite=True to handle preemption.
  gfile.copy(original_config_path, new_config_path, overwrite=True)
  logging.info('Saving results to %s', analysis_file)

  # Load the datasets to analyze.
  analysis_results = []
  eval_dataset_names = config.launcher.eval_dataset_names.split(',')
  for dataset_name in eval_dataset_names:
    logging.info('Evaluating with checkpoint_path: %s', checkpoint_path)
    logging.info('Evaluating on dataset: %s', dataset_name)
    dataset_info = dataset_utils.get_dataset(data_dir, config, dataset_name)
    run_configuration = dataclasses.replace(
        run_configuration,
        run_dir=analysis_dir,
        original_checkpoint_path=checkpoint_path,
        info=dataset_info.info,
        dataset_info=dataset_info,
    )
    metrics = workflows.predict_once(run_configuration, optimizer)
    logging.info('Done evaluating on dataset: %s', dataset_name)

    results = {
        'dataset_name': dataset_name,
        'accuracy': metrics['accuracy'].tolist(),
        'denominator': metrics['denominator'].tolist(),
    }
    analysis_results.append(results)

    with gfile.GFile(analysis_file, 'wb') as f:
      json.dump(analysis_results, f)
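
Note that analyze_once rewrites data.json inside the per-dataset loop, so a preempted run still leaves valid partial results on disk. A minimal stdlib-only sketch of that incremental-write pattern (the dataset names and metric values below are made up):

import json
import os
import tempfile

analysis_dir = tempfile.mkdtemp()
analysis_file = os.path.join(analysis_dir, 'data.json')

analysis_results = []
for dataset_name, accuracy in [('val', 0.91), ('test', 0.88)]:
    analysis_results.append({'dataset_name': dataset_name, 'accuracy': accuracy})
    # Rewrite the whole file each iteration; a crash between datasets still
    # leaves a valid, partially complete data.json behind.
    with open(analysis_file, 'w') as f:
        json.dump(analysis_results, f)
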
def run_eval_multi_dataset(run_configuration):
  """Evaluates on checkpoints as they become available."""
  config = run_configuration.config
  run_dir = run_configuration.run_dir
  data_dir = run_configuration.data_dir
  adapter = run_configuration.adapter
  optimizer = adapter.create_optimizer(run_configuration)

  eval_dataset_names = config.launcher.eval_dataset_names.split(',')
  dataset_infos = [dataset_utils.get_dataset(data_dir, config, name)
                   for name in eval_dataset_names]
  all_dataset_ids = set(range(len(dataset_infos)))
  dataset_ids_evaluated = set()

  last_dataset_id = -1
  last_checkpoint_path = None
  last_checkpoint_time = time.time()
  checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
  success_path = checkpoint_utils.build_success_path(run_dir)
  error_count = 0
  while True:
    success = gfile.exists(success_path)
    # Always evaluate at the latest checkpoint.
    checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)
    # Choose the dataset to evaluate.

    if checkpoint_path is not None and checkpoint_path != last_checkpoint_path:
      # The dataset ids evaluated at the latest checkpoint.
      # Our goal is to evaluate all the datasets at the final checkpoint as soon
      # as possible, while providing best effort progress updates along the way.
      dataset_ids_evaluated = set()

    dataset_id = (last_dataset_id + 1) % len(dataset_infos)
    if (dataset_id in dataset_ids_evaluated
        and dataset_ids_evaluated != all_dataset_ids):
      dataset_id = next(iter(all_dataset_ids - dataset_ids_evaluated))

    if ((checkpoint_path, dataset_id) != (last_checkpoint_path, last_dataset_id)
        and dataset_ids_evaluated != all_dataset_ids
        and checkpoint_path is not None):
      logging.info('Evaluating with checkpoint_path: %s', checkpoint_path)
      logging.info('Evaluating on dataset id: %d', dataset_id)
      run_configuration.dataset_info = dataset_infos[dataset_id]
      try:
        workflows.eval_once(run_configuration, checkpoint_path, optimizer)
      except:  # pylint: disable=bare-except
        logging.info('Could not evaluate %s on dataset %d', checkpoint_path,
                     dataset_id)
        error_count += 1
        if error_count >= 10 or config.debug:
          raise
      last_dataset_id = dataset_id
      dataset_ids_evaluated.add(dataset_id)
      last_checkpoint_path = checkpoint_path
      last_checkpoint_time = time.time()
    else:
      if success:
        logging.info('SUCCESS file found. Stopping.')
        break
      if time.time() - last_checkpoint_time > config.eval_timeout:
        logging.info('Timed out waiting for checkpoint. Stopping.')
        break
      logging.info('Waiting for checkpoint.')
      time.sleep(15)
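
The dataset-selection rule inside the polling loop is easy to miss: advance round-robin, but if that dataset has already been covered at the current checkpoint and others remain, jump to an uncovered one. A standalone distillation of just that rule, with small self-checks (function name and test values are illustrative only):

def pick_dataset(last_dataset_id, dataset_ids_evaluated, num_datasets):
    """Round-robin over datasets, preferring ones not yet evaluated."""
    all_ids = set(range(num_datasets))
    dataset_id = (last_dataset_id + 1) % num_datasets
    if dataset_id in dataset_ids_evaluated and dataset_ids_evaluated != all_ids:
        dataset_id = next(iter(all_ids - dataset_ids_evaluated))
    return dataset_id

assert pick_dataset(-1, set(), 3) == 0        # fresh checkpoint: start at 0
assert pick_dataset(0, {0, 1}, 3) == 2        # skip already-covered dataset 1
assert pick_dataset(2, {0, 1, 2}, 3) == 0     # all covered: plain round-robin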