def __init__(self,
              master,
              is_chief=True,
              checkpoint_dir=None,
              monitors=None,
              scaffold=None,
              config=None,
              clean_stop_exception_types=None):
     """Captures the default graph, builds the session and starts monitors.

     The constructor stores its arguments, finalizes the default graph,
     creates a `SessionManager` from the scaffold's `local_init_op` and
     `ready_op`, wraps `self._create_session` in a `RecoverableSession`,
     reads the scaffold's global step tensor to get the initial step,
     calls `begin()` on every monitor and finally calls `write_graph()`.

     Args:
       master: Stored unchanged on the instance; not otherwise used here.
         Presumably the master address consumed when sessions are created
         -- TODO confirm against `_create_session`.
       is_chief: Stored unchanged on the instance. Defaults to True.
       checkpoint_dir: Stored unchanged on the instance. Defaults to None.
       monitors: Iterable of objects exposing
         `begin(max_steps=..., init_step=...)`. A falsy value is replaced
         by an empty list.
       scaffold: Object exposing `local_init_op`, `ready_op` and
         `global_step_tensor`. A falsy value is replaced by a new
         `Scaffold()`.
       config: Stored unchanged on the instance. Defaults to None.
       clean_stop_exception_types: Stored unchanged on the instance.
         Defaults to None.
     """
     # Grab the graph first, then record the constructor arguments.
     self._graph = ops.get_default_graph()
     self._scaffold = scaffold or Scaffold()
     self._monitors = monitors or []
     self._clean_stop_exception_types = clean_stop_exception_types
     self._config = config
     self._is_chief = is_chief
     self._checkpoint_dir = checkpoint_dir
     self._master = master

     # The graph is finalized before any session machinery is built.
     self._graph.finalize()

     # Session manager is built from the scaffold's ops.
     active_scaffold = self._scaffold
     self._session_manager = sm.SessionManager(
         local_init_op=active_scaffold.local_init_op,
         ready_op=active_scaffold.ready_op,
         graph=ops.get_default_graph())

     # The recoverable wrapper receives the bound session factory.
     session_factory = self._create_session
     self._sess = recoverable_session.RecoverableSession(session_factory)

     # Read the starting step, then hand it to each monitor's begin().
     tf_sess = self._tf_sess
     self._init_step = tf_sess.run(self._scaffold.global_step_tensor)
     for mon in self._monitors:
          mon.begin(max_steps=None, init_step=self._init_step)

     # Written last: per the original note, write_graph() relies on
     # self._init_step being set.
     self.write_graph()
 def test_recovery(self):
   with self.test_session() as sess:
     c = tf.constant(0)
     v = tf.identity(c)
     # List of 3 sessions to use for recovery.  The first one aborts
     # after 1 run() call, the second after 2 run calls, the third
     # after 3 run calls.
     sessions_to_use = [AbortAtNSession(sess, x + 1)
                        for x in range(3)]
     self.assertEqual(3, len(sessions_to_use))
     # Make the recoverable session uses these 3 sessions in sequence by
     # passing a factory that pops from the session_to_use list.
     recoverable_sess = recoverable_session.RecoverableSession(
         lambda: sessions_to_use.pop(0))
     self.assertEqual(2, len(sessions_to_use))  # One session popped.
     # Using first session.
     self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51}))
     self.assertEqual(2, len(sessions_to_use))  # Still 2 sessions available
     # This will fail and recover by picking up the second session.
     self.assertEqual(42, recoverable_sess.run(v, feed_dict={c: 42}))
     self.assertEqual(1, len(sessions_to_use))  # Still 1 session available
     self.assertEqual(33, recoverable_sess.run(v, feed_dict={c: 33}))
     self.assertEqual(1, len(sessions_to_use))  # Still 1 session available
     # This will fail and recover by picking up the last session.
     self.assertEqual(24, recoverable_sess.run(v, feed_dict={c: 24}))
     self.assertEqual(0, len(sessions_to_use))  # All sessions used.
     self.assertEqual(11, recoverable_sess.run(v, feed_dict={c: 11}))
     self.assertEqual(0, recoverable_sess.run(v, feed_dict={c: 0}))
     # This will fail and throw a real error as the pop() will fail.
     with self.assertRaisesRegexp(IndexError, 'pop from empty list'):
       recoverable_sess.run(v, feed_dict={c: -12})
 def test_run(self):
   with self.test_session() as sess:
     c = tf.constant(0)
     v = tf.identity(c)
     recoverable_sess = recoverable_session.RecoverableSession(lambda: sess)
     self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51}))
 def test_properties(self):
   with self.test_session() as sess:
     tf.constant(0.0)
     recoverable_sess = recoverable_session.RecoverableSession(lambda: sess)
     self.assertEquals(sess.graph, recoverable_sess.graph)
     self.assertEquals(sess.sess_str, recoverable_sess.sess_str)