Example 1
    def __call__(self, server_state, train_metrics, round_num):
        """A function suitable for passing as an eval hook to the training_loop.

    Args:
      server_state: A `ServerState`.
      train_metrics: A `dict` of training metrics computed in TFF.
      round_num: The current round number.
    """
        tff.learning.assign_weights_to_keras_model(self.model,
                                                   server_state.model)
        eval_metrics = self.model.evaluate(self.eval_dataset, verbose=0)

        metrics = {
            'train':
            train_metrics,
            'eval':
            collections.OrderedDict(
                zip(['loss', 'sparse_categorical_accuracy'], eval_metrics))
        }
        flat_metrics = tree.flatten_with_path(metrics)
        flat_metrics = [('/'.join(map(str, path)), item)
                        for path, item in flat_metrics]
        flat_metrics = collections.OrderedDict(flat_metrics)
        flat_metrics['round'] = round_num

        logging.info('Evaluation at round {:d}:\n{!s}'.format(
            round_num, pprint.pformat(flat_metrics)))

        # Also write metrics to a tf.summary logdir
        with self.summary_writer.as_default():
            for name, value in flat_metrics.items():
                tf.compat.v2.summary.scalar(name, value, step=round_num)

        self.results = self.results.append(flat_metrics, ignore_index=True)
        utils_impl.atomic_write_to_csv(self.results, self.results_file)
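As a rough illustration of how such an eval hook is driven, the training loop typically constructs the hook once and calls it after every federated round. The sketch below is hypothetical: `iterative_process`, `sample_clients`, `num_rounds`, and `metrics_hook` are placeholder names, not part of the example above.

# Hypothetical driver loop for the eval hook above; all names here are
# placeholders standing in for the real training-loop objects.
state = iterative_process.initialize()
for round_num in range(1, num_rounds + 1):
    # One round of federated training; `next` returns the new server state and
    # the training metrics for the round.
    state, train_metrics = iterative_process.next(state, sample_clients(round_num))
    # The hook evaluates the global model and logs/persists all metrics.
    metrics_hook(state, train_metrics, round_num)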
Example 2
def _setup_outputs(root_output_dir, experiment_name, hparam_dict):
  """Set up directories for experiment loops, write hyperparameters to disk."""

  if not experiment_name:
    raise ValueError('experiment_name must be specified.')

  create_if_not_exists(root_output_dir)

  checkpoint_dir = os.path.join(root_output_dir, 'checkpoints', experiment_name)
  create_if_not_exists(checkpoint_dir)
  checkpoint_mngr = checkpoint_manager.FileCheckpointManager(checkpoint_dir)

  results_dir = os.path.join(root_output_dir, 'results', experiment_name)
  create_if_not_exists(results_dir)
  metrics_mngr = metrics_manager.ScalarMetricsManager(results_dir)

  summary_logdir = os.path.join(root_output_dir, 'logdir', experiment_name)
  create_if_not_exists(summary_logdir)
  summary_writer = tf.compat.v2.summary.create_file_writer(summary_logdir)

  hparam_dict['metrics_file'] = metrics_mngr.metrics_filename
  hparams_file = os.path.join(results_dir, 'hparams.csv')
  utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

  logging.info('Writing...')
  logging.info('    checkpoints to: %s', checkpoint_dir)
  logging.info('    metrics csv to: %s', metrics_mngr.metrics_filename)
  logging.info('    summaries to: %s', summary_logdir)

  return checkpoint_mngr, metrics_mngr, summary_writer
Example 3
    def clear_rounds_after(self, last_valid_round_num: int) -> None:
        """Metrics for rounds greater than `last_valid_round_num` are cleared out.

    By using this method, this class can be used upon restart of an experiment
    at `last_valid_round_num` to ensure that no duplicate rows of data exist in
    the CSV file. This method will atomically update the stored CSV file.

    Args:
      last_valid_round_num: All metrics for rounds later than this are expunged.

    Raises:
      RuntimeError: If no metrics exist (none were loaded during construction
        nor recorded via `update_metrics()`) and `last_valid_round_num` is not
        zero.
      ValueError: If `last_valid_round_num` is negative.
    """
        if last_valid_round_num < 0:
            raise ValueError('Attempting to clear metrics after round '
                             f'{last_valid_round_num}, which is negative.')
        if self._latest_round_num is None:
            if last_valid_round_num == 0:
                return
            raise RuntimeError('Metrics do not exist yet.')
        self._metrics = self._metrics.drop(self._metrics[
            self._metrics.round_num > last_valid_round_num].index)
        utils_impl.atomic_write_to_csv(self._metrics, self._metrics_filename)
        self._latest_round_num = last_valid_round_num
Example 4
  def __call__(self, server_state, metrics, round_num):
    """A function suitable for passing as an eval hook to the training_loop.

    Args:
      server_state: A `ServerState`.
      metrics: A dict of metrics computed in TFF.
      round_num: The current round number.
    """
    tff.learning.assign_weights_to_keras_model(self.model, server_state.model)
    eval_metrics = self.model.evaluate(self.eval_dataset, verbose=0)

    metrics['eval'] = collections.OrderedDict(
        zip(['loss', 'sparse_categorical_accuracy'], eval_metrics))

    flat_metrics = collections.OrderedDict(
        nest_fork.flatten_with_joined_string_paths(metrics))

    # Use a DataFrame just to get nice formatting.
    df = pd.DataFrame.from_dict(flat_metrics, orient='index', columns=['value'])
    print(df)

    # Also write metrics to a tf.summary logdir
    with self.summary_writer.as_default():
      for name, value in flat_metrics.items():
        tf.compat.v2.summary.scalar(name, value, step=round_num)

    self.results = self.results.append(flat_metrics, ignore_index=True)
    utils_impl.atomic_write_to_csv(self.results, self.results_file)
Example 5
def main(_):

  tf.enable_v2_behavior()

  experiment_output_dir = FLAGS.root_output_dir
  tensorboard_dir = os.path.join(experiment_output_dir, 'logdir',
                                 FLAGS.experiment_name)
  results_dir = os.path.join(experiment_output_dir, 'results',
                             FLAGS.experiment_name)

  for path in [experiment_output_dir, tensorboard_dir, results_dir]:
    try:
      tf.io.gfile.makedirs(path)
    except tf.errors.OpError:
      pass  # Directory already exists.

  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])
  hparam_dict['results_file'] = results_dir
  hparams_file = os.path.join(results_dir, 'hparams.csv')

  logging.info('Saving hyperparameters to: [%s]', hparams_file)
  utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

  train_dataset, eval_dataset = emnist_ae_dataset.get_centralized_emnist_datasets(
      batch_size=FLAGS.batch_size, only_digits=False)

  optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

  model = emnist_ae_models.create_autoencoder_model()
  model.compile(
      loss=tf.keras.losses.MeanSquaredError(),
      optimizer=optimizer,
      metrics=[tf.keras.metrics.MeanSquaredError()])

  logging.info('Training model:')
  logging.info(model.summary())

  csv_logger_callback = keras_callbacks.AtomicCSVLogger(results_dir)
  tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir)
  # Reduce the learning rate after a fixed number of epochs.
  def decay_lr(epoch, learning_rate):
    if (epoch + 1) % FLAGS.decay_epochs == 0:
      return learning_rate * FLAGS.lr_decay
    else:
      return learning_rate

  lr_callback = tf.keras.callbacks.LearningRateScheduler(decay_lr, verbose=1)

  history = model.fit(
      train_dataset,
      validation_data=eval_dataset,
      epochs=FLAGS.num_epochs,
      callbacks=[lr_callback, tensorboard_callback, csv_logger_callback])

  logging.info('Final metrics:')
  for name in ['loss', 'mean_squared_error']:
    metric = history.history['val_{}'.format(name)][-1]
    logging.info('\t%s: %.4f', name, metric)
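For concreteness, the `decay_lr` schedule above is a simple step decay: the learning rate is multiplied by `lr_decay` once every `decay_epochs` epochs. A minimal standalone sketch, with hypothetical values `decay_epochs=10` and `lr_decay=0.1`:

# Standalone illustration of the step decay used by the LearningRateScheduler
# above, with hypothetical values decay_epochs=10 and lr_decay=0.1.
decay_epochs = 10
lr_decay = 0.1
learning_rate = 0.05

for epoch in range(30):
    if (epoch + 1) % decay_epochs == 0:
        learning_rate *= lr_decay
# The rate has now been decayed three times: 0.05 * 0.1**3 == 5e-5.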
Example 6
    def test_atomic_read(self):
        for name in ['foo.csv', 'baz.csv.bz2']:
            dataframe = pd.DataFrame(dict(a=[1, 2], b=[4.0, 5.0]))
            csv_file = os.path.join(absltest.get_default_test_tmpdir(), name)
            utils_impl.atomic_write_to_csv(dataframe, csv_file)

            dataframe2 = utils_impl.atomic_read_from_csv(csv_file)
            pd.testing.assert_frame_equal(dataframe, dataframe2)
Example 7
    def clear_all_rounds(self) -> None:
        """Existing metrics for all rounds are cleared out.

    This method will atomically update the stored CSV file.
    """
        self._metrics = pd.DataFrame()
        utils_impl.atomic_write_to_csv(self._metrics, self._metrics_filename)
        self._latest_round_num = None
Example 8
    def __init__(self,
                 root_metrics_dir: str = '/tmp',
                 prefix: str = 'experiment',
                 use_bz2: bool = True):
        """Returns an initialized `ScalarMetricsManager`.

    This class will maintain metrics in a CSV file in the filesystem. The path
    of the file is {`root_metrics_dir`}/{`prefix`}.metrics.csv (if use_bz2 is
    set to False) or {`root_metrics_dir`}/{`prefix`}.metrics.csv.bz2 (if
    use_bz2 is set to True). To use this class upon restart of an experiment at
    an earlier round number, you can initialize and then call the
    clear_rounds_after() method to remove all rows for round numbers later than
    the restart round number. This ensures that no duplicate rows of data exist
    in the CSV.

    Args:
      root_metrics_dir: A path on the filesystem to store CSVs.
      prefix: A string to use as the prefix of filename. Usually the name of a
        specific run in a larger grid of experiments sharing a common
        `root_metrics_dir`.
      use_bz2: A boolean indicating whether to zip the result metrics csv using
        bz2.

    Raises:
      ValueError: If `root_metrics_dir` is empty string.
      ValueError: If `prefix` is empty string.
      ValueError: If the specified metrics csv file already exists but does not
        contain a `round_num` column.
    """
        super().__init__()
        if not root_metrics_dir:
            raise ValueError(
                'Empty string passed for root_metrics_dir argument.')
        if not prefix:
            raise ValueError('Empty string passed for prefix argument.')

        if use_bz2:
            # Using .bz2 rather than .zip due to
            # https://github.com/pandas-dev/pandas/issues/26023
            self._metrics_filename = os.path.join(root_metrics_dir,
                                                  f'{prefix}.metrics.csv.bz2')
        else:
            self._metrics_filename = os.path.join(root_metrics_dir,
                                                  f'{prefix}.metrics.csv')
        if not tf.io.gfile.exists(self._metrics_filename):
            utils_impl.atomic_write_to_csv(pd.DataFrame(),
                                           self._metrics_filename)

        self._metrics = utils_impl.atomic_read_from_csv(self._metrics_filename)
        if not self._metrics.empty and 'round_num' not in self._metrics.columns:
            raise ValueError(
                f'The specified csv file ({self._metrics_filename}) already exists '
                'but was not created by ScalarMetricsManager (it does not contain a '
                '`round_num` column).')

        self._latest_round_num = (None if self._metrics.empty else
                                  self._metrics.round_num.max(axis=0))
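Taken together with the `update_metrics` and `clear_rounds_after` methods shown in the other examples, a typical lifecycle looks roughly like the following. This is a hedged sketch: the directory, prefix, metric values, and restart round are made up for illustration.

# Hedged usage sketch of ScalarMetricsManager based on the docstrings in these
# examples; the paths, metric values, and restart round are illustrative only.
metrics_mngr = metrics_manager.ScalarMetricsManager(
    root_metrics_dir='/tmp/results', prefix='my_experiment')

# Append one row of (possibly nested) metrics per round; nesting is flattened
# into '/'-joined column names when written to the CSV.
for round_num in range(5):
    metrics_mngr.update_metrics(round_num, {'train': {'loss': 1.0 / (round_num + 1)}})

# When restarting the experiment from an earlier checkpoint, drop rows for any
# later rounds so the CSV contains no duplicates.
metrics_mngr.clear_rounds_after(last_valid_round_num=2)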
Example 9
    def update_metrics(self, round_num,
                       metrics_to_append: Dict[str, Any]) -> Dict[str, float]:
        """Updates the stored metrics data with metrics for a specific round.

    The specified `round_num` must be later than the latest round number for
    which metrics exist in the stored metrics data. This method will atomically
    update the stored CSV file. Also, if stored metrics already exist and
    `metrics_to_append` contains a new, previously unseen metric name, a new
    column in the dataframe will be added for that metric, and all previous rows
    will be filled with NaN values for that metric.

    Args:
      round_num: Communication round at which `metrics_to_append` was collected.
      metrics_to_append: A dictionary of metrics collected during `round_num`.
        These metrics can be in a nested structure, but the nesting will be
        flattened for storage in the CSV (with the new keys equal to the paths
        in the nested structure).

    Returns:
      A `collections.OrderedDict` of the data just added in a new row to the
        pandas.DataFrame. Compared with the input `metrics_to_append`, this data
        is flattened, with the key names equal to the path in the nested
        structure. Also, `round_num` has been added as an additional key.

    Raises:
      ValueError: If the provided round number is negative.
      ValueError: If the provided round number is less than or equal to the
        latest round number in the stored metrics data.
    """
        if round_num < 0:
            raise ValueError(
                f'Attempting to append metrics for round {round_num}, '
                'which is negative.')
        if self._latest_round_num is not None and round_num <= self._latest_round_num:
            raise ValueError(
                f'Attempting to append metrics for round {round_num}, '
                'but metrics already exist through round '
                f'{self._latest_round_num}.')

        # Add the round number to the metrics before storing to the csv file. If a
        # restart occurs, this column identifies which rows to trim via the
        # clear_rounds_after() method.
        metrics_to_append['round_num'] = round_num

        flat_metrics = tree.flatten_with_path(metrics_to_append)
        flat_metrics = [('/'.join(map(str, path)), item)
                        for path, item in flat_metrics]
        flat_metrics = collections.OrderedDict(flat_metrics)
        self._metrics = self._metrics.append(flat_metrics, ignore_index=True)
        utils_impl.atomic_write_to_csv(self._metrics, self._metrics_filename)
        self._latest_round_num = round_num

        return flat_metrics
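The '/'-joined flattening used above can be seen in isolation with a couple of lines. Here `tree` is the dm-tree package already imported by these modules; the nested dict is a made-up example.

# Isolated illustration of the metric flattening above: nested keys become
# '/'-joined column names. The nested dict is a made-up example.
import collections
import tree  # dm-tree

nested = {'eval': {'loss': 0.25, 'sparse_categorical_accuracy': 0.91}, 'round_num': 7}
flat = tree.flatten_with_path(nested)
flat = collections.OrderedDict(
    ('/'.join(map(str, path)), value) for path, value in flat)
# flat == {'eval/loss': 0.25, 'eval/sparse_categorical_accuracy': 0.91, 'round_num': 7}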
Example 10
  def on_epoch_end(self, epoch: int, logs: Dict[Any, Any] = None):
    results_path = os.path.join(self._path, 'metric_results.csv')
    if tf.io.gfile.exists(results_path):
      # Read the results logged so far.
      results_df = utils_impl.atomic_read_from_csv(results_path)
      # Slice off any results at or after the current epoch; their presence
      # indicates the job restarted from an earlier checkpoint.
      results_df = results_df[:epoch]
      # Add the new epoch.
      results_df = results_df.append(logs, ignore_index=True)
    else:
      results_df = pd.DataFrame(logs, index=[epoch])
    utils_impl.atomic_write_to_csv(results_df, results_path)
Example 11
  def test_atomic_write(self):
    for name in ['foo.csv', 'baz.csv.bz2']:
      dataframe = pd.DataFrame(dict(a=[1, 2], b=[4.0, 5.0]))
      output_file = os.path.join(absltest.get_default_test_tmpdir(), name)
      utils_impl.atomic_write_to_csv(dataframe, output_file)
      dataframe2 = pd.read_csv(output_file, index_col=0)
      pd.testing.assert_frame_equal(dataframe, dataframe2)

      # Overwriting
      dataframe3 = pd.DataFrame(dict(a=[1, 2, 3], b=[4.0, 5.0, 6.0]))
      utils_impl.atomic_write_to_csv(dataframe3, output_file)
      dataframe4 = pd.read_csv(output_file, index_col=0)
      pd.testing.assert_frame_equal(dataframe3, dataframe4)
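For context, these tests exercise an "atomic" write, which is commonly implemented by writing to a temporary file and then renaming it over the destination so readers never observe a partially written CSV. The sketch below illustrates that general pattern; it is an assumption about the technique, not the actual `utils_impl.atomic_write_to_csv` implementation.

# Sketch of the write-to-temp-then-rename pattern behind an atomic CSV write.
# Illustrative only; not the real utils_impl implementation.
import os
import tempfile

import pandas as pd
import tensorflow as tf


def atomic_write_to_csv_sketch(dataframe: pd.DataFrame, output_file: str):
  # Write the CSV to a local temporary file first. Because the temp file keeps
  # the destination's basename, pandas infers compression (e.g. '.bz2') from it.
  tmp_dir = tempfile.mkdtemp()
  tmp_name = os.path.join(tmp_dir, os.path.basename(output_file))
  dataframe.to_csv(tmp_name, header=True)
  # Copy next to the destination, then rename over it in a single step, so a
  # concurrent reader sees either the old file or the new one, never a partial.
  tmp_gfile_name = output_file + '.tmp'
  tf.io.gfile.copy(tmp_name, tmp_gfile_name, overwrite=True)
  tf.io.gfile.rename(tmp_gfile_name, output_file, overwrite=True)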
Example 12
  def test_constructor_raises_value_error_if_csvfile_is_invalid(self):
    dataframe_missing_round_num = pd.DataFrame.from_dict(
        _create_dummy_metrics())

    temp_dir = self.get_temp_dir()
    # This csvfile is 'invalid' in that it was not originally created by an
    # instance of ScalarMetricsManager, and is missing a column for
    # round_num.
    invalid_csvfile = os.path.join(temp_dir, 'foo.metrics.csv.bz2')
    utils_impl.atomic_write_to_csv(dataframe_missing_round_num, invalid_csvfile)

    with self.assertRaises(ValueError):
      metrics_manager.ScalarMetricsManager(temp_dir, prefix='foo')
Example 13
  def build(cls, exp_name, output_dir, eval_dataset, hparam_dict, model):
    """Constructs the MetricsHook.

    Args:
      exp_name: A unique filesystem-friendly name for the experiment.
      output_dir: A root output directory used for all experiment runs in a
        grid. The MetricsHook will combine this with exp_name to form suitable
        output directories for this run.
      eval_dataset: Evaluation dataset.
      hparam_dict: A dictionary of hyperparameters to be recorded to .csv and
        exported to TensorBoard.
      model: The model for evaluation.

    Returns:
      The `MetricsHook` object.
    """

    eval_dataset = eval_dataset.map(tuple)

    summary_logdir = os.path.join(output_dir, 'logdir/{}'.format(exp_name))
    _check_not_exists(summary_logdir)
    tf.io.gfile.makedirs(summary_logdir)

    summary_writer = tf.compat.v2.summary.create_file_writer(
        summary_logdir, name=exp_name)
    with summary_writer.as_default():
      hp.hparams(hparam_dict)

    # Using .bz2 rather than .zip due to
    # https://github.com/pandas-dev/pandas/issues/26023
    results_file = os.path.join(output_dir,
                                '{}.results.csv.bz2'.format(exp_name))

    # Also write the hparam_dict to a CSV:
    hparam_dict['results_file'] = results_file
    hparams_file = os.path.join(output_dir, '{}.hparams.csv'.format(exp_name))
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    logging.info('Writing ...')
    logging.info('   result csv to: %s', results_file)
    logging.info('    summaries to: %s', summary_logdir)

    return cls(
        results_file=results_file,
        summary_writer=summary_writer,
        eval_dataset=eval_dataset,
        model=model)
Example 14
def _setup_outputs(root_output_dir,
                   experiment_name,
                   hparam_dict,
                   write_metrics_with_bz2=True,
                   rounds_per_profile=0):
    """Set up directories for experiment loops, write hyperparameters to disk."""

    if not experiment_name:
        raise ValueError('experiment_name must be specified.')

    create_if_not_exists(root_output_dir)

    checkpoint_dir = os.path.join(root_output_dir, 'checkpoints',
                                  experiment_name)
    create_if_not_exists(checkpoint_dir)
    checkpoint_mngr = checkpoint_manager.FileCheckpointManager(checkpoint_dir)

    results_dir = os.path.join(root_output_dir, 'results', experiment_name)
    create_if_not_exists(results_dir)
    metrics_mngr = metrics_manager.ScalarMetricsManager(
        results_dir, use_bz2=write_metrics_with_bz2)

    summary_logdir = os.path.join(root_output_dir, 'logdir', experiment_name)
    create_if_not_exists(summary_logdir)
    summary_writer = tf.summary.create_file_writer(summary_logdir)

    if hparam_dict:
        hparam_dict['metrics_file'] = metrics_mngr.metrics_filename
        hparams_file = os.path.join(results_dir, 'hparams.csv')
        utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    logging.info('Writing...')
    logging.info('    checkpoints to: %s', checkpoint_dir)
    logging.info('    metrics csv to: %s', metrics_mngr.metrics_filename)
    logging.info('    summaries to: %s', summary_logdir)

    @contextlib.contextmanager
    def profiler(round_num):
        if (rounds_per_profile > 0 and round_num % rounds_per_profile == 0):
            with tf.profiler.experimental.Profile(summary_logdir):
                yield
        else:
            yield

    return checkpoint_mngr, metrics_mngr, summary_writer, profiler
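The returned `profiler` is a context manager that only starts a TensorFlow profile on rounds that are multiples of `rounds_per_profile`. A hypothetical training loop using the returned objects might look like the sketch below; `run_one_round` and the concrete paths are placeholders, not part of the example.

# Hypothetical use of the objects returned above; run_one_round() and the paths
# are placeholders.
checkpoint_mngr, metrics_mngr, summary_writer, profiler = _setup_outputs(
    root_output_dir='/tmp/exp',
    experiment_name='demo',
    hparam_dict={'learning_rate': 0.1},
    rounds_per_profile=10)

for round_num in range(100):
    with profiler(round_num):
        metrics = run_one_round()  # placeholder returning a dict of floats
    # update_metrics() returns the flattened row it just appended to the CSV.
    flat_metrics = metrics_mngr.update_metrics(round_num, metrics)
    with summary_writer.as_default():
        for name, value in flat_metrics.items():
            tf.summary.scalar(name, value, step=round_num)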
Example 15
def run_experiment():
  """Runs the training experiment."""
  np.random.seed(FLAGS.random_seed)
  tf.random.set_random_seed(FLAGS.random_seed)

  total_examples = 341873 if FLAGS.only_digits else 671585
  emnist_train, emnist_test = create_nonfed_emnist(total_examples)
  steps_per_epoch = int(total_examples / FLAGS.batch_size)

  model = create_compiled_keras_model()

  # Define TensorBoard callback
  log_dir = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name, 'log_dir')
  try:
    tf.io.gfile.makedirs(log_dir)
  except tf.errors.OpError:
    pass  # Directory already exists, we'll simply reuse.
  tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

  # Define CSV callback
  results_path = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                              'results.csv')
  csv_logger = tf.keras.callbacks.CSVLogger(results_path)

  # Write the hyperparameters to a CSV:
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])
  hparam_dict['results_file'] = results_path
  hparams_file = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                              'hparams.csv')
  utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

  model.fit(
      emnist_train,
      steps_per_epoch=steps_per_epoch,
      epochs=FLAGS.epochs,
      verbose=1,
      validation_data=emnist_test,
      callbacks=[tensorboard_callback, csv_logger])
  score = model.evaluate(emnist_test, verbose=0)
  print('Final test loss: %.4f' % score[0])
  print('Final test accuracy: %.4f' % score[1])
Example 16
    def __call__(self, train_metrics, eval_metrics, round_num):
        """A function suitable for passing as an eval hook to the training_loop.

    Args:
      train_metrics: A `dict` of training metrics computed in TFF.
      eval_metrics: A `dict` of evaluation metrics computed in TFF.
      round_num: The current round number.
    """
        metrics = {
            'train': train_metrics,
            'eval': eval_metrics,
            'round': round_num,
        }
        flat_metrics = tree.flatten_with_path(metrics)
        flat_metrics = [('/'.join(map(str, path)), item)
                        for path, item in flat_metrics]
        flat_metrics = collections.OrderedDict(flat_metrics)

        logging.info('Evaluation at round {:d}:\n{!s}'.format(
            round_num, pprint.pformat(flat_metrics)))

        # Also write metrics to a tf.summary logdir
        with self._summary_writer.as_default():
            for name, value in flat_metrics.items():
                tf.compat.v2.summary.scalar(name, value, step=round_num)

        if tf.io.gfile.exists(self._results_file):
            metrics = pd.read_csv(self._results_file,
                                  header=0,
                                  index_col=0,
                                  engine='c')
            # Remove any rows from `round_num` onward; if the experiment was
            # restarted at an earlier checkpoint, this avoids duplicate metrics.
            metrics = metrics[:round_num]
            metrics = metrics.append(flat_metrics, ignore_index=True)
        else:
            metrics = pd.DataFrame(flat_metrics, index=[0])
        utils_impl.atomic_write_to_csv(metrics, self._results_file)
Example 17
    def __init__(self, experiment_name, output_dir, hparam_dict):
        """Returns an initalized `MetricsHook`.

    Args:
      experiment_name: A unique filesystem-friendly name for the experiment.
      output_dir: A root output directory used for all experiment runs in a
        grid. The `MetricsHook` will combine this with `experiment_name` to form
        suitable output directories for this run.
      hparam_dict: A dictionary of hyperparameters to be recorded to .csv and
        exported to TensorBoard.
    """

        summary_logdir = os.path.join(output_dir,
                                      'logdir/{}'.format(experiment_name))
        _check_not_exists(summary_logdir, FLAGS.disable_check_exists)
        tf.io.gfile.makedirs(summary_logdir)

        self._summary_writer = tf.compat.v2.summary.create_file_writer(
            summary_logdir, name=experiment_name)
        with self._summary_writer.as_default():
            hp.hparams(hparam_dict)

        # Using .bz2 rather than .zip due to
        # https://github.com/pandas-dev/pandas/issues/26023
        self._results_file = os.path.join(output_dir, experiment_name,
                                          'results.csv.bz2')

        # Also write the hparam_dict to a CSV:
        hparam_dict['results_file'] = self._results_file
        hparams_file = os.path.join(output_dir, experiment_name, 'hparams.csv')
        utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

        logging.info('Writing ...')
        logging.info('   result csv to: %s', self._results_file)
        logging.info('    summaries to: %s', summary_logdir)

        _check_not_exists(self._results_file, FLAGS.disable_check_exists)
Example 18
  def on_epoch_end(self, epoch, logs=None):
    epoch_path = os.path.join(self._path,
                              'results.{:02d}.csv'.format(epoch))
    utils_impl.atomic_write_to_csv(pd.Series(logs), epoch_path)
Example 19
def run_experiment():
    """Runs the training experiment."""
    training_set, validation_set, test_set = (
        dataset.construct_word_level_datasets(
            vocab_size=FLAGS.vocab_size,
            batch_size=FLAGS.batch_size,
            client_epochs_per_round=1,
            max_seq_len=FLAGS.sequence_length,
            max_training_elements_per_user=-1,
            num_validation_examples=FLAGS.num_validation_examples,
            num_test_examples=FLAGS.num_test_examples))
    centralized_train = training_set.create_tf_dataset_from_all_clients()

    def _lstm_fn():
        return tf.keras.layers.LSTM(FLAGS.latent_size, return_sequences=True)

    model = models.create_recurrent_model(
        FLAGS.vocab_size,
        FLAGS.embedding_size,
        FLAGS.num_layers,
        _lstm_fn,
        'stackoverflow-lstm',
        shared_embedding=FLAGS.shared_embedding)
    logging.info('Training model: %s', model.summary())
    optimizer = utils_impl.create_optimizer_from_flags('centralized')
    model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy,
                  optimizer=optimizer,
                  weighted_metrics=['acc'])

    train_results_path = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                      'train_results')
    test_results_path = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                     'test_results')

    train_csv_logger = AtomicCSVLogger(train_results_path)
    test_csv_logger = AtomicCSVLogger(test_results_path)

    log_dir = os.path.join(FLAGS.root_output_dir, 'logdir', FLAGS.exp_name)
    try:
        tf.io.gfile.makedirs(log_dir)
        tf.io.gfile.makedirs(train_results_path)
        tf.io.gfile.makedirs(test_results_path)
    except tf.errors.OpError:
        pass  # log_dir already exists.

    train_tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        write_graph=True,
        update_freq=FLAGS.tensorboard_update_frequency)

    test_tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

    results_file = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                'results.csv.bz2')

    # Write the hyperparameters to a CSV:
    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])
    hparam_dict['results_file'] = results_file
    hparams_file = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                'hparams.csv')
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    oov, bos, eos, pad = dataset.get_special_tokens(FLAGS.vocab_size)
    class_weight = {x: 1.0 for x in range(FLAGS.vocab_size)}
    class_weight[oov] = 0.0  # No credit for predicting OOV.
    class_weight[bos] = 0.0  # Shouldn't matter since this is never a target.
    class_weight[eos] = 1.0  # Model should learn to predict end of sentence.
    class_weight[pad] = 0.0  # No credit for predicting pad.

    model.fit(centralized_train,
              epochs=FLAGS.epochs,
              verbose=1,
              class_weight=class_weight,
              validation_data=validation_set,
              callbacks=[train_csv_logger, train_tensorboard_callback])
    score = model.evaluate(
        test_set,
        verbose=1,
        callbacks=[test_csv_logger, test_tensorboard_callback])
    logging.info('Final test loss: %.4f', score[0])
    logging.info('Final test accuracy: %.4f', score[1])
Example 20
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    experiment_output_dir = FLAGS.root_output_dir
    tensorboard_dir = os.path.join(experiment_output_dir, 'logdir',
                                   FLAGS.experiment_name)
    results_dir = os.path.join(experiment_output_dir, 'results',
                               FLAGS.experiment_name)

    for path in [experiment_output_dir, tensorboard_dir, results_dir]:
        try:
            tf.io.gfile.makedirs(path)
        except tf.errors.OpError:
            pass  # Directory already exists.

    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])
    hparam_dict['results_file'] = results_dir
    hparams_file = os.path.join(results_dir, 'hparams.csv')

    logging.info('Saving hyperparameters to: [%s]', hparams_file)
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    train_dataset, eval_dataset = emnist_dataset.get_centralized_emnist_datasets(
        batch_size=FLAGS.batch_size, only_digits=False)

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

    if FLAGS.model == 'cnn':
        model = emnist_models.create_conv_dropout_model(only_digits=False)
    elif FLAGS.model == '2nn':
        model = emnist_models.create_two_hidden_layer_model(only_digits=False)
    else:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                  optimizer=optimizer,
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    logging.info('Training model:')
    logging.info(model.summary())

    csv_logger_callback = keras_callbacks.AtomicCSVLogger(results_dir)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=tensorboard_dir)

    # Reduce the learning rate after a fixed number of epochs.
    def decay_lr(epoch, learning_rate):
        if (epoch + 1) % FLAGS.decay_epochs == 0:
            return learning_rate * FLAGS.lr_decay
        else:
            return learning_rate

    lr_callback = tf.keras.callbacks.LearningRateScheduler(decay_lr, verbose=1)

    history = model.fit(
        train_dataset,
        validation_data=eval_dataset,
        epochs=FLAGS.num_epochs,
        callbacks=[lr_callback, tensorboard_callback, csv_logger_callback])

    logging.info('Final metrics:')
    for name in ['loss', 'sparse_categorical_accuracy']:
        metric = history.history['val_{}'.format(name)][-1]
        logging.info('\t%s: %.4f', name, metric)
Example 21
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.compat.v1.enable_v2_behavior()

  experiment_output_dir = FLAGS.root_output_dir
  tensorboard_dir = os.path.join(experiment_output_dir, 'logdir',
                                 FLAGS.experiment_name)
  results_dir = os.path.join(experiment_output_dir, 'results',
                             FLAGS.experiment_name)

  for path in [experiment_output_dir, tensorboard_dir, results_dir]:
    try:
      tf.io.gfile.makedirs(path)
    except tf.errors.OpError:
      pass  # Directory already exists.

  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])
  hparam_dict['results_file'] = results_dir
  hparams_file = os.path.join(results_dir, 'hparams.csv')

  logging.info('Saving hyperparameters to: [%s]', hparams_file)
  utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

  cifar_train, cifar_test = dataset.get_centralized_cifar100(
      train_batch_size=FLAGS.batch_size, crop_shape=CROP_SHAPE)

  optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()
  model = resnet_models.create_resnet18(
      input_shape=CROP_SHAPE, num_classes=NUM_CLASSES)
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=optimizer,
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  logging.info('Training model:')
  logging.info(model.summary())

  csv_logger_callback = keras_callbacks.AtomicCSVLogger(results_dir)
  tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir)

  # Reduce the learning rate after a fixed number of epochs.
  def decay_lr(epoch, learning_rate):
    if (epoch + 1) % FLAGS.decay_epochs == 0:
      return learning_rate * FLAGS.lr_decay
    else:
      return learning_rate

  lr_callback = tf.keras.callbacks.LearningRateScheduler(decay_lr, verbose=1)

  history = model.fit(
      cifar_train,
      validation_data=cifar_test,
      epochs=FLAGS.num_epochs,
      callbacks=[lr_callback, tensorboard_callback, csv_logger_callback])

  logging.info('Final metrics:')
  for name in ['loss', 'sparse_categorical_accuracy']:
    metric = history.history['val_{}'.format(name)][-1]
    logging.info('\t%s: %.4f', name, metric)
Example 22
def run_experiment():
    """Runs the training experiment."""
    (_, stackoverflow_validation,
     stackoverflow_test) = dataset.construct_word_level_datasets(
         FLAGS.vocab_size, FLAGS.batch_size, 1, FLAGS.sequence_length, -1,
         FLAGS.num_validation_examples)
    centralized_train = dataset.get_centralized_train_dataset(
        FLAGS.vocab_size, FLAGS.batch_size, FLAGS.sequence_length,
        FLAGS.shuffle_buffer_size)

    def _lstm_fn(latent_size):
        return tf.keras.layers.LSTM(latent_size, return_sequences=True)

    model = models.create_recurrent_model(
        FLAGS.vocab_size,
        _lstm_fn,
        'stackoverflow-lstm',
        shared_embedding=FLAGS.shared_embedding)
    logging.info('Training model: %s', model.summary())
    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()
    pad_token, oov_token, _, eos_token = dataset.get_special_tokens(
        FLAGS.vocab_size)
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=optimizer,
        metrics=[
            # Plus 4 for pad, oov, bos, eos
            keras_metrics.FlattenedCategoricalAccuracy(
                vocab_size=FLAGS.vocab_size + 4,
                name='accuracy_with_oov',
                masked_tokens=pad_token),
            keras_metrics.FlattenedCategoricalAccuracy(
                vocab_size=FLAGS.vocab_size + 4,
                name='accuracy_no_oov',
                masked_tokens=[pad_token, oov_token]),
            keras_metrics.FlattenedCategoricalAccuracy(
                vocab_size=FLAGS.vocab_size + 4,
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, oov_token, eos_token]),
        ])

    train_results_path = os.path.join(FLAGS.root_output_dir, 'train_results',
                                      FLAGS.experiment_name)
    test_results_path = os.path.join(FLAGS.root_output_dir, 'test_results',
                                     FLAGS.experiment_name)

    train_csv_logger = keras_callbacks.AtomicCSVLogger(train_results_path)
    test_csv_logger = keras_callbacks.AtomicCSVLogger(test_results_path)

    log_dir = os.path.join(FLAGS.root_output_dir, 'logdir',
                           FLAGS.experiment_name)
    try:
        tf.io.gfile.makedirs(log_dir)
        tf.io.gfile.makedirs(train_results_path)
        tf.io.gfile.makedirs(test_results_path)
    except tf.errors.OpError:
        pass  # log_dir already exists.

    train_tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        write_graph=True,
        update_freq=FLAGS.tensorboard_update_frequency)

    test_tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

    # Write the hyperparameters to a CSV:
    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])
    hparams_file = os.path.join(FLAGS.root_output_dir, FLAGS.experiment_name,
                                'hparams.csv')
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    model.fit(centralized_train,
              epochs=FLAGS.epochs,
              verbose=0,
              validation_data=stackoverflow_validation,
              callbacks=[train_csv_logger, train_tensorboard_callback])
    score = model.evaluate(
        stackoverflow_test,
        verbose=0,
        callbacks=[test_csv_logger, test_tensorboard_callback])
    logging.info('Final test loss: %.4f', score[0])
    logging.info('Final test accuracy: %.4f', score[1])
Example 23
def run_experiment():
    """Runs the training experiment."""
    vocab = _create_vocab()
    (stackoverflow_train, stackoverflow_val,
     stackoverflow_test) = construct_word_level_datasets(vocab)

    num_training_steps = FLAGS.num_training_examples / FLAGS.batch_size

    def _lstm_fn():
        return tf.keras.layers.LSTM(FLAGS.latent_size, return_sequences=True)

    model = models.create_recurrent_model(FLAGS.vocab_size,
                                          FLAGS.embedding_size,
                                          FLAGS.num_layers, _lstm_fn,
                                          'stackoverflow-lstm')
    if FLAGS.optimizer == 'sgd':
        optimizer = tf.keras.optimizers.SGD(learning_rate=FLAGS.learning_rate,
                                            momentum=FLAGS.momentum)
    elif FLAGS.optimizer == 'adagrad':
        optimizer = tf.keras.optimizers.Adagrad(
            learning_rate=FLAGS.learning_rate)
    else:
        raise ValueError(
            'Unsupported optimizer flag [{!s}].'.format(FLAGS.optimizer))
    model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy,
                  optimizer=optimizer,
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    train_results_path = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                      'train_results')
    test_results_path = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                     'test_results')

    train_csv_logger = AtomicCSVLogger(train_results_path)
    test_csv_logger = AtomicCSVLogger(test_results_path)

    log_dir = os.path.join(FLAGS.root_output_dir, 'logdir', FLAGS.exp_name)
    try:
        tf.io.gfile.makedirs(log_dir)
        tf.io.gfile.makedirs(train_results_path)
        tf.io.gfile.makedirs(test_results_path)
    except tf.errors.OpError:
        pass  # log_dir already exists.

    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        write_graph=True,
        update_freq=FLAGS.tensorboard_update_frequency)

    results_file = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                'results.csv.bz2')

    # Write the hyperparameters to a CSV:
    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])
    hparam_dict['results_file'] = results_file
    hparams_file = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                'hparams.csv')
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    model.fit(stackoverflow_train,
              steps_per_epoch=num_training_steps,
              epochs=25,
              verbose=1,
              validation_data=stackoverflow_val,
              callbacks=[train_csv_logger, tensorboard_callback])
    score = model.evaluate_generator(
        stackoverflow_test,
        verbose=1,
        callbacks=[test_csv_logger, tensorboard_callback])
    print('Final test loss: %.4f' % score[0])
    print('Final test accuracy: %.4f' % score[1])
Example 24
def run_experiment():
    """Runs the training experiment."""
    try:
        tf.io.gfile.makedirs(
            os.path.join(FLAGS.root_output_dir, FLAGS.exp_name))
    except tf.errors.OpError:
        pass

    train_set, validation_set, test_set = (
        dataset.construct_word_level_datasets(
            vocab_size=FLAGS.vocab_size,
            client_batch_size=FLAGS.batch_size,
            client_epochs_per_round=1,
            max_seq_len=FLAGS.sequence_length,
            max_elements_per_user=FLAGS.max_elements_per_user,
            centralized_train=True,
            shuffle_buffer_size=None,
            num_validation_examples=FLAGS.num_validation_examples,
            num_test_examples=FLAGS.num_test_examples))

    recurrent_model = tf.keras.layers.LSTM if FLAGS.lstm else tf.keras.layers.GRU

    def _layer_fn():
        return recurrent_model(FLAGS.latent_size, return_sequences=True)

    pad, oov, _, eos = dataset.get_special_tokens(FLAGS.vocab_size)

    model = models.create_recurrent_model(
        FLAGS.vocab_size,
        FLAGS.embedding_size,
        FLAGS.num_layers,
        _layer_fn,
        'stackoverflow-recurrent',
        shared_embedding=FLAGS.shared_embedding)
    logging.info('Training model: %s', model.summary())
    optimizer = utils_impl.create_optimizer_from_flags('centralized')
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=optimizer,
        metrics=[
            metrics.MaskedCategoricalAccuracy([pad], 'accuracy_with_oov'),
            metrics.MaskedCategoricalAccuracy([pad, oov], 'accuracy_no_oov'),
            metrics.MaskedCategoricalAccuracy([pad, oov, eos],
                                              'accuracy_no_oov_no_eos')
        ])

    train_results_path = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                      'train_results')
    test_results_path = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                     'test_results')

    train_csv_logger = AtomicCSVLogger(train_results_path)
    test_csv_logger = AtomicCSVLogger(test_results_path)

    log_dir = os.path.join(FLAGS.root_output_dir, 'logdir', FLAGS.exp_name)
    try:
        tf.io.gfile.makedirs(log_dir)
        tf.io.gfile.makedirs(train_results_path)
        tf.io.gfile.makedirs(test_results_path)
    except tf.errors.OpError:
        pass  # log_dir already exists.

    train_tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        write_graph=True,
        update_freq=FLAGS.tensorboard_update_frequency)

    test_tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

    results_file = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                'results.csv.bz2')

    # Write the hyperparameters to a CSV:
    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])
    hparam_dict['results_file'] = results_file
    hparams_file = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                'hparams.csv')
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    model.fit(train_set,
              epochs=FLAGS.epochs,
              verbose=1,
              steps_per_epoch=FLAGS.steps_per_epoch,
              validation_data=validation_set,
              callbacks=[train_csv_logger, train_tensorboard_callback])
    score = model.evaluate(
        test_set,
        verbose=1,
        callbacks=[test_csv_logger, test_tensorboard_callback])
    logging.info('Final test loss: %.4f', score[0])
    logging.info('Final test accuracy: %.4f', score[1])
Example 25
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    experiment_output_dir = FLAGS.root_output_dir
    tensorboard_dir = os.path.join(experiment_output_dir, 'logdir',
                                   FLAGS.experiment_name)
    results_dir = os.path.join(experiment_output_dir, 'results',
                               FLAGS.experiment_name)

    for path in [experiment_output_dir, tensorboard_dir, results_dir]:
        try:
            tf.io.gfile.makedirs(path)
        except tf.errors.OpError:
            pass  # Directory already exists.

    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])
    hparam_dict['results_file'] = results_dir
    hparams_file = os.path.join(results_dir, 'hparams.csv')

    logging.info('Saving hyperparameters to: [%s]', hparams_file)
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    train_dataset, eval_dataset = stackoverflow_lr_dataset.get_centralized_stackoverflow_datasets(
        batch_size=FLAGS.batch_size,
        vocab_tokens_size=FLAGS.vocab_tokens_size,
        vocab_tags_size=FLAGS.vocab_tags_size)

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

    model = stackoverflow_lr_models.create_logistic_model(
        vocab_tokens_size=FLAGS.vocab_tokens_size,
        vocab_tags_size=FLAGS.vocab_tags_size)

    model.compile(loss=tf.keras.losses.BinaryCrossentropy(
        from_logits=False, reduction=tf.keras.losses.Reduction.SUM),
                  optimizer=optimizer,
                  metrics=[
                      tf.keras.metrics.Precision(),
                      tf.keras.metrics.Recall(top_k=5)
                  ])

    logging.info('Training model:')
    logging.info(model.summary())

    csv_logger_callback = keras_callbacks.AtomicCSVLogger(results_dir)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=tensorboard_dir)

    # Reduce the learning rate after a fixed number of epochs.
    def decay_lr(epoch, learning_rate):
        if (epoch + 1) % FLAGS.decay_epochs == 0:
            return learning_rate * FLAGS.lr_decay
        else:
            return learning_rate

    lr_callback = tf.keras.callbacks.LearningRateScheduler(decay_lr, verbose=1)

    history = model.fit(
        train_dataset,
        validation_data=eval_dataset,
        epochs=FLAGS.num_epochs,
        callbacks=[lr_callback, tensorboard_callback, csv_logger_callback])

    logging.info('Final metrics:')
    for name in ['loss', 'precision', 'recall']:
        metric = history.history['val_{}'.format(name)][-1]
        logging.info('\t%s: %.4f', name, metric)
Example 26
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.compat.v1.enable_v2_behavior()

    experiment_output_dir = FLAGS.root_output_dir
    tensorboard_dir = os.path.join(experiment_output_dir, 'logdir',
                                   FLAGS.experiment_name)
    results_dir = os.path.join(experiment_output_dir, 'results',
                               FLAGS.experiment_name)

    for path in [experiment_output_dir, tensorboard_dir, results_dir]:
        try:
            tf.io.gfile.makedirs(path)
        except tf.errors.OpError:
            pass  # Directory already exists.

    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])
    hparam_dict['results_file'] = results_dir
    hparams_file = os.path.join(results_dir, 'hparams.csv')
    logging.info('Saving hyperparameters to: [%s]', hparams_file)
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    train_client_data, test_client_data = (
        tff.simulation.datasets.shakespeare.load_data())

    def preprocess(ds):
        return dataset.convert_snippets_to_character_sequence_examples(
            ds, FLAGS.batch_size, epochs=1).cache()

    train_dataset = train_client_data.create_tf_dataset_from_all_clients()
    if FLAGS.shuffle_train_data:
        train_dataset = train_dataset.shuffle(buffer_size=10000)
    train_dataset = preprocess(train_dataset)

    eval_dataset = preprocess(
        test_client_data.create_tf_dataset_from_all_clients())

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

    # Vocabulary with one OOV ID and zero for the mask.
    vocab_size = len(dataset.CHAR_VOCAB) + 2
    model = models.create_recurrent_model(vocab_size=vocab_size,
                                          batch_size=FLAGS.batch_size)
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras_metrics.FlattenedCategoricalAccuracy(vocab_size=vocab_size,
                                                       mask_zero=True)
        ])

    logging.info('Training model:')
    logging.info(model.summary())

    csv_logger_callback = keras_callbacks.AtomicCSVLogger(results_dir)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=tensorboard_dir)

    # Reduce the learning rate every 20 epochs.
    def decay_lr(epoch, lr):
        if (epoch + 1) % 20 == 0:
            return lr * 0.1
        else:
            return lr

    lr_callback = tf.keras.callbacks.LearningRateScheduler(decay_lr, verbose=1)

    history = model.fit(
        train_dataset,
        validation_data=eval_dataset,
        epochs=FLAGS.num_epochs,
        callbacks=[lr_callback, tensorboard_callback, csv_logger_callback])

    logging.info('Final metrics:')
    for name in ['loss', 'accuracy']:
        metric = history.history['val_{}'.format(name)][-1]
        logging.info('\t%s: %.4f', name, metric)
Example 27
def run(
    keras_model: tf.keras.Model,
    train_dataset: tf.data.Dataset,
    experiment_name: str,
    root_output_dir: str,
    num_epochs: int,
    hparams_dict: Optional[Dict[str, Any]] = None,
    decay_epochs: Optional[int] = None,
    lr_decay: Optional[float] = None,
    validation_dataset: Optional[tf.data.Dataset] = None,
    test_dataset: Optional[tf.data.Dataset] = None
) -> tf.keras.callbacks.History:
    """Run centralized training for a given compiled `tf.keras.Model`.

  Args:
    keras_model: A compiled `tf.keras.Model`.
    train_dataset: The `tf.data.Dataset` to be used for training.
    experiment_name: Name of the experiment, used as part of the name of the
      output directory.
    root_output_dir: The top-level output directory. The directory
      `root_output_dir/experiment_name` will contain TensorBoard logs, metrics
      CSVs and other outputs.
    num_epochs: How many training epochs to perform.
    hparams_dict: An optional dict specifying hyperparameters. If provided, the
      hyperparameters will be written to CSV.
    decay_epochs: Number of training epochs before decaying the learning rate.
    lr_decay: How much to decay the learning rate by every `decay_epochs`.
    validation_dataset: An optional `tf.data.Dataset` used for validation during
      training.
    test_dataset: An optional `tf.data.Dataset` used for testing after all
      training has completed.

  Returns:
    A `tf.keras.callbacks.History` object.
  """
    tensorboard_dir = os.path.join(root_output_dir, 'logdir', experiment_name)
    results_dir = os.path.join(root_output_dir, 'results', experiment_name)

    for path in [root_output_dir, tensorboard_dir, results_dir]:
        tf.io.gfile.makedirs(path)

    if hparams_dict:
        hparams_file = os.path.join(results_dir, 'hparams.csv')
        logging.info('Saving hyperparameters to: [%s]', hparams_file)
        hparams_df = pd.DataFrame(hparams_dict, index=[0])
        utils_impl.atomic_write_to_csv(hparams_df, hparams_file)

    csv_logger_callback = keras_callbacks.AtomicCSVLogger(results_dir)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=tensorboard_dir)
    training_callbacks = [tensorboard_callback, csv_logger_callback]

    if decay_epochs is not None and decay_epochs > 0:
        # Reduce the learning rate after a fixed number of epochs.
        def decay_lr(epoch, learning_rate):
            if (epoch + 1) % decay_epochs == 0:
                return learning_rate * lr_decay
            else:
                return learning_rate

        lr_callback = tf.keras.callbacks.LearningRateScheduler(decay_lr,
                                                               verbose=1)
        training_callbacks.append(lr_callback)

    logging.info('Training model:')
    logging.info(keras_model.summary())

    history = keras_model.fit(train_dataset,
                              validation_data=validation_dataset,
                              epochs=num_epochs,
                              callbacks=training_callbacks)

    logging.info('Final training metrics:')
    for metric in keras_model.metrics:
        name = metric.name
        metric = history.history[name][-1]
        logging.info('\t%s: %.4f', name, metric)

    if validation_dataset:
        logging.info('Final validation metrics:')
        for metric in keras_model.metrics:
            name = metric.name
            metric = history.history['val_{}'.format(name)][-1]
            logging.info('\t%s: %.4f', name, metric)

    if test_dataset:
        test_metrics = keras_model.evaluate(test_dataset, return_dict=True)
        logging.info('Test metrics:')
        for metric in keras_model.metrics:
            name = metric.name
            metric = test_metrics[name]
            logging.info('\t%s: %.4f', name, metric)

    return history
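A hedged end-to-end invocation of `run` might look like the following; the tiny synthetic dataset, model, and hyperparameter values are placeholders chosen only to make the sketch self-contained.

# Hypothetical invocation of run(); the synthetic data, model, and
# hyperparameter values below are placeholders.
import numpy as np
import tensorflow as tf

x = np.random.rand(64, 784).astype('float32')
y = np.random.randint(0, 10, size=(64,))
train_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(8)
validation_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(8)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='softmax', input_shape=(784,))
])
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

history = run(
    keras_model=model,
    train_dataset=train_dataset,
    experiment_name='centralized_demo',
    root_output_dir='/tmp/centralized',
    num_epochs=5,
    hparams_dict={'learning_rate': 0.1, 'batch_size': 8},
    decay_epochs=2,
    lr_decay=0.1,
    validation_dataset=validation_dataset)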