def save(self, model_url="", model_name=None):
    """Persist the estimator's model and optionally upload it.

    Parameters
    ----------
    model_url : str
        remote destination; when non-empty and the local save succeeded,
        the model is uploaded there and the remote url is returned.
    model_name : str or None
        file name to save under; falls back to ``self.model_name``.

    Returns
    -------
    str : the final model location (remote url if uploaded, else local path).
    """
    file_name = model_name or self.model_name
    save_dir = self.model_save_path
    # If the configured save path actually points at a file, split it
    # into directory + file name and keep the directory on the instance.
    if os.path.isfile(save_dir):
        save_dir, file_name = os.path.split(save_dir)
        self.model_save_path = save_dir
    FileOps.clean_folder([save_dir], clean=False)
    local_path = FileOps.join_path(save_dir, file_name)
    self.estimator.save(local_path)
    # Guard clause: nothing to upload, or the save produced no file.
    if not model_url or not FileOps.exists(local_path):
        return local_path
    FileOps.upload(local_path, model_url)
    return model_url
def evaluate(self, data, post_process=None, **kwargs):
    """
    Evaluate the performance of each task learned in training and
    filter out tasks based on the configured threshold rule.

    Parameters
    ----------
    data : BaseDataSource
        valid data, see `sedna.datasources.BaseDataSource` for more detail.
    post_process : function or str
        callable, or the name of a callback registered under
        `ClassType.CALLBACK`, applied to the evaluation result.
    kwargs: Dict
        parameters for `estimator` evaluate, Like:
        `ntree_limit` in Xgboost.XGBClassifier

    Returns
    -------
    the estimator's evaluation result (post-processed when
    `post_process` is given).
    """
    callback_func = None
    if callable(post_process):
        callback_func = post_process
    elif post_process is not None:
        callback_func = ClassFactory.get_cls(ClassType.CALLBACK, post_process)
    # Fetch the current knowledge-base index so the estimator evaluates
    # against the latest set of tasks.
    task_index_url = self.get_parameters("MODEL_URLS", self.config.task_index)
    index_url = self.estimator.estimator.task_index_url
    self.log.info(
        f"Download kb index from {task_index_url} to {index_url}")
    FileOps.download(task_index_url, index_url)
    res, tasks_detail = self.estimator.evaluate(data=data, **kwargs)
    drop_tasks = []
    model_filter_operator = self.get_parameters("operator", ">")
    model_threshold = float(self.get_parameters('model_threshold', 0.1))
    operator_map = {
        ">": lambda x, y: x > y,
        "<": lambda x, y: x < y,
        "=": lambda x, y: x == y,
        ">=": lambda x, y: x >= y,
        "<=": lambda x, y: x <= y,
    }
    if model_filter_operator not in operator_map:
        self.log.warn(f"operator {model_filter_operator} use to "
                      f"compare is not allow, set to <")
        model_filter_operator = "<"
    operator_func = operator_map[model_filter_operator]
    for detail in tasks_detail:
        scores = detail.scores
        entry = detail.entry
        self.log.info(f"{entry} scores: {scores}")
        # A task is dropped as soon as ANY of its scores matches the
        # filter rule (score <operator> threshold).
        if any(map(lambda x: operator_func(float(x), model_threshold),
                   scores.values())):
            # Bug fix: the condition uses `any`, so the message must say
            # "some" (it previously claimed "all" scores failed).
            self.log.warn(
                f"{entry} will not be deploy because some "
                f"scores {model_filter_operator} {model_threshold}")
            drop_tasks.append(entry)
    drop_task = ",".join(drop_tasks)
    # Mark the dropped tasks as inactive in the knowledge base.
    index_file = self.kb_server.update_task_status(drop_task, new_status=0)
    if not index_file:
        self.log.error("KB update Fail !")
        index_file = str(index_url)
    self.log.info(
        f"upload kb index from {index_file} to {self.config.task_index}")
    FileOps.upload(index_file, self.config.task_index)
    task_info_res = self.estimator.model_info(
        self.config.task_index, result=res,
        relpath=self.config.data_path_prefix)
    self.report_task_info(
        None, K8sResourceKindStatus.COMPLETED.value,
        task_info_res, kind="eval")
    return callback_func(res) if callback_func else res
