Code example #1
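This looks like an early revision of TensorFlow's Estimator._train_model: it builds the training graph from input_fn and the model function, attaches NaN-guard, logging and (on the chief) checkpoint-saving hooks, and then runs the training loop inside a MonitoredTrainingSession. save_checkpoint_secs=0 disables the session's built-in saving because the CheckpointSaverHook already handles it.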
  def _train_model(self, input_fn, hooks):
    all_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = training.create_global_step(g)
      with ops.device('/cpu:0'):
        features, labels = input_fn()
      estimator_spec = self._call_model_fn(features, labels,
                                           model_fn_lib.ModeKeys.FIT)
      ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
      all_hooks.extend([
          training.NanTensorHook(estimator_spec.loss),
          training.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=100)
      ])
      all_hooks.extend(hooks)
      all_hooks.extend(estimator_spec.training_hooks)

      scaffold = estimator_spec.scaffold or training.Scaffold()
      if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
        ops.add_to_collection(ops.GraphKeys.SAVERS,
                              training.Saver(
                                  sharded=True,
                                  max_to_keep=self._config.keep_checkpoint_max,
                                  defer_build=True))

      chief_hooks = []
      if (self._config.save_checkpoints_secs or
          self._config.save_checkpoints_steps):
        saver_hook_exists = any([
            isinstance(h, training.CheckpointSaverHook)
            for h in (all_hooks + chief_hooks +
                      estimator_spec.training_chief_hooks)
        ])
        if not saver_hook_exists:
          chief_hooks = [
              training.CheckpointSaverHook(
                  self._model_dir,
                  save_secs=self._config.save_checkpoints_secs,
                  save_steps=self._config.save_checkpoints_steps,
                  scaffold=scaffold)
          ]
      with training.MonitoredTrainingSession(
          master=self._config.master,
          is_chief=self._config.is_chief,
          checkpoint_dir=self._model_dir,
          scaffold=scaffold,
          hooks=all_hooks,
          chief_only_hooks=chief_hooks + estimator_spec.training_chief_hooks,
          save_checkpoint_secs=0,  # Saving is handled by a hook.
          save_summaries_steps=self._config.save_summary_steps,
          config=config_pb2.ConfigProto(allow_soft_placement=True)) as mon_sess:
        loss = None
        while not mon_sess.should_stop():
          _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
      return loss
Code example #2
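A custom predict method patterned after Estimator.predict: it rebuilds the prediction graph in graph mode, restores variables from checkpoint_dir through a MonitoredTrainingSession, and yields predictions either per batch or per example depending on yield_single_examples.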
    def predict(
        self,
        input_fn,
        predict_keys=None,
        hooks=None,
        checkpoint_dir=None,
        yield_single_examples=True,
    ):
        """Arguments are same with Estimator.predict"""
        with context.graph_mode():
            hooks = estimator._check_hooks_type(hooks)
            # Check that model has been trained.
            if not checkpoint_dir:
                raise ValueError("No checkpoint_dir")
            with ops.Graph().as_default() as g, g.device(self._device_fn):
                random_seed.set_random_seed(self._config.tf_random_seed)
                self._create_and_assert_global_step(g)
                features, input_hooks = self._get_features_from_input_fn(
                    input_fn, model_fn_lib.ModeKeys.PREDICT
                )
                estimator_spec = self._call_model_fn(
                    features,
                    None,
                    model_fn_lib.ModeKeys.PREDICT,
                    self.config,
                )

                predictions = self._extract_keys(
                    estimator_spec.predictions, predict_keys
                )
                all_hooks = list(input_hooks)
                all_hooks.extend(hooks)
                all_hooks.extend(
                    list(estimator_spec.prediction_hooks or [])
                )
                with training.MonitoredTrainingSession(
                    master=self._config.master,
                    is_chief=self._config.is_chief,
                    checkpoint_dir=checkpoint_dir,
                    hooks=all_hooks,
                    config=self._config.session_config,
                ) as mon_sess:

                    while not mon_sess.should_stop():
                        preds_evaluated = mon_sess.run(predictions)
                        if not yield_single_examples:
                            yield preds_evaluated
                        elif not isinstance(predictions, dict):
                            for pred in preds_evaluated:
                                yield pred
                        else:
                            for i in range(
                                self._extract_batch_length(preds_evaluated)
                            ):
                                yield {
                                    key: value[i]
                                    for key, value in six.iteritems(
                                        preds_evaluated
                                    )
                                }
Code example #3
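A test helper that creates one graph, one SyncReplicasOptimizer train op and one MonitoredTrainingSession per worker; the hook returned by make_session_run_hook coordinates gradient aggregation across the replicas and the token queue handled by the chief.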
def get_workers(num_workers, replicas_to_aggregate, workers):
    sessions = []
    graphs = []
    train_ops = []
    for worker_id in range(num_workers):
        graph = ops.Graph()
        is_chief = (worker_id == 0)
        with graph.as_default():
            with ops.device("/job:ps/task:0"):
                global_step = variables.VariableV1(0,
                                                   name="global_step",
                                                   trainable=False)
                var_0 = variables.VariableV1(0.0, name="v0")
            with ops.device("/job:ps/task:1"):
                var_1 = variables.VariableV1(1.0, name="v1")
                var_sparse = variables.VariableV1([[3.0], [4.0]],
                                                  name="v_sparse")

            with ops.device("/job:worker/task:" + str(worker_id)):
                grads_0 = constant_op.constant(0.1 + worker_id * 0.2)
                grads_1 = constant_op.constant(0.9 + worker_id * 0.2)
                # This is to test against sparse gradients.
                grads_sparse = ops.IndexedSlices(
                    constant_op.constant([0.1 + worker_id * 0.2], shape=[1,
                                                                         1]),
                    constant_op.constant([1]), constant_op.constant([2, 1]))
                sgd_opt = gradient_descent.GradientDescentOptimizer(2.0)
                sync_rep_opt = training.SyncReplicasOptimizer(
                    sgd_opt,
                    replicas_to_aggregate=replicas_to_aggregate,
                    total_num_replicas=num_workers)
                train_op = [
                    sync_rep_opt.apply_gradients(
                        zip([grads_0, grads_1, grads_sparse],
                            [var_0, var_1, var_sparse]),
                        global_step=global_step)
                ]
                sync_replicas_hook = sync_rep_opt.make_session_run_hook(
                    is_chief, num_tokens=num_workers)

            # Creates MonitoredSession
            session = training.MonitoredTrainingSession(
                master=workers[worker_id].target,
                is_chief=is_chief,
                hooks=[sync_replicas_hook])

        sessions.append(session)
        graphs.append(graph)
        train_ops.append(train_op)

    return sessions, graphs, train_ops
Code example #4
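A test helper for ModelAverageOptimizer: shared variables are routed through ModelAverageCustomGetter and a replica_device_setter, each worker applies its own constant gradients, and the optimizer's session run hook passed to MonitoredTrainingSession performs the periodic model averaging.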
def _get_workers(num_workers, steps, workers):
    sessions = []
    graphs = []
    train_ops = []
    for worker_id in range(num_workers):
        graph = ops.Graph()
        is_chief = (worker_id == 0)
        with graph.as_default():
            worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
            ma_custom = model_average_optimizer.ModelAverageCustomGetter(
                worker_device=worker_device)
            with variable_scope.variable_scope(
                    "", custom_getter=ma_custom), ops.device(
                        device_setter.replica_device_setter(
                            worker_device=worker_device,
                            ps_device="/job:ps/task:0/cpu:0",
                            ps_tasks=1)):

                global_step = variables.Variable(0,
                                                 name="global_step",
                                                 trainable=False)
                var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
                var_1 = variable_scope.get_variable(initializer=1.0, name="v1")

            with ops.device("/job:worker/task:" + str(worker_id)):
                if worker_id == 0:
                    grads_0 = constant_op.constant(-1.0)
                    grads_1 = constant_op.constant(-1.0)
                else:
                    grads_0 = constant_op.constant(-2.0)
                    grads_1 = constant_op.constant(-2.0)
                sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
                opt = model_average_optimizer.ModelAverageOptimizer(
                    opt=sgd_opt,
                    num_worker=num_workers,
                    ma_custom_getter=ma_custom,
                    is_chief=is_chief,
                    interval_steps=steps)
                train_op = [
                    opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
                                        global_step)
                ]
            ma_hook = opt.make_session_run_hook()
            # Creates MonitoredSession
            sess = training.MonitoredTrainingSession(workers[worker_id].target,
                                                     hooks=[ma_hook])

        sessions.append(sess)
        graphs.append(graph)
        train_ops.append(train_op)
    return sessions, graphs, train_ops
Code example #5
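The corresponding helper for ElasticAverageOptimizer (EASGD): the custom getter places the center variables on the parameter server, and the hook created by make_session_run_hook(is_chief, worker_id) synchronizes each worker's local copies with the center variables every communication_period steps.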
def _get_workers(num_workers, period, workers, moving_rate):
    sessions = []
    graphs = []
    train_ops = []
    for worker_id in range(num_workers):
        graph = ops.Graph()
        is_chief = (worker_id == 0)
        with graph.as_default():
            worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
            ea_custom = ElasticAverageCustomGetter(
                worker_device=worker_device)
            with variable_scope.variable_scope(
                    '', custom_getter=ea_custom), ops.device(
                        device_setter.replica_device_setter(
                            worker_device=worker_device,
                            ps_device="/job:ps/task:0/cpu:0",
                            ps_tasks=1)):
                global_step = variables.Variable(0,
                                                 name='global_step',
                                                 trainable=False)
                var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
                var_1 = variable_scope.get_variable(initializer=1.0, name="v1")

            with ops.device("/job:worker/task:" + str(worker_id)):
                grads_0 = constant_op.constant(-1.0)
                grads_1 = constant_op.constant(-1.0)

                sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
                opt = ElasticAverageOptimizer(opt=sgd_opt,
                                              num_worker=num_workers,
                                              moving_rate=moving_rate,
                                              communication_period=period,
                                              ea_custom_getter=ea_custom)
                train_op = [
                    opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]),
                                        global_step)
                ]
                easgd_hook = opt.make_session_run_hook(is_chief, worker_id)
            # Creates MonitoredSession
            sess = training.MonitoredTrainingSession(workers[worker_id].target,
                                                     hooks=[easgd_hook])

        sessions.append(sess)
        graphs.append(graph)
        train_ops.append(train_op)

    return sessions, graphs, train_ops
Code example #6
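A later revision of Estimator._train_model: compared with example #1 it adds a default loss summary, a control dependency on the global-step read tensor, support for saving_listeners attached to the first CheckpointSaverHook, and extra Saver options such as save_relative_paths.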
  def _train_model(self, input_fn, hooks, saving_listeners):
    worker_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = self._create_and_assert_global_step(g)
      global_step_read_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
      features, labels = self._get_features_and_labels_from_input_fn(
          input_fn, model_fn_lib.ModeKeys.TRAIN)
      with ops.control_dependencies([global_step_read_tensor]):
        estimator_spec = self._call_model_fn(
            features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
      # Check if the user created a loss summary, and add one if they didn't.
      # We assume here that the summary is called 'loss'. If it is not, we will
      # make another one with the name 'loss' to ensure it shows up in the right
      # graph in TensorBoard.
      if not any([x.op.name == 'loss'
                  for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
        summary.scalar('loss', estimator_spec.loss)
      ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
      worker_hooks.extend(hooks)
      worker_hooks.extend([
          training.NanTensorHook(estimator_spec.loss),
          training.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=100)
      ])
      worker_hooks.extend(estimator_spec.training_hooks)

      if not (estimator_spec.scaffold.saver or
              ops.get_collection(ops.GraphKeys.SAVERS)):
        ops.add_to_collection(
            ops.GraphKeys.SAVERS,
            training.Saver(
                sharded=True,
                max_to_keep=self._config.keep_checkpoint_max,
                keep_checkpoint_every_n_hours=(
                    self._config.keep_checkpoint_every_n_hours),
                defer_build=True,
                save_relative_paths=True))

      chief_hooks = []
      all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
      saver_hooks = [
          h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
      if (self._config.save_checkpoints_secs or
          self._config.save_checkpoints_steps):
        if not saver_hooks:
          chief_hooks = [
              training.CheckpointSaverHook(
                  self._model_dir,
                  save_secs=self._config.save_checkpoints_secs,
                  save_steps=self._config.save_checkpoints_steps,
                  scaffold=estimator_spec.scaffold)
          ]
          saver_hooks = [chief_hooks[0]]
      if saving_listeners:
        if not saver_hooks:
          raise ValueError(
              'There should be a CheckpointSaverHook to use saving_listeners. '
              'Please set one of the RunConfig.save_checkpoints_steps or '
              'RunConfig.save_checkpoints_secs.')
        else:
          # Only one CheckpointSaverHook is expected. If there are several, the
          # listeners are attached to the first one.
          saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access
      with training.MonitoredTrainingSession(
          master=self._config.master,
          is_chief=self._config.is_chief,
          checkpoint_dir=self._model_dir,
          scaffold=estimator_spec.scaffold,
          hooks=worker_hooks,
          chief_only_hooks=(
              tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
          save_checkpoint_secs=0,  # Saving is handled by a hook.
          save_summaries_steps=self._config.save_summary_steps,
          config=self._session_config,
          log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
        loss = None
        while not mon_sess.should_stop():
          _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
      return loss
Code example #7
File: estimator.py  Project: liuenliang/tensorflow-1
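Another snapshot of Estimator._train_model, taken from the fork noted above; it sits between examples #1 and #6, keeping the chief-only checkpoint hook and the deferred sharded Saver but without the loss summary or saving_listeners handling.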
    def _train_model(self, input_fn, hooks):
        all_hooks = []
        with ops.Graph().as_default() as g, g.device(self._device_fn):
            random_seed.set_random_seed(self._config.tf_random_seed)
            global_step_tensor = self._create_and_assert_global_step(g)
            features, labels = self._get_features_and_labels_from_input_fn(
                input_fn, model_fn_lib.ModeKeys.TRAIN)
            estimator_spec = self._call_model_fn(features, labels,
                                                 model_fn_lib.ModeKeys.TRAIN,
                                                 self.config)
            ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
            all_hooks.extend(hooks)
            all_hooks.extend([
                training.NanTensorHook(estimator_spec.loss),
                training.LoggingTensorHook(
                    {
                        'loss': estimator_spec.loss,
                        'step': global_step_tensor
                    },
                    every_n_iter=100)
            ])
            all_hooks.extend(estimator_spec.training_hooks)

            if not (estimator_spec.scaffold.saver
                    or ops.get_collection(ops.GraphKeys.SAVERS)):
                ops.add_to_collection(
                    ops.GraphKeys.SAVERS,
                    training.Saver(
                        sharded=True,
                        max_to_keep=self._config.keep_checkpoint_max,
                        keep_checkpoint_every_n_hours=(
                            self._config.keep_checkpoint_every_n_hours),
                        defer_build=True,
                        save_relative_paths=True))

            chief_hooks = []
            if (self._config.save_checkpoints_secs
                    or self._config.save_checkpoints_steps):
                saver_hook_exists = any([
                    isinstance(h, training.CheckpointSaverHook)
                    for h in (all_hooks + chief_hooks +
                              list(estimator_spec.training_chief_hooks))
                ])
                if not saver_hook_exists:
                    chief_hooks = [
                        training.CheckpointSaverHook(
                            self._model_dir,
                            save_secs=self._config.save_checkpoints_secs,
                            save_steps=self._config.save_checkpoints_steps,
                            scaffold=estimator_spec.scaffold)
                    ]
            with training.MonitoredTrainingSession(
                    master=self._config.master,
                    is_chief=self._config.is_chief,
                    checkpoint_dir=self._model_dir,
                    scaffold=estimator_spec.scaffold,
                    hooks=all_hooks,
                    chief_only_hooks=(
                        tuple(chief_hooks) +
                        tuple(estimator_spec.training_chief_hooks)),
                    save_checkpoint_secs=0,  # Saving is handled by a hook.
                    save_summaries_steps=self._config.save_summary_steps,
                    config=self._session_config,
                    log_step_count_steps=self._config.log_step_count_steps
            ) as mon_sess:
                loss = None
                while not mon_sess.should_stop():
                    _, loss = mon_sess.run(
                        [estimator_spec.train_op, estimator_spec.loss])
            return loss
Code example #8
File: train.py  Project: Jack44Wang/sgnmt
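The simplest usage in this set: a MonitoredTrainingSession configured only with a checkpoint directory and a 1200-second (20-minute) checkpoint interval, leaving every other argument at its default.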
def create_training_session(config):
    """Creates a MonitoredTrainingSession for training"""
    return training.MonitoredTrainingSession(checkpoint_dir=config.output_path,
                                             save_checkpoint_secs=1200)
Code example #9
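An extended ElasticAverageOptimizer test helper: in addition to the pattern of example #5 it optionally creates a partitioned variable when num_ps > 1 and collects the optimizer's swapping_saver() for every worker.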
def _get_workers(num_workers, period, workers, moving_rate, num_ps=1):
    sessions = []
    graphs = []
    train_ops = []
    savers = []
    for worker_id in range(num_workers):
        graph = ops.Graph()
        is_chief = (worker_id == 0)
        with graph.as_default():
            worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
            ea_custom = ElasticAverageCustomGetter(worker_device=worker_device)
            with variable_scope.variable_scope(
                    "", custom_getter=ea_custom), ops.device(
                        device_setter.replica_device_setter(
                            worker_device=worker_device,
                            ps_device="/job:ps/task:0/cpu:0",
                            ps_tasks=1)):
                global_step = training_util.get_or_create_global_step()
                var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
                var_1 = variable_scope.get_variable(initializer=1.0, name="v1")
            if num_ps > 1:
                with variable_scope.variable_scope(
                        "",
                        partitioner=partitioned_variables.
                        fixed_size_partitioner(num_ps, axis=0),
                        custom_getter=ea_custom), ops.device(
                            device_setter.replica_device_setter(
                                worker_device=worker_device,
                                ps_device="/job:ps/task:0/cpu:0",
                                ps_tasks=num_ps)):

                    partition_var = variable_scope.get_variable(
                        'partition_var',
                        shape=[2, 4],
                        initializer=init_ops.ones_initializer)
                    part_0 = list(partition_var)[0]
                    part_1 = list(partition_var)[1]

            with ops.device("/job:worker/task:" + str(worker_id)):
                grads_0 = constant_op.constant(-1.0)
                grads_1 = constant_op.constant(-1.0)
                grads_part_0 = constant_op.constant([[-1., -1., -1., -1.]])
                grads_part_1 = constant_op.constant([[-1., -1., -1., -1.]])

                sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
                opt = ElasticAverageOptimizer(opt=sgd_opt,
                                              num_worker=num_workers,
                                              moving_rate=moving_rate,
                                              communication_period=period,
                                              ea_custom_getter=ea_custom)
                if num_ps == 1:
                    train_op = [
                        opt.apply_gradients(
                            ([grads_0, var_0], [grads_1, var_1]), global_step)
                    ]
                else:
                    train_op = [
                        opt.apply_gradients(
                            ([grads_0, var_0], [grads_1, var_1],
                             [grads_part_0, part_0], [grads_part_1, part_1]),
                            global_step)
                    ]
                easgd_hook = opt.make_session_run_hook(is_chief, worker_id)
                saver = opt.swapping_saver()
            # Creates MonitoredSession
            sess = training.MonitoredTrainingSession(workers[worker_id].target,
                                                     hooks=[easgd_hook])

        sessions.append(sess)
        graphs.append(graph)
        train_ops.append(train_op)
        savers.append(saver)

    return sessions, graphs, train_ops, savers
Code example #10
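The same per-worker pattern applied to agn_optimizer.AGNOptimizer, here wrapping an AdamOptimizer; the hook from make_session_run_hook is again the only hook passed to MonitoredTrainingSession.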
def _get_workers(num_workers, period, workers, num_ps=1):
    sessions = []
    graphs = []
    train_ops = []
    for worker_id in range(num_workers):
        graph = ops.Graph()
        is_chief = (worker_id == 0)
        with graph.as_default():
            worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
            ps_device = device_setter.replica_device_setter(
                worker_device=worker_device,
                ps_device="/job:ps/task:0/cpu:0",
                ps_tasks=1)
            agn_getter = agn_optimizer.AGNCustomGetter(
                worker_device=worker_device)
            with variable_scope.variable_scope(
                    "", custom_getter=agn_getter), ops.device(ps_device):
                global_step = training_util.get_or_create_global_step()
                var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
                var_1 = variable_scope.get_variable(initializer=0.5, name="v1")
            if num_ps > 1:
                with variable_scope.variable_scope(
                        "",
                        partitioner=partitioned_variables.
                        fixed_size_partitioner(num_ps, axis=0),
                        custom_getter=agn_getter), ops.device(ps_device):

                    partition_var = variable_scope.get_variable(
                        "partition_var",
                        shape=[2, 4],
                        initializer=init_ops.zeros_initializer)
                    part_0 = list(partition_var)[0]
                    part_1 = list(partition_var)[1]

            with ops.device("/job:worker/task:" + str(worker_id)):
                grads_0 = constant_op.constant(-1.0)
                grads_1 = constant_op.constant(-1.0)
                grads_part_0 = constant_op.constant([[-1., -1., -1., -1.]])
                grads_part_1 = constant_op.constant([[-1., -1., -1., -1.]])

                optimizer = \
                    adam.AdamOptimizer(learning_rate=0.1, beta1=0.0, beta2=0.0)
                opt = agn_optimizer.AGNOptimizer(optimizer,
                                                 num_worker=num_workers,
                                                 communication_period=period,
                                                 custom_getter=agn_getter)
                if num_ps == 1:
                    train_op = [
                        opt.apply_gradients(
                            ([grads_0, var_0], [grads_1, var_1]), global_step)
                    ]
                else:
                    train_op = [
                        opt.apply_gradients(
                            ([grads_0, var_0], [grads_1, var_1],
                             [grads_part_0, part_0], [grads_part_1, part_1]),
                            global_step)
                    ]
                hook = opt.make_session_run_hook(is_chief, worker_id)
            # Creates MonitoredSession
            sess = training.MonitoredTrainingSession(workers[worker_id].target,
                                                     hooks=[hook])

        sessions.append(sess)
        graphs.append(graph)
        train_ops.append(train_op)

    return sessions, graphs, train_ops
Code example #11
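A heavily customized _train_with_estimator_spec: on top of the stock warm-start, NaN, logging and checkpoint hooks it can register an IteratorStringHandleHook for a combined train/eval input pipeline and a BestCheckpointSaverHook driven by a user-supplied evaluator, before entering the usual MonitoredTrainingSession training loop.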
    def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                                   global_step_tensor, saving_listeners,
                                   save_best_ckpt):
        """Train a model with the given Estimator Spec."""
        if self._warm_start_settings:
            logging.info('Warm-starting with WarmStartSettings: %s' %
                         (self._warm_start_settings, ))
            warm_starting_util.warm_start(*self._warm_start_settings)
        worker_hooks.extend(hooks)
        worker_hooks.append(training.NanTensorHook(estimator_spec.loss))
        if self._config.log_step_count_steps is not None:
            tensors = {"loss": estimator_spec.loss, "step": global_step_tensor}
            tensors.update({
                key.replace("/", ""): val
                for key, val in estimator_spec.predictions.items()
                if "/" in key
            })
            worker_hooks.append(
                training.LoggingTensorHook(
                    tensors, every_n_iter=self._config.log_step_count_steps))
        worker_hooks.extend(estimator_spec.training_hooks)

        # Create Saver object
        if not (estimator_spec.scaffold.saver
                or ops.get_collection(ops.GraphKeys.SAVERS)):
            ops.add_to_collection(
                ops.GraphKeys.SAVERS,
                training.Saver(sharded=True,
                               max_to_keep=self._config.keep_checkpoint_max,
                               keep_checkpoint_every_n_hours=(
                                   self._config.keep_checkpoint_every_n_hours),
                               defer_build=True,
                               save_relative_paths=True))

        chief_hooks = []
        all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
        saver_hooks = [
            h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)
        ]
        if (self._config.save_checkpoints_secs
                or self._config.save_checkpoints_steps):
            if not saver_hooks:
                chief_hooks = [
                    training.CheckpointSaverHook(
                        self._model_dir,
                        save_secs=self._config.save_checkpoints_secs,
                        save_steps=self._config.save_checkpoints_steps,
                        scaffold=estimator_spec.scaffold)
                ]
                saver_hooks = [chief_hooks[0]]
        if saving_listeners:
            if not saver_hooks:
                raise ValueError(
                    'There should be a CheckpointSaverHook to use saving_listeners. '
                    'Please set one of the RunConfig.save_checkpoints_steps or '
                    'RunConfig.save_checkpoints_secs.')
            else:
                # Only one CheckpointSaverHook is expected. If there are several,
                # the listeners are attached to the first one.
                saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access

        if self._train_with_eval:
            self.dataset_handle_hook = IteratorStringHandleHook(
                self.train_iterator, self.eval_iterator)
            worker_hooks.append(self.dataset_handle_hook)
            self._predict_keys = estimator_spec.predictions

        if save_best_ckpt:
            EvaluatorCls = self._params.get("evaluator", None)
            if EvaluatorCls is None or not issubclass(EvaluatorCls, EvaluateBase):
                raise TypeError(
                    "Parameter `evaluator` must be a subclass of EvaluateBase, "
                    "but got {}".format(EvaluatorCls))
            eval_kwargs = self._params.get("eval_kwargs", {})
            eval_steps = self._params.get("eval_steps", 2500)
            primary_metric = self._params.get("primary_metric", None)
            secondary_metric = self._params.get("secondary_metric", None)

            # We must construct Evaluator inside a graph scope
            evaluator = EvaluatorCls(self, **eval_kwargs)

            worker_hooks.append(
                BestCheckpointSaverHook(evaluator=evaluator,
                                        checkpoint_dir=self._model_dir,
                                        compare_fn=partial(
                                            evaluator.compare,
                                            primary_metric=primary_metric,
                                            secondary_metric=secondary_metric),
                                        tag=self._params["args"].tag,
                                        save_steps=eval_steps))

        # Training session monitor
        with training.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=estimator_spec.scaffold,
                hooks=worker_hooks,
                chief_only_hooks=(tuple(chief_hooks) +
                                  tuple(estimator_spec.training_chief_hooks)),
                save_checkpoint_secs=0,
                save_summaries_steps=self._config.save_summary_steps,
                config=self._session_config,
                log_step_count_steps=self._config.log_step_count_steps
        ) as mon_sess:
            loss = None

            # The handles exposed by self.dataset_handle_hook are only available once the MonitoredSession has been created.
            self._feed_dict = _add_key_value(
                self._feed_dict, self.handler,
                self.dataset_handle_hook.train_handle)
            while not mon_sess.should_stop():
                _, loss = mon_sess.run(
                    [estimator_spec.train_op, estimator_spec.loss],
                    self._feed_dict)
            return loss