def __init__(self, estimator, config=None):
    """Set up the job: load configuration, wrap the estimator in a
    backend, and derive the job/worker/namespace identifiers.

    ``config``, when given, is merged into the fresh ``BaseConfig``
    via ``from_json``.  When the ``MODEL_HOT_UPDATE`` parameter is the
    string ``"true"`` (case-insensitive), a background model-reload
    thread is started.
    """
    self.config = BaseConfig()
    if config:
        self.config.from_json(config)
    self.log = LOGGER
    self.estimator = set_backend(estimator=estimator, config=self.config)
    self.job_kind = K8sResourceKind.DEFAULT.value

    cfg = self.config
    # Fall back along job_name -> service_name; worker/namespace
    # default to the resolved job name.
    self.job_name = cfg.job_name or cfg.service_name
    self.worker_name = cfg.worker_name or self.job_name
    self.namespace = cfg.namespace or self.job_name
    self.lc_server = cfg.lc_server

    hot_update = self.get_parameters("MODEL_HOT_UPDATE", "False")
    if str(hot_update).lower() == "true":
        ModelLoadingThread(self.estimator, self.report_task_info).start()
def set_backend(estimator=None, config=None):
    """Create Trainer class"""
    if estimator is None:
        return
    if config is None:
        config = BaseConfig()

    # Environment variables win over config values.
    backend_type = str(
        os.getenv('BACKEND_TYPE', config.get("backend_type", "UNKNOWN"))
    ).upper()
    device_category = os.getenv(
        'DEVICE_CATEGORY', config.get("device_category", "CPU"))

    # A visible CUDA device forces the GPU category and enables CUDA.
    if 'CUDA_VISIBLE_DEVICES' in os.environ:
        os.environ['DEVICE_CATEGORY'] = 'GPU'
        use_cuda = True
    else:
        os.environ['DEVICE_CATEGORY'] = device_category
        use_cuda = False

    # Imports are deferred so only the selected backend's
    # dependencies need to be installed.
    if backend_type == "TENSORFLOW":
        from sedna.backend.tensorflow import TFBackend as backend_cls
    elif backend_type == "KERAS":
        from sedna.backend.tensorflow import KerasBackend as backend_cls
    else:
        warnings.warn(f"{backend_type} Not Support yet, use itself")
        from sedna.backend.base import BackendBase as backend_cls

    save_url = config.get("model_url")
    # base_model_url falls back to model_url when unset.
    save_path = config.get("base_model_url") or save_url
    save_name = config.get("model_name")

    return backend_cls(estimator=estimator,
                       use_cuda=use_cuda,
                       model_save_path=save_path,
                       model_name=save_name,
                       model_save_url=save_url)
def __init__(self):
    """Delegate to the BaseConfig constructor.

    Adds no fields of its own; presumably a placeholder subclass that
    only inherits the base environment-derived settings.
    """
    BaseConfig.__init__(self)
def __init__(self):
    """Read the big-model service bind address from the environment.

    Defaults: bind on all interfaces (0.0.0.0), port 5000.
    """
    BaseConfig.__init__(self)
    host = os.getenv("BIG_MODEL_BIND_IP", "0.0.0.0")
    port = os.getenv("BIG_MODEL_BIND_PORT", "5000")
    self.bind_ip = host
    self.bind_port = int(port)
def __init__(self):
    """Read the aggregation-service address from the environment.

    Defaults: 0.0.0.0 on port 7363.
    """
    BaseConfig.__init__(self)
    ip = os.getenv("AGG_IP", "0.0.0.0")
    port = os.getenv("AGG_PORT", "7363")
    self.agg_ip = ip
    self.agg_port = int(port)
def __init__(self):
    """Collect model-location settings from the environment.

    ``model_urls`` and ``base_model_url`` are None when the
    corresponding variables are unset.
    """
    BaseConfig.__init__(self)
    env = os.getenv
    self.model_urls = env("MODEL_URLS")
    self.base_model_url = env("BASE_MODEL_URL")
    # Fixed file name used when saving the model.
    self.saved_model_name = "model.pb"
