コード例 #1
0
    def test_num_steps(self):
        # Checks StopAtStep(num_steps=N): the session is expected to request a
        # stop after N increments counted from the step it started at, both
        # for a fresh session and for one restored from a step-3 checkpoint.
        ckpt_dir = self._test_dir('test_num_steps')
        with tf.Graph().as_default():
            global_step = tf.contrib.framework.get_or_create_global_step()
            increment = tf.assign_add(global_step, 1)

            # Phase 1: three steps from scratch, then write a checkpoint.
            fresh_scaffold = supervised_session.Scaffold()
            three_steps = [tf.contrib.learn.monitors.StopAtStep(num_steps=3)]
            with supervised_session.SupervisedSession(
                    '', scaffold=fresh_scaffold,
                    monitors=three_steps) as running:
                for _ in range(2):
                    running.run(increment)
                    self.assertFalse(running.should_stop())
                running.run(increment)
                self.assertTrue(running.should_stop())
                ckpt_path = fresh_scaffold.saver.save(
                    running.session, os.path.join(ckpt_dir, 'step-3'))

            # Phase 2: start from that checkpoint and expect four more steps.
            def load_ckpt(scaffold, sess):
                # Handed to Scaffold as init_fn; restores the step-3 save.
                scaffold.saver.restore(sess, ckpt_path)

            restored_scaffold = supervised_session.Scaffold(init_fn=load_ckpt)
            four_steps = [tf.contrib.learn.monitors.StopAtStep(num_steps=4)]
            with supervised_session.SupervisedSession(
                    '', scaffold=restored_scaffold,
                    monitors=four_steps) as running:
                self.assertEqual(3, running.run(global_step))
                for _ in range(3):
                    running.run(increment)
                    self.assertFalse(running.should_stop())
                running.run(increment)
                self.assertTrue(running.should_stop())
コード例 #2
0
 def test_recovery(self):
     # A session started on a directory that already holds a checkpoint is
     # expected to come up at the saved global step rather than at 0.
     ckpt_dir = self._test_dir('test_recovery')
     with tf.Graph().as_default():
         global_step = tf.contrib.framework.get_or_create_global_step()
         increment = tf.assign_add(global_step, 1)
         scaffold = supervised_session.Scaffold()
         # CheckpointSaver is configured with save_steps=1 so that a
         # checkpoint exists for the second session to recover from.
         saver_monitor = tf.contrib.learn.monitors.CheckpointSaver(
             ckpt_dir, save_steps=1, scaffold=scaffold)
         with supervised_session.SupervisedSession(
                 '',
                 scaffold=scaffold,
                 checkpoint_dir=ckpt_dir,
                 monitors=[saver_monitor]) as first_run:
             self.assertEqual(0, first_run.run(global_step))
             for expected_step in (1, 2):
                 self.assertEqual(expected_step, first_run.run(increment))
         # Second session: same scaffold and directory, no monitors.  It is
         # expected to find the checkpoint and resume at step 2.
         with supervised_session.SupervisedSession(
                 '', scaffold=scaffold,
                 checkpoint_dir=ckpt_dir) as second_run:
             self.assertEqual(2, second_run.run(global_step))
コード例 #3
0
 def test_last_step(self):
   # StopAtStep(last_step=N) is expected to request a stop once the global
   # step reaches N, whether the session starts at 0 or is restored at 3.
   ckpt_dir = self._test_dir('test_last_step')
   with tf.Graph().as_default():
     first_scaffold = supervised_session.Scaffold()
     global_step = first_scaffold.global_step_tensor
     increment = tf.assign_add(global_step, 1)
     # Phase 1: run until global step 3, then write a checkpoint.
     until_three = [tf.contrib.learn.monitors.StopAtStep(last_step=3)]
     with supervised_session.SupervisedSession(
         '', scaffold=first_scaffold, monitors=until_three) as running:
       self.assertEqual(0, running.run(global_step))
       self.assertFalse(running.should_stop())
       for expected_step in (1, 2):
         self.assertEqual(expected_step, running.run(increment))
         self.assertFalse(running.should_stop())
       self.assertEqual(3, running.run(increment))
       self.assertTrue(running.should_stop())
       ckpt_path = first_scaffold.saver.save(
           running.session, os.path.join(ckpt_dir, 'step-3'))
     # Phase 2: restore the step-3 checkpoint and run until global step 5.
     def load_ckpt(scaffold, sess):
       scaffold.saver.restore(sess, ckpt_path)
     second_scaffold = supervised_session.Scaffold(init_fn=load_ckpt)
     until_five = [tf.contrib.learn.monitors.StopAtStep(last_step=5)]
     with supervised_session.SupervisedSession(
         '', scaffold=second_scaffold, monitors=until_five) as running:
       self.assertEqual(3, running.run(global_step))
       self.assertFalse(running.should_stop())
       self.assertEqual(4, running.run(increment))
       self.assertFalse(running.should_stop())
       self.assertEqual(5, running.run(increment))
       self.assertTrue(running.should_stop())
