示例#1
0
  def run(self):
    """Polls the Cloud TPU node and closes the session if the node dies.

    Outside GCE there is no metadata to poll, so the thread logs a warning,
    clears `self._running` and returns. Otherwise, while `self._running` is
    true, it fetches the node metadata every `self._interval` seconds. When
    the reported state is TERMINATED or PREEMPTED it calls
    `self._session.close()`, waits one more interval, and raises if
    `self._session_closed` is still false.

    Raises:
      errors.UnavailableError: if the TPU node reached an unrecoverable state
        and `self._session_closed` was not set within one polling interval
        after `self._session.close()` was called.
    """
    if not tpu_cluster_resolver.is_running_in_gce():
      logging.warning(
          'TPUPollingThread is running in a non-GCE environment, exiting...')
      self._running = False
      return

    while self._running:
      response = self._cluster._fetch_cloud_tpu_metadata()  # pylint: disable=protected-access
      # The metadata may lack 'state' (the old code guarded for it only after
      # logging) or 'health'. Read both with .get() so a partial response is
      # logged instead of killing this thread with a KeyError.
      state = response.get('state')
      logging.warning(
          'TPUPollingThread found TPU %s in state %s, and health %s.',
          self._cluster._tpu,  # pylint: disable=protected-access
          state, response.get('health', 'UNKNOWN'))

      if state in ('TERMINATED', 'PREEMPTED'):
        logging.warning('TPU node %s reached an unrecoverable state %s, '
                        'terminating the session now.',
                        self._cluster._tpu,  # pylint: disable=protected-access
                        state)
        # Try to close the session, then give it one interval to finish.
        self._session.close()
        time.sleep(self._interval)

        # NOTE(review): `_session_closed` is not set in this method; presumably
        # another code path flips it once the close completes — verify.
        if not self._session_closed:
          # Raise an exception if session.close() appears to be stuck.
          logging.warning('Cannot close session on TPU node %s.',
                          self._cluster._tpu)  # pylint: disable=protected-access

          raise errors.UnavailableError(
              None, None, 'TPU node %s reached an unrecoverable state %s.' %
              (self._cluster._tpu, state))  # pylint: disable=protected-access

      time.sleep(self._interval)
