Example #1
0
  def testConstructingHookDoesntCreateEventFiles(self):
    output_dir = tempfile.mkdtemp()
    h1 = utils.GinConfigSaverHook(output_dir)
    h2 = utils.GinConfigSaverHook(output_dir)
    self.assertEqual(os.listdir(output_dir), [])

    def create_event_files(hook):
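      # Trigger the hook via after_create_session and return any events files
      # it wrote to output_dir.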
      with self.session() as sess:
        hook.after_create_session(sess)
      return [f for f in os.listdir(output_dir) if f.startswith('events')]

    self.assertEqual(len(create_event_files(h1)), 1)
    # Check that the second hook doesn't create another events file.
    self.assertEqual(len(create_event_files(h2)), 1)
Example #2
0
  def testConstructingHookDoesntCreateEventFiles(self):
    output_dir = tempfile.mkdtemp()
    h1 = utils.GinConfigSaverHook(output_dir)
    h2 = utils.GinConfigSaverHook(output_dir)
    self.assertEqual(os.listdir(output_dir), [])

    def create_event_files(hook):
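      # A MonitoredTrainingSession runs its chief-only hooks, which should
      # produce the events file in output_dir.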
      with tf.train.MonitoredTrainingSession(chief_only_hooks=[hook]):
        pass
      return [f for f in os.listdir(output_dir) if f.startswith('events')]

    self.assertEqual(len(create_event_files(h1)), 1)
    # Check that the second hook doesn't create another events file.
    self.assertEqual(len(create_event_files(h2)), 1)
Example #3
0
    def run_log_config_hook_maybe_with_summary(self, global_step_value,
                                               **kwargs):
        config.parse_config(GinConfigSaverHookTest.CONFIG_STR)

        configurable_fn()
        ConfigurableClass()
        no_args_fn()

        if global_step_value is not None:
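            # Create a global_step variable initialized to the requested value.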
            tf.get_variable(
                'global_step',
                shape=(),
                dtype=tf.int64,
                initializer=tf.constant_initializer(global_step_value),
                trainable=False)

        output_dir = tempfile.mkdtemp()
        summary_writer = tf.contrib.testing.FakeSummaryWriter(output_dir)
        h = utils.GinConfigSaverHook(output_dir,
                                     summary_writer=summary_writer,
                                     **kwargs)
        with tf.train.MonitoredSession(hooks=[h]):
            pass

        return output_dir, summary_writer
Example #4
0
  def run_log_config_hook_maybe_with_summary(self, global_step_value, **kwargs):
    config.parse_config(GinConfigSaverHookTest.CONFIG_STR)

    configurable_fn()
    ConfigurableClass()
    no_args_fn()

    output_dir = tempfile.mkdtemp()
    summary_writer = FakeSummaryWriter()
    h = utils.GinConfigSaverHook(
        output_dir, summary_writer=summary_writer, **kwargs)
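    # Drive the hook directly via after_create_session, optionally setting the
    # global step beforehand.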
    with self.session() as sess:
      if global_step_value is not None:
        global_step = tf.compat.v1.train.get_or_create_global_step()
        sess.run(global_step.assign(global_step_value))
      h.after_create_session(sess)

    return output_dir, summary_writer
Example #5
0
  def model_fn(self, features, labels, mode, config=None, params=None):
    """Estimator model_fn.

    Args:
      features: This is the first item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_feature_specification.
      labels: This is the second item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_label_specification.
      mode: (ModeKeys) Specifies if this is training, evaluation or prediction.
      config: (Optional tf.estimator.RunConfig or contrib_tpu.RunConfig) Will
        receive what is passed to Estimator in config parameter, or the default
        config (tf.estimator.RunConfig). Allows updating things in your model_fn
        based on configuration such as num_ps_replicas or model_dir.
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        python types. There are reserved keys for TPUEstimator, including
        'batch_size'.

    Raises:
      ValueError: If the mode key is not supported, not in [PREDICT, TRAIN,
        EVAL].

    Returns:
      An EstimatorSpec.
    """
    features = tensorspec_utils.validate_and_pack(
        expected_spec=self.get_feature_specification(mode),
        actual_tensors_or_spec=features,
        ignore_batch=True)
    if labels:
      labels = tensorspec_utils.validate_and_pack(
          expected_spec=self.get_label_specification(mode),
          actual_tensors_or_spec=labels,
          ignore_batch=True)
    inference_outputs = self.inference_network_fn(features, labels, mode,
                                                  config, params)
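    # inference_network_fn may return either the outputs alone or a
    # (outputs, update_ops) tuple; unpack the update ops when present.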
    update_ops = None
    if isinstance(inference_outputs, tuple):
      if len(inference_outputs) != 2:
        raise ValueError('Unknown output of inference_network_fn: '
                         'tuple of length %d' % len(inference_outputs))
      outputs = inference_outputs[0]
      update_ops = inference_outputs[1]
      inference_outputs = outputs

    if mode == tf.estimator.ModeKeys.PREDICT:
      model_fn_results = self.create_export_outputs_fn(features,
                                                       inference_outputs, mode,
                                                       config, params)
      export_outputs = None
      if isinstance(model_fn_results, tuple):
        predictions = model_fn_results[0]
        export_outputs = model_fn_results[1]
      elif isinstance(model_fn_results, dict):
        export_outputs = {}
        if len(model_fn_results) == 1:
          name, output = list(model_fn_results.items())[0]
          export_outputs[name] = tf.estimator.export.RegressionOutput(output)
        export_outputs[tf.saved_model.signature_constants
                       .DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
                           tf.estimator.export.PredictOutput(model_fn_results))
        predictions = model_fn_results
      else:
        raise ValueError('The create_export_outputs_fn should return a '
                         'tuple(predictions, export_outputs) or predictions.')

      return tf.estimator.EstimatorSpec(
          mode=mode, predictions=predictions, export_outputs=export_outputs)

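    # Both TRAIN and EVAL need the training loss; model_train_fn may return
    # just the loss tensor or a (loss, train_outputs) tuple.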
    train_fn_result = self.model_train_fn(features, labels, inference_outputs,
                                          mode, config, params)
    if isinstance(train_fn_result, tf.Tensor):
      train_loss = train_fn_result
      train_outputs = {}
    elif isinstance(train_fn_result, tuple):
      train_loss = train_fn_result[0]
      train_outputs = train_fn_result[1]
    else:
      raise ValueError('The model_train_fn should return a '
                       'tuple(loss, train_outputs) or loss.')

    if mode == tf.estimator.ModeKeys.TRAIN:
      # Create the tf.train.Optimizer.
      optimizer = self.create_optimizer()

      train_op = self.create_train_op(train_loss, optimizer, update_ops,
                                      train_outputs)

      self.add_summaries(features, labels, inference_outputs, train_loss,
                         train_outputs, mode, config, params)

      # Now that the optimizer has been created, the checkpoint can be
      # initialized. No new variables may be added after this point, otherwise
      # they would not be initialized.
      # Note, this feature is only available during training to bootstrap a
      # model (partially) from a different model. As soon as this checkpoint is
      # written, all other modes will use the local checkpoint within model_dir.
      self.maybe_init_from_checkpoint()
      training_hooks = []

      # EstimatorSpec has training_chief_hooks, but TPUEstimatorSpec does not,
      # so we have to use training_hooks here and check is_chief.
      if config and config.is_chief:  # pytype: disable=attribute-error
        training_hooks.append(
            gin_utils.GinConfigSaverHook(
                config.model_dir, summarize_config=True))
        if hasattr(self, 'writer_init_ops'):
          training_hooks.append(V2SummaryInitHook(self.writer_init_ops[mode]))

      # `SyncReplicasOptimizer` needs to attach a training hook.
      if self._sync_replicas_optimizer:
        training_hooks.append(
            self._sync_replicas_optimizer.make_session_run_hook(
                config.is_chief))  # pytype: disable=attribute-error

      # Return the value of the property first since it might be changed.
      scaffold_fn = self.scaffold_fn
      scaffold = scaffold_fn()

      # In order to export asynchronously, the saver has to be registered in
      # the graph collection. The scaffold function might already register a
      # saver, which is why the collection is checked here and a saver is only
      # added if none has been registered yet.
      if not tf.get_collection(tf.GraphKeys.SAVERS):
        # TODO(T2R_CONTRIBUTORS): Switch to using gin config for all saver params.
        keep_checkpoint_every_n_hours = None
        max_to_keep = None
        if config is not None:
          keep_checkpoint_every_n_hours = config.keep_checkpoint_every_n_hours
          max_to_keep = config.keep_checkpoint_max
        saver = gin_configurable_saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            max_to_keep=max_to_keep,
        )
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      return tf.estimator.EstimatorSpec(
          mode=mode,
          loss=train_loss,
          train_op=train_op,
          training_hooks=training_hooks,
          scaffold=scaffold)

    if mode == tf.estimator.ModeKeys.EVAL:
      self.add_summaries(features, labels, inference_outputs, train_loss,
                         train_outputs, mode, config, params)

      eval_metrics = self.model_eval_fn(features, labels, inference_outputs,
                                        train_loss, train_outputs, mode, config,
                                        params)
      evaluation_hooks = self.get_eval_hooks(config, params)
      if config and config.is_chief:  # pytype: disable=attribute-error
        eval_name = params.get('eval_name', 'eval')  # pytype: disable=attribute-error
        evaluation_hooks.append(
            gin_utils.GinConfigSaverHook(
                os.path.join(config.model_dir, eval_name),
                summarize_config=True))
        if hasattr(self, 'writer_init_ops'):
          evaluation_hooks.append(V2SummaryInitHook(self.writer_init_ops[mode]))
      return tf.estimator.EstimatorSpec(
          mode=mode,
          loss=train_loss,
          eval_metric_ops=eval_metrics,
          evaluation_hooks=evaluation_hooks)

    raise ValueError('The mode {} is not supported yet.'.format(mode))
Example #6
0
    def model_fn(self, features, labels, mode, config=None, params=None):
        """Estimator model_fn.

    Note, this function overrides the model_fn of the wrapped t2r_model since
    it replaces specifications with their corresponding TPU calls and introduces
    additional casting after the specification has been verified.

    Args:
      features: This is the first item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_feature_specification.
      labels: This is the second item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_label_specification.
      mode: (ModeKeys) Specifies if this is training, evaluation or prediction.
      config: (Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig) Will
        receive what is passed to Estimator in config parameter, or the default
        config (tf.estimator.RunConfig). Allows updating things in your model_fn
        based on configuration such as num_ps_replicas or model_dir.
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        python types. There are reserved keys for TPUEstimator, including
        'batch_size'.

    Raises:
      ValueError: If the mode key is not supported, not in [PREDICT, TRAIN,
        EVAL].

    Returns:
      A TPUEstimatorSpec.
    """

        features = tensorspec_utils.validate_and_pack(
            expected_spec=self.get_feature_specification(mode),
            actual_tensors_or_spec=features,
            ignore_batch=True)
        if labels:
            labels = tensorspec_utils.validate_and_pack(
                expected_spec=self.get_label_specification(mode),
                actual_tensors_or_spec=labels,
                ignore_batch=True)

        # In order to support both TPU and CPU for inference, tensors
        # with dtype=bfloat16 will be cast to float32.
        # Note, despite the casting, the benefits of bfloat16 are still
        # maintained for TPUs since this operation is a no-op on that platform.
        # See http://shortn/_TTg3ZyATRo for the rationale.
        features = tensorspec_utils.cast_bfloat16_to_float32(features)
        if labels is not None:
            labels = tensorspec_utils.cast_bfloat16_to_float32(labels)

        inference_outputs = self._t2r_model.inference_network_fn(
            features, labels, mode, config, params)

        if mode == tf.estimator.ModeKeys.PREDICT:
            model_fn_results = self._t2r_model.create_export_outputs_fn(
                features, inference_outputs, mode, config, params)
            export_outputs = None
            if isinstance(model_fn_results, tuple):
                predictions = model_fn_results[0]
                export_outputs = model_fn_results[1]
            elif isinstance(model_fn_results, dict):
                export_outputs = {}
                if len(model_fn_results) == 1:
                    name, output = list(model_fn_results.items())[0]
                    export_outputs[
                        name] = tf.estimator.export.RegressionOutput(output)
                export_outputs[tf.saved_model.signature_constants.
                               DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
                                   tf.estimator.export.PredictOutput(
                                       model_fn_results))
                predictions = model_fn_results
            else:
                raise ValueError(
                    'The create_export_outputs_fn should return a '
                    'tuple(predictions, export_outputs) or predictions.')

            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                predictions=predictions,
                export_outputs=export_outputs)

        train_fn_result = self._t2r_model.model_train_fn(
            features, labels, inference_outputs, mode, config, params)
        if isinstance(train_fn_result, tf.Tensor):
            train_loss = train_fn_result
            train_outputs = {}
        elif isinstance(train_fn_result, tuple):
            train_loss = train_fn_result[0]
            train_outputs = train_fn_result[1]
        else:
            raise ValueError('The model_train_fn should return a '
                             'tuple(loss, train_outputs) or loss.')

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Create the tf.train.Optimizer.
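            # Wrap the optimizer so gradients are aggregated across TPU shards.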
            optimizer = get_cross_shard_optimizer(
                self._t2r_model.create_optimizer())

            # Required for batch norm usage.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = self._t2r_model.create_train_op(
                    train_loss, optimizer)

            self._t2r_model.add_summaries(features, labels, inference_outputs,
                                          train_loss, train_outputs, mode,
                                          config, params)

            # For TPUs, initialization has to happen inside a scaffold
            # function. Since the model already contains an internal
            # implementation, that call is simply wrapped here.
            # No new variables may be added afterwards, otherwise they would
            # not be initialized.
            # Note, this feature is only available during training to bootstrap
            # a model (partially) from a different model. As soon as this
            # checkpoint is written, all other modes will use the local
            # checkpoint within model_dir.

            def create_scaffold_fn():
                """Creates a scaffold instance."""
                self._t2r_model.maybe_init_from_checkpoint()
                # Return the value of the property first since it might be changed.
                scaffold_fn = self._t2r_model.scaffold_fn
                scaffold = scaffold_fn()
                # In order to export asynchronously, the saver has to be
                # registered in the graph collection. The scaffold function
                # might already register a saver, which is why the collection
                # is checked here and a saver is only added if none has been
                # registered yet.
                if not tf.get_collection(tf.GraphKeys.SAVERS):
                    # TODO(T2R_CONTRIBUTORS): Switch to using gin config for all saver params.
                    keep_checkpoint_every_n_hours = None
                    max_to_keep = None
                    if config is not None:
                        keep_checkpoint_every_n_hours = config.keep_checkpoint_every_n_hours
                        max_to_keep = config.keep_checkpoint_max
                    saver = abstract_model.gin_configurable_saver(
                        keep_checkpoint_every_n_hours=
                        keep_checkpoint_every_n_hours,
                        max_to_keep=max_to_keep,
                    )
                    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                return scaffold

            training_hooks = []

            # EstimatorSpec has training_chief_hooks, but TPUEstimatorSpec does not,
            # so we have to use training_hooks here and check is_chief.
            if config and config.is_chief:  # pytype: disable=attribute-error
                training_hooks.append(
                    gin_utils.GinConfigSaverHook(config.model_dir,
                                                 summarize_config=True))

            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=train_loss,
                train_op=train_op,
                training_hooks=training_hooks,
                scaffold_fn=create_scaffold_fn)

        if mode == tf.estimator.ModeKeys.EVAL:
            self._t2r_model.add_summaries(features, labels, inference_outputs,
                                          train_loss, train_outputs, mode,
                                          config, params)
            eval_metrics = self._t2r_model.model_eval_fn(
                features, labels, inference_outputs, train_loss, train_outputs,
                mode, config, params)
            evaluation_hooks = self._t2r_model.get_eval_hooks(config, params)
            if config and config.is_chief:  # pytype: disable=attribute-error
                eval_name = params.get('eval_name', 'eval')  # pytype: disable=attribute-error
                evaluation_hooks.append(
                    gin_utils.GinConfigSaverHook(os.path.join(
                        config.model_dir, eval_name),
                                                 summarize_config=True))

            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=train_loss,
                eval_metrics=eval_metrics,
                evaluation_hooks=evaluation_hooks)

        raise ValueError('The mode {} is not supported yet.'.format(mode))
Example #7
0
def train(
    root_dir,
    agent,
    environment,
    training_loops,
    steps_per_loop=1,
    additional_metrics=(),
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10,
    policy_checkpoint_interval=10,
    log_interval=10,
    summary_interval=10):
  """A training driver."""

  if not common.resource_variables_enabled():
    raise RuntimeError(common.MISSING_RESOURCE_VARIABLES_ERROR)

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(train_dir)
  train_summary_writer.set_as_default()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(batch_size=environment.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=environment.batch_size),
    ] + list(additional_metrics)

    # Add to the replay buffer and other agent-specific observers.
    replay_buffer = build_replay_buffer(agent, environment.batch_size,
                                        steps_per_loop)
    agent_observers = [replay_buffer.add_batch] + train_metrics

    driver = dynamic_step_driver.DynamicStepDriver(
        env=environment,
        policy=agent.policy,
        num_steps=steps_per_loop * environment.batch_size,
        observers=agent_observers)

    collect_op, _ = driver.run()
    batch_size = driver.env.batch_size
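    # Read the whole replay buffer back in one deterministic pass so training
    # uses exactly the steps just collected.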
    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=steps_per_loop,
        single_deterministic_pass=True)
    trajectories, unused_info = tf.data.experimental.get_single_element(dataset)
    train_op = agent.train(experience=trajectories)
    clear_replay_op = replay_buffer.clear()

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        max_to_keep=1,
        agent=agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        max_to_keep=None,
        policy=agent.policy,
        global_step=global_step)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2]))

    init_agent_op = agent.initialize()

    config_saver = utils.GinConfigSaverHook(train_dir, summarize_config=True)
    config_saver.begin()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      common.initialize_uninitialized_variables(sess)

      config_saver.after_create_session(sess)

      global_step_call = sess.make_callable(global_step)
      global_step_val = global_step_call()

      sess.run(train_summary_writer.init())
      sess.run(collect_op)

      if global_step_val == 0:
        # Save an initial checkpoint so the evaluator runs for global_step=0.
        policy_checkpointer.save(global_step=global_step_val)
        sess.run(init_agent_op)

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_ops])
      clear_replay_call = sess.make_callable(clear_replay_op)

      timed_at_step = global_step_val
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec',
          data=steps_per_second_ph,
          step=global_step)

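      # Main loop: collect experience, run one training step over the buffer,
      # clear the buffer, then log and checkpoint at the configured intervals.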
      for _ in range(training_loops):
        # Collect and train.
        start_time = time.time()
        collect_call()
        total_loss, _ = train_step_call()
        clear_replay_call()
        global_step_val = global_step_call()

        time_acc += time.time() - start_time

        total_loss = total_loss.loss

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)
Example #8
0
    def model_fn(self, features, labels, mode, config=None, params=None):
        """Estimator model_fn.

    Args:
      features: This is the first item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_feature_specification.
      labels: This is the second item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_label_specification.
      mode: (ModeKeys) Specifies if this is training, evaluation or prediction.
      config: (Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig) Will
        receive what is passed to Estimator in config parameter, or the default
        config (tf.estimator.RunConfig). Allows updating things in your model_fn
        based on configuration such as num_ps_replicas or model_dir.
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        python types. There are reserved keys for TPUEstimator, including
        'batch_size'.

    Raises:
      ValueError: If the mode key is not supported, not in [PREDICT, TRAIN,
        EVAL].

    Returns:
      An EstimatorSpec.
    """

        features = tensorspec_utils.validate_and_pack(
            expected_spec=self.get_feature_specification(mode),
            actual_tensors_or_spec=features,
            ignore_batch=True)
        if labels:
            labels = tensorspec_utils.validate_and_pack(
                expected_spec=self.get_label_specification(mode),
                actual_tensors_or_spec=labels,
                ignore_batch=True)
        inference_outputs = self.inference_network_fn(features, labels, mode,
                                                      config, params)

        # After inference_network_fn no new variables are allowed to be added,
        # otherwise they would not be initialized.
        self.maybe_init_from_checkpoint()

        if mode == tf.estimator.ModeKeys.PREDICT:
            model_fn_results = self.create_export_outputs_fn(
                features, inference_outputs, mode, config, params)
            export_outputs = None
            if isinstance(model_fn_results, tuple):
                predictions = model_fn_results[0]
                export_outputs = model_fn_results[1]
            elif isinstance(model_fn_results, dict):
                export_outputs = {}
                if len(model_fn_results) == 1:
                    name, output = list(model_fn_results.items())[0]
                    export_outputs[
                        name] = tf.estimator.export.RegressionOutput(output)
                export_outputs[tf.saved_model.signature_constants.
                               DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
                                   tf.estimator.export.PredictOutput(
                                       model_fn_results))
                predictions = model_fn_results
            else:
                raise ValueError(
                    'The create_export_outputs_fn should return a '
                    'tuple(predictions, export_outputs) or predictions.')

            return tf.estimator.EstimatorSpec(mode=mode,
                                              predictions=predictions,
                                              export_outputs=export_outputs)

        train_fn_result = self.model_train_fn(features, labels,
                                              inference_outputs, mode, config,
                                              params)
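        # model_train_fn may return just the loss tensor or a
        # (loss, train_outputs) tuple.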
        if isinstance(train_fn_result, tf.Tensor):
            train_loss = train_fn_result
            train_outputs = {}
        elif isinstance(train_fn_result, tuple):
            train_loss = train_fn_result[0]
            train_outputs = train_fn_result[1]
        else:
            raise ValueError('The model_train_fn should return a '
                             'tuple(loss, train_outputs) or loss.')

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Create the tf.train.Optimizer.
            optimizer = self.create_optimizer()

            # Required for batch norm usage.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = self.create_train_op(train_loss, optimizer)

            self.add_summaries(features, labels, inference_outputs, train_loss,
                               train_outputs, mode, config, params)

            training_hooks = []

            # EstimatorSpec has training_chief_hooks, but TPUEstimatorSpec does not,
            # so we have to use training_hooks here and check is_chief.
            if config and config.is_chief:  # pytype: disable=attribute-error
                training_hooks.append(
                    gin_utils.GinConfigSaverHook(config.model_dir,
                                                 summarize_config=True))

            # `SyncReplicasOptimizer` needs to attach a training hook.
            if self._sync_replicas_optimizer:
                training_hooks.append(
                    self._sync_replicas_optimizer.make_session_run_hook(
                        config.is_chief))  # pytype: disable=attribute-error

            return tf.estimator.EstimatorSpec(mode=mode,
                                              loss=train_loss,
                                              train_op=train_op,
                                              training_hooks=training_hooks,
                                              scaffold=self._scaffold_fn())

        if mode == tf.estimator.ModeKeys.EVAL:
            self.add_summaries(features, labels, inference_outputs, train_loss,
                               train_outputs, mode, config, params)
            eval_metrics = self.model_eval_fn(features, labels,
                                              inference_outputs, train_loss,
                                              train_outputs, mode, config,
                                              params)
            evaluation_hooks = self.get_eval_hooks(config, params)
            if config and config.is_chief:  # pytype: disable=attribute-error
                eval_name = params.get('eval_name', 'eval')  # pytype: disable=attribute-error
                evaluation_hooks.append(
                    gin_utils.GinConfigSaverHook(os.path.join(
                        config.model_dir, eval_name),
                                                 summarize_config=True))
            return tf.estimator.EstimatorSpec(
                mode=mode,
                loss=train_loss,
                eval_metric_ops=eval_metrics,
                evaluation_hooks=evaluation_hooks)

        raise ValueError('The mode {} is not supported yet.'.format(mode))