コード例 #4
0
 def test_recover_and_retry_on_aborted_error(self):
     # An AbortedError raised during a run is expected to be retried
     # silently.  A CheckpointSaver is installed so there is state to recover
     # from, which is why the step count continues instead of resetting.
     ckpt_dir = self._test_dir('test_recover_and_retry_on_aborted_error')
     with tf.Graph().as_default():
         global_step = tf.contrib.framework.get_or_create_global_step()
         increment = tf.assign_add(global_step, 1)
         scaffold = supervised_session.Scaffold()
         abort_once = RaiseOnceAtStepN(
             3, tf.errors.AbortedError(None, None, 'Abort'))
         # save_steps=1: checkpoint after each step.
         save_each_step = tf.contrib.learn.monitors.CheckpointSaver(
             ckpt_dir, save_steps=1, scaffold=scaffold)
         with supervised_session.SupervisedSession(
                 '',
                 scaffold=scaffold,
                 checkpoint_dir=ckpt_dir,
                 monitors=[abort_once, save_each_step]) as running:
             self.assertEqual(0, running.run(global_step))
             for expected_step in (1, 2):
                 self.assertEqual(expected_step, running.run(increment))
             self.assertFalse(running.should_stop())
             # The third increment trips the monitor (AbortedError).  The
             # session is expected to restore and retry, so the caller still
             # observes step 3.
             self.assertEqual(3, running.run(increment))
             self.assertTrue(abort_once.raised)
             self.assertFalse(running.should_stop())
             self.assertEqual(4, running.run(increment))
             self.assertFalse(running.should_stop())
コード例 #5
0
 def test_retry_on_aborted_error(self):
     # An AbortedError is expected to be retried silently.  No
     # CheckpointSaver is installed here, so this covers retry only, not
     # recovery of earlier state.
     with tf.Graph().as_default():
         global_step = tf.contrib.framework.get_or_create_global_step()
         increment = tf.assign_add(global_step, 1)
         scaffold = supervised_session.Scaffold()
         abort_once = RaiseOnceAtStepN(
             3, tf.errors.AbortedError(None, None, 'Abort'))
         with supervised_session.SupervisedSession(
                 '', scaffold=scaffold, monitors=[abort_once]) as running:
             self.assertEqual(0, running.run(global_step))
             for expected_step in (1, 2):
                 self.assertEqual(expected_step, running.run(increment))
             self.assertFalse(running.should_stop())
             # The third increment trips the monitor (AbortedError).  The
             # retry is expected to run on a freshly initialized session, so
             # the step is back at 0 and this increment yields 1.
             self.assertEqual(1, running.run(increment))
             self.assertFalse(running.should_stop())
             self.assertTrue(abort_once.raised)
             self.assertEqual(2, running.run(increment))
             self.assertFalse(running.should_stop())
コード例 #6
0
 def test_stop_cleanly_when_no_exception_in_with_body(self):
   # When the "with body" finishes without raising, leaving the block is
   # expected to put the session in stop mode and close it.
   with tf.Graph().as_default():
     scaffold = supervised_session.Scaffold()
     increment = tf.assign_add(scaffold.global_step_tensor, 1)
     sup_session = supervised_session.SupervisedSession(
         '', scaffold=scaffold)
     with sup_session:
       for expected_step in (1, 2):
         self.assertEqual(expected_step, sup_session.run(increment))
       self.assertFalse(sup_session.should_stop())
     # Out of the block: the session should report stop and be closed.
     self.assertTrue(sup_session.should_stop())
     self.assertTrue(sup_session._is_closed())
コード例 #7
0
 def test_stop_cleanly_on_custom_exception_in_with_body(self):
   # A StopIteration raised directly in the "with body" is expected to be
   # absorbed when StopIteration is listed in clean_stop_exception_types,
   # leaving the session stopped and closed.
   with tf.Graph().as_default():
     scaffold = supervised_session.Scaffold()
     increment = tf.assign_add(scaffold.global_step_tensor, 1)
     clean_stop_types = [tf.errors.OutOfRangeError, StopIteration]
     sup_session = supervised_session.SupervisedSession(
         '',
         scaffold=scaffold,
         clean_stop_exception_types=clean_stop_types)
     with sup_session:
       for expected_step in (1, 2):
         self.assertEqual(expected_step, sup_session.run(increment))
       self.assertFalse(sup_session.should_stop())
       raise StopIteration('EOI')
     # Out of the block: the session should report stop and be closed.
     self.assertTrue(sup_session.should_stop())
     self.assertTrue(sup_session._is_closed())
