def inference(self, data=None, post_process=None, **kwargs):
    """Run prediction on `data` and flag samples from unseen tasks.

    Parameters
    ----------
    data : BaseDataSource
        Inference samples, see `sedna.datasources.BaseDataSource`
        for more detail.
    post_process : function
        Function or a registered method applied after `estimator`
        prediction, e.g. a label transform.
    kwargs : Dict
        Parameters for `estimator` predict, like `ntree_limit` in
        Xgboost.XGBClassifier.

    Returns
    -------
    result : array_like
        Results array, containing all inference results per sample.
    is_unseen_task : bool
        `true` means an unseen task was detected, `false` means not.
    tasks : List
        Tasks assigned to each sample.
    """
    # Refresh the local task index from the configured URL before
    # delegating to the estimator.
    remote_index = self.get_parameters("MODEL_URLS", self.config.task_index)
    local_index = self.estimator.estimator.task_index_url
    FileOps.download(remote_index, local_index)

    res, tasks = self.estimator.predict(
        data=data, post_process=post_process, **kwargs)

    is_unseen_task = False
    if self.unseen_task_detect:
        try:
            # The detector may be given directly as a callable class, or
            # as the name of a class registered under ClassType.UTD.
            factory = (self.unseen_task_detect
                       if callable(self.unseen_task_detect)
                       else ClassFactory.get_cls(ClassType.UTD,
                                                 self.unseen_task_detect))
            detector = factory()
        except ValueError as err:
            self.log.error(f"Lifelong learning Inference [UTD] : {err}")
        else:
            is_unseen_task = detector(
                tasks=tasks, result=res, **self.unseen_task_detect_param)
    return res, is_unseen_task, tasks
def run(self):
    """Poll the hot-update config and reload the production model.

    Every `check_time` seconds the hot-update config file is fetched
    and parsed; when it advertises a model version different from the
    one currently served, the model file is downloaded and loaded into
    `production_estimator` under `MODEL_MANIPULATION_SEM`, and the
    outcome is reported through `callback` (when set).
    """
    while self.run_flag:
        time.sleep(self.check_time)
        conf = FileOps.download(self.hot_update_conf)
        if not (conf and FileOps.exists(conf)):
            continue
        with open(conf, "r") as fin:
            try:
                conf_msg = json.load(fin)
                model_msg = conf_msg["model_config"]
                latest_version = str(model_msg["model_update_time"])
                # Extract the path here so a missing key is caught by
                # the same KeyError handler as the other fields.
                model_path = model_msg["model_path"]
            except (json.JSONDecodeError, KeyError):
                LOGGER.error(f"fail to parse model hot update config: "
                             f"{self.hot_update_conf}")
                continue
        # Compare versions BEFORE downloading: avoids re-fetching a
        # potentially large, unchanged model file on every poll cycle.
        if latest_version == self.version:
            continue
        model = FileOps.download(
            model_path,
            FileOps.join_path(self.temp_path, f"model.{latest_version}"))
        if not (model and FileOps.exists(model)):
            continue
        self.version = latest_version
        with self.MODEL_MANIPULATION_SEM:
            LOGGER.info(f"Update model start with version {self.version}")
            try:
                self.production_estimator.load(model)
                status = K8sResourceKindStatus.COMPLETED.value
                LOGGER.info(f"Update model complete "
                            f"with version {self.version}")
            except Exception as e:
                # Boundary: keep the poller alive, report the failure.
                LOGGER.error(f"fail to update model: {e}")
                status = K8sResourceKindStatus.FAILED.value
            if self.callback:
                self.callback(task_info=None, status=status, kind="deploy")
        gc.collect()
