def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
  """Runs `fetches` through the wrapped session with hooks around the call.

  See base class.

  Hooks are given a chance to extend the fetches/feeds before the run (via
  `self._call_hook_before_run`), and each hook in `self._hooks` receives its
  own results afterwards via `after_run`. If any hook requests a stop through
  the run context, the request is latched into `self._should_stop`.

  Args:
    fetches: what the caller wants evaluated.
    feed_dict: optional feeds supplied by the caller.
    options: optional `RunOptions`; a fresh one is created when falsy.
    run_metadata: optional `RunMetadata`; a fresh one is created when falsy.

  Returns:
    The evaluated value of the caller's `fetches`.

  Raises:
    RuntimeError: if called after `should_stop()` already returned True.
  """
  if self.should_stop():
    raise RuntimeError('Run called even after should_stop requested.')

  # The caller's fetches live under the 'caller' key; presumably
  # _call_hook_before_run adds each hook's fetches to this same dict, keyed
  # by the hook (see the per-hook lookup below) -- NOTE(review): confirm.
  combined_fetches = {'caller': fetches}
  hook_context = session_run_hook.SessionRunContext(
      original_args=session_run_hook.SessionRunArgs(fetches, feed_dict),
      session=self._sess)
  run_options = options if options else config_pb2.RunOptions()
  merged_feeds = self._call_hook_before_run(
      hook_context, combined_fetches, feed_dict, run_options)

  # Do session run.
  metadata = run_metadata if run_metadata else config_pb2.RunMetadata()
  results = _WrappedSession.run(
      self,
      fetches=combined_fetches,
      feed_dict=merged_feeds,
      options=run_options,
      run_metadata=metadata)

  for hook in self._hooks:
    # Hooks that fetched nothing have no entry in the results.
    hook_results = results[hook] if hook in results else None
    hook.after_run(
        hook_context,
        session_run_hook.SessionRunValues(
            results=hook_results,
            options=run_options,
            run_metadata=metadata))

  self._should_stop = self._should_stop or hook_context.stop_requested
  return results['caller']
def test_not_wait_for_step_zero(self):
  """GlobalStepWaiterHook(wait_until_step=0) returns from before_run at once."""
  graph = ops.Graph()
  with graph.as_default():
    variables.get_or_create_global_step()
    waiter_hook = basic_session_run_hooks.GlobalStepWaiterHook(
        wait_until_step=0)
    waiter_hook.begin()
    with session_lib.Session() as sess:
      context = session_run_hook.SessionRunContext(
          original_args=None, session=sess)
      # before_run should return without waiting for a global-step increment.
      waiter_hook.before_run(context)
def run_hook_with_indices(self, sweep_hook, row_indices, col_indices):
  """Drives one before_run / session run / after_run cycle of `sweep_hook`.

  Args:
    sweep_hook: hook under test; the `fetches` of the args returned by its
      `before_run` are evaluated in a test session.
    row_indices: value fed to `self._input_row_indices_ph`.
    col_indices: value fed to `self._input_col_indices_ph`.
  """
  with self.test_session() as sess:
    ctx = session_run_hook.SessionRunContext(original_args=None, session=sess)

    # Before run.
    requested = sweep_hook.before_run(ctx)
    feeds = {}
    feeds[self._input_row_indices_ph] = row_indices
    feeds[self._input_col_indices_ph] = col_indices

    # Run.
    fetched = sess.run(requested.fetches, feed_dict=feeds)

    # After run.
    sweep_hook.after_run(
        ctx,
        session_run_hook.SessionRunValues(
            results=fetched, options=None, run_metadata=None))
def test_wait_for_step(self):
  """before_run blocks until the global step reaches wait_until_step (1000)."""
  with ops.Graph().as_default():
    global_step = variables.get_or_create_global_step()
    hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=1000)
    hook.begin()
    with session_lib.Session() as sess:
      sess.run(variables_lib.global_variables_initializer())
      context = session_run_hook.SessionRunContext(
          original_args=None, session=sess)

      # Run before_run on a daemon thread so a stuck wait cannot hang the
      # test process.
      blocker = threading.Thread(target=hook.before_run, args=(context,))
      blocker.daemon = True
      blocker.start()

      # Before any step change the hook must still be waiting.
      time.sleep(1.0)
      self.assertTrue(blocker.is_alive())

      # (step to assign, seconds to wait, whether before_run is still blocked)
      phases = ((500, 1.0, True), (1100, 1.2, False))
      for new_step, pause, still_waiting in phases:
        sess.run(state_ops.assign(global_step, new_step))
        time.sleep(pause)
        if still_waiting:
          self.assertTrue(blocker.is_alive())
        else:
          self.assertFalse(blocker.is_alive())
def end(self, session):
  """Runs the evaluator one final time for the final model.

  Args:
    session: session used to read `self._global_step_tensor`; it is also
      wrapped in a `SessionRunContext` (with empty original args) and passed
      to `self._predict`.
  """
  final_step = session.run(self._global_step_tensor)
  self._predict(session_run_hook.SessionRunContext({}, session), final_step)