コード例 #8
0
 def test_raises_regular_exceptions_in_with_body(self):
   # An ordinary exception raised in the "with body" is expected to
   # propagate to the caller, with the session still ending up closed.
   with tf.Graph().as_default():
     scaffold = supervised_session.Scaffold()
     increment = tf.assign_add(scaffold.global_step_tensor, 1)
     sup_session = supervised_session.SupervisedSession(
         '', scaffold=scaffold)
     # The RuntimeError must be visible outside the session block.
     with self.assertRaisesRegexp(RuntimeError, 'regular exception'):
       with sup_session:
         for expected_step in (1, 2):
           self.assertEqual(expected_step, sup_session.run(increment))
         self.assertFalse(sup_session.should_stop())
         raise RuntimeError('regular exception')
     # Out of the block: the session should report stop and be closed.
     self.assertTrue(sup_session.should_stop())
     self.assertTrue(sup_session._is_closed())
コード例 #9
0
 def test_exit_cleanly_on_stop_iteration_exception(self):
     # A StopIteration raised by a monitor is expected to end the "with"
     # block without surfacing as a test error.
     with tf.Graph().as_default():
         global_step = tf.contrib.framework.get_or_create_global_step()
         increment = tf.assign_add(global_step, 1)
         scaffold = supervised_session.Scaffold()
         stop_monitor = RaiseOnceAtStepN(1, StopIteration)
         sup_session = supervised_session.SupervisedSession(
             '', scaffold=scaffold, monitors=[stop_monitor])
         # The session should exit its context cleanly.
         with sup_session:
             self.assertEqual(0, sup_session.run(global_step))
             self.assertFalse(sup_session.should_stop())
             # This run reaches step 1, where the monitor raises
             # StopIteration.  The exception is expected to leave run(), so
             # the assertion right after it must never execute.
             sup_session.run(increment)
             self.assertTrue(False)
         self.assertTrue(sup_session.should_stop())
コード例 #10
0
 def test_stop_cleanly_on_out_of_range_exception(self):
   # An OutOfRangeError raised by a monitor is expected to switch the
   # session into stop mode instead of propagating.
   with tf.Graph().as_default():
     scaffold = supervised_session.Scaffold()
     global_step = scaffold.global_step_tensor
     increment = tf.assign_add(global_step, 1)
     end_of_input = RaiseOnceAtStepN(
         3, tf.errors.OutOfRangeError(None, None, 'EOI'))
     with supervised_session.SupervisedSession(
         '', scaffold=scaffold, monitors=[end_of_input]) as running:
       self.assertEqual(0, running.run(global_step))
       for expected_step in (1, 2):
         self.assertEqual(expected_step, running.run(increment))
       self.assertFalse(running.should_stop())
       # The third increment trips the monitor (OutOfRangeError).  run() is
       # expected to return no result and the session to ask for a stop.
       self.assertEqual(None, running.run(increment))
       self.assertTrue(end_of_input.raised)
       self.assertTrue(running.should_stop())
コード例 #11
0
 def test_regular_exception_pass_through_in_with_body(self):
   # An ordinary exception raised by a monitor during run() is expected to
   # pass through the "with SupervisedSession" block and leave the session
   # in stop mode.
   with tf.Graph().as_default():
     scaffold = supervised_session.Scaffold()
     global_step = scaffold.global_step_tensor
     increment = tf.assign_add(global_step, 1)
     failing_monitor = RaiseOnceAtStepN(3, RuntimeError('regular exception'))
     sup_session = supervised_session.SupervisedSession(
         '', scaffold=scaffold, monitors=[failing_monitor])
     with self.assertRaisesRegexp(RuntimeError, 'regular exception'):
       with sup_session:
         self.assertEqual(0, sup_session.run(global_step))
         for expected_step in (1, 2):
           self.assertEqual(expected_step, sup_session.run(increment))
         self.assertFalse(sup_session.should_stop())
         # The third increment trips the monitor, which raises.
         sup_session.run(increment)
         # Unreachable if the exception propagated as expected.
         self.assertFalse(True)
     self.assertTrue(failing_monitor.raised)
     self.assertTrue(sup_session.should_stop())
