Exemplo n.º 1
0
  def test_two_listeners_with_default_saver(self):
    """Two listeners on one default-saver CheckpointSaverHook.

    Runs two training steps with save_steps=1 and checks that both listeners
    observe the same callback counts. It then reopens model_dir in a fresh
    graph to confirm the global step value of 2 was persisted.
    """
    with ops.Graph().as_default():
      step = variables.get_or_create_global_step()
      increment = state_ops.assign_add(step, 1)
      listeners = [MockCheckpointSaverListener() for _ in range(2)]
      saver_hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_steps=1, listeners=listeners)
      with monitored_session.SingularMonitoredSession(
          hooks=[saver_hook], checkpoint_dir=self.model_dir) as session:
        for _ in range(2):
          session.run(increment)
        final_step = session.run(step)
      first_counts, second_counts = [
          listener.get_counts() for listener in listeners]
    expected_counts = {
        'begin': 1,
        'before_save': 2,
        'after_save': 2,
        'end': 1,
    }
    self.assertEqual(2, final_step)
    self.assertEqual(expected_counts, first_counts)
    self.assertEqual(first_counts, second_counts)

    # A fresh graph restoring from model_dir should see the saved step.
    with ops.Graph().as_default():
      restored_step = variables.get_or_create_global_step()
      with monitored_session.SingularMonitoredSession(
          checkpoint_dir=self.model_dir) as restore_session:
        restored_value = restore_session.run(restored_step)
    self.assertEqual(2, restored_value)
Exemplo n.º 2
0
    def test_save_secs_saves_periodically(self, mock_time):
        """ProfilerHook(save_secs=2) writes timelines on a time schedule.

        Args:
          mock_time: patched clock whose `return_value` drives the hook. The
            patch decorator is outside this view; presumably it patches
            `time.time`, so confirm against the decorator.

        Expected behaviour pinned here: the first step is saved, and after
        that a new timeline appears only once 2 seconds (save_secs) have
        elapsed since the previous save. The mocked clock only ever moves
        forward.
        """
        # Pick a fixed start time.
        current_time = 1484863632.320497

        with self.graph.as_default():
            mock_time.return_value = current_time
            hook = ProfilerHook(save_secs=2, output_dir=self.output_dir)
            with monitored_session.SingularMonitoredSession(
                    hooks=[hook]) as sess:
                sess.run(self.train_op)  # Saved.
                self.assertEqual(1, self._count_timeline_files())
                sess.run(self.train_op)  # Not saved.
                self.assertEqual(1, self._count_timeline_files())
                # Simulate 2.5 seconds of sleep.
                mock_time.return_value = current_time + 2.5
                sess.run(self.train_op)  # Saved.
                self.assertEqual(2, self._count_timeline_files())

                # Pretend some small amount of time has passed since the save
                # at +2.5. The next save is due at +2.5 + 2 = +4.5.
                mock_time.return_value = current_time + 2.6
                sess.run(self.train_op)  # Not saved.
                # Edge test just before we should save the timeline.
                mock_time.return_value = current_time + 4.4
                sess.run(self.train_op)  # Not saved.
                self.assertEqual(2, self._count_timeline_files())

                mock_time.return_value = current_time + 4.5
                sess.run(self.train_op)  # Saved.
                self.assertEqual(3, self._count_timeline_files())
Exemplo n.º 3
0
 def test_save_secs_saves_in_first_step(self):
     """A time-based ProfilerHook writes one timeline on the very first run."""
     with self.graph.as_default():
         profiler = ProfilerHook(output_dir=self.output_dir, save_secs=2)
         with monitored_session.SingularMonitoredSession(
                 hooks=[profiler]) as session:
             session.run(self.train_op)
             self.assertEqual(1, self._count_timeline_files())
Exemplo n.º 4
0
 def test_listener_with_monitored_session(self):
   """A saver listener works when the hook and session share one Scaffold.

   Two training steps with save_steps=1 are expected to produce 'begin' and
   'end' once each, and 'before_save' and 'after_save' twice each.
   """
   with ops.Graph().as_default():
     shared_scaffold = monitored_session.Scaffold()
     step = variables.get_or_create_global_step()
     increment = state_ops.assign_add(step, 1)
     saver_listener = MockCheckpointSaverListener()
     saver_hook = basic_session_run_hooks.CheckpointSaverHook(
         self.model_dir,
         save_steps=1,
         scaffold=shared_scaffold,
         listeners=[saver_listener])
     with monitored_session.SingularMonitoredSession(
         hooks=[saver_hook],
         scaffold=shared_scaffold,
         checkpoint_dir=self.model_dir) as session:
       for _ in range(2):
         session.run(increment)
       final_step = session.run(step)
     counts = saver_listener.get_counts()
   expected_counts = {
       'begin': 1,
       'before_save': 2,
       'after_save': 2,
       'end': 1,
   }
   self.assertEqual(2, final_step)
   self.assertEqual(expected_counts, counts)