示例#2
0
  def run(self):
    """Polls the Cloud TPU node and closes the session if the node dies.

    Outside GCE there is no metadata to poll, so the thread logs a warning,
    clears `self._running` and returns. Otherwise, while `self._running` is
    true, it fetches the node metadata every `self._interval` seconds. When
    the reported state is TERMINATED or PREEMPTED it calls
    `self._session.close()`, waits one more interval, and raises if
    `self._session_closed` is still false.

    Raises:
      errors.UnavailableError: if the TPU node reached an unrecoverable state
        and `self._session_closed` was not set within one polling interval
        after `self._session.close()` was called.
    """
    if not tpu_cluster_resolver.is_running_in_gce():
      logging.warning(
          'TPUPollingThread is running in a non-GCE environment, exiting...')
      self._running = False
      return

    while self._running:
      response = self._cluster._fetch_cloud_tpu_metadata()  # pylint: disable=protected-access
      # The metadata may lack 'state' (the old code guarded for it only after
      # logging) or 'health'. Read both with .get() so a partial response is
      # logged instead of killing this thread with a KeyError.
      state = response.get('state')
      logging.warning(
          'TPUPollingThread found TPU %s in state %s, and health %s.',
          self._cluster._tpu,  # pylint: disable=protected-access
          state, response.get('health', 'UNKNOWN'))

      if state in ('TERMINATED', 'PREEMPTED'):
        logging.warning('TPU node %s reached an unrecoverable state %s, '
                        'terminating the session now.',
                        self._cluster._tpu,  # pylint: disable=protected-access
                        state)
        # Try to close the session, then give it one interval to finish.
        self._session.close()
        time.sleep(self._interval)

        # NOTE(review): `_session_closed` is not set in this method; presumably
        # another code path flips it once the close completes — verify.
        if not self._session_closed:
          # Raise an exception if session.close() appears to be stuck.
          logging.warning('Cannot close session on TPU node %s.',
                          self._cluster._tpu)  # pylint: disable=protected-access

          raise errors.UnavailableError(
              None, None, 'TPU node %s reached an unrecoverable state %s.' %
              (self._cluster._tpu, state))  # pylint: disable=protected-access

      time.sleep(self._interval)
  def run(self):
    """Polls the Cloud TPU node and terminates the process if the node dies.

    Outside GCE there is no metadata to poll, so the thread logs a warning,
    clears `self._running` and returns. Otherwise, while `self._running` is
    true, it fetches the node metadata every `self._interval` seconds. When
    the reported state is TERMINATED or PREEMPTED it logs the event and exits
    the whole process with status 1.
    """
    if not tpu_cluster_resolver.is_running_in_gce():
      logging.warning(
          'TPUPollingThread is running in a non-GCE environment, exiting...')
      self._running = False
      return

    while self._running:
      response = self._cluster._fetch_cloud_tpu_metadata()  # pylint: disable=protected-access
      # 'state' may be missing from the metadata (the guard below checks for
      # it), so read it with .get() rather than indexing before the check.
      state = response.get('state')
      logging.warning(
          'TPUPollingThread found TPU %s in state %s, and health %s.',
          self._cluster._tpu,  # pylint: disable=protected-access
          state, response.get('health', 'UNKNOWN'))

      if state in ('TERMINATED', 'PREEMPTED'):
        logging.warning(
            'TPU node %s reached an unrecoverable state %s, '
            'terminating training.',
            self._cluster._tpu,  # pylint: disable=protected-access
            state)
        # os._exit ends the whole process immediately. sys.exit called from
        # this non-main thread would only terminate the thread itself.
        os._exit(1)  # pylint: disable=protected-access

      time.sleep(self._interval)
    def run(self):
        """Watch the Cloud TPU client and exit the process once it is lost.

        When not running in GCE, logs a warning, clears `self._running` and
        returns without polling. Otherwise, every `self._interval` seconds it
        asks `self._cluster._cloud_tpu_client.recoverable()`. On a falsy
        answer it logs the TPU name and its current `state()`, then calls
        `os._exit(1)`.
        """
        if not tpu_cluster_resolver.is_running_in_gce():
            logging.warning(
                'TPUPollingThread is running in a non-GCE environment, exiting...'
            )
            self._running = False
            return

        while self._running:
            tpu_ok = self._cluster._cloud_tpu_client.recoverable()  # pylint: disable=protected-access
            if not tpu_ok:
                tpu_name = self._cluster._tpu  # pylint: disable=protected-access
                tpu_state = self._cluster._cloud_tpu_client.state()  # pylint: disable=protected-access
                logging.warning('TPUPollingThread found TPU %s in state %s',
                                tpu_name, tpu_state)
                # Hard-exit the process: the node cannot recover.
                os._exit(1)  # pylint: disable=protected-access
            time.sleep(self._interval)
 def after_create_session(self, session, coord):
     """Launch a `_TPUPollingThread` for `session` when running in GCE.

     Args:
       session: the session to monitor; passed to `_TPUPollingThread`.
       coord: not used; presumably required by the hook interface.
     """
     del coord  # Unused.
     if not tpu_cluster_resolver.is_running_in_gce():
         return
     self._tpu_poller = _TPUPollingThread(self._cluster, session)
     self._tpu_poller.start()
 def testIsNotRunningInGce(self):
   """Checks that `resolver.is_running_in_gce()` returns a falsy value."""
   detected_gce = resolver.is_running_in_gce()
   self.assertFalse(detected_gce)
 def testIsRunningInGce(self):
   """Checks that `resolver.is_running_in_gce()` returns a truthy value."""
   detected_gce = resolver.is_running_in_gce()
   self.assertTrue(detected_gce)
示例#8
0
 def testIsNotRunningInGce(self):
   """Checks that `resolver.is_running_in_gce()` returns a falsy value."""
   in_gce = resolver.is_running_in_gce()
   self.assertFalse(in_gce)
示例#9
0
 def testIsRunningInGce(self):
   """Checks that `resolver.is_running_in_gce()` returns a truthy value."""
   in_gce = resolver.is_running_in_gce()
   self.assertTrue(in_gce)
示例#10
0
 def after_create_session(self, session, coord):
   """Starts a `_TPUPollingThread` on `session` if running in GCE.

   Args:
     session: the session to monitor; handed to `_TPUPollingThread`.
     coord: unused here; presumably part of the hook interface.
   """
   del coord  # Unused.
   if not tpu_cluster_resolver.is_running_in_gce():
     return
   poller = _TPUPollingThread(self._cluster, session)
   self._tpu_poller = poller
   poller.start()