Ejemplo n.º 1
0
 def close(self):
     """Stop the coordinated threads, then close the wrapped session.

     When the coordinator has not already been told to stop, a stop is
     requested and the coordinated threads are joined. Any `Exception`
     raised while doing so is swallowed so it does not escape from
     `close()`. `WrappedSession.close` is invoked in every case.
     """
     try:
         coord = self._coord
         if coord.should_stop():
             # Someone else already initiated shutdown; nothing to join here.
             return
         coord.request_stop()
         coord.join(self._coordinated_threads_to_join)
     except Exception:  # pylint: disable=broad-except
         # Best effort: shutdown errors must not propagate out of close().
         pass
     finally:
         WrappedSession.close(self)
Ejemplo n.º 2
0
 def close(self):
   """Close the session after shutting down the coordinated threads.

   If no stop has been requested from the coordinator yet, request one and
   join the coordinated threads. Errors (subclasses of `Exception`) raised
   by that step are suppressed; the wrapped session's `close` always runs.
   """
   try:
     stop_pending = not self._coord.should_stop()
     if stop_pending:
       self._coord.request_stop()
       self._coord.join(self._coordinated_threads_to_join)
   except Exception:  # pylint: disable=broad-except
     # Errors from stopping/joining are intentionally not re-raised here.
     pass
   finally:
     WrappedSession.close(self)
Ejemplo n.º 3
0
    def __init__(self, sess, coord, coordinated_threads_to_join):
        """Wrap `sess` and attach the coordinator and its threads.

        Args:
          sess: The `tf.Session` object being wrapped.
          coord: A `tf.train.Coordinator` object.
          coordinated_threads_to_join: A list of threads handed to
            `coord.join` when the session is closed.
        """
        WrappedSession.__init__(self, sess)
        self._coord, self._coordinated_threads_to_join = (
            coord, coordinated_threads_to_join)
Ejemplo n.º 4
0
  def __init__(self, sess, coord, coordinated_threads_to_join):
    """Creates a `CoordinatedSession` around `sess`.

    Args:
      sess: A `tf.Session` object; the session being wrapped.
      coord: The `tf.train.Coordinator` that owns the threads.
      coordinated_threads_to_join: A list of threads; `close()` passes it
        to `coord.join`.
    """
    WrappedSession.__init__(self, sess)
    # Both are consumed by close() when shutting the session down.
    self._coord, self._coordinated_threads_to_join = (
        coord, coordinated_threads_to_join)
Ejemplo n.º 5
0
    def __init__(self, sess_factory):
        """Create a `RecoverableSession` from a session factory.

        The factory is called once here, and the session it returns becomes
        the session wrapped by this object.

        Args:
          sess_factory: A zero-argument callable that returns a
            `tf.Session`.
        """
        # Stored on the instance; presumably used to build replacement
        # sessions later (not visible here) -- TODO confirm.
        self._factory = sess_factory
        initial_session = self._factory()
        WrappedSession.__init__(self, initial_session)
    def __init__(self, sess_factory):
        """Build a `RecoverableSession` around a freshly created session.

        Args:
          sess_factory: Callable taking no arguments and returning a
            `tf.Session`; its result on this first call is the session
            that gets wrapped.
        """
        # Remember the factory before invoking it, so the attribute exists
        # even if creating the first session raises.
        self._factory = sess_factory
        first_sess = sess_factory()
        WrappedSession.__init__(self, first_sess)
Ejemplo n.º 7
0
    def __init__(self, sess, monitors, global_step_tensor):
        """Initializes a MonitoredSession object.

        Args:
          sess: A `tf.Session` or a `WrappedSession` object.
          monitors: An iterable of `tf.contrib.learn.BaseMonitor` objects.
            It is copied into a list; `None` is treated as no monitors.
          global_step_tensor: A `Tensor` which holds a scalar int value.
        """

        WrappedSession.__init__(self, sess)
        # Materialize the iterable: run() loops over the monitors several
        # times per call (and again on every later call), so storing a
        # one-shot iterator such as a generator would silently skip monitors.
        self._monitors = [] if monitors is None else list(monitors)
        self._should_stop = False
        self._global_step_tensor = global_step_tensor
        # Lazily initialized from the global step tensor on the first run().
        self._last_step = None
