Example #1
def setup_eval(
    dataset_builder: ub.datasets.BaseDataset,
    strategy,
    trial_dir: Optional[str],
    model: tf.keras.Model,
    metrics: Dict[str, tf.keras.metrics.Metric]) -> _EvalSetupResult:
  """Setup the test and optionally validation loggers, step fns and datasets."""
  test_dataset = dataset_builder.build('test')
  test_dataset = strategy.experimental_distribute_dataset(test_dataset)
  if trial_dir:
    test_summary_writer = tf.summary.create_file_writer(
        os.path.join(trial_dir, 'test'))
  else:
    test_summary_writer = None
  num_test_steps = (
      dataset_builder.info['num_test_examples'] //
      dataset_builder.eval_batch_size)
  test_fn = eval_step_fn(
      model,
      strategy,
      metrics,
      iterations_per_loop=num_test_steps)

  # Have to have separate val_fn and test_fn because otherwise tf.function
  # retraces the function each time, which is very slow, because we are passing
  # in a Python dict of metrics and int for iterations_per_loop.
  val_fn = None
  val_dataset = None
  val_summary_writer = None
  if dataset_builder.info['num_validation_examples'] > 0:
    num_val_steps = (
        dataset_builder.info['num_validation_examples'] //
        dataset_builder.eval_batch_size)
    val_dataset = dataset_builder.build('validation')
    val_dataset = strategy.experimental_distribute_dataset(val_dataset)
    if trial_dir:
      val_summary_writer = tf.summary.create_file_writer(
          os.path.join(trial_dir, 'validation'))
    if num_val_steps == num_test_steps:
      val_fn = test_fn
    else:
      # The metrics are reset at the start of each call to {val,test}_fn, so
      # reusing them is safe.
      val_fn = eval_step_fn(
          model,
          strategy,
          metrics,
          iterations_per_loop=num_val_steps)
  return (
      val_fn, val_dataset, val_summary_writer, test_fn, test_dataset,
      test_summary_writer)
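
Note on the retracing comment in the example above: tf.function folds Python arguments (ints, dicts of metrics) into its trace signature, so a distinct step fn is built per iteration count to avoid repeated retracing. A minimal, self-contained sketch (not from the codebase; eval_loop is a made-up name) that makes the behavior visible:

import tensorflow as tf

@tf.function
def eval_loop(iterations_per_loop):
  # A Python int becomes part of the trace key, so each distinct value
  # triggers a fresh trace (and graph build).
  total = tf.constant(0)
  for _ in tf.range(iterations_per_loop):
    total += 1
  return total

eval_loop(10)
eval_loop(10)  # Reuses the existing trace.
eval_loop(20)  # New Python value -> retrace.
# 2 traces: one for 10, one for 20 (counter available in recent TF releases).
print(eval_loop.experimental_get_tracing_count())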
Example #2
def setup_eval(
    validation_dataset_builder: Optional[ub.datasets.BaseDataset],
    test_dataset_builder: ub.datasets.BaseDataset,
    batch_size: int,
    strategy,
    trial_dir: str,
    model: tf.keras.Model,
    loss_fn,
    metric_names: List[str]) -> _EvalSetupResult:
  """Setup the test and optionally validation loggers, step fns and datasets."""
  test_dataset = test_dataset_builder.load(batch_size=batch_size)
  test_dataset = strategy.experimental_distribute_dataset(test_dataset)
  test_summary_writer = tf.summary.create_file_writer(
      os.path.join(trial_dir, 'summaries/test'))
  num_test_steps = test_dataset_builder.num_examples // batch_size
  test_metrics = {
      name: tf.keras.metrics.Mean(name, dtype=tf.float32)
      for name in metric_names
  }
  test_fn = eval_step_fn(
      model,
      loss_fn,
      strategy,
      test_metrics,
      iterations_per_loop=num_test_steps)

  # Have to have separate val_fn and test_fn because otherwise tf.function
  # retraces the function each time, which is very slow, because we are passing
  # in a Python dict of metrics and int for iterations_per_loop.
  val_fn = None
  val_dataset = None
  val_summary_writer = None
  if validation_dataset_builder:
    num_val_steps = validation_dataset_builder.num_examples // batch_size
    val_dataset = validation_dataset_builder.load(batch_size=batch_size)
    val_dataset = strategy.experimental_distribute_dataset(val_dataset)
    val_summary_writer = tf.summary.create_file_writer(
        os.path.join(trial_dir, 'summaries/val'))
    if num_val_steps == num_test_steps:
      val_fn = test_fn
    else:
      # The metrics are reset at the start of each call to {val,test}_fn, so
      # reusing them is safe.
      val_fn = eval_step_fn(
          model,
          loss_fn,
          strategy,
          test_metrics,
          iterations_per_loop=num_val_steps)
  return (
      val_fn, val_dataset, val_summary_writer, test_fn, test_dataset,
      test_summary_writer)
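
The metric_names in this variant become tf.keras.metrics.Mean objects. One plausible way to read them out and log them through the summary writer after an eval pass (an assumption, not shown in the excerpt; current_step is a placeholder):

# After an eval pass, each Mean metric holds the running average for its name.
results = {name: metric.result() for name, metric in test_metrics.items()}
with test_summary_writer.as_default():
  for name, value in results.items():
    tf.summary.scalar(name, value, step=current_step)
  test_summary_writer.flush()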
Example #3
def setup_eval(validation_dataset_builder: Optional[ub.datasets.BaseDataset],
               test_dataset_builder: ub.datasets.BaseDataset,
               batch_size: int,
               strategy,
               trial_dir: str,
               model: tf.keras.Model,
               metrics: Dict[str, tf.keras.metrics.Metric],
               ood_dataset_builder: Optional[ub.datasets.BaseDataset] = None,
               ood_metrics: Optional[Dict[str, tf.keras.metrics.Metric]] = None,
               mean_field_factor: float = -1) -> _EvalSetupResult:
    """Setup the test and optionally validation loggers, step fns and datasets."""
    test_dataset = test_dataset_builder.load(batch_size=batch_size)
    test_dataset = strategy.experimental_distribute_dataset(test_dataset)
    test_summary_writer = tf.summary.create_file_writer(
        os.path.join(trial_dir, 'test'))
    num_test_steps = test_dataset_builder.num_examples // batch_size
    test_fn = eval_step_fn(model,
                           strategy,
                           metrics,
                           iterations_per_loop=num_test_steps,
                           mean_field_factor=mean_field_factor)
    ood_fn = None
    ood_dataset = None
    ood_summary_writer = None

    if ((ood_dataset_builder and not ood_metrics)
            or (not ood_dataset_builder and ood_metrics)):
        raise ValueError('Both ood_dataset_builder and ood_metrics must be'
                         ' specified.')
    if ood_dataset_builder:
        ood_dataset = ood_dataset_builder.load(batch_size=batch_size)
        ood_dataset = strategy.experimental_distribute_dataset(ood_dataset)
        ood_summary_writer = tf.summary.create_file_writer(
            os.path.join(trial_dir, 'ood'))
        num_ood_steps = ood_dataset_builder.num_examples // batch_size

        ood_fn = eval_step_fn(model,
                              strategy,
                              ood_metrics,
                              iterations_per_loop=num_ood_steps,
                              label_key='is_in_distribution',
                              mean_field_factor=mean_field_factor)

    # Have to have separate val_fn and test_fn because otherwise tf.function
    # retraces the function each time, which is very slow, because we are passing
    # in a Python dict of metrics and int for iterations_per_loop.
    val_fn = None
    val_dataset = None
    val_summary_writer = None
    if validation_dataset_builder:
        num_val_steps = validation_dataset_builder.num_examples // batch_size
        val_dataset = validation_dataset_builder.load(batch_size=batch_size)
        val_dataset = strategy.experimental_distribute_dataset(val_dataset)
        val_summary_writer = tf.summary.create_file_writer(
            os.path.join(trial_dir, 'validation'))
        if num_val_steps == num_test_steps:
            val_fn = test_fn
        else:
            # The metrics are reset at the start of each call to {val,test}_fn, so
            # reusing them is safe.
            val_fn = eval_step_fn(model,
                                  strategy,
                                  metrics,
                                  iterations_per_loop=num_val_steps,
                                  mean_field_factor=mean_field_factor)

    return (test_fn, test_dataset, test_summary_writer, val_fn, val_dataset,
            val_summary_writer, ood_fn, ood_dataset, ood_summary_writer)
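
A hypothetical call to this version of setup_eval, illustrating that ood_dataset_builder and ood_metrics must be passed together and how the nine-element result is unpacked. All argument values (test_builder, ood_builder, strategy, model, metrics, ood_metrics) are assumed to exist elsewhere and are not defined in the excerpt:

# Hypothetical usage sketch; names below are placeholders, not from the excerpt.
(test_fn, test_dataset, test_summary_writer,
 val_fn, val_dataset, val_summary_writer,
 ood_fn, ood_dataset, ood_summary_writer) = setup_eval(
     validation_dataset_builder=None,
     test_dataset_builder=test_builder,
     batch_size=64,
     strategy=strategy,
     trial_dir='/tmp/trial',
     model=model,
     metrics=metrics,
     ood_dataset_builder=ood_builder,  # Must be paired with ood_metrics,
     ood_metrics=ood_metrics)          # otherwise a ValueError is raised.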
Example #4
def run_train_loop(
        train_dataset_builder: ub.datasets.BaseDataset,
        validation_dataset_builder: Optional[ub.datasets.BaseDataset],
        test_dataset_builder: ub.datasets.BaseDataset, batch_size: int,
        eval_batch_size: int, model: tf.keras.Model,
        optimizer: tf.keras.optimizers.Optimizer, eval_frequency: int,
        log_frequency: int, trial_dir: Optional[str], train_steps: int,
        mode: str, strategy: tf.distribute.Strategy,
        metrics: Dict[str, Union[tf.keras.metrics.Metric,
                                 rm.metrics.KerasMetric]],
        hparams: Dict[str, Any]):
    """Train, possibly evaluate the model, and record metrics."""

    checkpoint_manager = None
    last_checkpoint_step = 0
    if trial_dir:
        # TODO(znado): add train_iterator to this once DistributedIterators are
        # checkpointable.
        checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
        checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                        trial_dir,
                                                        max_to_keep=None)
        checkpoint_path = tf.train.latest_checkpoint(trial_dir)
        if checkpoint_path:
            last_checkpoint_step = int(checkpoint_path.split('-')[-1])
            if last_checkpoint_step >= train_steps:
                # If we have already finished training, exit.
                logging.info(
                    'Training has already finished at step %d. Exiting.',
                    train_steps)
                return
            elif last_checkpoint_step > 0:
                # Restore from where we previously finished.
                checkpoint.restore(checkpoint_manager.latest_checkpoint)
                logging.info('Resuming training from step %d.',
                             last_checkpoint_step)

    train_dataset = train_dataset_builder.load(batch_size=batch_size)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    train_iterator = iter(train_dataset)

    iterations_per_loop = min(eval_frequency, log_frequency)
    # We can only run `iterations_per_loop` steps at a time, because we cannot
    # checkpoint the model inside a tf.function.
    train_step_fn = _train_step_fn(model,
                                   optimizer,
                                   strategy,
                                   metrics,
                                   iterations_per_loop=iterations_per_loop)
    if trial_dir:
        train_summary_writer = tf.summary.create_file_writer(
            os.path.join(trial_dir, 'train'))
    else:
        train_summary_writer = None

    val_summary_writer = None
    test_summary_writer = None
    if mode == 'train_and_eval':
        (val_fn, val_dataset, val_summary_writer, test_fn, test_dataset,
         test_summary_writer) = eval_lib.setup_eval(
             validation_dataset_builder=validation_dataset_builder,
             test_dataset_builder=test_dataset_builder,
             batch_size=eval_batch_size,
             strategy=strategy,
             trial_dir=trial_dir,
             model=model,
             metrics=metrics)
    # Each call to train_step_fn will run iterations_per_loop steps.
    num_train_fn_steps = train_steps // iterations_per_loop
    # We are guaranteed that `last_checkpoint_step` will be divisible by
    # `iterations_per_loop` because that is how frequently we checkpoint.
    start_train_fn_step = last_checkpoint_step // iterations_per_loop
    for train_fn_step in range(start_train_fn_step, num_train_fn_steps):
        current_step = train_fn_step * iterations_per_loop
        # Checkpoint at the start of the step, before the training op is run.
        if (checkpoint_manager and current_step % eval_frequency == 0
                and current_step != last_checkpoint_step):
            checkpoint_manager.save(checkpoint_number=current_step)
        if mode == 'train_and_eval' and current_step % eval_frequency == 0:
            eval_lib.run_eval_epoch(
                val_fn,
                val_dataset,
                val_summary_writer,
                test_fn,
                test_dataset,
                test_summary_writer,
                current_step,
                hparams=None)  # Only write hparams on the last step.
        train_step_outputs = train_step_fn(train_iterator)
        if current_step % log_frequency == 0:
            _write_summaries(train_step_outputs, current_step,
                             train_summary_writer)
            train_step_outputs_np = {
                k: v.numpy()
                for k, v in train_step_outputs.items()
            }
            logging.info('Training metrics for step %d: %s', current_step,
                         train_step_outputs_np)

    if train_steps % iterations_per_loop != 0:
        remainder_train_step_fn = _train_step_fn(
            model,
            optimizer,
            strategy,
            metrics,
            iterations_per_loop=train_steps % iterations_per_loop)
        train_step_outputs = remainder_train_step_fn(train_iterator)

    # Always evaluate and record metrics at the end of training.
    _write_summaries(train_step_outputs, train_steps, train_summary_writer,
                     hparams)
    train_step_outputs_np = {
        k: v.numpy()
        for k, v in train_step_outputs.items()
    }
    logging.info('Training metrics for step %d: %s', train_steps,
                 train_step_outputs_np)
    if mode == 'train_and_eval':
        eval_lib.run_eval_epoch(val_fn,
                                val_dataset,
                                val_summary_writer,
                                test_fn,
                                test_dataset,
                                test_summary_writer,
                                train_steps,
                                hparams=hparams)
    # Save checkpoint at the end of training.
    if checkpoint_manager:
        checkpoint_manager.save(checkpoint_number=train_steps)
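
The loop bookkeeping in run_train_loop splits train_steps into whole iterations_per_loop-sized chunks plus a remainder handled by remainder_train_step_fn. A small arithmetic illustration with made-up values:

# Hypothetical values; only the arithmetic mirrors run_train_loop above.
train_steps = 1050
eval_frequency = 300
log_frequency = 100

iterations_per_loop = min(eval_frequency, log_frequency)  # 100 steps per tf.function call
num_train_fn_steps = train_steps // iterations_per_loop   # 10 full loops (steps 0..999)
remainder = train_steps % iterations_per_loop             # 50 steps run by the remainder fn
assert num_train_fn_steps * iterations_per_loop + remainder == train_steps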
Example #5
def setup_eval(
    dataset_builder: ub.datasets.BaseDataset,
    strategy,
    trial_dir: str,
    model: tf.keras.Model,
    metrics: Dict[str, tf.keras.metrics.Metric],
    ood_dataset_builder: Optional[ub.datasets.BaseDataset] = None,
    ood_metrics: Optional[Dict[str, tf.keras.metrics.Metric]] = None
) -> _EvalSetupResult:
    """Setup the test and optionally validation loggers, step fns and datasets."""
    test_dataset = dataset_builder.build('test')
    test_dataset = strategy.experimental_distribute_dataset(test_dataset)
    test_summary_writer = tf.summary.create_file_writer(
        os.path.join(trial_dir, 'test'))
    num_test_steps = (dataset_builder.info['num_test_examples'] //
                      dataset_builder.eval_batch_size)
    test_fn = eval_step_fn(model,
                           strategy,
                           metrics,
                           iterations_per_loop=num_test_steps)
    ood_fn = None
    ood_dataset = None
    ood_summary_writer = None

    if ((ood_dataset_builder and not ood_metrics)
            or (not ood_dataset_builder and ood_metrics)):
        raise ValueError('Both ood_dataset_builder and ood_metrics must be'
                         ' specified.')
    if ood_dataset_builder:
        ood_dataset_in = dataset_builder.build('test', ood_split='in')
        ood_dataset_out = ood_dataset_builder.build('test', ood_split='ood')
        ood_dataset = ood_dataset_in.concatenate(ood_dataset_out)
        ood_dataset = strategy.experimental_distribute_dataset(ood_dataset)
        ood_summary_writer = tf.summary.create_file_writer(
            os.path.join(trial_dir, 'ood'))
        num_test_steps = (dataset_builder.info['num_test_examples'] //
                          dataset_builder.eval_batch_size)
        num_ood_steps = (ood_dataset_builder.info['num_ood_examples'] //
                         ood_dataset_builder.eval_batch_size)

        ood_fn = eval_step_fn(model,
                              strategy,
                              ood_metrics,
                              iterations_per_loop=num_test_steps +
                              num_ood_steps,
                              label_key='is_in_distribution')

    # Have to have separate val_fn and test_fn because otherwise tf.function
    # retraces the function each time, which is very slow, because we are passing
    # in a Python dict of metrics and int for iterations_per_loop.
    val_fn = None
    val_dataset = None
    val_summary_writer = None
    if dataset_builder.info['num_validation_examples'] > 0:
        num_val_steps = (dataset_builder.info['num_validation_examples'] //
                         dataset_builder.eval_batch_size)
        val_dataset = dataset_builder.build('validation')
        val_dataset = strategy.experimental_distribute_dataset(val_dataset)
        val_summary_writer = tf.summary.create_file_writer(
            os.path.join(trial_dir, 'validation'))
        if num_val_steps == num_test_steps:
            val_fn = test_fn
        else:
            # The metrics are reset at the start of each call to {val,test}_fn, so
            # reusing them is safe.
            val_fn = eval_step_fn(model,
                                  strategy,
                                  metrics,
                                  iterations_per_loop=num_val_steps)

    return (test_fn, test_dataset, test_summary_writer, val_fn, val_dataset,
            val_summary_writer, ood_fn, ood_dataset, ood_summary_writer)
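
This variant builds its OOD evaluation set by concatenating the in-distribution test split with the OOD split, which is why its iterations_per_loop is num_test_steps + num_ood_steps. A minimal tf.data sketch with toy datasets (not the ub builders):

import tensorflow as tf

in_dist = tf.data.Dataset.range(6)   # Stands in for the 'in' test split.
ood = tf.data.Dataset.range(4)       # Stands in for the 'ood' split.
combined = in_dist.concatenate(ood)  # In-distribution examples first, then OOD.
assert combined.cardinality().numpy() == 10  # Lengths add, hence the summed step count.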