コード例 #12
0
 def test_stop_cleanly_on_custom_exception(self):
   # A caller-chosen exception type (StopIteration here), listed in
   # clean_stop_exception_types, is expected to stop the session cleanly
   # when a monitor raises it.
   with tf.Graph().as_default():
     scaffold = supervised_session.Scaffold()
     global_step = scaffold.global_step_tensor
     increment = tf.assign_add(global_step, 1)
     stop_monitor = RaiseOnceAtStepN(3, StopIteration('I choose you'))
     clean_stop_types = [tf.errors.OutOfRangeError, StopIteration]
     with supervised_session.SupervisedSession(
         '',
         scaffold=scaffold,
         monitors=[stop_monitor],
         clean_stop_exception_types=clean_stop_types) as running:
       self.assertEqual(0, running.run(global_step))
       for expected_step in (1, 2):
         self.assertEqual(expected_step, running.run(increment))
       self.assertFalse(running.should_stop())
       # The third increment trips the monitor (StopIteration).  run() is
       # expected to return no result and the session to ask for a stop.
       self.assertEqual(None, running.run(increment))
       self.assertTrue(stop_monitor.raised)
       self.assertTrue(running.should_stop())
コード例 #13
0
 def test_defaults(self):
     # With only a master string supplied, the session is expected to expose
     # a scaffold whose global step tensor evaluates to 0.
     with tf.Graph().as_default():
         with supervised_session.SupervisedSession('') as running:
             step_tensor = running.scaffold.global_step_tensor
             self.assertEqual(0, running.run(step_tensor))
