def _create_session(self):
    """Build a new coordinated, monitored session.

    This is the session factory for the RecoverableSession. On the chief,
    the session manager prepares the session using the scaffold's saver,
    init op, init feed dict and init_fn together with
    `self._checkpoint_dir`. Other workers instead call `wait_for_session`
    on the same master and config.

    Returns:
      A `coordinated_session.CoordinatedSession` wrapping a
      `monitored_session.MonitoredSession` around the raw session. Queue
      runners have already been started under a freshly created
      `coordinator.Coordinator`.
    """
    manager = self._session_manager
    scaffold = self._scaffold
    if not self._is_chief:
        raw_session = manager.wait_for_session(
            self._master, config=self._config)
    else:
        raw_session = manager.prepare_session(
            self._master,
            saver=scaffold.saver,
            checkpoint_dir=self._checkpoint_dir,
            config=self._config,
            init_op=scaffold.init_op,
            init_feed_dict=scaffold.init_feed_dict,
            init_fn=scaffold.init_fn)

    # Keep a handle on the raw session so the global step can be read
    # quickly when needed.
    self._tf_sess = raw_session

    stop_coord = coordinator.Coordinator(
        clean_stop_exception_types=self._clean_stop_exception_types)
    self._coord = stop_coord

    runner_threads = queue_runner.start_queue_runners(
        sess=raw_session, coord=stop_coord)
    self._coordinated_threads_to_join = runner_threads

    monitored = monitored_session.MonitoredSession(
        raw_session, self._monitors, scaffold.global_step_tensor)
    return coordinated_session.CoordinatedSession(
        monitored, stop_coord, runner_threads)
# Example #2
# 0
 def test_should_stop_on_coord_stop(self):
   """A stop requested directly on the Coordinator shows in should_stop()."""
   with self.test_session() as session:
     stop_coord = tf.train.Coordinator()
     wrapped = coordinated_session.CoordinatedSession(session, stop_coord, [])
     # Nothing has requested a stop yet.
     self.assertFalse(wrapped.should_stop())
     stop_coord.request_stop()
     # The request made on the coordinator is now visible via the session.
     self.assertTrue(wrapped.should_stop())
 def test_stop_threads_on_exception(self):
   """A failing run() stops the coordinator and its threads.

   Successful runs leave the three BusyWaitForCoordStop threads alive.
   A run that both feeds and fetches the same tensor raises
   InvalidArgumentError. After that, no thread may be alive, and both the
   coordinator and the session must report should_stop().
   """
   with self.test_session() as session:
     const = tf.constant(0)
     ident = tf.identity(const)
     stop_coord = tf.train.Coordinator()
     workers = [
         threading.Thread(target=BusyWaitForCoordStop, args=(stop_coord,))
         for _ in range(3)
     ]
     for worker in workers:
       worker.start()
     wrapped = coordinated_session.CoordinatedSession(
         session, stop_coord, workers)

     def check_alive(expected):
       # Asserts every worker thread's liveness equals `expected`.
       for worker in workers:
         self.assertEqual(expected, worker.is_alive())

     self.assertFalse(wrapped.should_stop())
     check_alive(True)
     self.assertEqual(0, wrapped.run(const))
     check_alive(True)
     self.assertEqual(1, wrapped.run(ident, feed_dict={const: 1}))
     check_alive(True)
     # NOTE(review): assertRaisesRegexp is a deprecated alias (removed in
     # Python 3.12); switch to assertRaisesRegex once Python 2 support is
     # no longer needed.
     with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
                                  'both fed and fetched'):
       wrapped.run(const, feed_dict={const: 2})
     check_alive(False)
     self.assertTrue(stop_coord.should_stop())
     self.assertTrue(wrapped.should_stop())
