Example #1
0
 def test_not_wait_for_step_zero(self):
     """A hook told to wait for step 0 must not block in before_run.

     Nothing in this test ever increments the global step, so the call to
     before_run below only returns if the hook does no waiting at all.
     NOTE(review): there is no timeout, so a regression makes this test hang
     rather than fail.
     """
     graph = tf.Graph()
     with graph.as_default():
         # Make sure a global step exists in the graph before begin() runs.
         tf.contrib.framework.get_or_create_global_step()
         waiter_hook = _GlobalStepWaiterHook(wait_until_step=0)
         waiter_hook.begin()
         with tf.Session() as session:
             context = tf.train.SessionRunContext(
                 session=session, original_args=None)
             # Must return without waiting for any global step increment.
             waiter_hook.before_run(context)
 def test_not_wait_for_step_zero(self):
   """A hook with wait_until_step=0 returns from before_run without waiting.

   The global step is never incremented here, so the test only completes if
   before_run does not block. NOTE(review): no timeout is applied, so a
   regression hangs instead of failing.
   """
   graph = tf.Graph()
   with graph.as_default():
     # A global step must exist in the graph before the hook's begin().
     tf.contrib.framework.get_or_create_global_step()
     waiter_hook = _GlobalStepWaiterHook(wait_until_step=0)
     waiter_hook.begin()
     with tf.Session() as session:
       context = tf.train.SessionRunContext(
           session=session, original_args=None)
       waiter_hook.before_run(context)
Example #3
0
 def test_wait_for_step(self):
     """before_run blocks until the global step reaches wait_until_step.

     The hook is driven from a background thread so the main thread can move
     the global step forward and observe whether the waiter is still blocked.
     """
     # Upper bound on how long the waiter may take to notice the step has
     # passed its target; join() returns as soon as the thread exits, so
     # this only costs time when the test is actually failing.
     exit_timeout_secs = 10.0
     with tf.Graph().as_default():
         gstep = tf.contrib.framework.get_or_create_global_step()
         hook = _GlobalStepWaiterHook(wait_until_step=1000)
         hook.begin()
         with tf.Session() as sess:
             sess.run(tf.initialize_all_variables())
             waiter = threading.Thread(
                 target=hook.before_run,
                 args=(tf.train.SessionRunContext(original_args=None,
                                                  session=sess), ))
             # Daemon thread so a waiter left blocked by a failing test
             # cannot keep the interpreter alive at exit.
             waiter.daemon = True
             waiter.start()
             # join(timeout) lasts the full timeout while the thread is
             # blocked (same as a sleep), but returns early if it exits, so
             # a waiter that wrongly returns is detected sooner.
             waiter.join(1.0)
             self.assertTrue(waiter.is_alive())
             # Below the target step: the waiter must keep waiting.
             sess.run(tf.assign(gstep, 500))
             waiter.join(1.0)
             self.assertTrue(waiter.is_alive())
             # Past the target step: the waiter must return. A generous join
             # timeout replaces the former fixed 1.2 s sleep, which could
             # fail spuriously on a slow or loaded machine.
             sess.run(tf.assign(gstep, 1100))
             waiter.join(exit_timeout_secs)
             self.assertFalse(
                 waiter.is_alive(),
                 msg='waiter still blocked after global step passed target')
 def test_wait_for_step(self):
   """before_run blocks until the global step reaches wait_until_step.

   The hook runs in a background thread so the main thread can advance the
   global step and check whether the waiter is still blocked.
   """
   # Upper bound for the waiter to notice the step passed its target; join()
   # returns as soon as the thread exits, so this only costs time on failure.
   exit_timeout_secs = 10.0
   with tf.Graph().as_default():
     gstep = tf.contrib.framework.get_or_create_global_step()
     hook = _GlobalStepWaiterHook(wait_until_step=1000)
     hook.begin()
     with tf.Session() as sess:
       sess.run(tf.initialize_all_variables())
       waiter = threading.Thread(
           target=hook.before_run,
           args=(tf.train.SessionRunContext(
               original_args=None, session=sess),))
       # Daemon thread so a waiter left blocked by a failing test cannot keep
       # the interpreter alive at exit.
       waiter.daemon = True
       waiter.start()
       # join(timeout) lasts the full timeout while the thread is blocked
       # (same as a sleep) but returns early if it exits, so a waiter that
       # wrongly returns is detected sooner.
       waiter.join(1.0)
       self.assertTrue(waiter.is_alive())
       # Below the target step: the waiter must keep waiting.
       sess.run(tf.assign(gstep, 500))
       waiter.join(1.0)
       self.assertTrue(waiter.is_alive())
       # Past the target step: the waiter must return. A generous join
       # timeout replaces the former fixed 1.2 s sleep, which could fail
       # spuriously on a slow or loaded machine.
       sess.run(tf.assign(gstep, 1100))
       waiter.join(exit_timeout_secs)
       self.assertFalse(
           waiter.is_alive(),
           msg='waiter still blocked after global step passed target')