def test_do_not_stop_if_checkpoint_is_not_there(self):
  """Hook sleeps and does not request a stop when no checkpoint exists.

  The global step is set to `last_step` (10), but `model_dir` is a freshly
  created, empty temporary directory. The test asserts that one run of the
  monitored session calls `time.sleep` (patched out) and does not stop.
  """
  with tf.Graph().as_default():
    step = tf.compat.v1.train.create_global_step()
    assign_ten = step.assign(10)
    no_op = tf.no_op()
    hook = hooks_lib._StopAtCheckpointStepHook(
        model_dir=tempfile.mkdtemp(), last_step=10)
    with tf.compat.v1.train.SingularMonitoredSession(
        hooks=[hook]) as mon_sess:
      # Bypass the hooks so that setting the step does not itself
      # trigger them.
      mon_sess.raw_session().run(assign_ten)
      # Patch sleep so the test observes the wait without actually
      # blocking.
      with tf.compat.v1.test.mock.patch.object(time, 'sleep') as mock_sleep:
        mon_sess.run(no_op)
        self.assertTrue(mock_sleep.called)
        self.assertFalse(mon_sess.should_stop())
def test_stop_if_checkpoint_step_is_laststep(self):
  """Hook stops immediately once a checkpoint at `last_step` is on disk.

  A checkpoint recording global step 10 is written to `model_dir` first. A
  subsequent monitored-session run at step 10 must then request a stop
  without calling the (patched) `time.sleep`.
  """
  ckpt_dir = tempfile.mkdtemp()
  ckpt_prefix = os.path.join(ckpt_dir, 'model.ckpt')
  with tf.Graph().as_default():
    global_step = tf.compat.v1.train.create_global_step()
    set_step_to_ten = global_step.assign(10)
    noop = tf.no_op()
    stop_hook = hooks_lib._StopAtCheckpointStepHook(
        model_dir=ckpt_dir, last_step=10)

    # Write a checkpoint whose saved global step equals `last_step`.
    with tf.compat.v1.Session() as plain_sess:
      plain_sess.run(set_step_to_ten)
      checkpoint_saver = tf.compat.v1.train.Saver()
      checkpoint_saver.save(plain_sess, ckpt_prefix)

    with tf.compat.v1.train.SingularMonitoredSession(
        hooks=[stop_hook]) as monitored:
      monitored.raw_session().run(set_step_to_ten)
      sleep_patch = tf.compat.v1.test.mock.patch.object(time, 'sleep')
      with sleep_patch as patched_sleep:
        monitored.run(noop)
        self.assertFalse(patched_sleep.called)
        self.assertTrue(monitored.should_stop())