Exemplo n.º 1
0
 def __init__(self, estimator, config=None):
     """Set up configuration, backend estimator and job identity.

     :param estimator: user model to be wrapped via set_backend()
     :param config: optional JSON/dict merged into the base config
     """
     cfg = BaseConfig()
     if config:
         cfg.from_json(config)
     self.config = cfg
     self.log = LOGGER
     self.estimator = set_backend(estimator=estimator, config=cfg)
     self.job_kind = K8sResourceKind.DEFAULT.value
     # Job/worker/namespace identities fall back to one another.
     job_name = cfg.job_name or cfg.service_name
     self.job_name = job_name
     self.worker_name = cfg.worker_name or job_name
     self.namespace = cfg.namespace or job_name
     self.lc_server = cfg.lc_server
     # Opt-in hot reload: spawn a watcher thread when enabled.
     hot_update = self.get_parameters("MODEL_HOT_UPDATE", "False")
     if str(hot_update).lower() == "true":
         ModelLoadingThread(self.estimator, self.report_task_info).start()
Exemplo n.º 2
0
def set_backend(estimator=None, config=None):
    """Create a backend trainer wrapper around *estimator*.

    Chooses the ML backend from the BACKEND_TYPE env var (falling back
    to the config), decides CPU/GPU placement, and returns the wrapped
    estimator. Returns None when no estimator is given.

    :param estimator: user model instance to wrap (None -> return None)
    :param config: config object with get(); defaults to BaseConfig()
    :return: a backend REGISTER instance, or None
    """
    if estimator is None:
        return
    if config is None:
        config = BaseConfig()
    use_cuda = False
    backend_type = os.getenv('BACKEND_TYPE',
                             config.get("backend_type", "UNKNOWN"))
    backend_type = str(backend_type).upper()
    device_category = os.getenv('DEVICE_CATEGORY',
                                config.get("device_category", "CPU"))
    # An *empty* CUDA_VISIBLE_DEVICES means "hide all GPUs", so only
    # switch to GPU when the variable is present AND non-empty.
    # (The previous `in os.environ` check wrongly enabled CUDA for
    # CUDA_VISIBLE_DEVICES="".)
    if os.environ.get('CUDA_VISIBLE_DEVICES'):
        os.environ['DEVICE_CATEGORY'] = 'GPU'
        use_cuda = True
    else:
        os.environ['DEVICE_CATEGORY'] = device_category

    # Import lazily so only the selected backend's deps are required.
    if backend_type == "TENSORFLOW":
        from sedna.backend.tensorflow import TFBackend as REGISTER
    elif backend_type == "KERAS":
        from sedna.backend.tensorflow import KerasBackend as REGISTER
    else:
        warnings.warn(f"{backend_type} Not Support yet, use itself")
        from sedna.backend.base import BackendBase as REGISTER
    model_save_url = config.get("model_url")
    # Fall back to model_url when no dedicated base-model path is set.
    base_model_save = config.get("base_model_url") or model_save_url
    model_save_name = config.get("model_name")
    return REGISTER(estimator=estimator,
                    use_cuda=use_cuda,
                    model_save_path=base_model_save,
                    model_name=model_save_name,
                    model_save_url=model_save_url)
Exemplo n.º 3
0
 def __init__(self):
     """Initialize by delegating to BaseConfig's constructor."""
     BaseConfig.__init__(self)
Exemplo n.º 4
0
 def __init__(self):
     """Big-model service config: bind endpoint read from environment."""
     BaseConfig.__init__(self)
     # Defaults: listen on all interfaces, port 5000.
     self.bind_ip = os.getenv("BIG_MODEL_BIND_IP", "0.0.0.0")
     port = os.getenv("BIG_MODEL_BIND_PORT", "5000")
     self.bind_port = int(port)
Exemplo n.º 5
0
 def __init__(self):
     """Aggregation-client config: aggregator endpoint from environment."""
     BaseConfig.__init__(self)
     # Where the federated-learning aggregator is reachable.
     ip = os.getenv("AGG_IP", "0.0.0.0")
     port = int(os.getenv("AGG_PORT", "7363"))
     self.agg_ip, self.agg_port = ip, port
Exemplo n.º 6
0
 def __init__(self):
     """Model-management config read from the environment."""
     BaseConfig.__init__(self)
     # Both URLs may be absent (None) when not configured.
     self.model_urls = os.environ.get("MODEL_URLS")
     self.base_model_url = os.environ.get("BASE_MODEL_URL")
     # Fixed on-disk name for the exported model artifact.
     self.saved_model_name = "model.pb"
