def test_two_listeners_with_default_saver(self):
    """Run two listeners on one CheckpointSaverHook, then restore the result.

    Two training steps are run with ``save_steps=1``. Both listeners must
    report the same callback counts, and a second session opened on the same
    ``checkpoint_dir`` must read the global step back as 2.
    """
    expected_counts = {
        'begin': 1,
        'before_save': 2,
        'after_save': 2,
        'end': 1
    }

    with ops.Graph().as_default():
        step_var = variables.get_or_create_global_step()
        increment = state_ops.assign_add(step_var, 1)
        first_listener = MockCheckpointSaverListener()
        second_listener = MockCheckpointSaverListener()
        saver_hook = basic_session_run_hooks.CheckpointSaverHook(
            self.model_dir,
            save_steps=1,
            listeners=[first_listener, second_listener])
        with monitored_session.SingularMonitoredSession(
                hooks=[saver_hook], checkpoint_dir=self.model_dir) as sess:
            for _ in range(2):
                sess.run(increment)
            trained_step = sess.run(step_var)
        # Counts are read only after the session has closed, so the 'end'
        # callback is included.
        first_counts = first_listener.get_counts()
        second_counts = second_listener.get_counts()
    self.assertEqual(2, trained_step)
    self.assertEqual(expected_counts, first_counts)
    self.assertEqual(first_counts, second_counts)

    # A fresh graph and session on the same directory should see the saved
    # global step.
    with ops.Graph().as_default():
        restored_step_var = variables.get_or_create_global_step()
        with monitored_session.SingularMonitoredSession(
                checkpoint_dir=self.model_dir) as sess2:
            restored_step = sess2.run(restored_step_var)
    self.assertEqual(2, restored_step)
def test_save_secs_saves_periodically(self, mock_time):
    """Check that ``save_secs=2`` writes a timeline only after 2 s have passed.

    The mocked clock is only ever moved forward, so each "not saved" step is
    checked against the time of the most recent save.

    Args:
      mock_time: Mock whose ``return_value`` is set here to control the clock
        seen by the hook. Presumably injected by a ``time.time`` patch
        decorator outside this view -- confirm against the decorator.
    """
    # Pick a fixed start time.
    current_time = 1484863632.320497

    with self.graph.as_default():
        mock_time.return_value = current_time
        hook = ProfilerHook(save_secs=2, output_dir=self.output_dir)
        with monitored_session.SingularMonitoredSession(
                hooks=[hook]) as sess:
            sess.run(self.train_op)  # Saved.
            self.assertEqual(1, self._count_timeline_files())
            sess.run(self.train_op)  # Not saved.
            self.assertEqual(1, self._count_timeline_files())
            # Simulate 2.5 seconds of sleep.
            mock_time.return_value = current_time + 2.5
            sess.run(self.train_op)  # Saved.
            self.assertEqual(2, self._count_timeline_files())

            # Pretend some small amount of time has passed since that save.
            mock_time.return_value = current_time + 2.6
            sess.run(self.train_op)  # Not saved.
            # Edge test just before we should save the timeline: the last
            # save happened at +2.5, so the next one is due at +4.5.
            mock_time.return_value = current_time + 4.4
            sess.run(self.train_op)  # Not saved.
            self.assertEqual(2, self._count_timeline_files())

            mock_time.return_value = current_time + 4.5
            sess.run(self.train_op)  # Saved.
            self.assertEqual(3, self._count_timeline_files())
def test_save_secs_saves_in_first_step(self):
    """Check that one timeline file exists after the first step with save_secs."""
    with self.graph.as_default():
        profiler_hooks = [
            ProfilerHook(save_secs=2, output_dir=self.output_dir)
        ]
        with monitored_session.SingularMonitoredSession(
                hooks=profiler_hooks) as sess:
            sess.run(self.train_op)
            timeline_count = self._count_timeline_files()
            self.assertEqual(1, timeline_count)
def test_listener_with_monitored_session(self):
    """Check listener callback counts when hook and session share a Scaffold.

    Two training steps are run with ``save_steps=1``. The listener's counts
    are read after the session has closed.
    """
    expected_counts = {
        'begin': 1,
        'before_save': 2,
        'after_save': 2,
        'end': 1
    }

    with ops.Graph().as_default():
        shared_scaffold = monitored_session.Scaffold()
        step_var = variables.get_or_create_global_step()
        increment = state_ops.assign_add(step_var, 1)
        saver_listener = MockCheckpointSaverListener()
        saver_hook = basic_session_run_hooks.CheckpointSaverHook(
            self.model_dir,
            save_steps=1,
            scaffold=shared_scaffold,
            listeners=[saver_listener])
        with monitored_session.SingularMonitoredSession(
                hooks=[saver_hook],
                scaffold=shared_scaffold,
                checkpoint_dir=self.model_dir) as sess:
            for _ in range(2):
                sess.run(increment)
            trained_step = sess.run(step_var)
        observed_counts = saver_listener.get_counts()
    self.assertEqual(2, trained_step)
    self.assertEqual(expected_counts, observed_counts)
