Example #1
0
    def test_invalid_parameters_on_construction(self):
        """Tests that `_TPUStopAtStepHook` rejects invalid constructor arguments.

        Each case checks both that a `ValueError` is raised and that its message
        names the violated constraint. The message is matched through
        `assertRaisesRegex`. A check written after the raising call, inside an
        `assertRaises` block, would never execute.

        The patterns use `.?` around identifiers so they match with or without
        backtick quoting in the message.
        """
        # Neither num_steps nor final_step given.
        with self.assertRaisesRegex(
                ValueError,
                'One of .?num_steps.? or .?final_step.? must be specified'):
            tpu_estimator._TPUStopAtStepHook(
                util_lib.IterationsPerLoopCounter(value=10, unit='count'),
                num_steps=None,
                final_step=None)

        # Both num_steps and final_step given.
        with self.assertRaisesRegex(
                ValueError,
                'Only one of .?num_steps.? or .?final_step.? can be specified'):
            tpu_estimator._TPUStopAtStepHook(
                util_lib.IterationsPerLoopCounter(value=10, unit='count'),
                num_steps=10,
                final_step=100)

        # Unsupported unit. Only num_steps is passed. Passing both step
        # arguments could trigger the "Only one of" error above instead,
        # depending on the constructor's check order, masking the unit check.
        with self.assertRaisesRegex(
                ValueError,
                'Only .?count.? or .?seconds.? are accepted'):
            tpu_estimator._TPUStopAtStepHook(
                util_lib.IterationsPerLoopCounter(value=10, unit='secs'),
                num_steps=10)
  def _validate_initialization(self, iterations_per_loop_counter, num_steps):
    """Checks the state of a freshly constructed `_TPUStopAtStepHook`.

    Args:
      iterations_per_loop_counter: Counter passed to the hook. Its `value` and
        `unit` are expected to be mirrored on `hook._iterations_per_loop_counter`.
      num_steps: Number of steps passed to the hook.
    """
    with tf.Session() as sess:
      global_step_tensor = training_util.get_or_create_global_step(sess.graph)
      global_step_tensor.load(0, session=sess)
      self.assertEqual(sess.run(global_step_tensor), 0)

      hook = tpu_estimator._TPUStopAtStepHook(
          iterations_per_loop_counter, num_steps=num_steps)
      self.assertEqual(1, hook._next_iteration_count)
      self.assertEqual(num_steps, hook._num_steps)
      self.assertIsNone(hook._final_step)
      self.assertEqual(iterations_per_loop_counter.value,
                       hook._iterations_per_loop_counter.value)
      self.assertEqual(iterations_per_loop_counter.unit,
                       hook._iterations_per_loop_counter.unit)
      if iterations_per_loop_counter.unit == 'count':
        # NOTE(review): this reads the public name `iteration_count_estimator`,
        # while the `else` branch checks the private `_iteration_count_estimator`.
        # Confirm which attribute is meant to be absent for `count` units.
        with self.assertRaises(AttributeError) as cm:
          _ = hook.iteration_count_estimator
        # Must run after the `with` block. Inside it, this line would be
        # skipped because the attribute access above raises first.
        self.assertIn('object has no attribute', str(cm.exception))
      else:
        self.assertIsInstance(hook._iteration_count_estimator,
                              iteration_count_estimator.IterationCountEstimator)
Example #3
0
    def _validate_hook_life_cycle(self, iterations_per_loop_counter,
                                  num_steps):
        """Tests the hook's execution life cycle.

        This test validates:
        - Correctly updating the iterations both for `iterations_per_loop_counter`
          specified as both `count` and `seconds`.
        - Terminating the session.run() by signaling termination via
          `request_stop()`, and only on the final step.
        - The computation of the final iterations count when the remaining step
          count is smaller than the iterations_per_loop_counter.value.

        Args:
          iterations_per_loop_counter: This is the number of train steps running
            in TPU before returning to CPU host for each `Session.run`. Can be
            specified as `count` or `seconds`.
          num_steps: Number of steps to execute.
        """

        class RunContextMock(object):
            """Stand-in run context: exposes `session`, records `request_stop`."""

            def __init__(self, session):
                self.session = session
                self.stop = False

            def request_stop(self):
                self.stop = True

        class RunValues(object):
            """Stand-in run values whose `results` carry an `elapsed_time`."""

            def __init__(self, elapsed_time_secs):
                self.results = {'elapsed_time': elapsed_time_secs}

        with self.test_session() as sess:
            global_step_tensor = tf.compat.v1.train.get_or_create_global_step(
                sess.graph)
            global_step_tensor.load(0, session=sess)
            self.assertEqual(sess.run(global_step_tensor), 0)

            default_iterations = 1
            hook = tpu_estimator._TPUStopAtStepHook(
                iterations_per_loop_counter, num_steps=num_steps)
            self.assertEqual(default_iterations, hook._next_iteration_count)
            self.assertEqual(num_steps, hook._num_steps)
            self.assertIsNone(hook._final_step)
            self.assertEqual(iterations_per_loop_counter.value,
                             hook._iterations_per_loop_counter.value)
            self.assertEqual(iterations_per_loop_counter.unit,
                             hook._iterations_per_loop_counter.unit)

            def _step(hook, is_final, expected_iterations):
                """Drives one begin/after_create_session/after_run cycle."""
                hook.begin()
                hook.after_create_session(sess, None)

                run_context = RunContextMock(sess)
                run_values = RunValues(1)
                # NOTE(review): presumably real wall-clock time must elapse for
                # `seconds`-based estimation; confirm before removing the sleep.
                time.sleep(1.0)
                hook.after_run(run_context, run_values)

                # A stop must be requested on the final step and only there.
                self.assertEqual(run_context.stop, is_final)
                if is_final:
                    self.assertEqual(hook._next_iteration_count,
                                     expected_iterations)
                else:
                    # Intermediate estimates are allowed to be off by one.
                    self.assertLessEqual(
                        abs(hook._next_iteration_count - expected_iterations),
                        1)

            # Estimates iterations when global_step < final_step.
            global_step = sess.run(tf.compat.v1.train.get_global_step())
            self.assertEqual(global_step, 0)
            _step(hook, is_final=False, expected_iterations=3)

            # Estimates iterations when global_step < final_step.
            global_step_tensor.load(2, session=sess)
            _step(hook, is_final=False, expected_iterations=3)

            # Estimates iterations when global_step < final_step, and
            # (final_step - global_step) < estimated-iterations.
            global_step_tensor.load(4, session=sess)
            _step(hook, is_final=False, expected_iterations=1)

            # Estimates iterations when global_step == final_step.
            global_step_tensor.load(5, session=sess)
            _step(hook, is_final=True, expected_iterations=0)