コード例 #1
0
    def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
        """Evaluates `fetches` in the wrapped session, calling hooks around it.

        The caller's fetches are placed under the 'caller' key of a combined
        fetch dict. `_call_hook_before_run` receives that dict together with
        the caller's feeds and the run options, and returns the feed dict that
        is actually used. After the run, every hook's `after_run` gets its own
        entry from the results (or None when there is none). If any hook asked
        to stop during this call, the stop flag becomes sticky.

        Args:
          fetches: what the caller wants evaluated.
          feed_dict: optional feeds supplied by the caller.
          options: optional `RunOptions`; a new one is created when falsy.
          run_metadata: optional `RunMetadata`; a new one is created when
            falsy.

        Returns:
          The value evaluated for the caller's `fetches`.

        Raises:
          RuntimeError: if called after `should_stop()` already returned True.
        """
        if self.should_stop():
            raise RuntimeError('Run called even after should_stop requested.')

        # Reserve the 'caller' slot; per-hook results are read back by hook key.
        combined_fetches = {'caller': fetches}
        context = session_run_hook.SessionRunContext(
            original_args=session_run_hook.SessionRunArgs(fetches, feed_dict),
            session=self._sess)

        run_options = options or config_pb2.RunOptions()
        merged_feeds = self._call_hook_before_run(
            context, combined_fetches, feed_dict, run_options)

        metadata = run_metadata or config_pb2.RunMetadata()
        results = _WrappedSession.run(
            self,
            fetches=combined_fetches,
            feed_dict=merged_feeds,
            options=run_options,
            run_metadata=metadata)

        for hook in self._hooks:
            # Hooks that contributed no fetches get None as their results.
            hook_results = results[hook] if hook in results else None
            hook_values = session_run_hook.SessionRunValues(
                results=hook_results,
                options=run_options,
                run_metadata=metadata)
            hook.after_run(context, hook_values)

        # Once any hook requests a stop, keep reporting it on later calls.
        self._should_stop = self._should_stop or context.stop_requested

        return results['caller']
コード例 #2
0
 def test_not_wait_for_step_zero(self):
   """A GlobalStepWaiterHook with wait_until_step=0 must not block."""
   with ops.Graph().as_default():
     variables.get_or_create_global_step()
     waiter_hook = basic_session_run_hooks.GlobalStepWaiterHook(
         wait_until_step=0)
     waiter_hook.begin()
     with session_lib.Session() as sess:
       context = session_run_hook.SessionRunContext(
           original_args=None, session=sess)
       # The global step is never advanced here, so before_run is expected
       # to return right away instead of waiting for an increment.
       waiter_hook.before_run(context)
コード例 #3
0
 def run_hook_with_indices(self, sweep_hook, row_indices, col_indices):
   """Drives `sweep_hook` through one before_run / run / after_run cycle.

   Args:
     sweep_hook: the hook under test.
     row_indices: value fed to `self._input_row_indices_ph`.
     col_indices: value fed to `self._input_col_indices_ph`.
   """
   with self.test_session() as sess:
     context = session_run_hook.SessionRunContext(
         original_args=None, session=sess)
     # Ask the hook what it wants evaluated.
     requested = sweep_hook.before_run(context)
     # Evaluate the hook's fetches with the given indices fed in.
     feeds = {
         self._input_row_indices_ph: row_indices,
         self._input_col_indices_ph: col_indices,
     }
     fetched = sess.run(requested.fetches, feed_dict=feeds)
     # Hand the fetched values back to the hook.
     sweep_hook.after_run(
         context,
         session_run_hook.SessionRunValues(
             results=fetched, options=None, run_metadata=None))
コード例 #4
0
 def test_wait_for_step(self):
   """GlobalStepWaiterHook.before_run blocks until the target step is reached.

   before_run runs in a background daemon thread. The test checks that the
   thread is still blocked while the global step is below 1000 (first after
   initialization, then at 500). It then checks that the thread exits once
   the step is set to 1100.
   """
   with ops.Graph().as_default():
     gstep = variables.get_or_create_global_step()
     hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=1000)
     hook.begin()
     with session_lib.Session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       waiter = threading.Thread(
           target=hook.before_run,
           args=(session_run_hook.SessionRunContext(
               original_args=None, session=sess),))
       waiter.daemon = True
       waiter.start()
       # join(timeout) returns early only if the thread exits. Each of these
       # still gives the waiter ~1s in which it must NOT return, and the test
       # fails fast if it does.
       waiter.join(timeout=1.0)
       self.assertTrue(waiter.is_alive())
       sess.run(state_ops.assign(gstep, 500))
       waiter.join(timeout=1.0)
       self.assertTrue(waiter.is_alive())
       sess.run(state_ops.assign(gstep, 1100))
       # The hook presumably polls the step periodically. A fixed short sleep
       # followed by an is_alive() check races that loop and can flake on a
       # loaded machine. join() returns as soon as the waiter finishes, and it
       # only reaches the timeout (failing below) if the waiter is stuck.
       waiter.join(timeout=10.0)
       self.assertFalse(waiter.is_alive())
コード例 #5
0
 def end(self, session):
     """Runs the evaluator one last time against the final model.

     Args:
       session: session used to read the current global step. It is also
         wrapped in the run context handed to `self._predict`.
     """
     final_step = session.run(self._global_step_tensor)
     context = session_run_hook.SessionRunContext(
         original_args={}, session=session)
     self._predict(context, final_step)