def __init__(self):
    """Read aggregation-server bind settings from the environment.

    Defaults: bind 0.0.0.0:7363, expect a single participant.
    """
    BaseConfig.__init__(self)
    host = os.getenv("AGG_BIND_IP", "0.0.0.0")
    port = os.getenv("AGG_BIND_PORT", "7363")
    expected = os.getenv("PARTICIPANTS_COUNT", "1")
    self.bind_ip = host
    self.bind_port = int(port)
    self.participants_count = int(expected)
class JobBase:
    """ sedna feature base class

    Wraps an estimator in a backend (via ``set_backend``) and provides
    the shared train/inference/evaluate entry points plus task-status
    reporting to the LC server.
    """
    # Class-level parameter accessor; presumably reads values from the
    # environment/context -- confirm against the Context class.
    parameters = Context

    def __init__(self, estimator, config=None):
        # ``config``, when given, is merged into a fresh BaseConfig
        # via from_json.
        self.config = BaseConfig()
        if config:
            self.config.from_json(config)
        self.log = LOGGER
        self.estimator = set_backend(estimator=estimator, config=self.config)
        self.job_kind = K8sResourceKind.DEFAULT.value
        # Identifier fallbacks: job_name -> service_name; worker and
        # namespace default to the resolved job name.
        self.job_name = self.config.job_name or self.config.service_name
        self.worker_name = self.config.worker_name or self.job_name
        self.namespace = self.config.namespace or self.job_name
        self.lc_server = self.config.lc_server
        # Opt-in hot model reloading: spawn a background thread that
        # reloads the model and reports via report_task_info.
        if str(self.get_parameters("MODEL_HOT_UPDATE",
                                   "False")).lower() == "true":
            ModelLoadingThread(self.estimator, self.report_task_info).start()

    @property
    def model_path(self):
        """Resolve the model file path.

        Order: config.model_url if it is an existing file, then an
        explicit 'model_path' parameter, then model_url joined with the
        estimator's model name.
        """
        if os.path.isfile(self.config.model_url):
            return self.config.model_url
        return self.get_parameters('model_path') or FileOps.join_path(
            self.config.model_url, self.estimator.model_name)

    def train(self, **kwargs):
        # Subclasses must implement their own training loop.
        raise NotImplementedError

    def inference(self, x=None, post_process=None, **kwargs):
        """Run prediction on ``x``, optionally applying a post-process
        callback (a callable, or a name registered under
        ClassType.CALLBACK)."""
        # NOTE(review): extra kwargs are forwarded as a single keyword
        # argument named ``kwargs``, unlike evaluate() which unpacks
        # them -- presumably the backend's predict() expects this;
        # confirm against the backend interface.
        res = self.estimator.predict(x, kwargs=kwargs)
        callback_func = None
        if callable(post_process):
            callback_func = post_process
        elif post_process is not None:
            callback_func = ClassFactory.get_cls(
                ClassType.CALLBACK, post_process)
        return callback_func(res) if callback_func else res

    def evaluate(self, data, post_process=None, **kwargs):
        """Evaluate ``data``, optionally applying a post-process
        callback resolved the same way as in inference()."""
        callback_func = None
        if callable(post_process):
            callback_func = post_process
        elif post_process is not None:
            callback_func = ClassFactory.get_cls(
                ClassType.CALLBACK, post_process)
        res = self.estimator.evaluate(data=data, **kwargs)
        return callback_func(res) if callback_func else res

    def get_parameters(self, param, default=None):
        # Thin wrapper over the class-level Context accessor.
        return self.parameters.get_parameters(param=param, default=default)

    def report_task_info(self, task_info, status, results=None, kind="train"):
        """Send a task-status message to the LC server.

        ``results`` and ``task_info`` are attached only when truthy.
        Delivery is best-effort: send failures are logged, not raised.
        """
        message = {
            "name": self.worker_name,
            "namespace": self.namespace,
            "ownerName": self.job_name,
            "ownerKind": self.job_kind,
            "kind": kind,
            "status": status
        }
        if results:
            message["results"] = results
        if task_info:
            message["ownerInfo"] = task_info
        try:
            LCClient.send(self.lc_server, self.worker_name, message)
        except Exception as err:
            self.log.error(err)