def testOnlyHooksHaveFeeds(self):
    """Feeds supplied only by the hooks are used for the wrapped run."""
    with tf.Graph().as_default(), tf.Session() as sess:
        first_hook = FakeHook()
        second_hook = FakeHook()
        hooked_sess = monitored_session._HookedSession(
            sess=sess, hooks=[first_hook, second_hook])

        a_tensor = tf.constant([0], name='a_tensor')
        b_tensor = tf.constant([0], name='b_tensor')
        total = a_tensor + b_tensor

        # Each hook feeds one of the two addends; the caller feeds nothing.
        first_hook.request = session_run_hook.SessionRunArgs(
            None, feed_dict={a_tensor: [5]})
        second_hook.request = session_run_hook.SessionRunArgs(
            None, feed_dict={b_tensor: [10]})

        sess.run(tf.initialize_all_variables())

        result = hooked_sess.run(fetches=total)
        self.assertEqual(result, [15])
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """See base class."""
    if self.should_stop():
        raise RuntimeError('Run called even after should_stop requested.')

    # The caller's fetches are stored under the 'caller' key; the same dict
    # is handed to `_call_hook_before_run` together with the feeds.
    combined_fetches = {'caller': fetches}

    original_args = session_run_hook.SessionRunArgs(fetches, feed_dict)
    run_context = session_run_hook.SessionRunContext(
        original_args=original_args, session=self._sess)
    feed_dict = self._call_hook_before_run(
        run_context, combined_fetches, feed_dict)

    # Do session run.
    outputs = _WrappedSession.run(
        self,
        fetches=combined_fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)

    # Every hook is notified; a hook with no entry in the outputs gets
    # `results=None`.
    for hook in self._hooks:
        if hook in outputs:
            hook_results = outputs[hook]
        else:
            hook_results = None
        run_values = session_run_hook.SessionRunValues(results=hook_results)
        hook.after_run(run_context, run_values)

    # Once a stop has been recorded it stays recorded.
    self._should_stop = self._should_stop or run_context.stop_requested

    return outputs['caller']
def before_run(self, run_context):  # pylint: disable=unused-argument
    """Request the global step tensor and the current train op as fetches.

    Args:
      run_context: Ignored.

    Returns:
      A `SessionRunArgs` whose fetches are a dict with the keys
      "global_step" and "train_op".
    """
    fetches = {
        "global_step": self._global_step_tensor,
        "train_op": self._current_train_op,
    }
    return session_run_hook.SessionRunArgs(fetches)
def testHooksAndUserFeedConflicts(self):
    """Feeding a tensor from both a hook and the caller raises an error."""
    with tf.Graph().as_default(), tf.Session() as sess:
        first_hook = FakeHook()
        second_hook = FakeHook()
        hooked_sess = monitored_session._HookedSession(
            sess=sess, hooks=[first_hook, second_hook])

        a_tensor = tf.constant([0], name='a_tensor')
        b_tensor = tf.constant([0], name='b_tensor')
        total = a_tensor + b_tensor

        first_hook.request = session_run_hook.SessionRunArgs(
            None, feed_dict={a_tensor: [5]})
        second_hook.request = session_run_hook.SessionRunArgs(
            None, feed_dict={b_tensor: [10]})

        sess.run(tf.initialize_all_variables())

        # The caller feeds `b_tensor`, which the second hook also feeds.
        user_feeds = {b_tensor: [10]}
        with self.assertRaisesRegexp(RuntimeError, 'Same tensor is fed'):
            hooked_sess.run(fetches=total, feed_dict=user_feeds)
def before_run(self, run_context):
    """Request the attempted and finalized tree-count tensors as fetches.

    Args:
      run_context: Ignored.

    Returns:
      A `SessionRunArgs` whose fetches are a dict with the keys
      "num_attempted_trees" and "num_finalized_trees".
    """
    del run_context  # unused by StopTrainingAfterNTrees.
    tree_count_fetches = {
        "num_attempted_trees": self._num_attempted_trees_tensor,
        "num_finalized_trees": self._num_finalized_trees_tensor,
    }
    return session_run_hook.SessionRunArgs(tree_count_fetches)
