def test_step_counter_every_n_secs(self):
        """StepCounterHook on a 0.1s timer writes rate summaries on later runs.

        Three training steps are separated by 0.2s sleeps; the test expects
        `global_step/sec` summaries at global steps 2 and 3.
        """
        with ops.Graph().as_default() as graph, session_lib.Session() as sess:
            step_var = variables.get_or_create_global_step()
            increment = state_ops.assign_add(step_var, 1)
            writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
            hook = basic_session_run_hooks.StepCounterHook(
                summary_writer=writer, every_n_steps=None, every_n_secs=0.1)

            hook.begin()
            sess.run(variables_lib.global_variables_initializer())
            hooked = monitored_session._HookedSession(sess, [hook])
            hooked.run(increment)
            # Each subsequent run happens after more than every_n_secs elapsed.
            for _ in range(2):
                time.sleep(0.2)
                hooked.run(increment)
            hook.end(sess)

            writer.assert_summaries(
                test_case=self,
                expected_logdir=self.log_dir,
                expected_graph=graph,
                expected_summaries={})
            self.assertTrue(writer.summaries, 'No summaries were created.')
            self.assertItemsEqual([2, 3], writer.summaries.keys())
            for entry in writer.summaries.values():
                value = entry[0].value[0]
                self.assertEqual('global_step/sec', value.tag)
                self.assertGreater(value.simple_value, 0)
    def test_global_step_name(self):
        """The rate summary tag is derived from a custom global-step variable.

        The global step here is the variable `bar/foo`, so the expected tag is
        `bar/foo/sec` at global step 2.
        """
        with ops.Graph().as_default() as graph, session_lib.Session() as sess:
            step_collections = [
                ops.GraphKeys.GLOBAL_STEP,
                ops.GraphKeys.GLOBAL_VARIABLES,
            ]
            with variable_scope.variable_scope('bar'):
                custom_step = variable_scope.get_variable(
                    'foo',
                    initializer=0,
                    trainable=False,
                    collections=step_collections)
            increment = state_ops.assign_add(custom_step, 1)
            writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
            hook = basic_session_run_hooks.StepCounterHook(
                summary_writer=writer, every_n_steps=1, every_n_secs=None)

            hook.begin()
            sess.run(variables_lib.global_variables_initializer())
            hooked = monitored_session._HookedSession(sess, [hook])
            for _ in range(2):
                hooked.run(increment)
            hook.end(sess)

            writer.assert_summaries(
                test_case=self,
                expected_logdir=self.log_dir,
                expected_graph=graph,
                expected_summaries={})
            self.assertTrue(writer.summaries, 'No summaries were created.')
            self.assertItemsEqual([2], writer.summaries.keys())
            tag = writer.summaries[2][0].value[0].tag
            self.assertEqual('bar/foo/sec', tag)
  def create_model_fn_ops(self, predictions, output_alternatives,
                          mode=model_fn.ModeKeys.INFER):
    """Builds a `ModelFnOps` populated with placeholder ops for tests.

    Args:
      predictions: Passed through as `ModelFnOps.predictions`.
      output_alternatives: Passed through as `ModelFnOps.output_alternatives`.
      mode: Mode key given to `ModelFnOps`. Defaults to `ModeKeys.INFER`.
        (Previously this argument was accepted but silently ignored.)

    Returns:
      A `model_fn.ModelFnOps` with a constant loss, a no-op train op, two
      constant eval metrics ("metric_key" and "loss"), one `StepCounterHook`
      in each of the chief and regular training hook lists, and a default
      `Scaffold`.
    """
    return model_fn.ModelFnOps(
        mode,
        predictions=predictions,
        loss=constant_op.constant([1]),
        train_op=control_flow_ops.no_op(),
        eval_metric_ops={
            "metric_key": (constant_op.constant(1.), control_flow_ops.no_op()),
            "loss": (constant_op.constant(1.), control_flow_ops.no_op()),
        },
        training_chief_hooks=[basic_session_run_hooks.StepCounterHook()],
        training_hooks=[basic_session_run_hooks.StepCounterHook()],
        output_alternatives=output_alternatives,
        scaffold=monitored_session.Scaffold())
