Example 1
  def _get_input_iterator(
      self, input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
      strategy: tf.distribute.Strategy) -> Optional[Iterator[Any]]:
    """Returns distributed dataset iterator.

    Args:
      input_fn: (params: dict) -> tf.data.Dataset.
      strategy: an instance of tf.distribute.Strategy.

    Returns:
      An iterator that yields input tensors.
    """

    if input_fn is None:
      return None
    # When training with multiple TPU workers, datasets need to be cloned
    # across workers. Since a Dataset instance cannot be cloned in eager mode,
    # we instead pass a callable that returns a dataset.
    input_data = input_fn(self._params)
    return iter(strategy.experimental_distribute_dataset(input_data))
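For context, a minimal sketch of how an iterator built this way is typically consumed: the distributed batches it yields are passed into strategy.run inside a tf.function. build_dataset, step_fn and train_step below are hypothetical placeholders standing in for input_fn and a real training step; they are not part of the example above.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()


def build_dataset(params):
  # Hypothetical stand-in for input_fn; `params` is ignored for brevity.
  features = tf.random.uniform([64, 8])
  labels = tf.random.uniform([64], maxval=2, dtype=tf.int32)
  return tf.data.Dataset.from_tensor_slices((features, labels)).batch(16)


dist_dataset = strategy.experimental_distribute_dataset(build_dataset(None))
iterator = iter(dist_dataset)


@tf.function
def train_step(dist_inputs):

  def step_fn(inputs):
    features, _ = inputs
    # A real step would run the model and apply gradients here.
    return tf.reduce_mean(features)

  per_replica = strategy.run(step_fn, args=(dist_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)


loss_like = train_step(next(iterator))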
Example 2
def run_train_loop(
        train_dataset_builder: ub.datasets.BaseDataset,
        validation_dataset_builder: Optional[ub.datasets.BaseDataset],
        test_dataset_builder: ub.datasets.BaseDataset, batch_size: int,
        eval_batch_size: int, model: tf.keras.Model,
        optimizer: tf.keras.optimizers.Optimizer, eval_frequency: int,
        log_frequency: int, trial_dir: Optional[str], train_steps: int,
        mode: str, strategy: tf.distribute.Strategy,
        metrics: Dict[str, Union[tf.keras.metrics.Metric,
                                 rm.metrics.KerasMetric]], hparams: Dict[str,
                                                                         Any]):
    """Train, possibly evaluate the model, and record metrics."""

    checkpoint_manager = None
    last_checkpoint_step = 0
    if trial_dir:
        # TODO(znado): add train_iterator to this once DistributedIterators are
        # checkpointable.
        checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
        checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                        trial_dir,
                                                        max_to_keep=None)
        checkpoint_path = tf.train.latest_checkpoint(trial_dir)
        if checkpoint_path:
            last_checkpoint_step = int(checkpoint_path.split('-')[-1])
            if last_checkpoint_step >= train_steps:
                # If we have already finished training, exit.
                logging.info(
                    'Training has already finished at step %d. Exiting.',
                    train_steps)
                return
            elif last_checkpoint_step > 0:
                # Restore from where we previously finished.
                checkpoint.restore(checkpoint_manager.latest_checkpoint)
                logging.info('Resuming training from step %d.',
                             last_checkpoint_step)

    train_dataset = train_dataset_builder.load(batch_size=batch_size)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    train_iterator = iter(train_dataset)

    iterations_per_loop = min(eval_frequency, log_frequency)
    # We can only run `iterations_per_loop` steps at a time, because we cannot
    # checkpoint the model inside a tf.function.
    train_step_fn = _train_step_fn(model,
                                   optimizer,
                                   strategy,
                                   metrics,
                                   iterations_per_loop=iterations_per_loop)
    if trial_dir:
        train_summary_writer = tf.summary.create_file_writer(
            os.path.join(trial_dir, 'train'))
    else:
        train_summary_writer = None

    val_summary_writer = None
    test_summary_writer = None
    if mode == 'train_and_eval':
        (val_fn, val_dataset, val_summary_writer, test_fn, test_dataset,
         test_summary_writer) = eval_lib.setup_eval(
             validation_dataset_builder=validation_dataset_builder,
             test_dataset_builder=test_dataset_builder,
             batch_size=eval_batch_size,
             strategy=strategy,
             trial_dir=trial_dir,
             model=model,
             metrics=metrics)
    # Each call to train_step_fn will run iterations_per_loop steps.
    num_train_fn_steps = train_steps // iterations_per_loop
    # We are guaranteed that `last_checkpoint_step` will be divisible by
    # `iterations_per_loop` because that is how frequently we checkpoint.
    start_train_fn_step = last_checkpoint_step // iterations_per_loop
    for train_fn_step in range(start_train_fn_step, num_train_fn_steps):
        current_step = train_fn_step * iterations_per_loop
        # Checkpoint at the start of the step, before the training op is run.
        if (checkpoint_manager and current_step % eval_frequency == 0
                and current_step != last_checkpoint_step):
            checkpoint_manager.save(checkpoint_number=current_step)
        if mode == 'train_and_eval' and current_step % eval_frequency == 0:
            eval_lib.run_eval_epoch(
                val_fn,
                val_dataset,
                val_summary_writer,
                test_fn,
                test_dataset,
                test_summary_writer,
                current_step,
                hparams=None)  # Only write hparams on the last step.
        train_step_outputs = train_step_fn(train_iterator)
        if current_step % log_frequency == 0:
            _write_summaries(train_step_outputs, current_step,
                             train_summary_writer)
            train_step_outputs_np = {
                k: v.numpy()
                for k, v in train_step_outputs.items()
            }
            logging.info('Training metrics for step %d: %s', current_step,
                         train_step_outputs_np)

    if train_steps % iterations_per_loop != 0:
        remainder_train_step_fn = _train_step_fn(
            model,
            optimizer,
            strategy,
            metrics,
            iterations_per_loop=train_steps % iterations_per_loop)
        train_step_outputs = remainder_train_step_fn(train_iterator)

    # Always evaluate and record metrics at the end of training.
    _write_summaries(train_step_outputs, train_steps, train_summary_writer,
                     hparams)
    train_step_outputs_np = {
        k: v.numpy()
        for k, v in train_step_outputs.items()
    }
    logging.info('Training metrics for step %d: %s', train_steps,
                 train_step_outputs_np)
    if mode == 'train_and_eval':
        eval_lib.run_eval_epoch(val_fn,
                                val_dataset,
                                val_summary_writer,
                                test_fn,
                                test_dataset,
                                test_summary_writer,
                                train_steps,
                                hparams=hparams)
    # Save checkpoint at the end of training.
    if checkpoint_manager:
        checkpoint_manager.save(checkpoint_number=train_steps)
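The helpers _train_step_fn, _write_summaries and eval_lib referenced in this example are defined elsewhere in the module and are not shown. As an illustration only, a multi-step distributed train function with the shape used above might look like the sketch below; it assumes dataset elements are dicts with 'features' and 'labels' keys and Keras-style metrics updated with (labels, logits), and it is not the actual uncertainty_baselines implementation.

import tensorflow as tf


def _train_step_fn(model, optimizer, strategy, metrics, iterations_per_loop):
    """Illustrative sketch of a multi-step distributed train function."""

    @tf.function
    def train_steps(iterator):

        def step_fn(inputs):
            # Assumed element structure: {'features': ..., 'labels': ...}.
            features, labels = inputs['features'], inputs['labels']
            with tf.GradientTape() as tape:
                logits = model(features, training=True)
                loss = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits, from_logits=True))
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            for metric in metrics.values():
                metric.update_state(labels, logits)

        # tf.range keeps this a single traced while-loop, so next(iterator)
        # can be called repeatedly inside the tf.function without unrolling.
        for _ in tf.range(iterations_per_loop):
            strategy.run(step_fn, args=(next(iterator),))
        # In practice the loss would also be tracked through one of the
        # metrics (e.g. a tf.keras.metrics.Mean) so it appears in the outputs.
        return {name: metric.result() for name, metric in metrics.items()}

    return train_steps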