def train(self, train_data, valid_data=None, post_process=None,
          action="initial", **kwargs):
    """
    Fit and update the knowledge base with the training data.

    Parameters
    ----------
    train_data : BaseDataSource
        Train data, see `sedna.datasources.BaseDataSource` for more detail.
    valid_data : BaseDataSource
        Valid data, BaseDataSource or None.
    post_process : function or str
        callable, or the name of a callback registered under
        `ClassType.CALLBACK`, applied after `estimator` train.
    action : str
        `update` or `initial` the knowledge base
        (currently unused here — TODO confirm intended handling).
    kwargs : Dict
        parameters for `estimator` training, Like:
        `early_stopping_rounds` in Xgboost.XGBClassifier

    Returns
    -------
    train_history : object
    """
    callback_func = None
    # Consistency with `evaluate`: accept a plain callable as well as a
    # registered callback name.
    if callable(post_process):
        callback_func = post_process
    elif post_process is not None:
        callback_func = ClassFactory.get_cls(ClassType.CALLBACK, post_process)
    res, task_index_url = self.estimator.train(
        train_data=train_data,
        valid_data=valid_data,
        **kwargs
    )
    # todo: Distinguishing incremental update and fully overwrite
    if isinstance(task_index_url, str) and FileOps.exists(task_index_url):
        task_index = FileOps.load(task_index_url)
    else:
        task_index = task_index_url
    extractor = task_index['extractor']
    task_groups = task_index['task_groups']
    model_upload_key = {}
    for task in task_groups:
        model_file = task.model.model
        save_model = FileOps.join_path(
            self.config.output_url, os.path.basename(model_file))
        # Upload each distinct model file only once across task groups.
        if model_file not in model_upload_key:
            model_upload_key[model_file] = FileOps.upload(
                model_file, save_model)
        model_file = model_upload_key[model_file]
        try:
            model = self.kb_server.upload_file(save_model)
        except Exception as err:
            # Best-effort fallback: keep a locally-loaded model instead
            # of the KB-hosted one when the upload fails.
            self.log.error(
                f"Upload task model of {model_file} fail: {err}")
            model = set_backend(
                estimator=self.estimator.estimator.base_model)
            model.load(model_file)
        task.model.model = model
        for _task in task.tasks:
            sample_dir = FileOps.join_path(
                self.config.output_url,
                f"{_task.samples.data_type}_{_task.entry}.sample")
            # Bug fix: save the per-task samples (`_task.samples`) —
            # the original saved the group-level `task.samples` to a
            # path derived from `_task`, so `data_url` pointed at the
            # wrong content.
            _task.samples.save(sample_dir)
            try:
                sample_dir = self.kb_server.upload_file(sample_dir)
            except Exception as err:
                self.log.error(
                    f"Upload task samples of {_task.entry} fail: {err}")
            _task.samples.data_url = sample_dir
    save_extractor = FileOps.join_path(
        self.config.output_url,
        KBResourceConstant.TASK_EXTRACTOR_NAME.value)
    extractor = FileOps.dump(extractor, save_extractor)
    try:
        extractor = self.kb_server.upload_file(extractor)
    except Exception as err:
        self.log.error(f"Upload task extractor fail: {err}")
    task_info = {"task_groups": task_groups, "extractor": extractor}
    fd, name = tempfile.mkstemp()
    # Bug fix: close the descriptor returned by mkstemp — FileOps.dump
    # writes via the path, so the open fd was leaked on every train.
    os.close(fd)
    FileOps.dump(task_info, name)
    index_file = self.kb_server.update_db(name)
    if not index_file:
        self.log.error("KB update Fail !")
        index_file = name
    FileOps.upload(index_file, self.config.task_index)
    task_info_res = self.estimator.model_info(
        self.config.task_index,
        relpath=self.config.data_path_prefix)
    self.report_task_info(
        None, K8sResourceKindStatus.COMPLETED.value, task_info_res)
    # Typo fix in log message: "idnex" -> "index".
    self.log.info(f"Lifelong learning Train task Finished, "
                  f"KB index save in {self.config.task_index}")
    return callback_func(self.estimator, res) if callback_func else res