def tf_allreduce(self, grads, op="MEAN"): if grads is None: logger.error("Grads is required for tf_allreduce operation") return CollectiveCommunicatorStatus.FAILED, grads # convert tf.Tensor to numpy numpy_data = [g.numpy() for g in grads] return self.allreduce(numpy_data, op)
def get_dataset(self):
    """
    If there's more data, this creates a `tf.data.Dataset` object.
    Otherwise, this returns `None`.
    """
    if self._pending_dataset:
        if self._pending_tasks:
            logger.error(
                "Cannot get new dataset when there are pending tasks"
            )
            return None
        self._reset()
        # We use a task to warm up the data reader in order to collect
        # useful metadata. Note that we only perform data fetching for
        # this task and `break` instantly so that `read_records()` runs
        # without iterating over all the records; this should not be
        # time consuming.
        if self._warm_up_task is None and not self._has_warmed_up:
            task = self._worker.get_task()
            self._warm_up_task = task
            for _ in self.data_reader.read_records(task):
                break
            self._has_warmed_up = True
        ds = tf.data.Dataset.from_generator(
            self._gen, self.data_reader.records_output_types
        )
        self._pending_dataset = False
        return ds
    else:
        return None
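
# A self-contained sketch of the tf.data.Dataset.from_generator pattern used by
# get_dataset() above. The generator and output type here are placeholders
# standing in for self._gen and self.data_reader.records_output_types; how the
# caller consumes the returned dataset is illustrative.
import tensorflow as tf

def _record_gen():
    # Stand-in for the worker's record generator: yields serialized records.
    for record in [b"record-0", b"record-1", b"record-2"]:
        yield record

ds = tf.data.Dataset.from_generator(_record_gen, output_types=tf.string)

# The caller typically batches and iterates the dataset it gets back.
for batch in ds.batch(2):
    print(batch.numpy())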
def run(self):
    """
    The main loop of the master.
    Dispatch tasks to the workers until all the tasks are completed.
    """
    try:
        while True:
            if self._stop_requested:
                break
            if self.pod_manager and self.pod_manager.all_workers_exited:
                if self.pod_manager.all_workers_failed:
                    logger.error("All workers failed")
                    self._exit_code = 1
                    break
                if self.task_manager and not self.task_manager.finished():
                    logger.warning(
                        "All workers exited but there are still "
                        "unfinished tasks"
                    )
                self.pod_manager.update_status(PodManagerStatus.FINISHED)
                break
            time.sleep(30)
    except KeyboardInterrupt:
        logger.warning("Server stopping")
    finally:
        self.stop()
    return self._exit_code
def monitor_status(self):
    retry_num = 0
    pod_succeeded = False
    while True:
        try:
            pod = self.client.get_pod(self.pod_name)
            if pod is None:
                retry_num += 1
                if retry_num > MAX_READ_POD_RETRIES:
                    logger.error("{} Not Found".format(self.pod_name))
                    break
                time.sleep(10)
                continue
            retry_num = 0
            logger.info("Pod Status : %s" % pod.status.phase)
            if pod.status.phase == PodStatus.SUCCEEDED:
                pod_succeeded = True
                break
            elif pod.status.phase == PodStatus.FAILED:
                logger.info(self.client.get_pod_log(self.pod_name))
                break
            else:
                time.sleep(30)
        except client.api_client.ApiException:
            time.sleep(60)
    return pod_succeeded
def GetModel(self, request, _):
    if not self._use_async:
        self._validate_model_version(request.version)

    if (
        request.method == elasticdl_pb2.MINIMUM
        or request.version == self._version
    ):
        if self._use_async:
            res = self._get_model_no_lock()
        else:
            with self._lock:
                res = self._get_model_no_lock()
        return res

    # Read from checkpoint for the fixed version model
    pb_model = elasticdl_pb2.Model()
    try:
        pb_model = self._checkpoint_service.get_checkpoint_model(
            request.version
        )
    except Exception:
        logger.error(
            "Failed to fetch checkpoint model for "
            "model version {}".format(request.version)
        )
    return pb_model
def monitor_status(self):
    retry_num = 0
    job_succeed = False
    master_old_log = ""
    while True:
        master_pod = self.client.get_master_pod()
        if master_pod is None:
            retry_num += 1
            if retry_num > MAX_READ_POD_RETRIES:
                logger.error(
                    "{} Not Found".format(self.client.get_master_pod_name())
                )
                break
            time.sleep(10)
            continue
        logger.info("Master status: {}".format(master_pod.status.phase))
        if master_pod.status.phase == PodStatus.SUCCEEDED:
            job_succeed = True
            break
        elif master_pod.status.phase == PodStatus.PENDING:
            time.sleep(10)
        elif master_pod.status.phase == PodStatus.FAILED:
            log = self.client.get_master_log()
            print_tail_log(log, tail_num=100)
            logger.error("Job {} Failed".format(self.job_name))
            break
        else:
            master_new_log = self.client.get_master_log()
            self.show_evaluation_and_task_log(master_new_log, master_old_log)
            master_old_log = master_new_log
            self.check_worker_status()
            self.check_ps_status()
            time.sleep(60)
    return job_succeed
def _event_cb(self, event):
    evt_obj = event.get("object")
    evt_type = event.get("type")
    if not evt_obj or not evt_type:
        logger.error("Event doesn't have object or type: %s" % event)
        return

    pod_name = evt_obj.metadata.name
    phase = evt_obj.status.phase
    logger.info(
        "Got event %s, phase %s for pod: %s" % (evt_type, phase, pod_name)
    )

    relaunch = False
    with self._lock:
        worker_id = self._pod_name_to_id.get(pod_name)
        if (
            worker_id is None
            and pod_name != self._k8s_client.get_master_pod_name()
        ):
            logger.error("Unknown worker pod name: %s" % pod_name)
            return

        self._pods_phase[worker_id] = (pod_name, phase)
        if evt_type == "DELETED":
            del self._pods_phase[worker_id]
            del self._pod_name_to_id[pod_name]
            self._task_d.recover_tasks(worker_id)
            # If a deleted pod was not "Succeeded", relaunch a worker.
            relaunch = (
                self._relaunch_deleted_live_worker and phase != "Succeeded"
            )

    if relaunch:
        logger.info("Relaunching worker.")
        self._start_worker(self._next_worker_id())
def __init__(self, service_name=None):
    if _FTLIB_INSTALLED:
        connection_try_num = 0
        while True:
            try:
                peer_list = list(self._get_peer_set(service_name))
            except Exception:
                if (
                    connection_try_num * 5
                    > _FTLIB_CONSENSUS_CONNECTION_TIMEOUT_SECS
                ):
                    logger.error(
                        "Cannot connect to FTLib consensus service in %s "
                        "seconds",
                        str(_FTLIB_CONSENSUS_CONNECTION_TIMEOUT_SECS),
                    )
                    self._ftlib = None
                    return
                # Sleep for 5s and try again.
                logger.warning(
                    "Cannot connect to FTLib consensus service, "
                    "trying again."
                )
                connection_try_num += 1
                time.sleep(5)
            else:
                break
        self._ftlib = BasicFTLib(
            consensus="gossip",
            commlib="pytorch",
            consensus_init_kwargs={
                "known_addr_list": peer_list,
                "custom_bind_addr": socket.gethostbyname(
                    socket.gethostname()
                ),
            },
        )
        connection_try_num = 0
        while peer_list and not self._ftlib.consensus_joined():
            logger.warning("Retry building consensus...")
            try:
                self._ftlib.manual_join(
                    known_addr_list=list(self._get_peer_set(service_name))
                )
            except Exception:
                if (
                    connection_try_num * 5
                    > _FTLIB_CONSENSUS_CONNECTION_TIMEOUT_SECS
                ):
                    logger.error(
                        "Cannot join FTLib consensus service in %s "
                        "seconds",
                        str(_FTLIB_CONSENSUS_CONNECTION_TIMEOUT_SECS),
                    )
                    self._ftlib = None
                    return
                logger.warning(
                    "Cannot join FTLib consensus service, trying again."
                )
                connection_try_num += 1
                time.sleep(5)
    else:
        logger.warning(
            "FTLib is not installed. The CollectiveCommunicator "
            "may not work as expected"
        )
        self._ftlib = None
def check_exceed_max_task_retries(self, task):
    self._task_retry_count.setdefault(task, 1)
    self._task_retry_count[task] += 1
    if self._task_retry_count[task] > _MAX_TASK_RETRIES:
        logger.error(
            "A %s task failed with %d retries"
            % (task.type, _MAX_TASK_RETRIES)
        )
        return True
    return False
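
# A tiny, self-contained illustration of the retry-count bookkeeping in
# check_exceed_max_task_retries() above: setdefault() seeds the count at 1 and
# the increment runs on every reported failure, so with _MAX_TASK_RETRIES = 3
# the third reported failure is the one that exceeds the limit. The values
# below are illustrative only.
_MAX_TASK_RETRIES = 3
retry_count = {}
task = "task-0"  # stands in for a task object

for failure in range(1, 4):
    retry_count.setdefault(task, 1)
    retry_count[task] += 1
    exceeded = retry_count[task] > _MAX_TASK_RETRIES
    print(failure, retry_count[task], exceeded)
# 1 2 False
# 2 3 False
# 3 4 True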
def request_stop(self, success, msg=""): self._stop_requested = True if success: self._exit_code = 0 logger.info(msg) else: self._exit_code = 1 logger.error(msg)
def _remove_ps(self, ps_id):
    logger.info("Removing PS: %d", ps_id)
    with self._lock:
        if ps_id not in self._ps_pods_phase:
            logger.error("Unknown PS id: %s" % ps_id)
            return
    self._k8s_client.delete_ps(ps_id)
def _remove_worker(self, worker_id):
    logger.info("Removing worker: %d", worker_id)
    with self._lock:
        if worker_id not in self._worker_pods_phase:
            logger.error("Unknown worker id: %s" % worker_id)
            return
    # TODO: change _k8s_client to accept pod name instead of worker id.
    self._k8s_client.delete_worker(worker_id)
def check_worker_status(self):
    for i in range(self.worker_num):
        worker_pod = self.client.get_worker_pod(i)
        worker_pod_name = self.client.get_worker_pod_name(i)
        if worker_pod is None:
            logger.error("Worker {} Not Found".format(worker_pod_name))
        elif worker_pod.status.phase == PodStatus.FAILED:
            logger.error(
                "Worker {} {}".format(
                    worker_pod_name, worker_pod.status.phase
                )
            )
def check_ps_status(self):
    for i in range(self.ps_num):
        ps_pod = self.client.get_ps_pod(i)
        ps_pod_name = self.client.get_ps_pod_name(i)
        if ps_pod is None:
            logger.error("PS {} Not Found".format(ps_pod_name))
        elif ps_pod.status.phase == PodStatus.FAILED:
            logger.error(
                "PS {} {}".format(ps_pod_name, ps_pod.status.phase)
            )
def _event_cb(self, event):
    evt_obj = event.get("object")
    evt_type = event.get("type")
    if not evt_obj or not evt_type:
        logger.error("Event doesn't have object or type: %s" % event)
        return
    if evt_obj.kind != "Pod":
        # We only care about pod related events
        return

    pod_name = evt_obj.metadata.name
    phase = evt_obj.status.phase
    logger.info(
        "Got event %s, phase %s for pod: %s" % (evt_type, phase, pod_name)
    )
    if pod_name == self._k8s_client.get_master_pod_name():
        # No need to care about master pod
        return

    relaunch_worker = False
    relaunch_ps = False
    ps_id = -1
    with self._lock:
        if pod_name in self._worker_pod_name_to_id:
            worker_id = self._worker_pod_name_to_id.get(pod_name)
            self._worker_pods_phase[worker_id] = (pod_name, phase)
            if evt_type == "DELETED":
                del self._worker_pods_phase[worker_id]
                del self._worker_pod_name_to_id[pod_name]
                self._task_d.recover_tasks(worker_id)
                # If a deleted pod was not "Succeeded", relaunch a worker.
                relaunch_worker = (
                    self._relaunch_deleted_live_worker
                    and phase != "Succeeded"
                )
        elif pod_name in self._ps_pod_name_to_id:
            ps_id = self._ps_pod_name_to_id.get(pod_name)
            self._ps_pods_phase[ps_id] = (pod_name, phase)
            if evt_type == "DELETED":
                del self._ps_pods_phase[ps_id]
                del self._ps_pod_name_to_id[pod_name]
                relaunch_ps = self._relaunch_deleted_live_ps
        else:
            logger.error("Unknown pod name: %s" % pod_name)
            return

    if relaunch_worker:
        logger.info("Relaunching worker.")
        self._start_worker(self._next_worker_id())
    elif relaunch_ps:
        logger.info("Relaunching ps.")
        self._start_ps(ps_id)
def set_training_params(
    self,
    batch_size,
    num_epochs,
    dataset_size,
    shuffle,
    shuffle_shards,
    num_minibatches_per_shard,
):
    logger.info(
        "Set training parameters: "
        "batch_size={}, num_epochs={}, dataset_size={}, shuffle={}, "
        "shuffle_shards={}, num_minibatches_per_shard={}".format(
            batch_size,
            num_epochs,
            dataset_size,
            shuffle,
            shuffle_shards,
            num_minibatches_per_shard,
        )
    )

    with self._lock:
        if self._training_shards:
            logger.info(
                "The training shards have already been initialized. "
                "Ignoring these training parameters."
            )
            return
        logger.info("Initializing with these training parameters.")

        # The master receives the training params to create shards.
        self._batch_size = batch_size
        self._shuffle = shuffle
        self._shuffle_shards = shuffle_shards
        self._num_minibatches_per_task = (
            num_minibatches_per_shard
            if num_minibatches_per_shard > 0
            else self._num_minibatches_per_task
        )
        self._records_per_task = (
            batch_size * self._num_minibatches_per_task
        )
        self._num_epochs = num_epochs
        if dataset_size > 0:
            self._dataset_size = dataset_size
            self._training_shards = self._create_shards_by_dataset_size(
                self._dataset_size
            )
        else:
            logger.error(
                "Not creating shards because dataset size {} <= 0".format(
                    dataset_size
                )
            )

    if self._training_shards:
        logger.info("Starting epoch %d", self._epoch)
        self.create_tasks(elasticai_api_pb2.TRAINING)
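
# A minimal sketch of the shard-size arithmetic in set_training_params() above,
# with hypothetical numbers. _create_shards_by_dataset_size is assumed to split
# dataset_size records into consecutive (start, count) shards of at most
# records_per_task records each; that splitting rule is an assumption here.
batch_size = 64
num_minibatches_per_shard = 8
dataset_size = 1000

records_per_task = batch_size * num_minibatches_per_shard  # 512 records per shard
shards = [
    (start, min(records_per_task, dataset_size - start))
    for start in range(0, dataset_size, records_per_task)
]
print(shards)  # [(0, 512), (512, 488)]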
def _get_embedding_cluster(self):
    startup_nodes = [
        {"host": ip, "port": "%d" % port}
        for ip, port_list in self._embedding_service_endpoint.items()
        for port in port_list
    ]
    try:
        redis_cluster = RedisCluster(
            startup_nodes=startup_nodes, decode_responses=False
        )
    except Exception as e:
        logger.error(e)
        return None
    else:
        return redis_cluster
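
# A self-contained sketch of how _get_embedding_cluster() above expands the
# {ip: [ports]} endpoint mapping into the startup_nodes list passed to
# RedisCluster. The endpoint values are illustrative.
endpoints = {"10.0.0.1": [7000, 7001], "10.0.0.2": [7000]}
startup_nodes = [
    {"host": ip, "port": "%d" % port}
    for ip, port_list in endpoints.items()
    for port in port_list
]
print(startup_nodes)
# [{'host': '10.0.0.1', 'port': '7000'}, {'host': '10.0.0.1', 'port': '7001'},
#  {'host': '10.0.0.2', 'port': '7000'}]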
def report_evaluation_metrics(self, evaluation_version, evaluation_metrics):
    if (
        self.model_version >= 0
        and evaluation_version != self.model_version
    ):
        logger.error(
            "Drop a wrong version evaluation: request %d, receive %d"
            % (self.model_version, evaluation_version)
        )
        return False
    for k, v in evaluation_metrics.items():
        if k in self._evaluation_metrics:
            self._evaluation_metrics[k] += tensor_to_ndarray(v)
        else:
            self._evaluation_metrics[k] = np.copy(tensor_to_ndarray(v))
    self._completed_minibatches += 1
    return True
def get_dataset(self):
    """
    Return a RecordIO dataset, or None if no more data.
    """
    if self._pending_dataset:
        if self._pending_tasks_with_counts:
            logger.error(
                "Cannot get new dataset when there are pending tasks"
            )
            return None
        self._reset()
        ds = tf.data.Dataset.from_generator(
            self._gen, self._data_reader.records_output_types
        )
        self._pending_dataset = False
        return ds
    else:
        return None
def report_evaluation_metrics(self, evaluation_version, model_outputs, labels):
    if (
        self.model_version >= 0
        and evaluation_version != self.model_version
    ):
        logger.error(
            "Drop a wrong version evaluation: request %d, receive %d"
            % (self.model_version, evaluation_version)
        )
        return False
    labels = tensor_to_ndarray(labels)
    for key, tensor in model_outputs.items():
        metrics = self._metrics_dict.get(key, {})
        if not metrics:
            continue
        outputs = tensor_to_ndarray(tensor)
        for metric_inst in metrics.values():
            metric_inst.update_state(labels, outputs)
    return True
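
# A self-contained sketch of the stateful Keras-metric accumulation that
# report_evaluation_metrics() above relies on: update_state() is called once per
# reported minibatch and result() aggregates over all of them. The metric choice
# and data here are illustrative only.
import numpy as np
import tensorflow as tf

accuracy = tf.keras.metrics.BinaryAccuracy()

# Two evaluation minibatches reported one after another.
accuracy.update_state(np.array([1, 0, 1]), np.array([0.9, 0.2, 0.4]))
accuracy.update_state(np.array([0, 1]), np.array([0.1, 0.8]))

print(float(accuracy.result()))  # 0.8 -> 4 of 5 predictions correct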
def _save_checkpoint(self, locking, is_eval_checkpoint):
    try:
        logger.info(
            "Saving checkpoint for model version %d" % self._version
        )
        if locking:
            self._lock.acquire()
        pb_model = self._get_model_no_lock()
        self._checkpoint_service.save(
            self._version, pb_model, is_eval_checkpoint
        )
        checkpoint_version = self._version
        if locking:
            self._lock.release()
        return checkpoint_version
    except Exception:
        logger.error(
            "Failed to save checkpoint file for model version %d"
            % self._version
        )
def allreduce(self, data, op="MEAN"): if data is None: logger.error("Data is required for allreduce operation") return CollectiveCommunicatorStatus.FAILED, data if op not in _SUPPORTED_ALLREDUCE_OPS: logger.error( "%s is not in list of supported allreduce operations: %s" % (op, _SUPPORTED_ALLREDUCE_OPS)) return CollectiveCommunicatorStatus.FAILED, data if self._ftlib is not None: res = self._ftlib.wait_gradients_ready(data) if res == FTAllReduceStatus.SUCCESS: return CollectiveCommunicatorStatus.SUCCEEDED, data else: return CollectiveCommunicatorStatus.FAILED, data else: logger.warning(_FTLIB_UNINSTALLED_DEFAULT_STATUS_MESSAGE) return CollectiveCommunicatorStatus.SUCCEEDED, data
def allreduce(self, data, op="MEAN"): if data is None: logger.error("Data is required for allreduce operation") return CollectiveCommunicatorStatus.FAILED, data if op not in _SUPPORTED_ALLREDUCE_OPS: logger.error( "%s is not in list of supported allreduce operations: %s" % (op, _SUPPORTED_ALLREDUCE_OPS)) return CollectiveCommunicatorStatus.FAILED, data if self._ftlib is not None: res = self._ftlib.allreduce_average(data) if res == FTAllReduceStatus.SUCCESS: return CollectiveCommunicatorStatus.SUCCEEDED, data else: return CollectiveCommunicatorStatus.FAILED, data else: logger.warning("FTLib is not installed. " "Default to succeeded for testing purposes") return CollectiveCommunicatorStatus.SUCCEEDED, data
def get_dataset(self):
    """
    If there's more data, this creates a `tf.data.Dataset` object.
    Otherwise, this returns `None`.
    """
    if self._pending_dataset:
        if self._pending_tasks:
            logger.error(
                "Cannot get new dataset when there are pending tasks"
            )
            return None
        self._reset()
        # We use a task to warm up the data reader in order to collect
        # useful metadata. Note that we only perform data fetching for
        # this task and `break` instantly so that `read_records()` runs
        # without iterating over all the records; this should not be
        # time consuming.
        if self._warm_up_task is None and not self._has_warmed_up:
            while True:
                task = self._worker.get_task()
                if task.type != elasticdl_pb2.WAIT:
                    break
                time.sleep(2)
            if task.type == elasticdl_pb2.TRAIN_END_CALLBACK:
                self._pending_train_end_callback_task = task
                return None
            elif not task.shard_name:
                logger.info("No more tasks, stopping")
                return None
            else:
                self._warm_up_task = task
                for _ in self.data_reader.read_records(task):
                    break
                self._has_warmed_up = True
        ds = tf.data.Dataset.from_generator(
            self._gen, self.data_reader.records_output_types
        )
        self._pending_dataset = False
        return ds
    else:
        return None
def _event_cb(self, event):
    evt_obj = event.get("object")
    evt_type = event.get("type")
    if not evt_obj or not evt_type:
        logger.error("Event doesn't have object or type: %s" % event)
        return
    if evt_obj.kind != "Pod":
        # We only care about pod related events
        return

    pod_name = evt_obj.metadata.name
    pod_ip = evt_obj.status.pod_ip
    phase = evt_obj.status.phase
    if pod_name == self._k8s_client.get_master_pod_name():
        # No need to care about master pod
        return

    relaunch_worker = False
    relaunch_ps = False
    worker_id = -1
    ps_id = -1
    with self._lock:
        if pod_name in self._failed_pods:
            return
        relaunch_failed_pod = False
        if evt_type == "MODIFIED" and phase == "Failed":
            self._failed_pods.append(pod_name)
            worker_id = self._worker_pod_name_to_id.get(pod_name, None)
            if worker_id is not None:
                # Recover tasks when the worker failed
                self._task_d.recover_tasks(worker_id)
            if (
                evt_obj.status.container_statuses
                and evt_obj.status.container_statuses[0].state.terminated
                and evt_obj.status.container_statuses[0].state.terminated.exit_code == 137
                and evt_obj.status.container_statuses[0].state.terminated.reason != "OOMKilled"
            ):
                relaunch_failed_pod = True
                logger.info(
                    "Pod %s is killed with reason %s."
                    % (
                        pod_name,
                        evt_obj.status.container_statuses[0].state.terminated.reason,
                    )
                )
        if pod_name in self._worker_pod_name_to_id:
            worker_id = self._worker_pod_name_to_id.get(pod_name)
            self._worker_pods_ip_phase[worker_id] = (
                pod_name,
                pod_ip,
                phase,
            )
            if evt_type == "DELETED" or relaunch_failed_pod:
                del self._worker_pods_ip_phase[worker_id]
                del self._worker_pod_name_to_id[pod_name]
                # If a deleted pod was not "Succeeded", relaunch a worker.
                relaunch_worker = (
                    self._relaunch_deleted_live_worker
                    and phase != "Succeeded"
                )
            else:
                workers_failed = []
                for (
                    pod_name,
                    _,
                    phase,
                ) in self._worker_pods_ip_phase.values():
                    workers_failed.append(phase == PodStatus.FAILED)
                self.all_workers_failed = all(workers_failed)
        elif pod_name in self._ps_pod_name_to_id:
            ps_id = self._ps_pod_name_to_id.get(pod_name)
            self._ps_pods_phase[ps_id] = (pod_name, phase)
            if evt_type == "DELETED" or relaunch_failed_pod:
                del self._ps_pods_phase[ps_id]
                del self._ps_pod_name_to_id[pod_name]
                relaunch_ps = self._relaunch_deleted_live_ps
        else:
            logger.error("Unknown pod name: %s" % pod_name)
            return

    if self._rendezvous_server:
        self._worker_addrs = self._get_alive_worker_addr()
        self._rendezvous_server.set_worker_hosts(self._worker_addrs)

    if relaunch_worker and worker_id >= 0:
        logger.info("Relaunching worker.")
        new_worker_id = self._next_worker_id()
        with self._lock:
            self._worker_pod_priority[
                new_worker_id
            ] = self._worker_pod_priority[worker_id]
        self._start_worker(new_worker_id)
    elif relaunch_ps:
        logger.info("Relaunching ps.")
        # Note: the ID and service address for relaunched parameter
        # server are intentionally left unchanged to support fault
        # tolerance.
        self._start_ps(ps_id)
def _event_cb(self, event):
    evt_obj = event.get("object")
    evt_type = event.get("type")
    if not evt_obj or not evt_type:
        logger.error("Event doesn't have object or type: %s" % event)
        return
    if evt_obj.kind != "Pod":
        # We only care about pod related events
        return

    pod_name = evt_obj.metadata.name
    phase = evt_obj.status.phase
    logger.info(
        "Got event %s, phase %s for pod: %s" % (evt_type, phase, pod_name)
    )
    if pod_name == self._k8s_client.get_master_pod_name():
        # No need to care about master pod
        return

    relaunch_worker = False
    relaunch_ps = False
    worker_id = -1
    ps_id = -1
    with self._lock:
        if pod_name in self._worker_pod_name_to_id:
            worker_id = self._worker_pod_name_to_id.get(pod_name)
            self._worker_pods_phase[worker_id] = (pod_name, phase)
            if evt_type == "DELETED":
                del self._worker_pods_phase[worker_id]
                del self._worker_pod_name_to_id[pod_name]
                self._task_d.recover_tasks(worker_id)
                # If a deleted pod was not "Succeeded", relaunch a worker.
                relaunch_worker = (
                    self._relaunch_deleted_live_worker
                    and phase != "Succeeded"
                )
        elif pod_name in self._ps_pod_name_to_id:
            ps_id = self._ps_pod_name_to_id.get(pod_name)
            self._ps_pods_phase[ps_id] = (pod_name, phase)
            if evt_type == "DELETED":
                del self._ps_pods_phase[ps_id]
                del self._ps_pod_name_to_id[pod_name]
                relaunch_ps = self._relaunch_deleted_live_ps
        else:
            logger.error("Unknown pod name: %s" % pod_name)
            return

    if relaunch_worker and worker_id >= 0:
        logger.info("Relaunching worker.")
        new_worker_id = self._next_worker_id()
        self._start_worker(new_worker_id)
        self._update_addr(
            worker_id,
            new_worker_id,
            self._worker_addrs,
            addr_get_fn=self._k8s_client.get_worker_service_address,
        )
    elif relaunch_ps:
        logger.info("Relaunching ps.")
        # Note: the ID and service address for relaunched parameter
        # server are intentionally left unchanged to support fault
        # tolerance.
        self._start_ps(ps_id)
def _event_cb(self, event):
    evt_obj = event.get("object")
    evt_type = event.get("type")
    if not evt_obj or not evt_type:
        logger.error("Event doesn't have object or type: %s" % event)
        return
    if evt_obj.kind != "Pod":
        # We only care about pod related events
        return

    pod_name = evt_obj.metadata.name
    phase = evt_obj.status.phase
    logger.info(
        "Got event %s, phase %s for pod: %s" % (evt_type, phase, pod_name)
    )
    if pod_name == self._k8s_client.get_master_pod_name():
        # No need to care about master pod
        return

    relaunch_worker = False
    relaunch_ps = False
    worker_id = -1
    ps_id = -1
    with self._lock:
        if pod_name in self._failed_pods:
            return
        # Workaround for memory leak issues in tf eager mode.
        # A pod may fail due to OOM from tf eager mode memory leak.
        failed_pod = False
        if (
            evt_type == "MODIFIED"
            and phase == "Failed"
            and evt_obj.status.container_statuses
            and evt_obj.status.container_statuses[0].state.terminated
            and evt_obj.status.container_statuses[0].state.terminated.reason == "OOMKilled"
        ):
            self._failed_pods.append(pod_name)
            failed_pod = True
            logger.info("Pod %s is OOMKilled." % pod_name)
        if pod_name in self._worker_pod_name_to_id:
            worker_id = self._worker_pod_name_to_id.get(pod_name)
            self._worker_pods_phase[worker_id] = (pod_name, phase)
            if evt_type == "DELETED" or failed_pod:
                del self._worker_pods_phase[worker_id]
                del self._worker_pod_name_to_id[pod_name]
                self._task_d.recover_tasks(worker_id)
                # If a deleted pod was not "Succeeded", relaunch a worker.
                relaunch_worker = (
                    self._relaunch_deleted_live_worker
                    and phase != "Succeeded"
                )
        elif pod_name in self._ps_pod_name_to_id:
            ps_id = self._ps_pod_name_to_id.get(pod_name)
            self._ps_pods_phase[ps_id] = (pod_name, phase)
            if evt_type == "DELETED" or failed_pod:
                del self._ps_pods_phase[ps_id]
                del self._ps_pod_name_to_id[pod_name]
                relaunch_ps = self._relaunch_deleted_live_ps
        else:
            logger.error("Unknown pod name: %s" % pod_name)
            return

    if relaunch_worker and worker_id >= 0:
        logger.info("Relaunching worker.")
        new_worker_id = self._next_worker_id()
        self._start_worker(new_worker_id)
        self._update_addr(
            worker_id,
            new_worker_id,
            self._worker_addrs,
            addr_get_fn=self._k8s_client.get_worker_service_address,
        )
    elif relaunch_ps:
        logger.info("Relaunching ps.")
        # Note: the ID and service address for relaunched parameter
        # server are intentionally left unchanged to support fault
        # tolerance.
        self._start_ps(ps_id)
def _event_cb(self, event):
    evt_obj = event.get("object")
    evt_type = event.get("type")
    if not evt_obj or not evt_type:
        logger.error("Event doesn't have object or type: %s" % event)
        return
    if evt_obj.kind != "Pod":
        # We only care about pod related events
        return

    pod_name = evt_obj.metadata.name
    phase = evt_obj.status.phase
    if pod_name == self._k8s_client.get_master_pod_name():
        # No need to care about master pod
        return

    relaunch_worker = False
    relaunch_ps = False
    worker_id = -1
    ps_id = -1
    with self._lock:
        if pod_name in self._failed_pods:
            return
        # When a pod fails with exit_code == 137, it may be deleted,
        # preempted, or OOMKilled. The master will try to relaunch it.
        # For OOMKilled, the relaunch is a workaround for memory leak
        # issues in tf eager mode.
        relaunch_failed_pod = False
        if (
            evt_type == "MODIFIED"
            and phase == "Failed"
            and evt_obj.status.container_statuses
            and evt_obj.status.container_statuses[0].state.terminated
            and evt_obj.status.container_statuses[0].state.terminated.exit_code == 137
        ):
            self._failed_pods.append(pod_name)
            relaunch_failed_pod = True
            logger.info(
                "Pod %s is killed with reason %s."
                % (
                    pod_name,
                    evt_obj.status.container_statuses[0].state.terminated.reason,
                )
            )
        if pod_name in self._worker_pod_name_to_id:
            worker_id = self._worker_pod_name_to_id.get(pod_name)
            self._worker_pods_phase[worker_id] = (pod_name, phase)
            if evt_type == "DELETED" or relaunch_failed_pod:
                del self._worker_pods_phase[worker_id]
                del self._worker_pod_name_to_id[pod_name]
                self._task_d.recover_tasks(worker_id)
                # If a deleted pod was not "Succeeded", relaunch a worker.
                relaunch_worker = (
                    self._relaunch_deleted_live_worker
                    and phase != "Succeeded"
                )
            else:
                workers_failed = []
                for pod_name, phase in self._worker_pods_phase.values():
                    workers_failed.append(phase == PodStatus.FAILED)
                self.all_workers_failed = all(workers_failed)
        elif pod_name in self._ps_pod_name_to_id:
            ps_id = self._ps_pod_name_to_id.get(pod_name)
            self._ps_pods_phase[ps_id] = (pod_name, phase)
            if evt_type == "DELETED" or relaunch_failed_pod:
                del self._ps_pods_phase[ps_id]
                del self._ps_pod_name_to_id[pod_name]
                relaunch_ps = self._relaunch_deleted_live_ps
        else:
            logger.error("Unknown pod name: %s" % pod_name)
            return

    if relaunch_worker and worker_id >= 0:
        logger.info("Relaunching worker.")
        new_worker_id = self._next_worker_id()
        with self._lock:
            self._worker_pod_priority[
                new_worker_id
            ] = self._worker_pod_priority[worker_id]
        self._start_worker(new_worker_id)
        with self._lock:
            self._update_worker_addr(worker_id, new_worker_id)
    elif relaunch_ps:
        logger.info("Relaunching ps.")
        # Note: the ID and service address for relaunched parameter
        # server are intentionally left unchanged to support fault
        # tolerance.
        self._start_ps(ps_id)