Esempio n. 4
0
 def test_step_counter_every_n_steps(self):
   """StepCounterHook(every_n_steps=10) writes rate summaries at steps 11 and 21.

   Runs 30 increments and also checks that `tf_logging.warning` is never
   called while doing so.
   """
   with ops.Graph().as_default() as graph, session_lib.Session() as sess:
     variables.get_or_create_global_step()
     increment = training_util._increment_global_step(1)
     writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
     hook = basic_session_run_hooks.StepCounterHook(
         summary_writer=writer, every_n_steps=10)
     hook.begin()
     sess.run(variables_lib.global_variables_initializer())
     hooked = monitored_session._HookedSession(sess, [hook])
     with test.mock.patch.object(tf_logging, 'warning') as warn_mock:
       for _ in range(30):
         time.sleep(0.01)
         hooked.run(increment)
       # logging.warning must not have been called at all.
       self.assertIsNone(warn_mock.call_args)
     hook.end(sess)
     writer.assert_summaries(
         test_case=self,
         expected_logdir=self.log_dir,
         expected_graph=graph,
         expected_summaries={})
     expected_steps = [11, 21]
     self.assertItemsEqual(expected_steps, writer.summaries.keys())
     for step in expected_steps:
       value = writer.summaries[step][0].value[0]
       self.assertEqual('global_step/sec', value.tag)
       self.assertGreater(value.simple_value, 0)
def MonitoredTrainingSession(
        master='',  # pylint: disable=invalid-name
        is_chief=True,
        checkpoint_dir=None,
        hooks=None,
        scaffold=None,
        config=None):
    """Creates a `MonitoredSession` for training.

    For a chief, this utility sets proper session initializer/restorer. It also
    creates hooks related to checkpoint and summary saving. For workers, this
    utility sets proper session creator which waits for the chief to
    initialize/restore.

    Args:
      master: `String` the TensorFlow master to use.
      is_chief: If `True`, it will take care of initialization and recovery the
        underlying TensorFlow session. If `False`, it will wait on a chief to
        initialize or recover the TensorFlow session.
      checkpoint_dir: A string.  Optional path to a directory where to restore
        variables.
      hooks: Optional list of `SessionRunHook` objects. The list is copied;
        the caller's list is never modified.
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified, a default one is created. It's used to finalize the
        graph.
      config: `ConfigProto` proto used to configure the session.

    Returns:
      A `MonitoredSession` object.
    """
    # Copy so the chief hooks appended below never leak into the caller's list
    # (reusing one list across calls would otherwise accumulate duplicates).
    hooks = list(hooks or [])
    scaffold = scaffold or Scaffold()
    if not is_chief:
        session_creator = WorkerSessionCreator(scaffold=scaffold,
                                               master=master,
                                               config=config)
    else:
        session_creator = ChiefSessionCreator(scaffold=scaffold,
                                              checkpoint_dir=checkpoint_dir,
                                              master=master,
                                              config=config)
        # Only the chief writes step-rate summaries, summaries and checkpoints.
        hooks.extend([
            basic_session_run_hooks.StepCounterHook(output_dir=checkpoint_dir),
            basic_session_run_hooks.SummarySaverHook(
                scaffold=scaffold, output_dir=checkpoint_dir),
            basic_session_run_hooks.CheckpointSaverHook(checkpoint_dir,
                                                        save_secs=600,
                                                        scaffold=scaffold),
        ])

    return MonitoredSession(session_creator=session_creator, hooks=hooks)
 def model_fn(features, labels, mode):
   """Returns an `EstimatorSpec` whose hooks the surrounding test can inspect.

   For TRAIN and EVAL the spec carries a `StepCounterHook` (every step, writing
   under the test temp dir) plus a `MagicMock` wrapping a plain
   `SessionRunHook`. Any other mode returns None.
   """
   del features, labels
   step = training_util.get_global_step()

   def _new_hooks():
     # A fresh pair per call: a real step counter and an inspectable mock hook.
     counter = basic_session_run_hooks.StepCounterHook(
         every_n_steps=1, output_dir=self.get_temp_dir())
     spy = tf.compat.v1.test.mock.MagicMock(
         wraps=tf.compat.v1.train.SessionRunHook(),
         spec=tf.compat.v1.train.SessionRunHook)
     return [counter, spy]

   if mode == model_fn_lib.ModeKeys.TRAIN:
     train_hooks = _new_hooks()
     return model_fn_lib.EstimatorSpec(
         mode,
         loss=tf.constant(1.),
         train_op=step.assign_add(1),
         training_hooks=train_hooks)
   if mode == model_fn_lib.ModeKeys.EVAL:
     eval_hooks = _new_hooks()
     return model_fn_lib.EstimatorSpec(
         mode=mode,
         loss=tf.constant(1.),
         evaluation_hooks=eval_hooks)
   return None
