예제 #1
0
 def test_stop_threads_on_close_after_exception(self):
   """close() still stops registered threads after a failed run().

   Three threads are registered with the coordinator. They must stay alive
   across successful run() calls and across a run() that raises TypeError in
   the calling thread. After close(), every thread must have finished and
   both the coordinator and the wrapper must report should_stop().
   """
   with self.test_session() as sess:
     c = tf.constant(0)
     v = tf.identity(c)
     coord = tf.train.Coordinator()
     # Presumably busy_wait_for_coord_stop spins until the coordinator stops.
     # If an assertion below fails before close() runs, these non-daemon
     # threads would never exit and would hang the test process, so request a
     # stop at cleanup time on every path.
     self.addCleanup(coord.request_stop)
     threads = [threading.Thread(
         target=busy_wait_for_coord_stop, args=(coord,)) for _ in range(3)]
     for t in threads:
       coord.register_thread(t)
       t.start()
     coord_sess = monitored_session._CoordinatedSession(sess, coord)
     self.assertFalse(coord_sess.should_stop())
     for t in threads:
       self.assertTrue(t.is_alive())
     self.assertEqual(0, coord_sess.run(c))
     for t in threads:
       self.assertTrue(t.is_alive())
     self.assertEqual(1, coord_sess.run(v, feed_dict={c: 1}))
     for t in threads:
       self.assertTrue(t.is_alive())
     # An invalid fetch raises in this thread; the workers must be unaffected
     # until close() is called.
     with self.assertRaisesRegexp(TypeError, 'None has invalid type'):
       coord_sess.run([None], feed_dict={c: 2})
     coord_sess.close()
     for t in threads:
       self.assertFalse(t.is_alive())
     self.assertTrue(coord.should_stop())
     self.assertTrue(coord_sess.should_stop())
예제 #2
0
 def test_should_stop_on_coord_stop(self):
   """A stop requested directly on the coordinator is seen by should_stop()."""
   with self.test_session() as session:
     stop_coord = tf.train.Coordinator()
     wrapped = monitored_session._CoordinatedSession(session, stop_coord)
     # Nothing has asked the coordinator to stop yet.
     self.assertFalse(wrapped.should_stop())
     stop_coord.request_stop()
     # The wrapper must now reflect the coordinator's stop request.
     self.assertTrue(wrapped.should_stop())
예제 #3
0
 def test_run(self):
   """run() honours feed_dict and returns the fetched value."""
   with self.test_session() as session:
     zero = tf.constant(0)
     echoed = tf.identity(zero)
     wrapped = monitored_session._CoordinatedSession(
         session, tf.train.Coordinator())
     # Feeding the constant overrides its value, so the identity yields 42.
     fetched = wrapped.run(echoed, feed_dict={zero: 42})
     self.assertEqual(42, fetched)
예제 #4
0
 def test_properties(self):
   """graph and sess_str match those of the wrapped session."""
   with self.test_session() as sess:
     # Presumably added so the session's graph is not empty.
     tf.constant(0.0)
     coord = tf.train.Coordinator()
     coord_sess = monitored_session._CoordinatedSession(sess, coord)
     # assertEqual rather than the deprecated assertEquals alias, which was
     # removed from unittest in Python 3.12.
     self.assertEqual(sess.graph, coord_sess.graph)
     self.assertEqual(sess.sess_str, coord_sess.sess_str)
예제 #5
0
 def test_should_stop_on_close(self):
   """After close(), the wrapper reports should_stop()."""
   with self.test_session() as sess:
     # Use tf.train.Coordinator like the other tests, so this test depends
     # only on the `tf` import rather than a separate `coordinator` module.
     coord = tf.train.Coordinator()
     coord_sess = monitored_session._CoordinatedSession(sess, coord)
     self.assertFalse(coord_sess.should_stop())
     coord_sess.close()
     self.assertTrue(coord_sess.should_stop())
예제 #6
0
 def test_dont_request_stop_on_exception_in_main_thread(self):
   """An exception raised by run() in the caller does not request a stop.

   A TypeError from an invalid fetch must propagate to the caller. Afterwards,
   both the coordinator and the wrapper must still report
   should_stop() == False.
   """
   with self.test_session() as session:
     zero = tf.constant(0)
     echoed = tf.identity(zero)
     stop_coord = tf.train.Coordinator()
     wrapped = monitored_session._CoordinatedSession(session, stop_coord)
     self.assertFalse(wrapped.should_stop())
     self.assertEqual(0, wrapped.run(zero))
     self.assertEqual(1, wrapped.run(echoed, feed_dict={zero: 1}))
     # None is not a valid fetch, so run() raises in this (the main) thread.
     with self.assertRaisesRegexp(TypeError, 'None has invalid type'):
       wrapped.run([None], feed_dict={zero: 2})
     # The error must not have been converted into a stop request.
     for stopper in (stop_coord, wrapped):
       self.assertFalse(stopper.should_stop())
예제 #7
0
 def test_stop_threads_on_close(self):
   """close() leaves every registered thread finished and the coord stopped."""
   with self.test_session() as sess:
     coord = tf.train.Coordinator()
     # Presumably busy_wait_for_coord_stop spins until the coordinator stops.
     # If anything fails before close() runs, these non-daemon threads would
     # never exit and would hang the test process, so request a stop at
     # cleanup time on every path.
     self.addCleanup(coord.request_stop)
     threads = [threading.Thread(
         target=busy_wait_for_coord_stop, args=(coord,)) for _ in range(3)]
     for t in threads:
       coord.register_thread(t)
       t.start()
     coord_sess = monitored_session._CoordinatedSession(sess, coord)
     coord_sess.close()
     for t in threads:
       self.assertFalse(t.is_alive())
     self.assertTrue(coord.should_stop())
     self.assertTrue(coord_sess.should_stop())