def get_workers_list(cluster_resolver):
    """Returns a comma-separated string of the cluster's worker addresses.

    Every task address of the 'worker' job in the resolver's cluster spec is
    included, with the port suffix ':8470' rewritten to ':8466'
    (presumably the profiler service port -- confirm against the caller).

    Args:
      cluster_resolver: Object whose `cluster_spec()` returns a cluster spec
        exposing `task_indices(job_name)` and `task_address(job_name, index)`.

    Returns:
      The rewritten worker addresses joined by ','.

    Raises:
      UnavailableError: If the resolver returns a falsy cluster spec.
    """
    worker_job_name = 'worker'
    cluster_spec = cluster_resolver.cluster_spec()
    if not cluster_spec:
        # Pass real None for node_def and op rather than the placeholder
        # strings 'None': a non-None op is expected to be an actual op object.
        raise errors.UnavailableError(
            None, None,
            'Cluster spec not found, your client must run in GCE environment.')

    return ','.join(
        cluster_spec.task_address(worker_job_name, i).replace(':8470', ':8466')
        for i in cluster_spec.task_indices(worker_job_name))
def stop(save=True):
    """Ends the active profiling session, optionally exporting its results.

    Under the module-level profiler lock, the current profiler is looked up;
    if one exists its results are exported to TensorBoard when `save` is
    truthy, and the module-level profiler reference is then cleared.

    Args:
      save: Whether to export the collected results to TensorBoard before
        the session is discarded. Defaults to True.

    Raises:
      UnavailableError: If no profiling session is currently active.
    """
    global _profiler
    with _profiler_lock:
        active_session = _profiler
        if active_session is None:
            raise errors.UnavailableError(
                None, None,
                'Cannot export profiling results. No profiler is running.')
        if save:
            active_session.export_to_tb()
        # Cleared only after a successful (or skipped) export, so a failing
        # export leaves the session registered.
        _profiler = None
def run(self):
    """Polls Cloud TPU metadata and closes the session if the TPU is lost.

    Exits immediately (clearing `self._running`) when not running in GCE.
    Otherwise, while `self._running` is set, fetches the TPU metadata every
    `self._interval` seconds and logs the reported state and health. When the
    state is 'TERMINATED' or 'PREEMPTED', the session is asked to close; if
    `self._session_closed` is still falsy one interval later, an error is
    raised.

    Raises:
      UnavailableError: If the TPU reached an unrecoverable state and the
        session was not marked closed after `close()` was requested.
    """
    if not tpu_cluster_resolver.is_running_in_gce():
        logging.warning(
            'TPUPollingThread is running in a non-GCE environment, exiting...')
        self._running = False
        return

    while self._running:
        response = self._cluster._fetch_cloud_tpu_metadata()  # pylint: disable=protected-access
        tpu_name = self._cluster._tpu  # pylint: disable=protected-access
        # Read 'state' defensively: indexing response['state'] directly would
        # raise KeyError when the key is absent and kill the polling thread.
        state = response.get('state', 'UNKNOWN')
        logging.warning(
            'TPUPollingThread found TPU %s in state %s, and health %s.',
            tpu_name, state, response.get('health', 'UNKNOWN'))

        if state in ('TERMINATED', 'PREEMPTED'):
            logging.warning(
                'TPU node %s reached an unrecoverable state %s, '
                'terminating the session now.', tpu_name, state)
            # Try to close the session, then give it one interval to finish.
            self._session.close()
            time.sleep(self._interval)

            # NOTE(review): _session_closed is presumably set elsewhere once
            # the session has finished closing -- confirm against the class.
            if not self._session_closed:
                # Raise an exception if session.close() is stuck.
                logging.warning('Cannot close session on TPU node %s.',
                                tpu_name)
                raise errors.UnavailableError(
                    None, None,
                    'TPU node %s reached an unrecoverable state %s.' %
                    (tpu_name, state))

        time.sleep(self._interval)