Exemplo n.º 7
0
    def __init__(self):
        """Aggregator-server config: bind endpoint and expected client count."""
        BaseConfig.__init__(self)

        env = os.getenv
        # Defaults: all interfaces, port 7363, single participant.
        self.bind_ip = env("AGG_BIND_IP", "0.0.0.0")
        self.bind_port = int(env("AGG_BIND_PORT", "7363"))
        self.participants_count = int(env("PARTICIPANTS_COUNT", "1"))
Exemplo n.º 8
0
class JobBase:
    """Sedna feature base class.

    Wires together configuration, the backend-wrapped estimator and
    status reporting to the LC (local controller) server.
    """
    # Context provides get_parameters() for env/config lookups.
    parameters = Context

    def __init__(self, estimator, config=None):
        """Build the job from *estimator* and an optional config.

        :param estimator: user model wrapped via set_backend()
        :param config: optional JSON/dict merged into BaseConfig
        """
        self.config = BaseConfig()
        if config:
            self.config.from_json(config)
        self.log = LOGGER
        self.estimator = set_backend(estimator=estimator, config=self.config)
        self.job_kind = K8sResourceKind.DEFAULT.value
        # Job/worker/namespace identities fall back to one another.
        self.job_name = self.config.job_name or self.config.service_name
        self.worker_name = self.config.worker_name or self.job_name
        self.namespace = self.config.namespace or self.job_name
        self.lc_server = self.config.lc_server
        # Opt-in hot model reload: a background thread reloads the model
        # and reports status via report_task_info.
        if str(self.get_parameters("MODEL_HOT_UPDATE",
                                   "False")).lower() == "true":
            ModelLoadingThread(self.estimator, self.report_task_info).start()

    @staticmethod
    def _resolve_post_process(post_process):
        """Map *post_process* to a callable, or None.

        Callables pass through unchanged; anything else non-None is
        looked up by name in the CALLBACK registry of ClassFactory.
        (Shared by inference() and evaluate(), which previously
        duplicated this logic.)
        """
        if callable(post_process):
            return post_process
        if post_process is not None:
            return ClassFactory.get_cls(ClassType.CALLBACK, post_process)
        return None

    @property
    def model_path(self):
        """Resolve the model file path from config or parameters."""
        # NOTE(review): assumes config.model_url is set (isfile would
        # raise on None) — confirm callers guarantee this.
        if os.path.isfile(self.config.model_url):
            return self.config.model_url
        return self.get_parameters('model_path') or FileOps.join_path(
            self.config.model_url, self.estimator.model_name)

    def train(self, **kwargs):
        """Train the model; implemented by concrete job subclasses."""
        raise NotImplementedError

    def inference(self, x=None, post_process=None, **kwargs):
        """Predict on *x*, optionally post-processing the result.

        :param x: input samples
        :param post_process: callable or registered callback name
        :return: raw or post-processed prediction result
        """
        # NOTE(review): kwargs is deliberately passed as one keyword
        # argument named "kwargs" (not expanded) — the backend predict()
        # appears to expect it that way; confirm against the backend API.
        res = self.estimator.predict(x, kwargs=kwargs)
        callback_func = self._resolve_post_process(post_process)
        return callback_func(res) if callback_func else res

    def evaluate(self, data, post_process=None, **kwargs):
        """Evaluate the model on *data*, optionally post-processing.

        :param data: evaluation dataset
        :param post_process: callable or registered callback name
        :return: raw or post-processed evaluation result
        """
        callback_func = self._resolve_post_process(post_process)
        res = self.estimator.evaluate(data=data, **kwargs)
        return callback_func(res) if callback_func else res

    def get_parameters(self, param, default=None):
        """Read a named parameter from the Context (env/config)."""
        return self.parameters.get_parameters(param=param, default=default)

    def report_task_info(self, task_info, status, results=None, kind="train"):
        """Send a task-status message to the LC server (best effort).

        Failures are logged rather than raised so that reporting can
        never break the training/inference loop.
        """
        message = {
            "name": self.worker_name,
            "namespace": self.namespace,
            "ownerName": self.job_name,
            "ownerKind": self.job_kind,
            "kind": kind,
            "status": status
        }
        if results:
            message["results"] = results
        if task_info:
            message["ownerInfo"] = task_info
        try:
            LCClient.send(self.lc_server, self.worker_name, message)
        except Exception as err:
            self.log.error(err)