def _event_cb(self, event):
    """Kubernetes watch callback: track worker pod phases.

    Updates the pod-phase bookkeeping under ``self._lock`` and, when a
    live (non-"Succeeded") worker pod is DELETED, recovers its tasks and
    optionally relaunches a replacement worker.

    Args:
        event: a Kubernetes watch event dict; ``event["object"]`` is the
            pod object and ``event["type"]`` is e.g. "MODIFIED"/"DELETED".
    """
    evt_obj = event.get("object")
    evt_type = event.get("type")
    if not evt_obj or not evt_type:
        logger.error("Event doesn't have object or type: %s" % event)
        return

    pod_name = evt_obj.metadata.name
    phase = evt_obj.status.phase
    logger.info(
        "Got event %s, phase %s for pod: %s" % (evt_type, phase, pod_name)
    )

    relaunch = False
    with self._lock:
        worker_id = self._pod_name_to_id.get(pod_name)
        if worker_id is None:
            if pod_name != self._k8s_client.get_master_pod_name():
                logger.error("Unknown worker pod name: %s" % pod_name)
            # Fix: the master pod carries no worker bookkeeping. The
            # previous code fell through here with worker_id=None, which
            # stored a None key in _pods_phase and raised KeyError on a
            # DELETED master-pod event (master is not in _pod_name_to_id).
            # NOTE(review): master phase is no longer tracked under the
            # None key — confirm nothing reads _pods_phase[None].
            return

        self._pods_phase[worker_id] = (pod_name, phase)
        if evt_type == "DELETED":
            del self._pods_phase[worker_id]
            del self._pod_name_to_id[pod_name]
            # Make the dead worker's in-flight tasks schedulable again.
            self._task_d.recover_tasks(worker_id)
            # If a deleted pod was not "Succeeded", relaunch a worker.
            relaunch = (
                self._relaunch_deleted_live_worker and phase != "Succeeded"
            )

    # Launch outside the lock to avoid holding it during pod creation.
    if relaunch:
        logger.info("Relaunching worker.")
        self._start_worker(self._next_worker_id())
def GetModel(self, request, _):
    """gRPC handler: return the model for the requested version.

    Serves the in-memory model when the client accepts any minimum
    version or asks for the current version; otherwise falls back to the
    checkpoint store for the exact requested version (returning an empty
    model proto if the checkpoint fetch fails).
    """
    if not self._use_async:
        self._validate_model_version(request.version)

    wants_current = (
        request.method == elasticdl_pb2.MINIMUM
        or request.version == self._version
    )
    if wants_current:
        # Async mode updates the model without locking; sync mode must
        # serialize access with the training loop.
        if self._use_async:
            return self._get_model_no_lock()
        with self._lock:
            return self._get_model_no_lock()

    # Read from checkpoint for the fixed version model
    pb_model = elasticdl_pb2.Model()
    try:
        pb_model = self._checkpoint_service.get_checkpoint_model(
            request.version
        )
    except Exception:
        # Best-effort: an empty Model proto is returned on failure.
        logger.error(
            "Failed to fetch checkpoint model for "
            "model version {}".format(request.version)
        )
    return pb_model
def _remove_worker(self, worker_id):
    """Delete the Kubernetes pod of a known worker.

    Logs an error and does nothing if ``worker_id`` has no recorded
    pod phase.
    """
    logger.info("Removing worker: %d", worker_id)
    with self._lock:
        known = worker_id in self._pods_phase
    if not known:
        logger.error("Unknown worker id: %s" % worker_id)
        return
    # TODO: change _k8s_client to accept pod name instead of worker id.
    self._k8s_client.delete_worker(worker_id)
def _get_embedding_cluster(self):
    """Connect to the Redis cluster backing the embedding service.

    Builds the startup-node list from
    ``self._embedding_service_endpoint`` (a mapping of host IP to a list
    of ports) and returns a ``RedisCluster`` client, or None if the
    connection attempt raises.
    """
    startup_nodes = []
    for ip, port_list in self._embedding_service_endpoint.items():
        for port in port_list:
            startup_nodes.append({"host": ip, "port": "%d" % port})

    try:
        return RedisCluster(
            startup_nodes=startup_nodes, decode_responses=False
        )
    except Exception as e:
        # Best-effort: callers treat None as "no cluster available".
        logger.error(e)
        return None
def report_evaluation_metrics(self, evaluation_version, evaluation_metrics):
    """Accumulate evaluation metrics reported by a worker.

    Rejects reports whose version does not match ``self.model_version``
    (when a version is set, i.e. >= 0). Accepted metric tensors are
    summed element-wise into ``self._evaluation_metrics``.

    Returns:
        True if the report was accepted, False if it was dropped.
    """
    has_fixed_version = self.model_version >= 0
    if has_fixed_version and evaluation_version != self.model_version:
        logger.error(
            "Drop a wrong version evaluation: request %d, receive %d"
            % (self.model_version, evaluation_version)
        )
        return False

    for name, tensor_pb in evaluation_metrics.items():
        values = tensor_to_ndarray(tensor_pb)
        if name in self._evaluation_metrics:
            # Element-wise accumulation into the existing ndarray.
            self._evaluation_metrics[name] += values
        else:
            # Copy so later in-place += never aliases the decoded tensor.
            self._evaluation_metrics[name] = np.copy(values)
    self._completed_minibatches += 1
    return True
def get_dataset(self):
    """
    Return a RecordIO dataset, or None if no more data.
    """
    # Nothing queued: no dataset to hand out.
    if not self._pending_dataset:
        return None
    # A new dataset may only start once all outstanding tasks finished.
    if self._pending_tasks_with_counts:
        logger.error(
            "Cannot get new dataset when there are pending tasks")
        return None

    self._reset()
    dataset = tf.data.Dataset.from_generator(
        self._gen, (tf.string), (tf.TensorShape([])))
    self._pending_dataset = False
    return dataset
def _save_checkpoint(self, locking, is_eval_checkpoint):
    """Persist the current model version via the checkpoint service.

    Args:
        locking: acquire ``self._lock`` around model serialization.
        is_eval_checkpoint: forwarded to the checkpoint service to mark
            evaluation checkpoints.

    Returns:
        The saved model version, or None if saving failed (the error is
        logged, not raised).
    """
    try:
        logger.info(
            "Saving checkpoint for model version %d" % self._version)
        if locking:
            self._lock.acquire()
        # Fix: the original released the lock only on the success path,
        # so an exception in _get_model_no_lock()/save() while holding
        # the lock left it locked forever. try/finally guarantees the
        # release.
        try:
            pb_model = self._get_model_no_lock()
            self._checkpoint_service.save(
                self._version, pb_model, is_eval_checkpoint)
            checkpoint_version = self._version
        finally:
            if locking:
                self._lock.release()
        return checkpoint_version
    except Exception:
        # Best-effort, matching the original contract: log and return
        # None instead of propagating.
        logger.error(
            "Failed to save checkpoint file for model version %d"
            % self._version)