    def begin(self):
        """Build eval graph and restoring op."""
        self._timer.reset()
        self._graph = ops.Graph()
        self._global_step_tensor = training_util._get_or_create_global_step_read(
        )  # pylint: disable=protected-access
        with self._graph.as_default():
            (self._scaffold, self._update_op, self._eval_dict,
             self._all_hooks) = self._estimator._evaluate_build_graph(
                 self._input_fn, self._hooks, checkpoint_path=None)

            for h in self._all_hooks:
                if isinstance(h, tpu_estimator.TPUInfeedOutfeedSessionHook):
                    h._should_initialize_tpu = False  # pylint: disable=protected-access

            if self._scaffold.saver is not None:
                raise ValueError('InMemoryEval does not support custom saver')
            if self._scaffold.init_fn is not None:
                raise ValueError(
                    'InMemoryEval does not support custom init_fn')

            self._var_name_to_eval_var = {
                v.name: v
                for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
            }
            self._var_name_to_placeholder = {
                v.name: array_ops.placeholder(v.dtype)
                for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
            }
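The two maps built above let an in-memory evaluation reuse the training
weights without a checkpoint round-trip. A minimal sketch of how they are
typically consumed (not the library's exact code; it assumes a
`self._var_name_to_train_var` map from names to the corresponding variables
in the training graph):

    def _build_feed_dict(self, train_session):
        # Read the current value of every training variable by name ...
        var_name_to_value = train_session.run(self._var_name_to_train_var)
        # ... and pair each value with the placeholder created in begin(),
        # so assign ops in the eval graph can copy the training weights in.
        return {
            self._var_name_to_placeholder[name]: value
            for name, value in var_name_to_value.items()
        }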
    def begin(self):
        """Build eval graph and restoring op."""
        self._timer.reset()
        self._graph = ops.Graph()
        self._global_step_tensor = training_util._get_or_create_global_step_read(
        )  # pylint: disable=protected-access
        with self._graph.as_default():
            with variable_scope.variable_scope('', use_resource=True):
                training_util.get_or_create_global_step()
            features, input_hooks = self._estimator._get_features_from_input_fn(  # pylint: disable=protected-access
                self._input_fn, model_fn_lib.ModeKeys.PREDICT)
            estimator_spec = self._estimator._call_model_fn(  # pylint: disable=protected-access
                features, None, model_fn_lib.ModeKeys.PREDICT,
                self._estimator.config)

            self._all_hooks = list(input_hooks) + list(
                estimator_spec.prediction_hooks)
            self._predictions = self._estimator._extract_keys(  # pylint: disable=protected-access
                estimator_spec.predictions,
                predict_keys=None)
            self._var_name_to_eval_var = {
                v.name: v
                for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
            }
            self._var_name_to_placeholder = {
                v.name: array_ops.placeholder(v.dtype)
                for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
            }
            logging.info('Placeholders: %s', self._var_name_to_placeholder)

            for h in self._all_hooks:
                logging.info('Hook: %s', h)
                if isinstance(h, tpu_estimator.TPUInfeedOutfeedSessionHook):
                    h._should_initialize_tpu = False  # pylint: disable=protected-access
Example #3
 def begin(self):
     self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
     if self._global_step_tensor is None:
         raise RuntimeError("Global step should be created to use LearningRateHook.")
     self._args = [self._global_step_tensor, self._learning_rate_var]
     self._learning_rate_ph = tf.placeholder(tf.float32, name='learning_rate_ph')
     self._learning_rate_op = self._learning_rate_var.assign(self._learning_rate_ph)
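The placeholder/assign pair built in this begin() is typically driven from a
later hook callback. A minimal sketch, assuming a hypothetical
`self._compute_learning_rate(step)` schedule helper:

 def after_run(self, run_context, run_values):
     # Hypothetical sketch: look up the current step, derive the desired
     # learning rate, and push it into the variable via the assign op.
     step = run_context.session.run(self._global_step_tensor)
     new_lr = self._compute_learning_rate(step)  # hypothetical helper
     run_context.session.run(self._learning_rate_op,
                             feed_dict={self._learning_rate_ph: new_lr})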
 def begin(self):
     self._next_step = None
     self._global_step_tensor = training_util._get_or_create_global_step_read(
     )  # pylint: disable=protected-access
     if self._global_step_tensor is None:
         raise RuntimeError(
             "Global step should be created to use ProfilerHook.")
Example #5
 def begin(self):
     self._worker_is_started = False
     self._global_step_tensor = training_util._get_or_create_global_step_read(
     )  # pylint: disable=protected-access
     if self._global_step_tensor is None:
         raise RuntimeError(
             "Global step should be created to use _GlobalStepWaiterHook.")
Example #6
 def begin(self):
     self._global_step_tensor = training_util._get_or_create_global_step_read(
     )  # pylint: disable=protected-access
     if self._global_step_tensor is None:
         raise RuntimeError(
             'Global step should be created to use StopAtCheckpointStepHook.'
         )
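For context, a simplified sketch of the companion callback that makes a
stop-at-step hook actually stop (the real hook also waits for a checkpoint
at the target step; `self._last_step` is assumed to be set in __init__):

 def after_run(self, run_context, run_values):
     # Simplified sketch: request a stop once the global step reaches the
     # target step.
     step = run_context.session.run(self._global_step_tensor)
     if step >= self._last_step:
         run_context.request_stop()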
    def _setup(self):
        self.steps = 0
        self.session = tf.Session()
        (train_x, train_y), (test_x, test_y) = iris_data.load_data()
        self.train_x = train_x
        self.train_y = train_y

        self.test_x = test_x
        self.test_y = test_y

        # Feature columns describe how to use the input.
        my_feature_columns = []
        for key in train_x.keys():
            my_feature_columns.append(
                tf.feature_column.numeric_column(key=key))

        layer_size = int(self.config['layer_size'])

        # Build a DNN with two hidden layers of layer_size units each.
        self.classifier = tf.estimator.Estimator(
            model_fn=my_model,
            params={
                'feature_columns': my_feature_columns,
                # Two hidden layers of layer_size nodes each.
                'hidden_units': [layer_size, layer_size],
                # The model must choose between 3 classes.
                'n_classes': 3,
            })

        self.saver = None
        self.global_step_tensor = training_util._get_or_create_global_step_read(
        )  # pylint: disable=protected-access
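A sketch of how a Tune-style trainable might drive this estimator (the
`iris_data.train_input_fn`/`eval_input_fn` signatures follow the TensorFlow
get-started example this snippet appears to build on; treat the exact names
and batch size as assumptions):

    def _train(self):
        # Run a fixed chunk of training, then report evaluation accuracy.
        self.classifier.train(
            input_fn=lambda: iris_data.train_input_fn(
                self.train_x, self.train_y, batch_size=100),
            steps=100)
        eval_result = self.classifier.evaluate(
            input_fn=lambda: iris_data.eval_input_fn(
                self.test_x, self.test_y, batch_size=100))
        self.steps += 1
        return {"mean_accuracy": eval_result["accuracy"]}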
Example #8
    def _train_model_default(self, input_fn, hooks, saving_listeners,
                             save_best_ckpt):
        """Initiate training with `input_fn`, without `DistributionStrategies`.

        Args:
            input_fn: A function that provides input data for training as mini-batches.
            hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
            callbacks inside the training loop.
            saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used
            for callbacks that run immediately before or after checkpoint savings.
            save_best_ckpt: boolean

        Returns:
            Loss from training
        """
        worker_hooks = []
        with ops.Graph().as_default() as g:
            random_seed.set_random_seed(self._config.tf_random_seed)
            global_step_tensor = self._create_and_assert_global_step(g)
            training_util._get_or_create_global_step_read(g)

            if self._train_with_eval:
                self.handler = array_ops.placeholder(dtypes.string,
                                                     shape=(),
                                                     name="Handler")
                features, labels, self.train_iterator, self.eval_iterator, input_hooks = (
                    self._get_features_and_labels_for_train_and_eval(
                        input_fn, self.handler))
            else:
                self.handler, self.train_iterator, self.eval_iterator = None, None, None
                features, labels, input_hooks = (
                    self._get_features_and_labels_from_input_fn(
                        input_fn, model_fn_lib.ModeKeys.TRAIN))

            worker_hooks.extend(input_hooks)

            estimator_spec = self._call_model_fn(features, labels,
                                                 model_fn_lib.ModeKeys.TRAIN,
                                                 self.config)
            self._feed_dict = self._params["model_instances"][0].feed_dict

            return self._train_with_estimator_spec(estimator_spec,
                                                   worker_hooks, hooks,
                                                   global_step_tensor,
                                                   saving_listeners,
                                                   save_best_ckpt)
Example #9
 def begin(self):
     if self._summary_writer is None and self._output_dir:
         self._summary_writer = SummaryWriterCache.get(self._output_dir)
     self._global_step_tensor = training_util._get_or_create_global_step_read(
     )
     if self._global_step_tensor is None:
         raise RuntimeError(
             "Global step should be created to use StepCounterHook.")
Example #10
 def begin(self):
   if self._summary_writer is None and self._output_dir:
     self._summary_writer = SummaryWriterCache.get(self._output_dir)
   self._next_step = None
   self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
   if self._global_step_tensor is None:
     raise RuntimeError(
         "Global step should be created to use SummarySaverHook.")
 def begin(self):
   self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
   self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
   if self._global_step_tensor is None:
     raise RuntimeError(
         "Global step should be created to use CheckpointSaverHook.")
   for l in self._listeners:
     l.begin()
Example #12
 def begin(self):
     self._global_step_tensor = training_util._get_or_create_global_step_read(
     )
     if self._global_step_tensor is None:
         raise RuntimeError(
             "Global step should be created to use EarlyStoppingHook.")
     self._prev_step = -1
     self._step = 0
Example #14
 def begin(self):
     # if self._summary_writer is None and self._output_dir:
     #   self._summary_writer = SummaryWriterCache.get(self._output_dir)
     self._global_step_tensor = training_util._get_or_create_global_step_read(
     )  # pylint: disable=protected-access
     if self._global_step_tensor is None:
         raise RuntimeError(
             "Global step should be created to use StepCounterHook.")
     self._summary_tag = training_util.get_global_step().op.name + "/sec"
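The `_summary_tag` computed above ("<global_step_name>/sec") conventionally
feeds a steps-per-second summary. A minimal sketch of that reporting pattern
(Summary comes from tensorflow.core.framework.summary_pb2; the writer guard
matters here because this variant leaves the summary writer commented out):

 def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
     steps_per_sec = elapsed_steps / elapsed_time
     if getattr(self, "_summary_writer", None) is not None:
         summary = Summary(value=[Summary.Value(
             tag=self._summary_tag, simple_value=steps_per_sec)])
         self._summary_writer.add_summary(summary, global_step)
     logging.info("%s: %g", self._summary_tag, steps_per_sec)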
 def begin(self):
     self._file_writer = writer.FileWriter(self._output_dir,
                                           filename_suffix="",
                                           session=ops.get_default_session)
     self._next_step = None
     self._global_step_tensor = training_util._get_or_create_global_step_read(
     )  # pylint: disable=protected-access
     if self._global_step_tensor is None:
         raise RuntimeError(
             "Global step should be created to use ProfilerHook.")
 def begin(self):
     self._summary_writer = writer.FileWriter(
         self._checkpoint_dir,
         session=ops.get_default_session,
         filename_suffix="")
     self._global_step_tensor = training_util._get_or_create_global_step_read(
     )  # pylint: disable=protected-access
     if self._global_step_tensor is None:
         raise RuntimeError(
             "Global step should be created to use CheckpointSaverHook.")
     for l in self._listeners:
         l.begin()
 def begin(self):
     if self._summary_writer is None and self._output_dir:
         self._summary_writer = writer.FileWriter(
             self._output_dir,
             session=ops.get_default_session,
             filename_suffix="")
     self._global_step_tensor = training_util._get_or_create_global_step_read(
     )  # pylint: disable=protected-access
     if self._global_step_tensor is None:
         raise RuntimeError(
             "Global step should be created to use StepCounterHook.")
     self._summary_tag = training_util.get_global_step().op.name + "/sec"
 def test_reads_before_increments(self):
   with ops.Graph().as_default():
     training_util.create_global_step()
     read_tensor = training_util._get_or_create_global_step_read()
     inc_op = training_util._increment_global_step(1)
     inc_three_op = training_util._increment_global_step(3)
     with monitored_session.MonitoredTrainingSession() as sess:
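       # The read tensor is wired, via control dependencies inside
       # _increment_global_step, to run before any increment fetched in the
       # same call, so each fetch observes the pre-increment step value.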
       read_value, _ = sess.run([read_tensor, inc_op])
       self.assertEqual(0, read_value)
       read_value, _ = sess.run([read_tensor, inc_three_op])
       self.assertEqual(1, read_value)
       read_value = sess.run(read_tensor)
       self.assertEqual(4, read_value)
Example #20
 def get_temperature(self, add_to_tensorboard=False):
     global_step = tf.cast(training_util._get_or_create_global_step_read(),
                           tf.float32)
     e = tf.constant(0.000001)
     temperature = tf.maximum(
         .01,
         tf.constant(1.0) / (log(global_step, 3.0) + e))
     if add_to_tensorboard:
         with tf.control_dependencies(
             [tf.print("Temperature", temperature, summarize=-1)]):
             tf.summary.histogram("Temperature",
                                  temperature,
                                  family=self.scope_name)
     return temperature
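The schedule above anneals the temperature as max(0.01, 1 / log_3(step)),
assuming `log(x, base)` is a base-`base` logarithm helper defined elsewhere
in the snippet's module. A quick plain-Python check of the curve:

    import math

    def temperature(step, eps=1e-6):
        # Mirrors the TF expression above for steps > 1.
        return max(0.01, 1.0 / (math.log(step, 3) + eps))

    for step in (3, 9, 81, 3 ** 10):
        print(step, round(temperature(step), 4))
    # 3 -> ~1.0, 9 -> ~0.5, 81 -> ~0.25, 3**10 -> ~0.1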
 def begin(self):
     if self.summary_writer is None and self.output_dir:
         self.summary_writer = SummaryWriterCache.get(self.output_dir)
     graph = ops.get_default_graph()
     self.fake_seq = graph.get_tensor_by_name("model/" + FAKE_PROTEINS +
                                              ":0")
     self.labels = graph.get_tensor_by_name("model/" + LABELS + ":0")
     self.d_score = graph.get_tensor_by_name("model/d_score:0")
     self.global_step_tensor = training_util._get_or_create_global_step_read(
     )
     if self.global_step_tensor is None:
         raise RuntimeError("Could not global step tensor")
     if self.fake_seq is None:
         raise RuntimeError("Could not get fake seq tensor")
Example #22
 def begin(self):
     g = tf.get_default_graph()
     for _, tensor in self._tensor_dict.items():
         if tensor.graph != g:
             raise ValueError("metric tensor %s graph not equal "
                              "to current graph %s." % (tensor, g))
     self._next_step = None
     self._global_step_tensor = \
         training_util._get_or_create_global_step_read() # pylint: disable=protected-access
     if self._global_step_tensor is None:
         raise RuntimeError("Global step should be created "
                            "to use GlobalStepTensorStatsHook.")
     self._global_step_key = "global_step"
     i = 0
     while self._global_step_key in self._tensor_dict:
         self._global_step_key = "global_step_" + str(i)
         i += 1
     self._tensor_dict[self._global_step_key] = self._global_step_tensor
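A sketch of how the `_tensor_dict` assembled here (with the global step
slotted in under a collision-free key) would typically be fetched on each
step; `tf.train.SessionRunArgs` is the standard mechanism:

 def before_run(self, run_context):
     # Request every monitored tensor, including the global step added above.
     return tf.train.SessionRunArgs(self._tensor_dict)

 def after_run(self, run_context, run_values):
     stats = dict(run_values.results)
     step = stats.pop(self._global_step_key)
     # ... record/log the remaining tensor stats against `step` ...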
Example #23
    def begin(self):
        """
        Is called once before the default graph in the active tensorflow session is
        finalized and the training has starts.
        The hook can modify the graph by adding new operations to it.
        After the begin() call the graph will be finalized and the other callbacks can not modify
        the graph anymore. Second call of begin() on the same graph, should not change the graph.
        """
        # Create a summary writer if possible.
        if self._summary_writer is None and self._output_dir:
            self._summary_writer = summary_io.SummaryWriterCache.get(
                self._output_dir)

        # Get read access to the global step tensor.
        self._global_step_tensor = training_util._get_or_create_global_step_read(
        )  # pylint: disable=protected-access
        if self._global_step_tensor is None:
            raise RuntimeError(
                "Global step should be created to use StepCounterHook.")
Example #24
 def test_reads_from_cache(self):
     with ops.Graph().as_default():
         training_util.create_global_step()
         first = training_util._get_or_create_global_step_read()
         second = training_util._get_or_create_global_step_read()
         self.assertEqual(first, second)
Example #25
 def test_global_step_read_is_none_if_there_is_no_global_step(self):
     with ops.Graph().as_default():
         self.assertIsNone(training_util._get_or_create_global_step_read())
         training_util.create_global_step()
         self.assertIsNotNone(
             training_util._get_or_create_global_step_read())
Example #26
    def _train_model(self, input_fn, hooks):
        all_hooks = []
        with ops.Graph().as_default() as g, g.device(self._device_fn):
            random_seed.set_random_seed(self._config.tf_random_seed)
            global_step_tensor = self._create_and_assert_global_step(g)
            global_step_read_tensor = training_util._get_or_create_global_step_read(
            )  # pylint: disable=protected-access
            with ops.control_dependencies([global_step_read_tensor]):
                features, labels = self._get_features_and_labels_from_input_fn(
                    input_fn, model_fn_lib.ModeKeys.TRAIN)
            estimator_spec = self._call_model_fn(features, labels,
                                                 model_fn_lib.ModeKeys.TRAIN,
                                                 self.config)
            ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
            all_hooks.extend(hooks)
            all_hooks.extend([
                training.NanTensorHook(estimator_spec.loss),
                training.LoggingTensorHook(
                    {
                        'loss': estimator_spec.loss,
                        'step': global_step_tensor
                    },
                    every_n_iter=100)
            ])
            all_hooks.extend(estimator_spec.training_hooks)

            if not (estimator_spec.scaffold.saver
                    or ops.get_collection(ops.GraphKeys.SAVERS)):
                ops.add_to_collection(
                    ops.GraphKeys.SAVERS,
                    training.Saver(
                        sharded=True,
                        max_to_keep=self._config.keep_checkpoint_max,
                        keep_checkpoint_every_n_hours=(
                            self._config.keep_checkpoint_every_n_hours),
                        defer_build=True,
                        save_relative_paths=True))

            chief_hooks = []
            if (self._config.save_checkpoints_secs
                    or self._config.save_checkpoints_steps):
                saver_hook_exists = any([
                    isinstance(h, training.CheckpointSaverHook)
                    for h in (all_hooks + chief_hooks +
                              list(estimator_spec.training_chief_hooks))
                ])
                if not saver_hook_exists:
                    chief_hooks = [
                        training.CheckpointSaverHook(
                            self._model_dir,
                            save_secs=self._config.save_checkpoints_secs,
                            save_steps=self._config.save_checkpoints_steps,
                            scaffold=estimator_spec.scaffold)
                    ]
            with training.MonitoredTrainingSession(
                    master=self._config.master,
                    is_chief=self._config.is_chief,
                    checkpoint_dir=self._model_dir,
                    scaffold=estimator_spec.scaffold,
                    hooks=all_hooks,
                    chief_only_hooks=(
                        tuple(chief_hooks) +
                        tuple(estimator_spec.training_chief_hooks)),
                    save_checkpoint_secs=0,  # Saving is handled by a hook.
                    save_summaries_steps=self._config.save_summary_steps,
                    config=self._session_config,
                    log_step_count_steps=self._config.log_step_count_steps
            ) as mon_sess:
                loss = None
                while not mon_sess.should_stop():
                    _, loss = mon_sess.run(
                        [estimator_spec.train_op, estimator_spec.loss])
            return loss
 def begin(self):
     self.merged_ops = tf.summary.merge_all()
     self._global_step_tensor = training_util._get_or_create_global_step_read(
     )
 def begin(self):
     self._global_step_tensor = training_util._get_or_create_global_step_read(
     )  # pylint: disable=protected-access
Example #29
 def begin(self):
     self._global_step_tensor = training_util._get_or_create_global_step_read(
     )
     self._previous_step = 0
     self._step = 0
Example #32
 def begin(self):
   self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
   if self._global_step_tensor is None:
     raise RuntimeError(
         'Global step should be created to use StopAtCheckpointStepHook.')
Example #33
 def begin(self):
     self._summary_writer = tf.summary.FileWriterCache.get(
         self._summary_dir)
     self._global_step_tensor = training_util._get_or_create_global_step_read(
     )
  def _train_model(self, input_fn, hooks, saving_listeners):
    worker_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = self._create_and_assert_global_step(g)
      global_step_read_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
      features, labels = self._get_features_and_labels_from_input_fn(
          input_fn, model_fn_lib.ModeKeys.TRAIN)
      with ops.control_dependencies([global_step_read_tensor]):
        estimator_spec = self._call_model_fn(
            features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
      # Check if the user created a loss summary, and add one if they didn't.
      # We assume here that the summary is called 'loss'. If it is not, we will
      # make another one with the name 'loss' to ensure it shows up in the right
      # graph in TensorBoard.
      if not any([x.op.name == 'loss'
                  for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
        summary.scalar('loss', estimator_spec.loss)
      ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
      worker_hooks.extend(hooks)
      worker_hooks.extend([
          training.NanTensorHook(estimator_spec.loss),
          training.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=100)
      ])
      worker_hooks.extend(estimator_spec.training_hooks)

      if not (estimator_spec.scaffold.saver or
              ops.get_collection(ops.GraphKeys.SAVERS)):
        ops.add_to_collection(
            ops.GraphKeys.SAVERS,
            training.Saver(
                sharded=True,
                max_to_keep=self._config.keep_checkpoint_max,
                keep_checkpoint_every_n_hours=(
                    self._config.keep_checkpoint_every_n_hours),
                defer_build=True,
                save_relative_paths=True))

      chief_hooks = []
      all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
      saver_hooks = [
          h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
      if (self._config.save_checkpoints_secs or
          self._config.save_checkpoints_steps):
        if not saver_hooks:
          chief_hooks = [
              training.CheckpointSaverHook(
                  self._model_dir,
                  save_secs=self._config.save_checkpoints_secs,
                  save_steps=self._config.save_checkpoints_steps,
                  scaffold=estimator_spec.scaffold)
          ]
          saver_hooks = [chief_hooks[0]]
      if saving_listeners:
        if not saver_hooks:
          raise ValueError(
              'There should be a CheckpointSaverHook to use saving_listeners. '
              'Please set one of the RunConfig.save_checkpoints_steps or '
              'RunConfig.save_checkpoints_secs.')
        else:
          # It is expected to have one CheckpointSaverHook. If multiple, we pick
          # up the first one to add listener.
          saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access
      with training.MonitoredTrainingSession(
          master=self._config.master,
          is_chief=self._config.is_chief,
          checkpoint_dir=self._model_dir,
          scaffold=estimator_spec.scaffold,
          hooks=worker_hooks,
          chief_only_hooks=(
              tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
          save_checkpoint_secs=0,  # Saving is handled by a hook.
          save_summaries_steps=self._config.save_summary_steps,
          config=self._session_config,
          log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
        loss = None
        while not mon_sess.should_stop():
          _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
      return loss
 def begin(self):
   self._next_step = None
   self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
   if self._global_step_tensor is None:
     raise RuntimeError("Global step should be created to use ProfilerHook.")
Example #36
  def _train_model(self, input_fn, hooks, saving_listeners):
    worker_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = self._create_and_assert_global_step(g)
      global_step_read_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
      with ops.control_dependencies([global_step_read_tensor]):
        features, labels = self._get_features_and_labels_from_input_fn(
            input_fn, model_fn_lib.ModeKeys.TRAIN)
      estimator_spec = self._call_model_fn(
          features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
      # Check if the user created a loss summary, and add one if they didn't.
      # We assume here that the summary is called 'loss'. If it is not, we will
      # make another one with the name 'loss' to ensure it shows up in the right
      # graph in TensorBoard.
      if not any([x.op.name == 'loss'
                  for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
        summary.scalar('loss', estimator_spec.loss)
      ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
      worker_hooks.extend(hooks)
      worker_hooks.extend([
          training.NanTensorHook(estimator_spec.loss),
          training.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=100)
      ])
      worker_hooks.extend(estimator_spec.training_hooks)

      if not (estimator_spec.scaffold.saver or
              ops.get_collection(ops.GraphKeys.SAVERS)):
        ops.add_to_collection(
            ops.GraphKeys.SAVERS,
            training.Saver(
                sharded=True,
                max_to_keep=self._config.keep_checkpoint_max,
                keep_checkpoint_every_n_hours=(
                    self._config.keep_checkpoint_every_n_hours),
                defer_build=True,
                save_relative_paths=True))

      chief_hooks = []
      all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
      saver_hooks = [
          h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
      if (self._config.save_checkpoints_secs or
          self._config.save_checkpoints_steps):
        if not saver_hooks:
          chief_hooks = [
              training.CheckpointSaverHook(
                  self._model_dir,
                  save_secs=self._config.save_checkpoints_secs,
                  save_steps=self._config.save_checkpoints_steps,
                  scaffold=estimator_spec.scaffold)
          ]
          saver_hooks = [chief_hooks[0]]
      if saving_listeners:
        if not saver_hooks:
          raise ValueError(
              'There should be a CheckpointSaverHook to use saving_listeners. '
              'Please set one of the RunConfig.save_checkpoints_steps or '
              'RunConfig.save_checkpoints_secs.')
        else:
          # It is expected to have one CheckpointSaverHook. If multiple, we pick
          # up the first one to add listener.
          saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access
      with training.MonitoredTrainingSession(
          master=self._config.master,
          is_chief=self._config.is_chief,
          checkpoint_dir=self._model_dir,
          scaffold=estimator_spec.scaffold,
          hooks=worker_hooks,
          chief_only_hooks=(
              tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
          save_checkpoint_secs=0,  # Saving is handled by a hook.
          save_summaries_steps=self._config.save_summary_steps,
          config=self._session_config,
          log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
        loss = None
        while not mon_sess.should_stop():
          _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
      return loss
 def begin(self):
   self._worker_is_started = False
   self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
   if self._global_step_tensor is None:
     raise RuntimeError(
         "Global step should be created to use _GlobalStepWaiterHook.")