Esempio n. 7
0
 def test_log_warning_if_global_step_not_increased(self):
   """A warning is logged when the global step stays constant across runs."""
   with ops.Graph().as_default(), session_lib.Session() as sess:
     variables.get_or_create_global_step()
     # Increment by 0, so the global step never changes.
     frozen_step = training_util._increment_global_step(0)
     sess.run(variables_lib.global_variables_initializer())
     hook = basic_session_run_hooks.StepCounterHook(
         every_n_steps=1, every_n_secs=None)
     hook.begin()
     hooked = monitored_session._HookedSession(sess, [hook])
     # A first run lets the hook record the starting global step.
     hooked.run(frozen_step)
     with test.mock.patch.object(tf_logging, 'warning') as warn_mock:
       for _ in range(30):
         hooked.run(frozen_step)
       self.assertRegexpMatches(
           str(warn_mock.call_args),
           'global step.*has not been increased')
     hook.end(sess)
Esempio n. 8
0
def _monitored_train(graph,
                     output_dir,
                     train_op,
                     loss_op,
                     global_step_tensor=None,
                     init_op=None,
                     init_feed_dict=None,
                     init_fn=None,
                     log_every_steps=10,
                     supervisor_is_chief=True,
                     supervisor_master='',
                     supervisor_save_model_secs=600,
                     supervisor_save_model_steps=None,
                     keep_checkpoint_max=5,
                     supervisor_save_summaries_secs=None,
                     supervisor_save_summaries_steps=100,
                     feed_fn=None,
                     steps=None,
                     fail_on_nan_loss=True,
                     hooks=None,
                     max_steps=None):
  """Train a model via monitored_session.

  Given `graph`, a directory to write outputs to (`output_dir`), and some ops,
  run a training loop. The given `train_op` performs one step of training on the
  model. The `loss_op` represents the objective function of the training. It is
  expected to increment the `global_step_tensor`, a scalar integer tensor
  counting training steps. This function uses `Supervisor` to initialize the
  graph (from a checkpoint if one is available in `output_dir`), write summaries
  defined in the graph, and write regular checkpoints as defined by
  `supervisor_save_model_secs`.

  Training continues until `global_step_tensor` evaluates to `max_steps`, or, if
  `fail_on_nan_loss`, until `loss_op` evaluates to `NaN`. In that case the
  program is terminated with exit code 1.

  Args:
    graph: A graph to train. It is expected that this graph is not in use
      elsewhere.
    output_dir: A directory to write outputs to.
    train_op: An op that performs one training step when run.
    loss_op: A scalar loss tensor.
    global_step_tensor: A tensor representing the global step. If none is given,
      one is extracted from the graph using the same logic as in `Supervisor`.
    init_op: An op that initializes the graph. If `None`, use `Supervisor`'s
      default.
    init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
      This feed dictionary will be used when `init_op` is evaluated.
    init_fn: Optional callable passed to Supervisor to initialize the model.
    log_every_steps: Output logs regularly. The logs contain timing data and the
      current loss. A `0` or negative value disables logging.
    supervisor_is_chief: Whether the current process is the chief supervisor in
      charge of restoring the model and running standard services.
    supervisor_master: The master string to use when preparing the session.
    supervisor_save_model_secs: Save checkpoints every this many seconds. Can
        not be specified with `supervisor_save_model_steps`.
    supervisor_save_model_steps: Save checkpoints every this many steps. Can not
        be specified with `supervisor_save_model_secs`.
    keep_checkpoint_max: The maximum number of recent checkpoint files to
      keep. As new files are created, older files are deleted. If None or 0,
      all checkpoint files are kept. This is simply passed as the max_to_keep
      arg to `tf.Saver` constructor.
    supervisor_save_summaries_secs: Save summaries every
      `supervisor_save_summaries_secs` seconds when training.
    supervisor_save_summaries_steps: Save summaries every
      `supervisor_save_summaries_steps` steps when training. Exactly one of
      `supervisor_save_model_steps` and `supervisor_save_model_secs` should be
      specified, and the other should be None.
    feed_fn: A function that is called every iteration to produce a `feed_dict`
      passed to `session.run` calls. Optional.
    steps: Trains for this many steps (e.g. current global step + `steps`).
    fail_on_nan_loss: If true, raise `NanLossDuringTrainingError` if `loss_op`
      evaluates to `NaN`. If false, continue training as if nothing happened.
    hooks: List of `SessionRunHook` subclass instances. Used for callbacks
      inside the training loop.
    max_steps: Number of total steps for which to train model. If `None`,
      train forever. Two calls fit(steps=100) means 200 training iterations.
      On the other hand two calls of fit(max_steps=100) means, second call
      will not do any iteration since first call did all 100 steps.

  Returns:
    The final loss value.

  Raises:
    ValueError: If `output_dir`, `train_op`, `loss_op`, or `global_step_tensor`
      is not provided. See `tf.contrib.framework.get_global_step` for how we
      look up the latter if not provided explicitly.
    NanLossDuringTrainingError: If `fail_on_nan_loss` is `True`, and loss ever
      evaluates to `NaN`.
    ValueError: If both `steps` and `max_steps` are not `None`.
  """
  if (steps is not None) and (max_steps is not None):
    raise ValueError('Can not provide both steps and max_steps.')
  if not output_dir:
    raise ValueError('Output directory should be non-empty %s.' % output_dir)
  if train_op is None:
    raise ValueError('Missing train_op.')
  if loss_op is None:
    raise ValueError('Missing loss_op.')
  if hooks is None:
    hooks = []
  if not isinstance(hooks, list):
    raise ValueError('Hooks should be a list.')
  with graph.as_default():
    global_step_tensor = contrib_variables.assert_or_get_global_step(
        graph, global_step_tensor)
  if global_step_tensor is None:
    raise ValueError('No "global_step" was provided or found in the graph.')

  if max_steps is not None:
    try:
      start_step = load_variable(output_dir, global_step_tensor.name)
      if max_steps <= start_step:
        logging.info('Skipping training since max_steps has already saved.')
        return None
    except:  # pylint: disable=bare-except
      pass

  # Adapted SessionRunHooks such as ExportMonitor depend on the
  # CheckpointSaverHook to be executed before they should be executed.
  # The `hooks` param comprises of deprecated monitor hooks
  # (such as ExportMonitor). Appending them after the basic_session_run_hooks.
  all_hooks = []
  with graph.as_default():
    all_hooks.append(basic_session_run_hooks.NanTensorHook(
        loss_op, fail_on_nan_loss=fail_on_nan_loss))
    if log_every_steps > 0:
      all_hooks.append(basic_session_run_hooks.LoggingTensorHook({
          'loss': loss_op.name,
          'step': global_step_tensor.name
      }, every_n_iter=log_every_steps))

    def make_saver():
      return tf_saver.Saver(
          sharded=True, max_to_keep=keep_checkpoint_max, defer_build=True,
          write_version=saver_pb2.SaverDef.V1)

    scaffold = monitored_session.Scaffold(
        init_op=init_op,
        init_feed_dict=init_feed_dict,
        init_fn=init_fn,
        saver=monitored_session.Scaffold.get_or_default('saver',
                                                        ops.GraphKeys.SAVERS,
                                                        make_saver))

    if not supervisor_is_chief:
      session_creator = monitored_session.WorkerSessionCreator(
          scaffold=scaffold,
          master=supervisor_master)
    else:
      session_creator = monitored_session.ChiefSessionCreator(
          scaffold=scaffold,
          checkpoint_dir=output_dir,
          master=supervisor_master)
      summary_writer = summary_io.SummaryWriterCache.get(output_dir)
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              summary_writer=summary_writer))
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              save_secs=supervisor_save_summaries_secs,
              save_steps=supervisor_save_summaries_steps,
              summary_writer=summary_writer,
              scaffold=scaffold))
      if (supervisor_save_model_secs is not None
          or supervisor_save_model_steps is not None):
        all_hooks.append(
            basic_session_run_hooks.CheckpointSaverHook(
                output_dir,
                save_secs=supervisor_save_model_secs,
                save_steps=supervisor_save_model_steps,
                scaffold=scaffold))

    if steps is not None or max_steps is not None:
      all_hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
    all_hooks.extend(hooks)

    with monitored_session.MonitoredSession(
        session_creator=session_creator,
        hooks=all_hooks) as super_sess:
      loss = None
      while not super_sess.should_stop():
        _, loss = super_sess.run([train_op, loss_op], feed_fn() if feed_fn else
                                 None)
    summary_io.SummaryWriterCache.clear()
    return loss
