def test_stop_threads_on_close_after_exception(self):
    """Check that close() still stops registered threads after run() raised.

    Three threads running busy_wait_for_coord_stop are registered with the
    coordinator. They are expected to stay alive across two successful
    run() calls and across a run() call that raises TypeError. Only after
    close() are they expected to be finished, with both the coordinator and
    the session wrapper reporting should_stop().
    """
    with self.test_session() as sess:
        const = tf.constant(0)
        ident = tf.identity(const)
        coord = tf.train.Coordinator()

        # Build every thread first, then register and start them one by one.
        workers = []
        for _ in range(3):
            workers.append(threading.Thread(
                target=busy_wait_for_coord_stop, args=(coord,)))
        for worker in workers:
            coord.register_thread(worker)
            worker.start()

        wrapped = monitored_session._CoordinatedSession(sess, coord)

        def check_workers_alive():
            # Each worker is asserted individually, in creation order.
            for worker in workers:
                self.assertTrue(worker.is_alive())

        self.assertFalse(wrapped.should_stop())
        check_workers_alive()

        plain_result = wrapped.run(const)
        self.assertEqual(0, plain_result)
        check_workers_alive()

        fed_result = wrapped.run(ident, feed_dict={const: 1})
        self.assertEqual(1, fed_result)
        check_workers_alive()

        # NOTE(review): assertRaisesRegexp is a deprecated alias on Python 3;
        # presumably kept for Python 2 support -- confirm before renaming.
        with self.assertRaisesRegexp(TypeError, 'None has invalid type'):
            wrapped.run([None], feed_dict={const: 2})

        wrapped.close()
        for worker in workers:
            self.assertFalse(worker.is_alive())
        self.assertTrue(coord.should_stop())
        self.assertTrue(wrapped.should_stop())
def test_should_stop_on_coord_stop(self):
    """Check that the wrapper's should_stop() follows coord.request_stop()."""
    with self.test_session() as sess:
        coord = tf.train.Coordinator()
        wrapped = monitored_session._CoordinatedSession(sess, coord)

        # Nothing has asked for a stop yet.
        stop_before = wrapped.should_stop()
        self.assertFalse(stop_before)

        coord.request_stop()

        stop_after = wrapped.should_stop()
        self.assertTrue(stop_after)
def test_run(self):
    """Check that run() on the wrapper returns the fed value of the fetch."""
    with self.test_session() as sess:
        const = tf.constant(0)
        ident = tf.identity(const)
        coord = tf.train.Coordinator()
        wrapped = monitored_session._CoordinatedSession(sess, coord)

        # Feeding 42 for the constant should surface through the identity op.
        fetched = wrapped.run(ident, feed_dict={const: 42})
        self.assertEqual(42, fetched)
def test_properties(self):
    """Check that graph and sess_str on the wrapper match the wrapped session.

    Uses assertEqual rather than the deprecated assertEquals alias, which
    newer Python releases no longer provide.
    """
    with self.test_session() as sess:
        # The op's return value is not needed; the call is kept as in the
        # original test so the graph contents are unchanged.
        tf.constant(0.0)
        coord = tf.train.Coordinator()
        coord_sess = monitored_session._CoordinatedSession(sess, coord)
        self.assertEqual(sess.graph, coord_sess.graph)
        self.assertEqual(sess.sess_str, coord_sess.sess_str)
def test_should_stop_on_close(self):
    """Check that should_stop() turns true once the wrapper is closed."""
    with self.test_session() as sess:
        # NOTE(review): sibling tests construct this via tf.train.Coordinator;
        # presumably the same class -- confirm and unify if so.
        coord = coordinator.Coordinator()
        wrapped = monitored_session._CoordinatedSession(sess, coord)

        stop_before = wrapped.should_stop()
        self.assertFalse(stop_before)

        wrapped.close()

        stop_after = wrapped.should_stop()
        self.assertTrue(stop_after)
def test_dont_request_stop_on_exception_in_main_thread(self):
    """Check that a TypeError raised by run() does not trigger a stop.

    After two successful run() calls and one that raises TypeError, neither
    the coordinator nor the session wrapper should report should_stop().
    """
    with self.test_session() as sess:
        const = tf.constant(0)
        ident = tf.identity(const)
        coord = tf.train.Coordinator()
        wrapped = monitored_session._CoordinatedSession(sess, coord)

        self.assertFalse(wrapped.should_stop())

        plain_result = wrapped.run(const)
        self.assertEqual(0, plain_result)

        fed_result = wrapped.run(ident, feed_dict={const: 1})
        self.assertEqual(1, fed_result)

        # NOTE(review): assertRaisesRegexp is a deprecated alias on Python 3;
        # presumably kept for Python 2 support -- confirm before renaming.
        with self.assertRaisesRegexp(TypeError, 'None has invalid type'):
            wrapped.run([None], feed_dict={const: 2})

        # The failure above happened in this thread; no stop was requested.
        self.assertFalse(coord.should_stop())
        self.assertFalse(wrapped.should_stop())
def test_stop_threads_on_close(self):
    """Check that close() leaves every registered thread finished.

    Three threads running busy_wait_for_coord_stop are registered with the
    coordinator and started. After the wrapper is closed, none may be alive
    and both the coordinator and the wrapper must report should_stop().
    """
    with self.test_session() as sess:
        coord = tf.train.Coordinator()

        # Build every thread first, then register and start them one by one.
        workers = []
        for _ in range(3):
            workers.append(threading.Thread(
                target=busy_wait_for_coord_stop, args=(coord,)))
        for worker in workers:
            coord.register_thread(worker)
            worker.start()

        wrapped = monitored_session._CoordinatedSession(sess, coord)
        wrapped.close()

        for worker in workers:
            self.assertFalse(worker.is_alive())
        self.assertTrue(coord.should_stop())
        self.assertTrue(wrapped.should_stop())