def close(self):
    """Stops the coordinated threads, then closes the wrapped session.

    Stop is requested on the coordinator unless it has already been requested,
    and the threads in `coordinated_threads_to_join` are always joined. Any
    exception raised while stopping or joining is logged and suppressed so that
    this stage never makes `close()` raise. The wrapped session is closed in
    every case.
    """
    import logging as std_logging  # local: keeps this block self-contained
    try:
        if not self._coord.should_stop():
            self._coord.request_stop()
        # Join unconditionally: even if stop was already requested (e.g. by a
        # thread that hit an error), the threads should finish before the
        # session they may still be using is closed below.
        self._coord.join(self._coordinated_threads_to_join)
    except Exception:  # pylint: disable=broad-except
        # Don't raise exception at close, but leave a trace instead of
        # dropping it silently.
        std_logging.getLogger(__name__).warning(
            'Ignoring exception while stopping coordinated threads at close.',
            exc_info=True)
    finally:
        WrappedSession.close(self)
def __init__(self, sess, coord, coordinated_threads_to_join):
    """Builds a `CoordinatedSession` around an existing session.

    Args:
      sess: The `tf.Session` object being wrapped.
      coord: The `tf.train.Coordinator` that controls the threads.
      coordinated_threads_to_join: A list of threads to join when closing.
    """
    WrappedSession.__init__(self, sess)
    self._coord, self._coordinated_threads_to_join = (
        coord, coordinated_threads_to_join)
def __init__(self, sess_factory):
    """Builds a `RecoverableSession` from a session factory.

    The factory is kept on the instance, and it is called once right away to
    produce the session that this object initially wraps.

    Args:
      sess_factory: A zero-argument callable that returns a `tf.Session`.
    """
    self._factory = sess_factory
    initial_session = self._factory()
    WrappedSession.__init__(self, initial_session)
def __init__(self, sess, monitors, global_step_tensor):
    """Initializes a MonitoredSession object.

    Args:
      sess: A `tf.Session` or a `WrappedSession` object.
      monitors: An iterable of `tf.contrib.learn.BaseMonitor` objects. It is
        copied into a list because `run()` walks the monitors more than once
        per step. `None` is treated as "no monitors".
      global_step_tensor: A `Tensor` which holds a scalar int value.
    """
    WrappedSession.__init__(self, sess)
    # Materialize the iterable: a one-shot iterator (e.g. a generator) would be
    # exhausted after the first pass over the monitors in run(), silently
    # skipping the later per-step callbacks.
    self._monitors = [] if monitors is None else list(monitors)
    self._should_stop = False
    self._global_step_tensor = global_step_tensor
    # Lazily read from the graph on the first run() call.
    self._last_step = None
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """See base class.

    Executes `fetches` in one session call, together with the global step
    tensor and any tensors requested by the monitors' `step_begin`. Each
    monitor's `step_end` then receives the requested values and may ask to
    stop. Finally, `post_step` is invoked on every monitor with the wrapped
    session.

    Returns:
      Whatever the underlying session returned for `fetches`.

    Raises:
      RuntimeError: if called after a stop has been requested.
      ValueError: if a monitor's `step_begin` returns a non-empty non-list.
    """
    if self.should_stop():
        raise RuntimeError('Run called even after should_stop requested.')

    if self._last_step is None:
        self._last_step = WrappedSession.run(self, self._global_step_tensor)
    step = self._last_step + 1

    # Gather what every monitor wants evaluated alongside the caller's fetches.
    requested = []
    for monitor in self._monitors:
        requests = monitor.step_begin(step)
        if not requests:
            continue
        # TODO(ispir): remove following restriction after b/30136815 fixed
        if not isinstance(requests, list):
            raise ValueError('Monitor.step_begin should return a list.')
        requested.extend(requests)

    combined = {}
    combined['caller'] = fetches
    combined[self._global_step_tensor] = self._global_step_tensor
    combined['monitors'] = [
        _as_graph_element(request, self.graph) for request in requested
    ]

    results = WrappedSession.run(self,
                                 fetches=combined,
                                 feed_dict=feed_dict,
                                 options=options,
                                 run_metadata=run_metadata)
    self._last_step = results[self._global_step_tensor]

    # Report back to the monitors; every monitor sees step_end even if an
    # earlier one already asked to stop.
    if requested:
        per_request = dict(zip(requested, results['monitors']))
    else:
        per_request = {}
    for monitor in self._monitors:
        stop_requested = monitor.step_end(step, per_request)
        if not self._should_stop:
            self._should_stop = stop_requested

    for monitor in self._monitors:
        monitor.post_step(step, self._sess)

    return results['caller']
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """See base class.

    Executes `fetches` in one session call, together with the global step
    tensor and any tensors requested by the monitors' `step_begin`. Each
    monitor's `step_end` then receives the requested values and may ask to
    stop.

    Returns:
      Whatever the underlying session returned for `fetches`.

    Raises:
      RuntimeError: if called after a stop has been requested.
    """
    if self.should_stop():
        raise RuntimeError('run called even after should_stop requested.')
    if self._last_step is None:
        self._last_step = WrappedSession.run(self, self._global_step_tensor)
        logging.info('Initialized step to: %d', self._last_step)
    monitors_step = self._last_step + 1
    monitor_fetches = []
    for monitor in self._monitors:
        # A monitor with nothing to evaluate may return None (or another empty
        # value) from step_begin; treat that as "no requests" rather than
        # failing with a TypeError inside extend().
        monitor_fetches.extend(monitor.step_begin(monitors_step) or [])
    actual_fetches = {
        'caller': fetches,
        self._global_step_tensor: self._global_step_tensor,
        'monitors': [_as_graph_element(f, self.graph) for f in monitor_fetches]
    }

    # Do session run.
    outputs = WrappedSession.run(self,
                                 fetches=actual_fetches,
                                 feed_dict=feed_dict,
                                 options=options,
                                 run_metadata=run_metadata)
    self._last_step = outputs[self._global_step_tensor]

    # Call monitors step_end and stop if one of them tells to stop.
    if monitor_fetches:
        monitor_outputs = dict(zip(monitor_fetches, outputs['monitors']))
    else:
        monitor_outputs = {}

    for monitor in self._monitors:
        induce_stop = monitor.step_end(monitors_step, monitor_outputs)
        self._should_stop = self._should_stop or induce_stop
    return outputs['caller']
def run(self, fetches, **kwargs):
    """Records the keyword arguments, then runs only `fetches`.

    The keyword arguments are kept in `self.args_called` for later inspection.
    They are deliberately not forwarded to the wrapped session; only `fetches`
    is passed on.
    """
    self.args_called = {name: value for name, value in kwargs.items()}
    return WrappedSession.run(self, fetches)
def __init__(self, sess):
    """Wraps `sess` and starts with an empty record of run() keyword args."""
    WrappedSession.__init__(self, sess)
    self.args_called = dict()