Esempio n. 9
0
def MonitoredTrainingSession(
        master='',  # pylint: disable=invalid-name
        is_chief=True,
        checkpoint_dir=None,
        scaffold=None,
        hooks=None,
        chief_only_hooks=None,
        save_checkpoint_secs=600,
        save_summaries_steps=100,
        save_summaries_secs=None,
        config=None,
        stop_grace_period_secs=120,
        log_step_count_steps=100):
    """Creates a `MonitoredSession` for training.

    A chief gets a session creator that initializes or restores the session,
    plus default hooks for step counting, summaries and checkpoints (only when
    `checkpoint_dir` is set). A worker gets a session creator that waits for
    the chief to initialize/restore, and only the user-supplied `hooks`.

    Args:
      master: `String` the TensorFlow master to use.
      is_chief: If `True`, it will take care of initialization and recovery the
        underlying TensorFlow session. If `False`, it will wait on a chief to
        initialize or recover the TensorFlow session.
      checkpoint_dir: A string.  Optional path to a directory where to restore
        variables.
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified, a default one is created. It's used to finalize the
        graph.
      hooks: Optional list of `SessionRunHook` objects.
      chief_only_hooks: list of `SessionRunHook` objects. Activate these hooks
        if `is_chief==True`, ignore otherwise.
      save_checkpoint_secs: The frequency, in seconds, that a checkpoint is
        saved using a default checkpoint saver. If `save_checkpoint_secs` is set
        to `None`, then the default checkpoint saver isn't used.
      save_summaries_steps: The frequency, in number of global steps, that the
        summaries are written to disk using a default summary saver. If both
        `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
        the default summary saver isn't used.
      save_summaries_secs: The frequency, in secs, that the summaries are
        written to disk using a default summary saver.  If both
        `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
        the default summary saver isn't used.
      config: an instance of `tf.ConfigProto` proto used to configure the
        session. It's the `config` argument of constructor of `tf.Session`.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
      log_step_count_steps: The frequency, in number of global steps, that the
        global step/sec is logged.

    Returns:
      A `MonitoredSession` object.
    """
    scaffold = scaffold or Scaffold()

    if not is_chief:
        # Workers only wait on the chief and run the user-supplied hooks.
        return MonitoredSession(
            session_creator=WorkerSessionCreator(
                scaffold=scaffold, master=master, config=config),
            hooks=hooks or [],
            stop_grace_period_secs=stop_grace_period_secs)

    def _is_positive(value):
        # Truthy only for a set, strictly positive frequency.
        return value and value > 0

    chief_hooks = list(chief_only_hooks) if chief_only_hooks else []
    creator = ChiefSessionCreator(scaffold=scaffold,
                                  checkpoint_dir=checkpoint_dir,
                                  master=master,
                                  config=config)

    if checkpoint_dir:
        chief_hooks.append(
            basic_session_run_hooks.StepCounterHook(
                output_dir=checkpoint_dir, every_n_steps=log_step_count_steps))

        wants_summaries = (_is_positive(save_summaries_steps)
                           or _is_positive(save_summaries_secs))
        if wants_summaries:
            chief_hooks.append(
                basic_session_run_hooks.SummarySaverHook(
                    scaffold=scaffold,
                    save_steps=save_summaries_steps,
                    save_secs=save_summaries_secs,
                    output_dir=checkpoint_dir))
        if _is_positive(save_checkpoint_secs):
            chief_hooks.append(
                basic_session_run_hooks.CheckpointSaverHook(
                    checkpoint_dir,
                    save_secs=save_checkpoint_secs,
                    scaffold=scaffold))

    # User hooks always run after the chief's default hooks.
    if hooks:
        chief_hooks.extend(hooks)
    return MonitoredSession(session_creator=creator,
                            hooks=chief_hooks,
                            stop_grace_period_secs=stop_grace_period_secs)