def testFetchesHookRequests(self):
    """Each hook gets its own requested fetches; the caller gets theirs."""
    with tf.Graph().as_default(), tf.Session() as sess:
        first_hook = FakeHook()
        second_hook = FakeHook()
        hooked_sess = monitored_session._HookedSession(
            sess=sess, hooks=[first_hook, second_hook])

        a_tensor = tf.constant([0], name='a_tensor')
        another_tensor = tf.constant([5], name='another_tensor')
        third_tensor = tf.constant([10], name='third_tensor')

        first_hook.request = session_run_hook.SessionRunArgs([another_tensor])
        second_hook.request = session_run_hook.SessionRunArgs([third_tensor])

        sess.run(tf.initialize_all_variables())

        caller_output = hooked_sess.run(fetches=a_tensor)
        self.assertEqual(caller_output, [0])

        # Hook results are checked in the same order the hooks were registered.
        expectations = ((first_hook, [5]), (second_hook, [10]))
        for hook, expected in expectations:
            self.assertEqual(hook.last_run_values.results, expected)
def testBothHooksAndUserHaveFeeds(self):
    """Hook feeds and caller feeds are combined without mutating the caller's."""
    with tf.Graph().as_default(), tf.Session() as sess:
        first_hook = FakeHook()
        second_hook = FakeHook()
        hooked_sess = monitored_session._HookedSession(
            sess=sess, hooks=[first_hook, second_hook])

        a_tensor = tf.constant([0], name='a_tensor')
        b_tensor = tf.constant([0], name='b_tensor')
        c_tensor = tf.constant([0], name='c_tensor')
        total = a_tensor + b_tensor + c_tensor

        first_hook.request = session_run_hook.SessionRunArgs(
            None, feed_dict={a_tensor: [5]})
        second_hook.request = session_run_hook.SessionRunArgs(
            None, feed_dict={b_tensor: [10]})

        sess.run(tf.initialize_all_variables())

        feed_dict = {c_tensor: [20]}
        result = hooked_sess.run(fetches=total, feed_dict=feed_dict)
        self.assertEqual(result, [35])

        # User feed_dict should not be changed
        self.assertEqual(len(feed_dict), 1)
def before_run(self, run_context):
    """Collect the fetches requested by the monitors for the upcoming step.

    On the first call (while `self._last_step` is None) the global step
    tensor is evaluated through `run_context.session` and `_last_step` is
    set to that value plus one.

    Args:
      run_context: Context whose `session` is used to read the global step.

    Returns:
      A `SessionRunArgs` whose fetches always contain the global step tensor
      (keyed by itself) and, when any monitor requested fetches, a "monitors"
      dict mapping each requested item to its graph element.

    Raises:
      ValueError: If a monitor's `step_begin` returns a truthy non-list.
    """
    if self._last_step is None:
        current_step = run_context.session.run(self._global_step_tensor)
        self._last_step = current_step + 1

    fetches = {self._global_step_tensor: self._global_step_tensor}

    requested = []
    for monitor in self._monitors:
        step_requests = monitor.step_begin(self._last_step)
        if not step_requests:
            continue
        if not isinstance(step_requests, list):
            raise ValueError("Monitor.step_begin should return a list.")
        requested.extend(step_requests)

    if requested:
        # Resolve every request first, then pair each with its original key.
        graph_elements = [_as_graph_element(item) for item in requested]
        fetches["monitors"] = dict(zip(requested, graph_elements))

    return session_run_hook.SessionRunArgs(fetches)
def testCallsHooksBeginEnd(self):
    """Every hook sees exactly one before_run and one after_run per run."""
    with tf.Graph().as_default(), tf.Session() as sess:
        first_hook = FakeHook()
        second_hook = FakeHook()
        hooked_sess = monitored_session._HookedSession(
            sess=sess, hooks=[first_hook, second_hook])

        a_tensor = tf.constant([0], name='a_tensor')
        sess.run(tf.initialize_all_variables())
        hooked_sess.run(a_tensor)

        # Neither hook requested fetches, so both should see empty results
        # and the caller's original arguments.
        expected_values = session_run_hook.SessionRunValues(results=None)
        expected_args = session_run_hook.SessionRunArgs(a_tensor)
        for hook in (first_hook, second_hook):
            self.assertEqual(hook.last_run_values, expected_values)
            self.assertEqual(hook.last_run_context.original_args, expected_args)
            self.assertEqual(hook.last_run_context.session, sess)
            self.assertEqual(hook.call_counter['before_run'], 1)
            self.assertEqual(hook.call_counter['after_run'], 1)
def before_run(self, run_context):
    """Return run args with no fetches and `self.feed_fn` as the feed_dict.

    Args:
      run_context: Ignored.

    Returns:
      A `SessionRunArgs` with `fetches=None` and `feed_dict=self.feed_fn`.
    """
    del run_context  # unused by FeedFnHook.
    # NOTE(review): `self.feed_fn` is forwarded as-is rather than being
    # called. Confirm whether the consumer of these args expects a callable
    # here or an already-built feed dict.
    feeds = self.feed_fn
    return session_run_hook.SessionRunArgs(fetches=None, feed_dict=feeds)