def run(self):
    """Poll the TPU node's metadata and close the session on a terminal state.

    Returns immediately (after clearing ``self._running``) when not running
    in GCE. Otherwise, while ``self._running`` is true, fetches the Cloud TPU
    metadata every ``self._interval`` seconds and logs the reported state and
    health. When the state is ``TERMINATED`` or ``PREEMPTED`` the session is
    closed, and if ``self._session_closed`` is still false one interval
    later an error is raised.

    Raises:
        errors.UnavailableError: The TPU reached an unrecoverable state and
            the session was not marked closed after ``session.close()``.
    """
    if not tpu_cluster_resolver.is_running_in_gce():
        logging.warning(
            'TPUPollingThread is running in a non-GCE environment, exiting...')
        self._running = False
        return

    while self._running:
        response = self._cluster._fetch_cloud_tpu_metadata()  # pylint: disable=protected-access

        # The metadata may omit either key; look them up defensively so that
        # merely logging the status cannot raise KeyError and stop polling.
        state = response.get('state', 'UNKNOWN')
        health = response.get('health', 'UNKNOWN')
        logging.warning(
            'TPUPollingThread found TPU %s in state %s, and health %s.',
            self._cluster._tpu, state, health)  # pylint: disable=protected-access

        if state in ('TERMINATED', 'PREEMPTED'):
            logging.warning(
                'TPU node %s reached an unrecoverable state %s, '
                'terminating the session now.',
                self._cluster._tpu,  # pylint: disable=protected-access
                state)
            # Try to close the session, then give it one interval to finish.
            self._session.close()
            time.sleep(self._interval)

            if not self._session_closed:
                # session.close() appears to be stuck: report it and raise.
                # NOTE(review): if this runs as a threading.Thread target, the
                # exception ends only this thread, not the process -- confirm
                # that is the intended effect.
                logging.warning('Cannot close session on TPU node %s.',
                                self._cluster._tpu)  # pylint: disable=protected-access
                raise errors.UnavailableError(
                    None, None,
                    'TPU node %s reached an unrecoverable state %s.' %
                    (self._cluster._tpu, state))  # pylint: disable=protected-access

        time.sleep(self._interval)
def run(self):
    """Poll the TPU node's metadata and exit the process on a terminal state.

    Returns immediately (after clearing ``self._running``) when not running
    in GCE. Otherwise, while ``self._running`` is true, fetches the Cloud TPU
    metadata every ``self._interval`` seconds and logs the reported state and
    health. When the state is ``TERMINATED`` or ``PREEMPTED`` the process is
    ended with ``os._exit(1)``.
    """
    if not tpu_cluster_resolver.is_running_in_gce():
        logging.warning(
            'TPUPollingThread is running in a non-GCE environment, exiting...')
        self._running = False
        return

    while self._running:
        response = self._cluster._fetch_cloud_tpu_metadata()  # pylint: disable=protected-access

        # 'state' may be absent (the check below already allows for that), so
        # read it defensively instead of indexing: logging must not raise
        # KeyError and stop polling.
        state = response.get('state', 'UNKNOWN')
        logging.warning(
            'TPUPollingThread found TPU %s in state %s, and health %s.',
            self._cluster._tpu,  # pylint: disable=protected-access
            state,
            response.get('health', 'UNKNOWN'))

        if state in ('TERMINATED', 'PREEMPTED'):
            logging.warning(
                'TPU node %s reached an unrecoverable state %s, '
                'terminating training.',
                self._cluster._tpu,  # pylint: disable=protected-access
                state)
            # os._exit ends the process immediately, skipping cleanup
            # handlers; a normal exception here would not stop the process.
            os._exit(1)  # pylint: disable=protected-access

        time.sleep(self._interval)
def run(self):
    """Poll the Cloud TPU client and exit the process once it is unrecoverable.

    Returns immediately (after clearing ``self._running``) when not running
    in GCE. Otherwise, while ``self._running`` is true, asks the cluster's
    Cloud TPU client every ``self._interval`` seconds whether the TPU is
    recoverable; if it is not, logs the TPU and its state and ends the
    process with ``os._exit(1)``.
    """
    in_gce = tpu_cluster_resolver.is_running_in_gce()
    if not in_gce:
        logging.warning(
            'TPUPollingThread is running in a non-GCE environment, exiting...'
        )
        self._running = False
        return

    while self._running:
        # pylint: disable=protected-access
        if not self._cluster._cloud_tpu_client.recoverable():
            logging.warning(
                'TPUPollingThread found TPU %s in state %s',
                self._cluster._tpu,
                self._cluster._cloud_tpu_client.state(),
            )
            # Immediate process exit, bypassing normal interpreter cleanup.
            os._exit(1)
        # pylint: enable=protected-access
        time.sleep(self._interval)
def after_create_session(self, session, coord):
    """Start a TPU polling thread for the new session when running in GCE.

    Args:
        session: The newly created session, handed to the polling thread.
        coord: Not used by this method.
    """
    if not tpu_cluster_resolver.is_running_in_gce():
        # Outside GCE no poller is created.
        return

    self._tpu_poller = _TPUPollingThread(self._cluster, session)
    self._tpu_poller.start()
def testIsNotRunningInGce(self):
    """Checks that the resolver reports it is not running in GCE."""
    # NOTE(review): presumably the environment probe is patched by a
    # decorator or fixture outside this view -- confirm.
    running_in_gce = resolver.is_running_in_gce()
    self.assertFalse(running_in_gce)
def testIsRunningInGce(self):
    """Checks that the resolver reports it is running in GCE."""
    # NOTE(review): presumably the environment probe is patched by a
    # decorator or fixture outside this view -- confirm.
    running_in_gce = resolver.is_running_in_gce()
    self.assertTrue(running_in_gce)