Esempio n. 10
0
def train(train_op,
          logdir,
          master='',
          is_chief=True,
          scaffold=None,
          hooks=None,
          chief_only_hooks=None,
          save_checkpoint_secs=600,
          save_summaries_steps=100,
          config=None):
    """Runs the training loop.

    Args:
      train_op: A `Tensor` that, when executed, will apply the gradients and
        return the loss value.
      logdir: The directory where the graph and checkpoints are saved.
      master: The URL of the master.
      is_chief: Specifies whether or not the training is being run by the
        primary replica during replica training.
      scaffold: An tf.train.Scaffold instance.
      hooks: List of `tf.train.SessionRunHook` callbacks which are run inside
        the training loop. The list is copied; the caller's list is never
        modified.
      chief_only_hooks: List of `tf.train.SessionRunHook` instances which are
        run inside the training loop for the chief trainer only.
      save_checkpoint_secs: The frequency, in seconds, that a checkpoint is
        saved using a default checkpoint saver. If `save_checkpoint_secs` is set
        to `None`, then the default checkpoint saver isn't used.
      save_summaries_steps: The frequency, in number of global steps, that the
        summaries are written to disk using a default summary saver. If
        `save_summaries_steps` is set to `None`, then the default summary saver
        isn't used.
      config: An instance of `tf.ConfigProto`.

    Returns:
      the value of the loss function after training.

    Raises:
      ValueError: if `is_chief` is `True`, `logdir` is `None`, and either
        `save_checkpoint_secs` or `save_summaries_steps` is set.
    """
    # TODO(nsilberman): move this logic into monitored_session.py
    scaffold = scaffold or monitored_session.Scaffold()

    # Copy so the chief hooks appended below never leak into the caller's list.
    all_hooks = list(hooks or [])

    if is_chief:
        # Validate up front so a bad configuration fails before any setup.
        if save_summaries_steps and logdir is None:
            raise ValueError(
                'logdir cannot be None when save_summaries_steps is not None')
        if save_checkpoint_secs and logdir is None:
            raise ValueError(
                'logdir cannot be None when save_checkpoint_secs is not None')

        session_creator = monitored_session.ChiefSessionCreator(
            scaffold=scaffold,
            checkpoint_dir=logdir,
            master=master,
            config=config)

        if chief_only_hooks:
            all_hooks.extend(chief_only_hooks)

        all_hooks.append(
            basic_session_run_hooks.StepCounterHook(output_dir=logdir))

        if save_summaries_steps:
            all_hooks.append(
                basic_session_run_hooks.SummarySaverHook(
                    scaffold=scaffold,
                    save_steps=save_summaries_steps,
                    output_dir=logdir))

        if save_checkpoint_secs:
            all_hooks.append(
                basic_session_run_hooks.CheckpointSaverHook(
                    logdir, save_secs=save_checkpoint_secs, scaffold=scaffold))
    else:
        session_creator = monitored_session.WorkerSessionCreator(
            scaffold=scaffold, master=master, config=config)

    with monitored_session.MonitoredSession(session_creator=session_creator,
                                            hooks=all_hooks) as session:
        loss = None
        while not session.should_stop():
            loss = session.run(train_op)
    return loss
