示例#1
0
    def testBasicTrainLoopExceptionAborts(self):
        """An exception raised by the step function propagates out of the loop."""
        log_dir = _test_dir("basic_train_loop_exception_aborts")
        # Number of times the step function has been invoked so far.
        call_count = 0

        def failing_step(unused_sess):
            # Succeed on the first two calls, then raise on the third.
            nonlocal call_count
            call_count += 1
            if call_count == 3:
                raise RuntimeError("Failed")

        with ops.Graph().as_default():
            sv = supervisor.Supervisor(logdir=log_dir)
            with self.assertRaisesRegex(RuntimeError, "Failed"):
                basic_loops.basic_train_loop(sv, failing_step)
  def testBasicTrainLoopExceptionAborts(self):
    """An exception raised by train_fn propagates out of basic_train_loop."""
    logdir = _test_dir("basic_train_loop_exception_aborts")

    def train_fn(unused_sess):
      train_fn.counter += 1
      # Succeed twice, then fail on the third call.
      if train_fn.counter == 3:
        raise RuntimeError("Failed")

    # Function attribute used to count the number of calls.
    train_fn.counter = 0

    with ops.Graph().as_default():
      # Construct the Supervisor inside the graph scope: its `graph` argument
      # defaults to the current default graph, so building it outside would
      # bind it to the global graph rather than this one.
      sv = supervisor.Supervisor(logdir=logdir)
      # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
      # (removed in Python 3.12).
      with self.assertRaisesRegex(RuntimeError, "Failed"):
        basic_loops.basic_train_loop(sv, train_fn)
  def testBasicTrainLoop(self):
    """train_fn receives args/kwargs and the loop ends after request_stop()."""
    logdir = _test_dir("basic_train_loop")
    # Counts the number of calls (a list so the closure can mutate it).
    num_calls = [0]

    def train_fn(unused_sess, sv, y, a):
      num_calls[0] += 1
      # Verify the extra positional and keyword arguments were forwarded.
      self.assertEqual("y", y)
      self.assertEqual("A", a)
      if num_calls[0] == 3:
        sv.request_stop()

    with ops.Graph().as_default():
      # Construct the Supervisor inside the graph scope: its `graph` argument
      # defaults to the current default graph, so building it outside would
      # bind it to the global graph rather than this one.
      sv = supervisor.Supervisor(logdir=logdir)
      basic_loops.basic_train_loop(
          sv, train_fn, args=(sv, "y"), kwargs={"a": "A"})
      self.assertEqual(3, num_calls[0])
示例#4
0
    def testBasicTrainLoop(self):
        """Extra args/kwargs reach the step function; request_stop ends the loop."""
        log_dir = _test_dir("basic_train_loop")
        # How many times the step function has run.
        step_count = 0

        # `a` must keep its name: it is supplied by keyword via `kwargs`.
        def step(unused_sess, sv_arg, y_arg, a):
            nonlocal step_count
            step_count += 1
            self.assertEqual("y", y_arg)
            self.assertEqual("A", a)
            # Ask the supervisor to stop once the third step has run.
            if step_count == 3:
                sv_arg.request_stop()

        with ops.Graph().as_default():
            sv = supervisor.Supervisor(logdir=log_dir)
            extra_args = (sv, "y")
            extra_kwargs = {"a": "A"}
            basic_loops.basic_train_loop(sv, step, args=extra_args,
                                         kwargs=extra_kwargs)
            self.assertEqual(3, step_count)
示例#5
0
    def testBasicTrainLoopRetryOnAborted(self):
        """AbortedError is retried; any other error terminates the loop."""
        log_dir = _test_dir("basic_train_loop_exception_aborts")

        class FlakyTrainer:
            """Step function that aborts until its retry budget is spent."""

            def __init__(self):
                self.num_calls = 0
                self.retries_left = 2

            def step(self, unused_sess):
                self.num_calls += 1
                # Spend one retry on calls 2, 5, 8, ...
                if self.num_calls % 3 == 2:
                    self.retries_left -= 1
                if self.retries_left <= 0:
                    raise RuntimeError("Failed Again")
                raise errors_impl.AbortedError(None, None, "Aborted here")

        with ops.Graph().as_default():
            sv = supervisor.Supervisor(logdir=log_dir)
            trainer = FlakyTrainer()
            with self.assertRaisesRegex(RuntimeError, "Failed Again"):
                basic_loops.basic_train_loop(sv, trainer.step)
            self.assertEqual(0, trainer.retries_left)
  def testBasicTrainLoopRetryOnAborted(self):
    """AbortedError from train_fn is retried; other errors end the loop."""
    # NOTE(review): reuses the exception-aborts test's log dir name; kept as-is.
    logdir = _test_dir("basic_train_loop_exception_aborts")

    class AbortAndRetry(object):
      """Raises AbortedError until its retry budget reaches zero."""

      def __init__(self):
        self.num_calls = 0
        self.retries_left = 2

      def train_fn(self, unused_sess):
        self.num_calls += 1
        # Consume one retry on calls 2, 5, 8, ...
        if self.num_calls % 3 == 2:
          self.retries_left -= 1
        if self.retries_left > 0:
          raise errors_impl.AbortedError(None, None, "Aborted here")
        else:
          raise RuntimeError("Failed Again")

    with ops.Graph().as_default():
      # Construct the Supervisor inside the graph scope: its `graph` argument
      # defaults to the current default graph, so building it outside would
      # bind it to the global graph rather than this one.
      sv = supervisor.Supervisor(logdir=logdir)
      aar = AbortAndRetry()
      # assertRaisesRegex / assertEqual replace the deprecated
      # assertRaisesRegexp / assertEquals aliases (removed in Python 3.12).
      with self.assertRaisesRegex(RuntimeError, "Failed Again"):
        basic_loops.basic_train_loop(sv, aar.train_fn)
      self.assertEqual(0, aar.retries_left)