def test_recovery(self): with self.test_session() as sess: c = tf.constant(0) v = tf.identity(c) # List of 3 sessions to use for recovery. The first one aborts # after 1 run() call, the second after 2 run calls, the third # after 3 run calls. sessions_to_use = [AbortAtNSession(sess, x + 1) for x in range(3)] self.assertEqual(3, len(sessions_to_use)) # Make the recoverable session uses these 3 sessions in sequence by # passing a factory that pops from the session_to_use list. recoverable_sess = monitored_session._RecoverableSession( lambda: sessions_to_use.pop(0)) self.assertEqual(2, len(sessions_to_use)) # One session popped. # Using first session. self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51})) self.assertEqual( 2, len(sessions_to_use)) # Still 2 sessions available # This will fail and recover by picking up the second session. self.assertEqual(42, recoverable_sess.run(v, feed_dict={c: 42})) self.assertEqual(1, len(sessions_to_use)) # Still 1 session available self.assertEqual(33, recoverable_sess.run(v, feed_dict={c: 33})) self.assertEqual(1, len(sessions_to_use)) # Still 1 session available # This will fail and recover by picking up the last session. self.assertEqual(24, recoverable_sess.run(v, feed_dict={c: 24})) self.assertEqual(0, len(sessions_to_use)) # All sessions used. self.assertEqual(11, recoverable_sess.run(v, feed_dict={c: 11})) self.assertEqual(0, recoverable_sess.run(v, feed_dict={c: 0})) # This will fail and throw a real error as the pop() will fail. with self.assertRaisesRegexp(IndexError, 'pop from empty list'): recoverable_sess.run(v, feed_dict={c: -12})
def test_run(self): with self.test_session() as sess: c = tf.constant(0) v = tf.identity(c) recoverable_sess = monitored_session._RecoverableSession( lambda: sess) self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51}))
def test_properties(self): with self.test_session() as sess: tf.constant(0.0) recoverable_sess = monitored_session._RecoverableSession( lambda: sess) self.assertEquals(sess.graph, recoverable_sess.graph) self.assertEquals(sess.sess_str, recoverable_sess.sess_str)
def test_recovery(self): with self.test_session() as sess: c = tf.constant(0) v = tf.identity(c) # List of 3 sessions to use for recovery. The first one aborts # after 1 run() call, the second after 2 run calls, the third # after 3 run calls. sessions_to_use = [AbortAtNSession(sess, x + 1) for x in range(3)] self.assertEqual(3, len(sessions_to_use)) # Make the recoverable session uses these 3 sessions in sequence by # passing a factory that pops from the session_to_use list. recoverable_sess = monitored_session._RecoverableSession( lambda: sessions_to_use.pop(0)) self.assertEqual(2, len(sessions_to_use)) # One session popped. # Using first session. self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51})) self.assertEqual(2, len(sessions_to_use)) # Still 2 sessions available # This will fail and recover by picking up the second session. self.assertEqual(42, recoverable_sess.run(v, feed_dict={c: 42})) self.assertEqual(1, len(sessions_to_use)) # Still 1 session available self.assertEqual(33, recoverable_sess.run(v, feed_dict={c: 33})) self.assertEqual(1, len(sessions_to_use)) # Still 1 session available # This will fail and recover by picking up the last session. self.assertEqual(24, recoverable_sess.run(v, feed_dict={c: 24})) self.assertEqual(0, len(sessions_to_use)) # All sessions used. self.assertEqual(11, recoverable_sess.run(v, feed_dict={c: 11})) self.assertEqual(0, recoverable_sess.run(v, feed_dict={c: 0})) # This will fail and throw a real error as the pop() will fail. with self.assertRaisesRegexp(IndexError, 'pop from empty list'): recoverable_sess.run(v, feed_dict={c: -12})
def test_run(self): with self.test_session() as sess: c = tf.constant(0) v = tf.identity(c) recoverable_sess = monitored_session._RecoverableSession( self._SessionReturner(sess)) self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51}))
def test_properties(self): with self.test_session() as sess: tf.constant(0.0) recoverable_sess = monitored_session._RecoverableSession( self._SessionReturner(sess)) self.assertEquals(sess.graph, recoverable_sess.graph) self.assertEquals(sess.sess_str, recoverable_sess.sess_str)