def testBasicTrainLoopExceptionAborts(self):
  """Checks that an exception raised by train_fn escapes basic_train_loop.

  train_fn raises RuntimeError("Failed") on its third invocation; the test
  expects that error to propagate out of the loop.
  """
  # Mutable holder so the nested function can update the call count.
  calls = {"count": 0}

  def train_fn(unused_sess):
    calls["count"] += 1
    if calls["count"] == 3:
      raise RuntimeError("Failed")

  logdir = _test_dir("basic_train_loop_exception_aborts")
  with ops.Graph().as_default():
    sv = supervisor.Supervisor(logdir=logdir)
    with self.assertRaisesRegex(RuntimeError, "Failed"):
      basic_loops.basic_train_loop(sv, train_fn)
def testBasicTrainLoopExceptionAborts(self):
  """Checks that an exception raised by train_fn aborts basic_train_loop.

  train_fn raises RuntimeError("Failed") on its third call; the error must
  propagate out of the loop rather than being swallowed.
  """
  logdir = _test_dir("basic_train_loop_exception_aborts")

  def train_fn(unused_sess):
    train_fn.counter += 1
    if train_fn.counter == 3:
      raise RuntimeError("Failed")

  # Function attribute used to count the number of calls.
  train_fn.counter = 0
  with ops.Graph().as_default():
    # Build the Supervisor inside the graph context: it attaches to the
    # default graph at construction time, which must be the same graph the
    # training loop runs under.
    sv = supervisor.Supervisor(logdir=logdir)
    # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(RuntimeError, "Failed"):
      basic_loops.basic_train_loop(sv, train_fn)
def testBasicTrainLoop(self):
  """Checks that basic_train_loop forwards args/kwargs and honors request_stop.

  train_fn asserts it receives positional "y" and keyword a="A", and calls
  sv.request_stop() on its third invocation; the loop must then have invoked
  train_fn exactly three times.
  """
  logdir = _test_dir("basic_train_loop")
  # Counts the number of calls (a list so the nested function can mutate it).
  num_calls = [0]

  def train_fn(unused_sess, sv, y, a):
    num_calls[0] += 1
    self.assertEqual("y", y)
    self.assertEqual("A", a)
    if num_calls[0] == 3:
      sv.request_stop()

  with ops.Graph().as_default():
    # Build the Supervisor inside the graph context so it is bound to the
    # same default graph the loop runs under.
    sv = supervisor.Supervisor(logdir=logdir)
    basic_loops.basic_train_loop(
        sv, train_fn, args=(sv, "y"), kwargs={"a": "A"})
  self.assertEqual(3, num_calls[0])
def testBasicTrainLoop(self):
  """Checks that basic_train_loop forwards args/kwargs and honors request_stop.

  train_fn asserts it receives positional "y" and keyword a="A", and calls
  sv.request_stop() on its third invocation; the loop must then have invoked
  train_fn exactly three times.
  """
  num_calls = 0

  def train_fn(unused_sess, sv, y, a):
    nonlocal num_calls
    num_calls += 1
    self.assertEqual("y", y)
    self.assertEqual("A", a)
    if num_calls == 3:
      sv.request_stop()

  logdir = _test_dir("basic_train_loop")
  with ops.Graph().as_default():
    sv = supervisor.Supervisor(logdir=logdir)
    loop_args = (sv, "y")
    loop_kwargs = {"a": "A"}
    basic_loops.basic_train_loop(
        sv, train_fn, args=loop_args, kwargs=loop_kwargs)
  self.assertEqual(3, num_calls)
def testBasicTrainLoopRetryOnAborted(self):
  """Checks that the loop survives AbortedError but not other exceptions.

  The first AbortedError (call 2) must not end the loop; the RuntimeError
  raised later (call 5) must propagate, by which point both retries are used.
  """

  class AbortAndRetry:
    """Training callable that fails on every call where num_calls % 3 == 2.

    The first such failure is an AbortedError; once retries_left reaches 0
    it raises RuntimeError("Failed Again") instead.
    """

    def __init__(self):
      self.num_calls = 0
      self.retries_left = 2

    def train_fn(self, unused_sess):
      self.num_calls += 1
      if self.num_calls % 3 != 2:
        return
      self.retries_left -= 1
      if self.retries_left <= 0:
        raise RuntimeError("Failed Again")
      raise errors_impl.AbortedError(None, None, "Aborted here")

  logdir = _test_dir("basic_train_loop_exception_aborts")
  with ops.Graph().as_default():
    sv = supervisor.Supervisor(logdir=logdir)
    aar = AbortAndRetry()
    with self.assertRaisesRegex(RuntimeError, "Failed Again"):
      basic_loops.basic_train_loop(sv, aar.train_fn)
    self.assertEqual(0, aar.retries_left)
def testBasicTrainLoopRetryOnAborted(self):
  """Checks that the loop survives AbortedError but not other exceptions.

  The first AbortedError (call 2) must not end the loop; the RuntimeError
  raised later (call 5) must propagate, by which point both retries are used.
  """
  logdir = _test_dir("basic_train_loop_exception_aborts")

  class AbortAndRetry:
    """Training callable that fails on every call where num_calls % 3 == 2.

    While retries remain it raises AbortedError; once retries_left reaches 0
    it raises RuntimeError("Failed Again") instead.
    """

    def __init__(self):
      self.num_calls = 0
      self.retries_left = 2

    def train_fn(self, unused_sess):
      self.num_calls += 1
      if self.num_calls % 3 == 2:
        self.retries_left -= 1
      if self.retries_left > 0:
        raise errors_impl.AbortedError(None, None, "Aborted here")
      else:
        raise RuntimeError("Failed Again")

  with ops.Graph().as_default():
    # Build the Supervisor inside the graph context so it is bound to the
    # same default graph the loop runs under.
    sv = supervisor.Supervisor(logdir=logdir)
    aar = AbortAndRetry()
    # assertRaisesRegexp / assertEquals are deprecated aliases removed in
    # Python 3.12.
    with self.assertRaisesRegex(RuntimeError, "Failed Again"):
      basic_loops.basic_train_loop(sv, aar.train_fn)
    self.assertEqual(0, aar.retries_left)