def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Execute one monitored step of the wrapped session.

    The step proceeds in four phases:
      1. Each monitor's `step_begin` may request extra fetches.
      2. One session run fetches the caller's `fetches`, the global step and
         all monitor requests together.
      3. Each monitor's `step_end` receives the fetched monitor values. A
         truthy return from any of them sets the stop flag.
      4. Each monitor's `post_step` is called with `self._sess`.

    Args:
      fetches: Forwarded as the caller's part of the combined fetch.
      feed_dict: Forwarded unchanged to the session run.
      options: Forwarded unchanged to the session run.
      run_metadata: Forwarded unchanged to the session run.

    Returns:
      Whatever the session produced for the caller's `fetches`.

    Raises:
      RuntimeError: If `should_stop()` is already truthy on entry.
      ValueError: If a monitor's `step_begin` returns a non-empty value that
        is not a list.
    """
    if self.should_stop():
        raise RuntimeError('Run called even after should_stop requested.')

    # The global step is read lazily, the first time a step is run.
    if self._last_step is None:
        self._last_step = WrappedSession.run(self, self._global_step_tensor)
    step = self._last_step + 1

    # Phase 1: gather the extra fetches every monitor asks for.
    requested = []
    for mon in self._monitors:
        extra = mon.step_begin(step)
        if not extra:
            continue
        # TODO(ispir): remove following restriction after b/30136815 fixed
        if not isinstance(extra, list):
            raise ValueError('Monitor.step_begin should return a list.')
        requested.extend(extra)

    # Phase 2: a single session run covering caller, global step and monitors.
    graph_elements = [_as_graph_element(f, self.graph) for f in requested]
    combined = {
        'caller': fetches,
        self._global_step_tensor: self._global_step_tensor,
        'monitors': graph_elements,
    }
    results = WrappedSession.run(self,
                                 fetches=combined,
                                 feed_dict=feed_dict,
                                 options=options,
                                 run_metadata=run_metadata)
    self._last_step = results[self._global_step_tensor]

    # Phase 3: report values keyed by the original (un-converted) requests.
    fetched_by_request = (
        dict(zip(requested, results['monitors'])) if requested else {})
    for mon in self._monitors:
        # Call step_end first and unconditionally, so every monitor sees the
        # step even after an earlier one has already asked to stop.
        wants_stop = mon.step_end(step, fetched_by_request)
        self._should_stop = self._should_stop or wants_stop

    # Phase 4: post-step hooks.
    for mon in self._monitors:
        mon.post_step(step, self._sess)

    return results['caller']
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Run `fetches` in the wrapped session while driving the monitors.

    Each monitor's `step_begin` is called before the session run and may
    return additional fetches. The caller's fetches, the global step and all
    monitor requests are fetched in one session run. Afterwards each
    monitor's `step_end` is called with a dict mapping every requested fetch
    to its value. A truthy return from any `step_end` sets the stop flag.

    Args:
      fetches: Forwarded as the caller's part of the combined fetch.
      feed_dict: Forwarded unchanged to the session run.
      options: Forwarded unchanged to the session run.
      run_metadata: Forwarded unchanged to the session run.

    Returns:
      The value(s) the session produced for the caller's `fetches`.

    Raises:
      RuntimeError: If `should_stop()` is already truthy on entry.
    """
    if self.should_stop():
        raise RuntimeError('run called even after should_stop requested.')
    if self._last_step is None:
        # The global step is read lazily, the first time a step is run.
        self._last_step = WrappedSession.run(self, self._global_step_tensor)
        logging.info('Initialized step to: %d', self._last_step)

    monitors_step = self._last_step + 1
    monitor_fetches = []
    for monitor in self._monitors:
        monitor_requests = monitor.step_begin(monitors_step)
        # A monitor with nothing to fetch may return None (or an empty
        # sequence). Extending with None would raise TypeError, so skip it.
        if monitor_requests:
            monitor_fetches.extend(monitor_requests)

    actual_fetches = {
        'caller': fetches,
        self._global_step_tensor: self._global_step_tensor,
        'monitors': [_as_graph_element(f, self.graph) for f in monitor_fetches]
    }

    # Do session run.
    outputs = WrappedSession.run(self,
                                 fetches=actual_fetches,
                                 feed_dict=feed_dict,
                                 options=options,
                                 run_metadata=run_metadata)
    self._last_step = outputs[self._global_step_tensor]

    # Call monitors step_end and stop if one of them tells to stop.
    # Results are keyed by the original requests, not the graph elements.
    if monitor_fetches:
        monitor_outputs = dict(zip(monitor_fetches, outputs['monitors']))
    else:
        monitor_outputs = {}

    for monitor in self._monitors:
        # step_end runs for every monitor, even after a stop was requested.
        induce_stop = monitor.step_end(monitors_step, monitor_outputs)
        self._should_stop = self._should_stop or induce_stop
    return outputs['caller']
def run(self, fetches, **kwargs):
    """Record the keyword arguments, then run only `fetches`.

    A shallow copy of the keyword arguments is stored on `self.args_called`
    instead of being forwarded. Only `fetches` is passed on to
    `WrappedSession.run`.
    """
    self.args_called = kwargs.copy()
    # The remaining arguments are deliberately not forwarded here.
    return WrappedSession.run(self, fetches)