Esempio n. 11
0
def PartialRestoreSession(
        master='',  # pylint: disable=invalid-name
        is_chief=True,
        checkpoint_dir=None,
        restore_var_list=None,
        scaffold=None,
        hooks=None,
        chief_only_hooks=None,
        save_checkpoint_secs=600,
        save_summaries_steps=monitored_session.USE_DEFAULT,
        save_summaries_secs=monitored_session.USE_DEFAULT,
        config=None,
        stop_grace_period_secs=120,
        log_step_count_steps=100):
    """Creates a `MonitoredSession` for training.

    Supports partial restoration from checkpoints with parameter
    `restore_var_list`, by adding `CheckpointRestorerHook`.

    For a chief, this utility sets proper session initializer/restorer. It also
    creates hooks related to checkpoint and summary saving. For workers, this
    utility sets proper session creator which waits for the chief to
    initialize/restore. Please check `tf.train.MonitoredSession` for more
    information.

    Args:
      master: `String` the TensorFlow master to use.
      is_chief: If `True`, it will take care of initialization and recovery the
        underlying TensorFlow session. If `False`, it will wait on a chief to
        initialize or recover the TensorFlow session.
      checkpoint_dir: A string.  Optional path to a directory where to restore
        variables.
      restore_var_list: a list of variables, optional, if not all variables
        should be recovered from checkpoint.
        Useful when changing network structures during training, i.e.,
        finetuning a pretrained model with new layers. When given, the chief
        session creator does not restore from `checkpoint_dir` itself; a
        `CheckpointRestorerHook` restores only these variables instead.
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified, a default one is created. It's used to finalize the
        graph.
      hooks: Optional list of `SessionRunHook` objects.
      chief_only_hooks: list of `SessionRunHook` objects. Activate these hooks
        if `is_chief==True`, ignore otherwise.
      save_checkpoint_secs: The frequency, in seconds, that a checkpoint is
        saved using a default checkpoint saver. If `save_checkpoint_secs` is set
        to `None`, then the default checkpoint saver isn't used.
      save_summaries_steps: The frequency, in number of global steps, that the
        summaries are written to disk using a default summary saver. If both
        `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
        the default summary saver isn't used. Default 100.
      save_summaries_secs: The frequency, in secs, that the summaries are
        written to disk using a default summary saver.  If both
        `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
        the default summary saver isn't used. Default not enabled.
      config: an instance of `tf.ConfigProto` proto used to configure the
        session. It's the `config` argument of constructor of `tf.Session`.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
      log_step_count_steps: The frequency, in number of global steps, that the
        global step/sec is logged.

    Returns:
      A `MonitoredSession` object.
    """
    # Resolve USE_DEFAULT sentinels: with neither summary option given, save
    # every 100 steps; otherwise an unspecified option is simply disabled.
    if save_summaries_steps == monitored_session.USE_DEFAULT \
            and save_summaries_secs == monitored_session.USE_DEFAULT:
        save_summaries_steps = 100
        save_summaries_secs = None
    elif save_summaries_secs == monitored_session.USE_DEFAULT:
        save_summaries_secs = None
    elif save_summaries_steps == monitored_session.USE_DEFAULT:
        save_summaries_steps = None

    scaffold = scaffold or monitored_session.Scaffold()
    if not is_chief:
        session_creator = monitored_session.WorkerSessionCreator(
            scaffold=scaffold, master=master, config=config)
        return monitored_session.MonitoredSession(
            session_creator=session_creator,
            hooks=hooks or [],
            stop_grace_period_secs=stop_grace_period_secs)

    all_hooks = []
    if chief_only_hooks:
        all_hooks.extend(chief_only_hooks)
    if restore_var_list is None:
        restore_checkpoint_dir = checkpoint_dir
    else:
        # Partial restore: the session creator skips the checkpoint and the
        # restorer hook loads only `restore_var_list`.
        restore_checkpoint_dir = None
        all_hooks.append(
            CheckpointRestorerHook(checkpoint_dir, var_list=restore_var_list))
        all_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        # Materialize a list: a lazy `filter` object would be logged as
        # "<filter object at ...>" on Python 3 instead of the variables.
        missing_vars = [v for v in all_vars if v not in restore_var_list]
        logging.warning("MonitoredTrainingSession not restoring %s",
                        missing_vars)
    session_creator = monitored_session.ChiefSessionCreator(
        scaffold=scaffold,
        checkpoint_dir=restore_checkpoint_dir,
        master=master,
        config=config)

    if checkpoint_dir:
        all_hooks.append(
            basic_session_run_hooks.StepCounterHook(
                output_dir=checkpoint_dir, every_n_steps=log_step_count_steps))

        if (save_summaries_steps
                and save_summaries_steps > 0) or (save_summaries_secs
                                                  and save_summaries_secs > 0):
            all_hooks.append(
                basic_session_run_hooks.SummarySaverHook(
                    scaffold=scaffold,
                    save_steps=save_summaries_steps,
                    save_secs=save_summaries_secs,
                    output_dir=checkpoint_dir))
        if save_checkpoint_secs and save_checkpoint_secs > 0:
            all_hooks.append(
                basic_session_run_hooks.CheckpointSaverHook(
                    checkpoint_dir,
                    save_secs=save_checkpoint_secs,
                    scaffold=scaffold))

    if hooks:
        all_hooks.extend(hooks)
    return monitored_session.MonitoredSession(
        session_creator=session_creator,
        hooks=all_hooks,
        stop_grace_period_secs=stop_grace_period_secs)