# Example #4
# 0
 def test_run(self):
   """run() forwards fetches and feed_dict to the wrapped session."""
   with self.test_session() as session:
     const = tf.constant(0)
     ident = tf.identity(const)
     stop_coord = tf.train.Coordinator()
     wrapped = coordinated_session.CoordinatedSession(session, stop_coord, [])
     # Feeding the constant overrides its value seen by the identity op.
     fetched = wrapped.run(ident, feed_dict={const: 42})
     self.assertEqual(42, fetched)
# Example #5
# 0
 def test_properties(self):
   """CoordinatedSession exposes the wrapped session's graph and sess_str."""
   with self.test_session() as sess:
     # NOTE(review): presumably adds an op so the graph is non-empty; the
     # constant is never used directly.
     tf.constant(0.0)
     coord = tf.train.Coordinator()
     coord_sess = coordinated_session.CoordinatedSession(sess, coord, [])
     # assertEqual, not the deprecated assertEquals alias (removed in
     # Python 3.12); assertEqual also exists on Python 2.7.
     self.assertEqual(sess.graph, coord_sess.graph)
     self.assertEqual(sess.sess_str, coord_sess.sess_str)
# Example #6
# 0
 def test_stop_threads_on_close(self):
   """After close(), the coordinator is stopped and no thread is alive."""
   with self.test_session() as session:
     stop_coord = tf.train.Coordinator()
     workers = []
     for _ in range(3):
       workers.append(
           threading.Thread(target=BusyWaitForCoordStop, args=(stop_coord,)))
     for worker in workers:
       worker.start()
     wrapped = coordinated_session.CoordinatedSession(
         session, stop_coord, workers)
     wrapped.close()
     # None of the busy-waiting threads may still be running.
     for worker in workers:
       self.assertFalse(worker.is_alive())
     self.assertTrue(stop_coord.should_stop())
     self.assertTrue(wrapped.should_stop())
# Example #7
# 0
 def test_request_stop_on_exception(self):
   """A TypeError raised by run() leaves the coordinator stopped.

   Two valid runs succeed without stopping anything. Fetching `[None]`
   then raises TypeError ('None has invalid type'), after which both the
   coordinator and the session report should_stop().
   """
   with self.test_session() as session:
     const = tf.constant(0)
     ident = tf.identity(const)
     stop_coord = tf.train.Coordinator()
     wrapped = coordinated_session.CoordinatedSession(session, stop_coord, [])
     self.assertFalse(wrapped.should_stop())
     first = wrapped.run(const)
     self.assertEqual(0, first)
     second = wrapped.run(ident, feed_dict={const: 1})
     self.assertEqual(1, second)
     # NOTE(review): assertRaisesRegexp is a deprecated alias (removed in
     # Python 3.12); switch to assertRaisesRegex once Python 2 support is
     # no longer needed.
     with self.assertRaisesRegexp(TypeError, 'None has invalid type'):
       wrapped.run([None], feed_dict={const: 2})
     self.assertTrue(stop_coord.should_stop())
     self.assertTrue(wrapped.should_stop())
# Example #8
# 0
 def test_request_stop_on_exception(self):
   """An InvalidArgumentError raised by run() leaves the coordinator stopped.

   Two valid runs succeed. Feeding and fetching the same tensor then raises
   InvalidArgumentError ('both fed and fetched'), after which both the
   coordinator and the session report should_stop().
   """
   with self.test_session() as session:
     const = tf.constant(0)
     ident = tf.identity(const)
     stop_coord = tf.train.Coordinator()
     wrapped = coordinated_session.CoordinatedSession(session, stop_coord, [])
     self.assertFalse(wrapped.should_stop())
     first = wrapped.run(const)
     self.assertEqual(0, first)
     second = wrapped.run(ident, feed_dict={const: 1})
     self.assertEqual(1, second)
     # NOTE(review): assertRaisesRegexp is a deprecated alias (removed in
     # Python 3.12); switch to assertRaisesRegex once Python 2 support is
     # no longer needed.
     with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
                                  'both fed and fetched'):
       wrapped.run(const, feed_dict={const: 2})
     self.assertTrue(stop_coord.should_stop())
     self.assertTrue(wrapped.should_stop())