Ejemplo n.º 8
0
  def __init__(self, sess, monitors, global_step_tensor):
    """Initializes a MonitoredSession object.

    Args:
      sess: A `tf.Session` or a `WrappedSession` object.
      monitors: An iterable of `tf.contrib.learn.BaseMonitor` objects. It is
        copied into a list; `None` is treated as no monitors.
      global_step_tensor: A `Tensor` which holds a scalar int value.
    """

    WrappedSession.__init__(self, sess)
    # Materialize the iterable: run() loops over the monitors several times
    # per call (and again on every later call), so storing a one-shot
    # iterator such as a generator would silently skip monitors.
    self._monitors = [] if monitors is None else list(monitors)
    self._should_stop = False
    self._global_step_tensor = global_step_tensor
    # Lazily initialized from the global step tensor on the first run().
    self._last_step = None
Ejemplo n.º 9
0
  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Runs `fetches` and drives every monitor through one step.

    On the first call the global step tensor is evaluated to seed the step
    counter. Each monitor's `step_begin` may request extra fetches (as a
    list). Those are run together with the caller's fetches and the global
    step. Every monitor then receives `step_end` (a truthy result is
    recorded in `_should_stop`), followed by `post_step`.

    Args:
      fetches: The caller's fetches, forwarded to the wrapped session.
      feed_dict: Forwarded unchanged to `WrappedSession.run`.
      options: Forwarded unchanged to `WrappedSession.run`.
      run_metadata: Forwarded unchanged to `WrappedSession.run`.

    Returns:
      The fetched value(s) for `fetches` only.

    Raises:
      RuntimeError: If `should_stop()` is already true.
      ValueError: If a monitor's `step_begin` returns a truthy non-list.
    """
    if self.should_stop():
      raise RuntimeError('Run called even after should_stop requested.')

    if self._last_step is None:
      # Seed the step counter from the graph's global step on first use.
      self._last_step = WrappedSession.run(self, self._global_step_tensor)
    step = self._last_step + 1

    # Gather the extra fetches each monitor wants for this step.
    extra_fetches = []
    for monitor in self._monitors:
      requested = monitor.step_begin(step)
      if not requested:
        continue
      # TODO(ispir): remove following restriction after b/30136815 fixed
      if not isinstance(requested, list):
        raise ValueError('Monitor.step_begin should return a list.')
      extra_fetches.extend(requested)

    graph_elements = [_as_graph_element(f, self.graph) for f in extra_fetches]
    outputs = WrappedSession.run(
        self,
        fetches={
            'caller': fetches,
            self._global_step_tensor: self._global_step_tensor,
            'monitors': graph_elements,
        },
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)
    self._last_step = outputs[self._global_step_tensor]

    # Key each monitor result by the object the monitor originally asked for.
    results_by_fetch = (dict(zip(extra_fetches, outputs['monitors']))
                        if extra_fetches else {})

    # Every monitor sees step_end, even after an earlier one asked to stop.
    for monitor in self._monitors:
      stop_requested = monitor.step_end(step, results_by_fetch)
      if not self._should_stop:
        self._should_stop = stop_requested

    for monitor in self._monitors:
      monitor.post_step(step, self._sess)

    return outputs['caller']
Ejemplo n.º 10
0
    def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
        """Run `fetches` on the wrapped session and step the monitors once.

        The global step is read from the wrapped session on the first call.
        Each monitor's `step_begin` may request extra fetches, which are run
        together with the caller's fetches and the global step tensor. Every
        monitor then receives `step_end` with a dict of the values it
        requested.

        Args:
          fetches: The caller's fetches, forwarded to the wrapped session.
          feed_dict: Forwarded unchanged to `WrappedSession.run`.
          options: Forwarded unchanged to `WrappedSession.run`.
          run_metadata: Forwarded unchanged to `WrappedSession.run`.

        Returns:
          The fetched value(s) for `fetches` only.

        Raises:
          RuntimeError: If `should_stop()` is already true.
        """
        if self.should_stop():
            raise RuntimeError('run called even after should_stop requested.')

        if self._last_step is None:
            self._last_step = WrappedSession.run(self,
                                                 self._global_step_tensor)
            # Log once, when the step counter is actually initialized,
            # rather than on every call to run().
            logging.info('Initialized step to: %d', self._last_step)

        monitors_step = self._last_step + 1
        monitor_fetches = []
        for monitor in self._monitors:
            monitor_requests = monitor.step_begin(monitors_step)
            # A monitor needing nothing may return None or an empty list;
            # extend() would raise TypeError on None.
            if monitor_requests:
                monitor_fetches.extend(monitor_requests)

        actual_fetches = {
            'caller': fetches,
            self._global_step_tensor: self._global_step_tensor,
            'monitors':
            [_as_graph_element(f, self.graph) for f in monitor_fetches]
        }

        # Do session run.
        outputs = WrappedSession.run(self,
                                     fetches=actual_fetches,
                                     feed_dict=feed_dict,
                                     options=options,
                                     run_metadata=run_metadata)
        self._last_step = outputs[self._global_step_tensor]

        # Map each requested fetch back to its computed value for step_end.
        if monitor_fetches:
            monitor_outputs = dict(zip(monitor_fetches, outputs['monitors']))
        else:
            monitor_outputs = {}

        # Every monitor gets step_end; a truthy return is recorded in
        # _should_stop.
        for monitor in self._monitors:
            induce_stop = monitor.step_end(monitors_step, monitor_outputs)
            self._should_stop = self._should_stop or induce_stop

        return outputs['caller']
Ejemplo n.º 11
0
  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Runs `fetches` on the wrapped session and steps the monitors once.

    The global step is read from the wrapped session on the first call. Each
    monitor's `step_begin` may request extra fetches, which are run together
    with the caller's fetches and the global step tensor. Every monitor then
    receives `step_end` with a dict of the values it requested.

    Args:
      fetches: The caller's fetches, forwarded to the wrapped session.
      feed_dict: Forwarded unchanged to `WrappedSession.run`.
      options: Forwarded unchanged to `WrappedSession.run`.
      run_metadata: Forwarded unchanged to `WrappedSession.run`.

    Returns:
      The fetched value(s) for `fetches` only.

    Raises:
      RuntimeError: If `should_stop()` is already true.
    """
    if self.should_stop():
      raise RuntimeError('run called even after should_stop requested.')

    if self._last_step is None:
      self._last_step = WrappedSession.run(self, self._global_step_tensor)
      # Log once, when the step counter is actually initialized, rather than
      # on every call to run().
      logging.info('Initialized step to: %d', self._last_step)

    monitors_step = self._last_step + 1
    monitor_fetches = []
    for monitor in self._monitors:
      monitor_requests = monitor.step_begin(monitors_step)
      # A monitor needing nothing may return None or an empty list; extend()
      # would raise TypeError on None.
      if monitor_requests:
        monitor_fetches.extend(monitor_requests)

    actual_fetches = {
        'caller': fetches,
        self._global_step_tensor: self._global_step_tensor,
        'monitors': [_as_graph_element(f, self.graph) for f in monitor_fetches]
    }

    # Do session run.
    outputs = WrappedSession.run(self,
                                 fetches=actual_fetches,
                                 feed_dict=feed_dict,
                                 options=options,
                                 run_metadata=run_metadata)
    self._last_step = outputs[self._global_step_tensor]

    # Map each requested fetch back to its computed value for step_end.
    if monitor_fetches:
      monitor_outputs = dict(zip(monitor_fetches, outputs['monitors']))
    else:
      monitor_outputs = {}

    # Every monitor gets step_end; a truthy return is recorded in _should_stop.
    for monitor in self._monitors:
      induce_stop = monitor.step_end(monitors_step, monitor_outputs)
      self._should_stop = self._should_stop or induce_stop

    return outputs['caller']
 def run(self, fetches, **kwargs):
     """Record the keyword arguments, then run only `fetches`.

     The keyword arguments are saved as a new dict in `args_called` and are
     not forwarded to the wrapped session's `run`.
     """
     recorded = kwargs.copy()
     self.args_called = recorded
     return WrappedSession.run(self, fetches)
 def __init__(self, sess):
     """Wrap `sess`, starting with no recorded `run()` keyword arguments."""
     WrappedSession.__init__(self, sess)
     self.args_called = dict()
Ejemplo n.º 14
0
 def run(self, fetches, **kwargs):
     """Run just `fetches`, keeping a copy of any keyword arguments.

     The copy is stored in `args_called`; the wrapped session is called
     with `fetches` alone.
     """
     self.args_called = {name: value for name, value in kwargs.items()}
     return WrappedSession.run(self, fetches)
Ejemplo n.º 15
0
 def __init__(self, sess):
     """Initialize the wrapper around `sess` with an empty `args_called`."""
     WrappedSession.__init__(self, sess)
     # Filled in by run() with the keyword arguments of the latest call.
     self.args_called = dict()