def load(self, model_url="", model_name=None, **kwargs):
    """Instantiate the wrapped estimator (if needed) and load weights.

    Parameters
    ----------
    model_url : str
        Local file path or remote URL of the model; when given, it is
        downloaded into `model_save_path` before loading.
    model_name : str
        File name used to build the local model path; defaults to
        `self.model_name`.
    kwargs : Dict
        Forwarded to the estimator constructor when the estimator is
        still a class/factory rather than an instance.
    """
    file_name = model_name or self.model_name

    # A callable estimator is still a factory: instantiate it with
    # whatever keyword arguments it accepts.
    if callable(self.estimator):
        init_args = self.parse_kwargs(self.estimator, **kwargs)
        self.estimator = self.estimator(**init_args)

    # Prefer an explicit local file; otherwise fall back to a
    # previously saved one.
    if model_url and os.path.isfile(model_url):
        self.model_save_path, file_name = os.path.split(model_url)
    elif os.path.isfile(self.model_save_path):
        self.model_save_path, file_name = os.path.split(self.model_save_path)

    local_path = FileOps.join_path(self.model_save_path, file_name)
    if model_url:
        local_path = FileOps.download(model_url, local_path)
    self.has_load = True

    loadable = hasattr(self.estimator, "load") and os.path.exists(local_path)
    return self.estimator.load(model_url=local_path) if loadable else None
def evaluate(self, data, post_process=None, **kwargs):
    """Evaluate every trained task and disable under-performing ones.

    Each task's scores are compared against `model_threshold` using the
    configured `operator`; a task is dropped (status set to 0 in the
    knowledge base) as soon as ANY of its scores matches the filter.
    The refreshed task index is then uploaded back to
    `self.config.task_index`.

    Parameters
    ----------
    data : BaseDataSource
        Valid data, see `sedna.datasources.BaseDataSource` for more
        detail.
    post_process : function
        Function or registered callback applied to the evaluation
        result before it is returned.
    kwargs : Dict
        Parameters for `estimator` evaluate, like `ntree_limit` in
        Xgboost.XGBClassifier.

    Returns
    -------
    The estimator evaluation result, transformed by `post_process`
    when one is given.
    """
    callback_func = None
    if callable(post_process):
        callback_func = post_process
    elif post_process is not None:
        callback_func = ClassFactory.get_cls(ClassType.CALLBACK,
                                             post_process)

    # Refresh the local knowledge-base index before evaluating.
    task_index_url = self.get_parameters("MODEL_URLS",
                                         self.config.task_index)
    index_url = self.estimator.estimator.task_index_url
    self.log.info(
        f"Download kb index from {task_index_url} to {index_url}")
    FileOps.download(task_index_url, index_url)

    res, tasks_detail = self.estimator.evaluate(data=data, **kwargs)

    model_filter_operator = self.get_parameters("operator", ">")
    model_threshold = float(self.get_parameters('model_threshold', 0.1))
    operator_map = {
        ">": lambda x, y: x > y,
        "<": lambda x, y: x < y,
        "=": lambda x, y: x == y,
        ">=": lambda x, y: x >= y,
        "<=": lambda x, y: x <= y,
    }
    if model_filter_operator not in operator_map:
        self.log.warn(f"operator {model_filter_operator} use to "
                      f"compare is not allow, set to <")
        model_filter_operator = "<"
    operator_func = operator_map[model_filter_operator]

    drop_tasks = []
    for detail in tasks_detail:
        scores = detail.scores
        entry = detail.entry
        self.log.info(f"{entry} scores: {scores}")
        # A single matching score is enough to disable the task; the
        # log message reflects that (the condition is `any`, not `all`).
        if any(operator_func(float(x), model_threshold)
               for x in scores.values()):
            self.log.warn(
                f"{entry} will not be deployed because some "
                f"scores are {model_filter_operator} {model_threshold}")
            drop_tasks.append(entry)

    drop_task = ",".join(drop_tasks)
    index_file = self.kb_server.update_task_status(drop_task, new_status=0)
    if not index_file:
        # Best effort: fall back to the locally downloaded index.
        self.log.error("KB update Fail !")
        index_file = str(index_url)

    self.log.info(
        f"upload kb index from {index_file} to {self.config.task_index}")
    FileOps.upload(index_file, self.config.task_index)
    task_info_res = self.estimator.model_info(
        self.config.task_index, result=res,
        relpath=self.config.data_path_prefix)
    self.report_task_info(None, K8sResourceKindStatus.COMPLETED.value,
                          task_info_res, kind="eval")
    return callback_func(res) if callback_func else res