Esempio n. 12
0
def MonitoredTrainingSession(
        master='',  # pylint: disable=invalid-name
        is_chief=True,
        checkpoint_dir=None,
        scaffold=None,
        hooks=None,
        chief_only_hooks=None,
        save_checkpoint_secs=None,
        save_summaries_steps=None,
        save_summaries_secs=None,
        config=None,
        stop_grace_period_secs=120,
        log_step_count_steps=100,
        max_wait_secs=7200,
        save_checkpoint_steps=None):
    """Creates a `MonitoredSession` for training.

    For a chief (`is_chief=True`) this uses a `tf.train.ChiefSessionCreator`
    (which initializes the session or restores it from `checkpoint_dir`) and,
    when `checkpoint_dir` is set, adds a step-counter hook and a checkpoint
    saver hook. For workers (`is_chief=False`) this uses a
    `tf.train.WorkerSessionCreator` that waits for the chief to
    initialize/restore, and installs only the user-supplied `hooks`. Please
    check `tf.train.MonitoredSession` for more information.

    Args:
      master: `String` the TensorFlow master to use.
      is_chief: If `True`, this process takes care of initialization and
        recovery of the underlying TensorFlow session. If `False`, it waits
        (up to `max_wait_secs`) for a chief to initialize or recover the
        session.
      checkpoint_dir: A string. Optional path to a directory from which to
        restore variables and into which checkpoints and step-rate summaries
        are written (chief only).
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified, the session creator builds a default one. It's used to
        finalize the graph.
      hooks: Optional list of `SessionRunHook` objects, installed on both
        chief and workers.
      chief_only_hooks: list of `SessionRunHook` objects. Activate these hooks
        if `is_chief==True`, ignore otherwise.
      save_checkpoint_secs: The frequency, in seconds, at which a checkpoint is
        saved by the default checkpoint saver. If both `save_checkpoint_secs`
        and `save_checkpoint_steps` are `None`, a checkpoint is saved every
        1800 seconds. If both are provided, only `save_checkpoint_secs` is
        used. A non-positive value (with no step schedule) disables the saver.
      save_summaries_steps: Accepted for signature compatibility with
        `tf.train.MonitoredTrainingSession`; currently unused (no summary saver
        hook is installed by this function).
      save_summaries_secs: Accepted for signature compatibility with
        `tf.train.MonitoredTrainingSession`; currently unused (no summary saver
        hook is installed by this function).
      config: an instance of `tf.ConfigProto` proto used to configure the
        session. It's the `config` argument of constructor of `tf.Session`.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
      log_step_count_steps: The frequency, in number of global steps, at which
        the global step/sec is logged. Only used on the chief when
        `checkpoint_dir` is set; ignored when `None` or not positive.
      max_wait_secs: Maximum time workers should wait for the session to
        become available. Only used when `is_chief` is `False`.
      save_checkpoint_steps: The frequency, in number of global steps, at which
        a checkpoint is saved by the default checkpoint saver. Only used when
        `save_checkpoint_secs` is `None`.

    Returns:
      A `MonitoredSession` object.
    """
    if not is_chief:
        # Workers never initialize/restore variables or write checkpoints
        # themselves; they block until the chief has prepared the session.
        worker_creator = tf.train.WorkerSessionCreator(
            scaffold=scaffold,
            master=master,
            config=config,
            max_wait_secs=max_wait_secs)
        return tf.train.MonitoredSession(
            session_creator=worker_creator,
            hooks=hooks or [],
            stop_grace_period_secs=stop_grace_period_secs)

    # Resolve the checkpoint schedule while honouring caller-supplied values.
    # 1800 s is the fallback only when the caller specified neither option.
    if save_checkpoint_secs is None and save_checkpoint_steps is None:
        save_checkpoint_secs = 1800
    elif save_checkpoint_secs is not None and save_checkpoint_steps is not None:
        # Documented precedence: the seconds-based schedule wins.
        save_checkpoint_steps = None

    all_hooks = []
    if chief_only_hooks:
        all_hooks.extend(chief_only_hooks)
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=scaffold,
        checkpoint_dir=checkpoint_dir,
        master=master,
        config=config)

    if checkpoint_dir:
        if log_step_count_steps and log_step_count_steps > 0:
            # NOTE(review): a fresh FileWriter is created on every call and is
            # not closed here; confirm whether the hook or caller closes it.
            all_hooks.append(
                basic_session_run_hooks.StepCounterHook(
                    output_dir=checkpoint_dir,
                    every_n_steps=log_step_count_steps,
                    summary_writer=tf.summary.FileWriter(checkpoint_dir)))

        if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
                save_checkpoint_steps and save_checkpoint_steps > 0):
            all_hooks.append(
                nabu_hooks.CheckpointSaverHook(
                    checkpoint_dir,
                    save_steps=save_checkpoint_steps,
                    save_secs=save_checkpoint_secs,
                    scaffold=scaffold))

    if hooks:
        all_hooks.extend(hooks)
    return tf.train.MonitoredSession(
        session_creator=session_creator,
        hooks=all_hooks,
        stop_grace_period_secs=stop_grace_period_secs)