def run_session(self, hooks, should_stop):
    """Run a single no-op step under `hooks` and check the stop state.

    Args:
      hooks: A hook or a list of hooks; a non-list value is wrapped in a list.
      should_stop: Value that ``mon_sess.should_stop()`` is expected to equal
        after the step.
    """
    if not isinstance(hooks, list):
        hooks = [hooks]

    with ops.Graph().as_default():
        training_util.create_global_step()
        step_op = control_flow_ops.no_op()
        with monitored_session.SingularMonitoredSession(
                hooks=hooks) as mon_sess:
            mon_sess.run(step_op)
            stop_requested = mon_sess.should_stop()
            self.assertEqual(stop_requested, should_stop)
def test_stop(self):
    """Check that _CheckForStoppingHook stops the session once the stop var is set.

    The session must not ask to stop after a plain no-op step, and must ask
    to stop once the stop variable has been assigned ``True``.
    """
    hook = early_stopping._CheckForStoppingHook()
    with ops.Graph().as_default():
        no_op = control_flow_ops.no_op()
        assign_op = state_ops.assign(
            early_stopping._get_or_create_stop_var(), True)
        with monitored_session.SingularMonitoredSession(
                hooks=[hook]) as mon_sess:
            mon_sess.run(no_op)
            self.assertFalse(mon_sess.should_stop())
            mon_sess.run(assign_op)
            # There is no guarantee that the stop variable is read after the
            # assign op has completed within the same run call, so run one
            # more no_op to make sure the updated value is observed.
            if not mon_sess.should_stop():
                mon_sess.run(no_op)
            self.assertTrue(mon_sess.should_stop())
def test_stop(self):
    """Check _StopOnPredicateHook for a predicate returning False, then True.

    For each predicate a fresh hook, graph and session are built and one
    no-op step is run. Both ``should_stop()`` and the hook's stop variable
    (read through the raw session) must match the predicate's value.
    """

    def check_single_step(should_stop_fn, assert_fn):
        # The hook is constructed before the graph is entered.
        hook = early_stopping._StopOnPredicateHook(
            should_stop_fn=should_stop_fn, run_every_secs=0)
        with ops.Graph().as_default():
            training_util.create_global_step()
            step_op = control_flow_ops.no_op()
            with monitored_session.SingularMonitoredSession(
                    hooks=[hook]) as mon_sess:
                mon_sess.run(step_op)
                assert_fn(mon_sess.should_stop())
                assert_fn(mon_sess.raw_session().run(hook._stop_var))

    check_single_step(lambda: False, self.assertFalse)
    check_single_step(lambda: True, self.assertTrue)
def test_run_metadata_saves_in_first_step(self):
    """Check that run metadata tagged 'step_1' is recorded on the first step.

    A fake summary writer is installed for the duration of the test and is
    always uninstalled afterwards, even if the test body fails, so that it
    does not stay installed for later tests.
    """
    writer_cache.FileWriterCache.clear()
    fake_summary_writer.FakeSummaryWriter.install()
    try:
        fake_writer = writer_cache.FileWriterCache.get(self.output_dir)
        with self.graph.as_default():
            hook = basic_session_run_hooks.ProfilerHook(
                save_secs=2, output_dir=self.output_dir)
            with monitored_session.SingularMonitoredSession(
                    hooks=[hook]) as sess:
                sess.run(self.train_op)  # Saved.
                self.assertEqual(
                    list(fake_writer._added_run_metadata.keys()),
                    ['step_1'])
    finally:
        fake_summary_writer.FakeSummaryWriter.uninstall()
def test_save_steps_saves_periodically(self):
    """Check that ``save_steps=2`` writes a timeline on every other step."""
    # Expected number of timeline files after each of five consecutive steps:
    # steps 1, 3 and 5 save; steps 2 and 4 do not.
    expected_counts_per_step = (1, 1, 2, 2, 3)

    with self.graph.as_default():
        profiler = ProfilerHook(save_steps=2, output_dir=self.output_dir)
        with monitored_session.SingularMonitoredSession(
                hooks=[profiler]) as sess:
            # Nothing is written before the first step runs.
            self.assertEqual(0, self._count_timeline_files())
            for expected_count in expected_counts_per_step:
                sess.run(self.train_op)
                self.assertEqual(expected_count,
                                 self._count_timeline_files())
def test_stop(self):
    """Check that setting the stop variable makes the session request a stop.

    After the assign op runs, one extra no-op step is allowed before the
    final assertion.
    """
    stopping_hook = early_stopping._CheckForStoppingHook()
    with ops.Graph().as_default():
        step_op = control_flow_ops.no_op()
        stop_var = early_stopping._get_or_create_stop_var()
        request_stop = state_ops.assign(stop_var, True)
        with monitored_session.SingularMonitoredSession(
                hooks=[stopping_hook]) as mon_sess:
            mon_sess.run(step_op)
            self.assertFalse(mon_sess.should_stop())

            mon_sess.run(request_stop)
            # Because there are no guarantees that the stop variable will be
            # read after the assign op is completed, run another no_op to
            # ensure that the updated value is read.
            stop_seen = mon_sess.should_stop()
            if not stop_seen:
                mon_sess.run(step_op)
            self.assertTrue(mon_sess.should_stop())