def test_not_wait_for_step_zero(self):
    """A waiter hook built with wait_until_step=0 must not block in before_run.

    The global step variable is created but never initialized or incremented
    here, so the test passes as long as before_run returns; nothing else is
    asserted.
    """
    graph = tf.Graph()
    with graph.as_default():
        tf.contrib.framework.get_or_create_global_step()
        waiter_hook = _GlobalStepWaiterHook(wait_until_step=0)
        waiter_hook.begin()
        with tf.Session() as session:
            context = tf.train.SessionRunContext(original_args=None,
                                                 session=session)
            # Before run should return without waiting gstep increment.
            waiter_hook.before_run(context)
def test_not_wait_for_step_zero(self):
    """before_run returns straight away when the hook waits for step 0.

    Only the global step variable is created (it is not initialized), and the
    test succeeds simply by before_run returning.
    """
    graph = tf.Graph()
    with graph.as_default():
        tf.contrib.framework.get_or_create_global_step()
        step_zero_hook = _GlobalStepWaiterHook(wait_until_step=0)
        step_zero_hook.begin()
        with tf.Session() as session:
            ctx = tf.train.SessionRunContext(original_args=None, session=session)
            # Before run should return without waiting gstep increment.
            step_zero_hook.before_run(ctx)
def test_wait_for_step(self):
    """before_run blocks until the global step reaches wait_until_step.

    before_run is driven from a daemon thread. The thread must still be alive
    before any assignment and after the global step is set to 500 (below the
    1000 target), and must exit once the step is set to 1100.
    """
    with tf.Graph().as_default():
        gstep = tf.contrib.framework.get_or_create_global_step()
        hook = _GlobalStepWaiterHook(wait_until_step=1000)
        hook.begin()
        with tf.Session() as sess:
            sess.run(tf.initialize_all_variables())
            run_context = tf.train.SessionRunContext(
                original_args=None, session=sess)
            waiter = threading.Thread(target=hook.before_run,
                                      args=(run_context,))
            # Daemon so a waiter that never returns cannot keep the process
            # alive after the test finishes.
            waiter.daemon = True
            waiter.start()

            # join(timeout) only returns early if the waiter exits, which at
            # this point would be a failure; otherwise it waits the full second.
            waiter.join(timeout=1.0)
            self.assertTrue(waiter.is_alive())

            sess.run(tf.assign(gstep, 500))
            waiter.join(timeout=1.0)
            self.assertTrue(waiter.is_alive())

            sess.run(tf.assign(gstep, 1100))
            # join returns as soon as the waiter exits, so a generous bound
            # tolerates slow or loaded machines (unlike a fixed sleep) without
            # slowing down the passing case.
            waiter.join(timeout=5.0)
            self.assertFalse(waiter.is_alive())
def test_wait_for_step(self):
    """before_run blocks until the global step reaches wait_until_step.

    before_run is driven from a daemon thread. The thread must still be alive
    before any assignment and after the global step is set to 500 (below the
    1000 target), and must exit once the step is set to 1100.
    """
    with tf.Graph().as_default():
        gstep = tf.contrib.framework.get_or_create_global_step()
        hook = _GlobalStepWaiterHook(wait_until_step=1000)
        hook.begin()
        with tf.Session() as sess:
            sess.run(tf.initialize_all_variables())
            run_context = tf.train.SessionRunContext(
                original_args=None, session=sess)
            waiter = threading.Thread(target=hook.before_run,
                                      args=(run_context,))
            # Daemon so a waiter that never returns cannot keep the process
            # alive after the test finishes.
            waiter.daemon = True
            waiter.start()

            # join(timeout) only returns early if the waiter exits, which at
            # this point would be a failure; otherwise it waits the full second.
            waiter.join(timeout=1.0)
            self.assertTrue(waiter.is_alive())

            sess.run(tf.assign(gstep, 500))
            waiter.join(timeout=1.0)
            self.assertTrue(waiter.is_alive())

            sess.run(tf.assign(gstep, 1100))
            # join returns as soon as the waiter exits, so a generous bound
            # tolerates slow or loaded machines (unlike a fixed sleep) without
            # slowing down the passing case.
            waiter.join(timeout=5.0)
            self.assertFalse(waiter.is_alive())