def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Runs `fetches`, surrounding the session run with every hook's callbacks.

    Args:
      fetches: What the caller wants evaluated.
      feed_dict: Optional caller feeds. The value actually fed is whatever
        `self._call_hook_before_run` returns for it.
      options: Forwarded as-is to the underlying session run.
      run_metadata: Forwarded as-is to the underlying session run.

    Returns:
      The result stored under the 'caller' key of the combined run output,
      i.e. the value computed for `fetches`.

    Raises:
      RuntimeError: if `self.should_stop()` is already true.
    """
    if self.should_stop():
        raise RuntimeError('Run called even after should_stop requested.')

    # The caller's fetches sit under a fixed key so that hook-requested fetches
    # can ride along in the same run. Hook results are later read back using
    # the hook object as the key; presumably `_call_hook_before_run` inserts
    # them that way (confirm).
    combined_fetches = {'caller': fetches}

    caller_args = session_run_hook.SessionRunArgs(fetches, feed_dict)
    context = session_run_hook.SessionRunContext(
        original_args=caller_args, session=self._sess)
    merged_feeds = self._call_hook_before_run(
        context, combined_fetches, feed_dict)

    # Invoke the base-class implementation explicitly to do the real run.
    results = _WrappedSession.run(
        self,
        fetches=combined_fetches,
        feed_dict=merged_feeds,
        options=options,
        run_metadata=run_metadata)

    for each_hook in self._hooks:
        # Hooks that contributed no fetches get `results=None`.
        hook_result = results[each_hook] if each_hook in results else None
        each_hook.after_run(
            context, session_run_hook.SessionRunValues(results=hook_result))

    # Sticky: once `_should_stop` is truthy, a later run does not reset it.
    self._should_stop = self._should_stop or context.stop_requested

    return results['caller']
def testCallsHooksBeginEnd(self):
    """One run() call invokes before_run and after_run once on every hook."""
    with tf.Graph().as_default(), tf.Session() as session:
        fake_hooks = (FakeHook(), FakeHook())
        hooked = monitored_session._HookedSession(
            sess=session, hooks=list(fake_hooks))
        a_tensor = tf.constant([0], name='a_tensor')
        session.run(tf.initialize_all_variables())
        hooked.run(a_tensor)

        # Presumably FakeHook requests no fetches of its own, so each hook is
        # expected to receive `results=None` (confirm against FakeHook).
        expected_values = session_run_hook.SessionRunValues(results=None)
        expected_args = session_run_hook.SessionRunArgs(a_tensor)
        for fake in fake_hooks:
            self.assertEqual(fake.last_run_values, expected_values)
            self.assertEqual(fake.last_run_context.original_args,
                             expected_args)
            self.assertEqual(fake.last_run_context.session, session)
            for callback_name in ('before_run', 'after_run'):
                self.assertEqual(fake.call_counter[callback_name], 1)