Example #1
def run_eval(run_configuration):
  """Evaluates on checkpoints as they become available."""
  config = run_configuration.config
  run_dir = run_configuration.run_dir
  adapter = run_configuration.adapter
  optimizer = adapter.create_optimizer(run_configuration)

  last_checkpoint_path = None
  last_checkpoint_time = time.time()
  checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
  success_path = checkpoint_utils.build_success_path(run_dir)
  error_count = 0
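  # Poll for new checkpoints, evaluating the latest one whenever it changes.
  # Stop once training writes a SUCCESS file or no new checkpoint arrives
  # within config.eval_timeout seconds.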
  while True:
    success = tf.io.gfile.exists(success_path)
    checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is not None and checkpoint_path != last_checkpoint_path:
      logging.info('Evaluating with checkpoint_path: %s', checkpoint_path)
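      # Tolerate occasional evaluation failures, but re-raise after repeated
      # errors (or immediately when debugging).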
      try:
        eval_once(run_configuration, checkpoint_path, optimizer)
      except:  # pylint: disable=bare-except
        logging.info('Could not evaluate %s', checkpoint_path)
        error_count += 1
        if error_count >= 10 or config.debug:
          raise
      last_checkpoint_path = checkpoint_path
      last_checkpoint_time = time.time()
    else:
      if success:
        logging.info('SUCCESS file found. Stopping.')
        break
      if time.time() - last_checkpoint_time > config.eval_timeout:
        logging.info('Timed out waiting for checkpoint. Stopping.')
        break
      logging.info('Waiting for checkpoint.')
      time.sleep(15)
Example #2
def run_train_single_device(run_configuration):
    """Runs the training workflow without pmap or jit."""
    config = run_configuration.config
    run_dir = run_configuration.run_dir
    adapter = run_configuration.adapter
    checkpoint_path = run_configuration.original_checkpoint_path
    dataset = run_configuration.dataset_info.dataset

    random_seed = 0
    rng = jax.random.PRNGKey(random_seed)
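    # Fold in the host id so each host draws an independent random stream for
    # initialization and dropout.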
    rng = jax.random.fold_in(rng, jax.host_id())
    dropout_rng, init_rng = jax.random.split(rng)

    # Set up optimizer.
    optimizer = adapter.create_optimizer(run_configuration, rng=init_rng)

    # Set up train step.
    train_step = adapter.make_train_step(single_device=True)

    # Set up checkpointing.
    # TODO(dbieber): Set up phoenix.
    checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
    if checkpoint_path is None:
        checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)
    optimizer = checkpoint_utils.handle_restart_behavior(
        checkpoint_path, optimizer, config)

    start_step = int(optimizer.state.step)
    num_train_steps = config.train.total_steps

    # Begin training loop.
    dataset_iter_raw = iter(dataset)
    dataset_iter = adapter.preprocess(dataset_iter_raw, single_device=True)

    for step, example in zip(range(start_step, num_train_steps), dataset_iter):
        print(f'Step #{step}')
        train_inputs = adapter.get_train_inputs(example)
        optimizer, metrics, dropout_rng, logits, state = train_step(
            optimizer, train_inputs, dropout_rng)
        del metrics, logits, state  # Unused.

        # Save a Checkpoint.
        if ((step % config.logging.save_freq == 0 and step > 0)
                or step == num_train_steps - 1):
            if jax.host_id() == 0 and config.logging.save_freq:
                # Save unreplicated optimizer + model state.
                checkpoint_utils.save_checkpoint(checkpoint_dir, optimizer,
                                                 step)
Example #3
def run_train(run_configuration):
    """Runs the training workflow."""
    config = run_configuration.config
    run_dir = run_configuration.run_dir
    adapter = run_configuration.adapter
    log_dir = os.path.join(run_dir, 'train')
    checkpoint_path = run_configuration.original_checkpoint_path

    dataset = run_configuration.dataset_info.dataset
    info = run_configuration.dataset_info.info

    random_seed = 0
    rng = jax.random.PRNGKey(random_seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = jax.random.split(rng)
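    # One dropout RNG per local device for the pmapped train step.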
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    # Set up optimizer.
    optimizer = adapter.create_optimizer(run_configuration, rng=init_rng)

    # Set up train step.
    train_step = adapter.make_train_step()

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(log_dir)

    # Set up checkpointing.
    # TODO(dbieber): Set up phoenix.
    checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
    if checkpoint_path is None:
        checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)
    optimizer = checkpoint_utils.handle_restart_behavior(
        checkpoint_path, optimizer, config)

    start_step = int(optimizer.state.step)
    num_train_steps = config.train.total_steps

    # Replicate optimizer.
    optimizer = flax.jax_utils.replicate(optimizer)

    # Begin training loop.
    dataset_iter_raw = iter(dataset)
    dataset_iter = adapter.preprocess(dataset_iter_raw)

    summary_freq = config.logging.summary_freq
    metrics_all = []
    tick = time.time()
    for step, example in zip(range(start_step, num_train_steps), dataset_iter):
        train_inputs = adapter.get_train_inputs(example)
        optimizer, metrics, dropout_rngs, logits, state = train_step(
            optimizer, train_inputs, dropout_rngs)
        metrics_all.append(metrics)

        # Save a Checkpoint
        if ((step % config.logging.save_freq == 0 and step > 0)
                or step == num_train_steps - 1):
            if jax.host_id() == 0 and config.logging.save_freq:
                # Save unreplicated optimizer + model state.
                checkpoint_utils.save_checkpoint(
                    checkpoint_dir, jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.
        if summary_freq and step % summary_freq == 0 and step > 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
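            # Metrics are accumulated as sums; dividing by the accumulated
            # 'denominator' turns them into means over the summary window.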
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)
            logging.info('train step: %d, loss: %.4f', step, summary['loss'])
            if jax.host_id() == 0:
                tock = time.time()
                steps_per_sec = summary_freq / (tock - tick)
                examples_per_sec = denominator / (tock - tick)
                tick = tock
                summary_writer.scalar('per-second/steps', steps_per_sec, step)
                summary_writer.scalar('per-second/examples', examples_per_sec,
                                      step)
                for key, val in summary.items():
                    summary_writer.scalar(key, val, step)

                adapter.write_summaries(example, logits, summary_writer, info,
                                        step, state)

                summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []
Example #4
def configure(data_dir, run_dir, config, override_values, xm_parameters=None):
    """Sets up the Learned Interpreter code with the specified configuration."""
    seed()

    # Apply any overrides set at the command line or in the launcher.
    if config.overrides != config.default_overrides:
        logging.info('Applying overrides set at command line: %s',
                     config.overrides)
        overrides_lib.apply_overrides(
            config, override_names=config.overrides.split(','))
    config.update_from_flattened_dict(override_values)

    # If a checkpoint is specified, it determines the "original run."
    # Otherwise the run_dir, if already present, determines the "original run."
    config_filepath = os.path.join(run_dir, 'config.json')
    if checkpoint_utils.is_checkpoint_specified(config.checkpoint):
        original_checkpoint_path = checkpoint_utils.get_specified_checkpoint_path(
            run_dir, config.checkpoint)
        original_run_dir = checkpoint_utils.get_run_dir(
            original_checkpoint_path)
        original_config_filepath = os.path.join(original_run_dir,
                                                'config.json')
    else:
        checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
        original_checkpoint_path = checkpoint_utils.latest_checkpoint(
            checkpoint_dir)
        original_config_filepath = config_filepath
    original_config_exists = gfile.exists(original_config_filepath)

    # Handle any existing configs.
    if original_config_exists:
        original_config = config_utils.load_config(original_config_filepath)

        # Handle the model config.
        if config.runner.model_config == 'load':
            logging.info('Loading the model config from %s',
                         original_config_filepath)
            config.model.update(original_config.model)
            config.dataset.representation = original_config.dataset.representation
        elif config.runner.model_config == 'assert':
            same_config = config_utils.equals(config.model,
                                              original_config.model)
            # Resolution:
            # Either use a new run_dir, or set model_config to 'load' or 'keep'.
            assert same_config, 'Model config has changed.'
        else:
            assert config.runner.model_config == 'keep'

        # Handle the dataset config.
        if config.runner.dataset_config == 'load':
            logging.info('Loading the data config from %s',
                         original_config_filepath)
            config.dataset.update(original_config.dataset)
        elif config.runner.dataset_config == 'assert':
            same_config = config_utils.equals(config.dataset,
                                              original_config.dataset)
            assert same_config, 'Dataset config has changed.'
        else:
            assert config.runner.dataset_config == 'keep'

    elif (config.runner.model_config == 'load'
          or config.runner.dataset_config == 'load'):
        raise ValueError('Original model config not found.')

    # In interactive mode, force batch size 1.
    if config.runner.mode == 'interact':
        config.dataset.batch_size = 1

    config_exists = gfile.exists(config_filepath)
    if not config_exists and config.runner.mode == 'train':
        gfile.makedirs(run_dir)
        config_utils.save_config(config, config_filepath)

    # Load dataset.
    if config.setup.setup_dataset:
        dataset_info = dataset_utils.get_dataset(data_dir, config)
        info = dataset_info.info
    else:
        dataset_info = None
        info = None

    # Create model.
    if config.setup.setup_model:
        model = models_lib.get_model(info, config)
    else:
        model = None

    adapter = adapters_lib.get_default_adapter(info, config)

    return RunConfiguration(mode=config.runner.mode,
                            method=config.runner.method,
                            run_dir=run_dir,
                            data_dir=data_dir,
                            original_checkpoint_path=original_checkpoint_path,
                            model=model,
                            info=info,
                            config=config,
                            adapter=adapter,
                            dataset_info=dataset_info)
Example #5
def run_eval_multi_dataset(run_configuration):
  """Evaluates on checkpoints as they become available."""
  config = run_configuration.config
  run_dir = run_configuration.run_dir
  data_dir = run_configuration.data_dir
  adapter = run_configuration.adapter
  optimizer = adapter.create_optimizer(run_configuration)

  eval_dataset_names = config.launcher.eval_dataset_names.split(',')
  dataset_infos = [dataset_utils.get_dataset(data_dir, config, name)
                   for name in eval_dataset_names]
  all_dataset_ids = set(range(len(dataset_infos)))
  dataset_ids_evaluated = set()

  last_dataset_id = -1
  last_checkpoint_path = None
  last_checkpoint_time = time.time()
  checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
  success_path = checkpoint_utils.build_success_path(run_dir)
  error_count = 0
  while True:
    success = gfile.exists(success_path)
    # Always evaluate at the latest checkpoint.
    checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)
    # Choose the dataset to evaluate.

    if checkpoint_path is not None and checkpoint_path != last_checkpoint_path:
      # A new checkpoint resets the set of dataset ids evaluated at the
      # latest checkpoint. The goal is to evaluate every dataset at the final
      # checkpoint as soon as possible, while providing best-effort progress
      # updates along the way.
      dataset_ids_evaluated = set()

    dataset_id = (last_dataset_id + 1) % len(dataset_infos)
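    # The id above is a round-robin choice; if it was already evaluated at the
    # current checkpoint, fall back to a dataset that has not been evaluated.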
    if (dataset_id in dataset_ids_evaluated
        and dataset_ids_evaluated != all_dataset_ids):
      dataset_id = next(iter(all_dataset_ids - dataset_ids_evaluated))

    if ((checkpoint_path, dataset_id) != (last_checkpoint_path, last_dataset_id)
        and dataset_ids_evaluated != all_dataset_ids
        and checkpoint_path is not None):
      logging.info('Evaluating with checkpoint_path: %s', checkpoint_path)
      logging.info('Evaluating on dataset id: %d', dataset_id)
      run_configuration.dataset_info = dataset_infos[dataset_id]
      try:
        workflows.eval_once(run_configuration, checkpoint_path, optimizer)
      except:  # pylint: disable=bare-except
        logging.info('Could not evaluate %s on dataset %d', checkpoint_path,
                     dataset_id)
        error_count += 1
        if error_count >= 10 or config.debug:
          raise
      last_dataset_id = dataset_id
      dataset_ids_evaluated.add(dataset_id)
      last_checkpoint_path = checkpoint_path
      last_checkpoint_time = time.time()
    else:
      if success:
        logging.info('SUCCESS file found. Stopping.')
        break
      if time.time() - last_checkpoint_time > config.eval_timeout:
        logging.info('Timed out waiting for checkpoint. Stopping.')
        break
      logging.info('Waiting for checkpoint.')
      time.sleep(15)
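For context, configure() builds the RunConfiguration consumed by the workflow
functions above. The sketch below shows one plausible way to wire them
together; the main() wrapper and the 'eval' mode string are assumptions for
illustration, not part of the examples.

def main(data_dir, run_dir, config, override_values):
  # Hypothetical glue code: build the run configuration, then dispatch on the
  # configured mode. Only configure, run_train, and run_eval come from the
  # examples above; everything else is illustrative.
  run_configuration = configure(data_dir, run_dir, config, override_values)
  if run_configuration.mode == 'train':
    run_train(run_configuration)
  elif run_configuration.mode == 'eval':  # Assumed mode name.
    run_eval(run_configuration)
  else:
    raise ValueError('Unsupported mode: %s' % run_configuration.mode)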