コード例 #14
0
def _supervised_train(graph,
                      output_dir,
                      train_op,
                      loss_op,
                      global_step_tensor=None,
                      init_op=None,
                      init_feed_dict=None,
                      init_fn=None,
                      log_every_steps=10,
                      supervisor_is_chief=True,
                      supervisor_master='',
                      supervisor_save_model_secs=600,
                      keep_checkpoint_max=5,
                      supervisor_save_summaries_steps=100,
                      feed_fn=None,
                      steps=None,
                      fail_on_nan_loss=True,
                      monitors=None,
                      max_steps=None):
    """Train a model via `supervised_session.SupervisedSession`.

    Given `graph`, a directory to write outputs to (`output_dir`), and some
    ops, run a training loop. Each iteration runs `train_op` and `loss_op`
    together. `train_op` performs one step of training and is expected to
    increment `global_step_tensor`, a scalar integer tensor counting training
    steps; `loss_op` is the objective being trained.

    The session is created with `checkpoint_dir=output_dir`. On the chief,
    step-counter, NaN-loss and loss-printing monitors are installed, plus a
    summary saver (when the scaffold has a summary op) and a checkpoint saver.

    The loop runs until the session reports `should_stop()`; `steps` and
    `max_steps` are enforced by a `StopAtStep` monitor.

    Args:
      graph: A graph to train. It is expected that this graph is not in use
        elsewhere.
      output_dir: A directory to write outputs to.
      train_op: An op that performs one training step when run.
      loss_op: A scalar loss tensor.
      global_step_tensor: A tensor representing the global step. If `None`,
        one is looked up in `graph` via
        `contrib_variables.assert_or_get_global_step`.
      init_op: An op that initializes the graph; forwarded to `Scaffold`.
      init_feed_dict: A dictionary that maps `Tensor` objects to feed values;
        forwarded to `Scaffold` for use with `init_op`.
      init_fn: Optional callable forwarded to `Scaffold` to initialize the
        model.
      log_every_steps: Interval, in steps, at which the loss is printed
        (chief only).
      supervisor_is_chief: Whether the current process is the chief. Only the
        chief gets the logging, summary and checkpoint monitors; on other
        workers the monitor list is pruned to those with
        `run_on_all_workers` set.
      supervisor_master: The master string to use when preparing the session.
      supervisor_save_model_secs: If truthy, a `CheckpointSaver` monitor is
        installed on the chief. NOTE: the value itself is not used as an
        interval here; the saver is given
        `3 * supervisor_save_summaries_steps` instead.
      keep_checkpoint_max: The maximum number of recent checkpoint files to
        keep; forwarded to `Scaffold`.
      supervisor_save_summaries_steps: Save summaries every
        `supervisor_save_summaries_steps` steps when training.
      feed_fn: A function that is called every iteration to produce a
        `feed_dict` passed to `session.run` calls. Optional.
      steps: Trains for this many steps (e.g. current global step + `steps`).
      fail_on_nan_loss: Forwarded to the `NanLoss` monitor (chief only). If
        true, a NaN loss is expected to raise `NanLossDuringTrainingError`;
        if false, training continues.
      monitors: List of `BaseMonitor` subclass instances. Used for callbacks
        inside the training loop. The list is copied; the caller's list is
        not modified.
      max_steps: Number of total steps for which to train model. If `None`,
        train forever. Two calls fit(steps=100) means 200 training
        iterations. On the other hand two calls of fit(max_steps=100) means,
        second call will not do any iteration since first call did all 100
        steps.

    Returns:
      The loss value from the last executed training step, or `None` if no
      step was executed (including when training is skipped because the
      checkpoint in `output_dir` is already at or past `max_steps`).

    Raises:
      ValueError: If both `steps` and `max_steps` are not `None`; if
        `output_dir` is empty; if `train_op` or `loss_op` is `None`; if
        `monitors` is not a list; or if no global step tensor is provided or
        found in the graph.
      NanLossDuringTrainingError: Presumably raised by the `NanLoss` monitor
        when `fail_on_nan_loss` is true and the loss evaluates to `NaN`; the
        behaviour lives in `monitors_lib`, not in this function.
    """
    if (steps is not None) and (max_steps is not None):
        raise ValueError('Can not provide both steps and max_steps.')
    if not output_dir:
        raise ValueError('Output directory should be non-empty %s.' %
                         output_dir)
    if train_op is None:
        raise ValueError('Missing train_op.')
    if loss_op is None:
        raise ValueError('Missing loss_op.')
    if monitors is None:
        monitors = []
    if not isinstance(monitors, list):
        raise ValueError('Monitors should be a list.')
    # Work on a private copy: the default monitors added below must not leak
    # into the caller's list, otherwise reusing that list for a second call
    # would install every default monitor twice.
    monitors = list(monitors)

    with graph.as_default():
        global_step_tensor = contrib_variables.assert_or_get_global_step(
            graph, global_step_tensor)
    if global_step_tensor is None:
        raise ValueError(
            'No "global_step" was provided or found in the graph.')

    if max_steps is not None:
        # Best effort: a missing or unreadable checkpoint just means there is
        # nothing to skip.  Only the read is guarded, and only against
        # ordinary exceptions, so Ctrl-C / SystemExit still propagate.
        try:
            start_step = checkpoints.load_variable(output_dir,
                                                   global_step_tensor.name)
        except Exception:  # pylint: disable=broad-except
            start_step = None
        if start_step is not None and max_steps <= start_step:
            logging.info(
                'Skipping training since max_steps has already saved.')
            return None

    with graph.as_default():
        # See question about adding the summary writer to the scaffold.
        if supervisor_is_chief:
            summary_writer = summary_writer_cache.SummaryWriterCache.get(
                output_dir)
            monitors.extend([
                monitors_lib.StepCounter(summary_writer=summary_writer),
                monitors_lib.NanLoss(loss_op,
                                     fail_on_nan_loss=fail_on_nan_loss),
                monitors_lib.PrintTensor({'loss': loss_op.name},
                                         every_n=log_every_steps),
            ])

        # Finalize graph and add savers
        # TODO(ispir): remove keep_checkpoint_max from Scaffold interface
        scaffold = supervised_session.Scaffold(
            global_step_tensor=global_step_tensor,
            init_op=init_op,
            init_feed_dict=init_feed_dict,
            init_fn=init_fn,
            keep_checkpoint_max=keep_checkpoint_max)
        if supervisor_is_chief:
            if scaffold.summary_op is not None:
                monitors.append(
                    monitors_lib.SummarySaver(
                        scaffold.summary_op,
                        save_steps=supervisor_save_summaries_steps,
                        summary_writer=summary_writer))
            if supervisor_save_model_secs:
                # NOTE(review): positional order here is (steps, saver, dir).
                # Confirm it matches the CheckpointSaver signature in use.
                monitors.append(
                    monitors_lib.CheckpointSaver(
                        # Make CheckpointSaver use a timer or change arg to be steps.
                        3 * supervisor_save_summaries_steps,
                        scaffold.saver,
                        output_dir))

        if steps is not None or max_steps is not None:
            monitors.append(monitors_lib.StopAtStep(steps, max_steps))
        if not supervisor_is_chief:
            # Prune list of monitor to the ones runnable on all workers.
            monitors = [
                monitor for monitor in monitors if monitor.run_on_all_workers
            ]

        with supervised_session.SupervisedSession(
                supervisor_master,
                is_chief=supervisor_is_chief,
                checkpoint_dir=output_dir,
                monitors=monitors,
                scaffold=scaffold) as super_sess:
            loss = None
            while not super_sess.should_stop():
                _, loss = super_sess.run([train_op, loss_op],
                                         feed_fn() if feed_fn else None)
            return loss