Exemplo n.º 5
0
 def run_session(self, hooks, should_stop):
   """Run one no-op step under `hooks` and check the resulting stop state.

   Args:
     hooks: a single hook or a list of hooks. Any non-list value is wrapped
       in a one-element list.
     should_stop: expected value of `should_stop()` after the step.
   """
   if not isinstance(hooks, list):
     hooks = [hooks]
   with ops.Graph().as_default():
     training_util.create_global_step()
     step_op = control_flow_ops.no_op()
     with monitored_session.SingularMonitoredSession(
         hooks=hooks) as session:
       session.run(step_op)
       self.assertEqual(session.should_stop(), should_stop)
Exemplo n.º 6
0
 def test_stop(self):
   """_CheckForStoppingHook requests a stop once the stop variable is True."""
   hook = early_stopping._CheckForStoppingHook()
   with ops.Graph().as_default():
     no_op = control_flow_ops.no_op()
     assign_op = state_ops.assign(early_stopping._get_or_create_stop_var(),
                                  True)
     with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess:
       mon_sess.run(no_op)
       self.assertFalse(mon_sess.should_stop())
       mon_sess.run(assign_op)
       # There is no guarantee that the stop variable is read after the assign
       # op has completed, so the stop may not be visible yet. Run one more
       # no_op so the updated value is read, instead of flakily failing here.
       if not mon_sess.should_stop():
         mon_sess.run(no_op)
       self.assertTrue(mon_sess.should_stop())
Exemplo n.º 7
0
  def test_stop(self):
    """_StopOnPredicateHook reflects its predicate in the stop state.

    For each predicate result (False, then True), one step is run with
    run_every_secs=0. Both `should_stop()` and the hook's stop variable are
    expected to match that result.
    """
    for predicate_result in (False, True):
      check = self.assertTrue if predicate_result else self.assertFalse
      # Bind the current value eagerly so the lambda does not late-bind.
      hook = early_stopping._StopOnPredicateHook(
          should_stop_fn=lambda value=predicate_result: value,
          run_every_secs=0)
      with ops.Graph().as_default():
        training_util.create_global_step()
        no_op = control_flow_ops.no_op()
        with monitored_session.SingularMonitoredSession(
            hooks=[hook]) as mon_sess:
          mon_sess.run(no_op)
          check(mon_sess.should_stop())
          check(mon_sess.raw_session().run(hook._stop_var))
Exemplo n.º 8
0
 def test_run_metadata_saves_in_first_step(self):
   """ProfilerHook records run metadata for the first training step.

   A FakeSummaryWriter is installed for the duration of the test, and the
   keys of its `_added_run_metadata` are inspected. The fake is uninstalled in
   a `finally` block so that a failing assertion cannot leave it installed
   for later tests.
   """
   # Clear the cache so the `get` below builds a writer after the fake is
   # installed (presumably a FakeSummaryWriter; the assertion relies on its
   # `_added_run_metadata` attribute).
   writer_cache.FileWriterCache.clear()
   fake_summary_writer.FakeSummaryWriter.install()
   try:
     fake_writer = writer_cache.FileWriterCache.get(self.output_dir)
     with self.graph.as_default():
       hook = basic_session_run_hooks.ProfilerHook(
           save_secs=2, output_dir=self.output_dir)
       with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
         sess.run(self.train_op)  # Saved.
         self.assertEqual(
             list(fake_writer._added_run_metadata.keys()), ['step_1'])
   finally:
     fake_summary_writer.FakeSummaryWriter.uninstall()
Exemplo n.º 9
0
 def test_save_steps_saves_periodically(self):
     """ProfilerHook(save_steps=2) saves a timeline on runs 1, 3 and 5."""
     with self.graph.as_default():
         profiler = ProfilerHook(save_steps=2, output_dir=self.output_dir)
         with monitored_session.SingularMonitoredSession(
                 hooks=[profiler]) as session:
             self.assertEqual(0, self._count_timeline_files())
             # Expected timeline-file count after each successive run:
             # odd-numbered runs save, even-numbered runs do not.
             for expected_files in (1, 1, 2, 2, 3):
                 session.run(self.train_op)
                 self.assertEqual(expected_files,
                                  self._count_timeline_files())
Exemplo n.º 10
0
    def test_stop(self):
        """_CheckForStoppingHook stops the session once the stop var is set."""
        stop_hook = early_stopping._CheckForStoppingHook()
        with ops.Graph().as_default():
            noop = control_flow_ops.no_op()
            stop_var = early_stopping._get_or_create_stop_var()
            set_stop = state_ops.assign(stop_var, True)
            with monitored_session.SingularMonitoredSession(
                    hooks=[stop_hook]) as session:
                session.run(noop)
                self.assertFalse(session.should_stop())

                session.run(set_stop)

                # The stop variable is not guaranteed to be read after the
                # assign op completes. If the stop is not visible yet, run one
                # more no_op so the updated value is read, then check again.
                already_stopped = session.should_stop()
                if not already_stopped:
                    session.run(noop)
                    self.assertTrue(session.should_stop())