Example #1
    def inference(self, data=None, post_process=None, **kwargs):
        """
        predict the result for input data based on training knowledge.

        Parameters
        ----------
        data : BaseDataSource
            inference sample, see `sedna.datasources.BaseDataSource` for
            more detail.
        post_process: function
            function or a registered method,  effected after `estimator`
            prediction, like: label transform.
        kwargs: Dict
            parameters for `estimator` predict, Like:
            `ntree_limit` in Xgboost.XGBClassifier

        Returns
        -------
        result : array_like
            results array, contain all inference results in each sample.
        is_unseen_task : bool
            `true` means detect an unseen task, `false` means not
        tasks : List
            tasks assigned to each sample.
        """
        task_index_url = self.get_parameters("MODEL_URLS",
                                             self.config.task_index)
        index_url = self.estimator.estimator.task_index_url
        FileOps.download(task_index_url, index_url)
        res, tasks = self.estimator.predict(data=data,
                                            post_process=post_process,
                                            **kwargs)

        is_unseen_task = False
        if self.unseen_task_detect:
            try:
                if callable(self.unseen_task_detect):
                    unseen_task_detect_algorithm = self.unseen_task_detect()
                else:
                    unseen_task_detect_algorithm = ClassFactory.get_cls(
                        ClassType.UTD, self.unseen_task_detect)()
            except ValueError as err:
                self.log.error("Lifelong learning "
                               "Inference [UTD] : {}".format(err))
            else:
                is_unseen_task = unseen_task_detect_algorithm(
                    tasks=tasks, result=res, **self.unseen_task_detect_param)
        return res, is_unseen_task, tasks
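The unseen-task detector resolved through `ClassFactory.get_cls(ClassType.UTD, ...)` above is simply a registered callable that receives the assigned tasks and the prediction result and returns a boolean. A minimal sketch of providing such a detector and consuming the three-value return of `inference` (the alias, the detector rule and the `ll_job`/`test_data` names are assumptions, not taken from the example):

from sedna.common.class_factory import ClassFactory, ClassType

@ClassFactory.register(ClassType.UTD, alias="MyTaskCountFilter")  # hypothetical detector
class MyTaskCountFilter:
    """Toy rule: flag the batch as unseen when no task could be assigned."""
    def __init__(self, **kwargs):
        self.min_tasks = int(kwargs.get("min_tasks", 1))

    def __call__(self, tasks=None, result=None, **kwargs):
        return not tasks or len(tasks) < self.min_tasks

# `ll_job` stands for an already configured lifelong-learning job using
# unseen_task_detect="MyTaskCountFilter"; `test_data` is a BaseDataSource.
res, is_unseen_task, tasks = ll_job.inference(test_data)
if is_unseen_task:
    pass  # e.g. hand the samples over for labelling and retraining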
Example #2
    def run(self):
        # Poll the hot-update config; when a newer model version appears,
        # download it and reload the serving estimator.
        while self.run_flag:
            time.sleep(self.check_time)
            conf = FileOps.download(self.hot_update_conf)
            if not (conf and FileOps.exists(conf)):
                continue
            with open(conf, "r") as fin:
                try:
                    conf_msg = json.load(fin)
                    model_msg = conf_msg["model_config"]
                    latest_version = str(model_msg["model_update_time"])
                    model = FileOps.download(
                        model_msg["model_path"],
                        FileOps.join_path(self.temp_path,
                                          f"model.{latest_version}"))
                except (json.JSONDecodeError, KeyError):
                    LOGGER.error(f"failed to parse model hot update config: "
                                 f"{self.hot_update_conf}")
                    continue
            if not (model and FileOps.exists(model)):
                continue
            if latest_version == self.version:
                # already serving this version, nothing to do
                continue
            self.version = latest_version
            with self.MODEL_MANIPULATION_SEM:
                # serialize model swaps so inference never sees a
                # half-loaded model
                LOGGER.info(f"Update model start with version {self.version}")
                try:
                    self.production_estimator.load(model)
                    status = K8sResourceKindStatus.COMPLETED.value
                    LOGGER.info(f"Update model complete "
                                f"with version {self.version}")
                except Exception as e:
                    LOGGER.error(f"failed to update model: {e}")
                    status = K8sResourceKindStatus.FAILED.value
                if self.callback:
                    # report the deploy status via the registered callback
                    self.callback(task_info=None, status=status, kind="deploy")
            gc.collect()
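The loop above only consumes a small JSON document: it reads `model_config.model_update_time` as the version tag and `model_config.model_path` as the artifact to download. A sketch of producing such a hot-update config (the keys come from the code above; the concrete values and file name are made up):

import json

conf_msg = {
    "model_config": {
        "model_update_time": "2021-06-01-10-00-00",     # becomes `latest_version`
        "model_path": "/models/deploy/model_latest.pb"  # downloaded into temp_path
    }
}
with open("hot_update_conf.json", "w") as fout:
    json.dump(conf_msg, fout)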
Example #3
    def load(self, model_url="", model_name=None, **kwargs):
        # Resolve the local model path, download the model if a remote URL
        # is given, then delegate loading to the wrapped estimator.
        mname = model_name or self.model_name
        if callable(self.estimator):
            # the estimator was registered as a class/factory: instantiate
            # it with only the kwargs its constructor accepts
            varkw = self.parse_kwargs(self.estimator, **kwargs)
            self.estimator = self.estimator(**varkw)
        if model_url and os.path.isfile(model_url):
            self.model_save_path, mname = os.path.split(model_url)
        elif os.path.isfile(self.model_save_path):
            self.model_save_path, mname = os.path.split(self.model_save_path)
        model_path = FileOps.join_path(self.model_save_path, mname)
        if model_url:
            model_path = FileOps.download(model_url, model_path)
        self.has_load = True
        if not (hasattr(self.estimator, "load")
                and os.path.exists(model_path)):
            return
        return self.estimator.load(model_url=model_path)
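`load` only assumes that the wrapped estimator is either an instance or a factory callable, and that it exposes its own `load(model_url=...)`. A sketch of an estimator satisfying that contract (the class, the `joblib` usage and every method other than `load` are illustrative assumptions):

import joblib

class SklearnEstimator:
    """Minimal estimator the wrapper above can drive."""
    def __init__(self, **kwargs):
        self.model = None

    def load(self, model_url=""):
        # called by the wrapper with the resolved local model path
        self.model = joblib.load(model_url)
        return self.model

    def predict(self, datas, **kwargs):
        return self.model.predict(datas)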
Example #4
    def evaluate(self, data, post_process=None, **kwargs):
        """
        evaluated the performance of each task from training, filter tasks
        based on the defined rules.

        Parameters
        ----------
        data : BaseDataSource
            valid data, see `sedna.datasources.BaseDataSource` for more detail.
        kwargs: Dict
            parameters for `estimator` evaluate, Like:
            `ntree_limit` in Xgboost.XGBClassifier
        """

        callback_func = None
        if callable(post_process):
            callback_func = post_process
        elif post_process is not None:
            callback_func = ClassFactory.get_cls(ClassType.CALLBACK,
                                                 post_process)
        task_index_url = self.get_parameters("MODEL_URLS",
                                             self.config.task_index)
        index_url = self.estimator.estimator.task_index_url
        self.log.info(
            f"Download kb index from {task_index_url} to {index_url}")
        FileOps.download(task_index_url, index_url)
        res, tasks_detail = self.estimator.evaluate(data=data, **kwargs)
        drop_tasks = []

        model_filter_operator = self.get_parameters("operator", ">")
        model_threshold = float(self.get_parameters('model_threshold', 0.1))

        operator_map = {
            ">": lambda x, y: x > y,
            "<": lambda x, y: x < y,
            "=": lambda x, y: x == y,
            ">=": lambda x, y: x >= y,
            "<=": lambda x, y: x <= y,
        }
        if model_filter_operator not in operator_map:
            self.log.warn(f"operator {model_filter_operator} use to "
                          f"compare is not allow, set to <")
            model_filter_operator = "<"
        operator_func = operator_map[model_filter_operator]

        for detail in tasks_detail:
            scores = detail.scores
            entry = detail.entry
            self.log.info(f"{entry} scores: {scores}")
            if any(
                    map(lambda x: operator_func(float(x), model_threshold),
                        scores.values())):
                self.log.warn(
                    f"{entry} will not be deployed because some "
                    f"scores {model_filter_operator} {model_threshold}")
                drop_tasks.append(entry)
                continue
        drop_task = ",".join(drop_tasks)
        index_file = self.kb_server.update_task_status(drop_task, new_status=0)
        if not index_file:
            self.log.error(f"KB update Fail !")
            index_file = str(index_url)
        self.log.info(
            f"upload kb index from {index_file} to {self.config.task_index}")
        FileOps.upload(index_file, self.config.task_index)
        task_info_res = self.estimator.model_info(
            self.config.task_index,
            result=res,
            relpath=self.config.data_path_prefix)
        self.report_task_info(None,
                              K8sResourceKindStatus.COMPLETED.value,
                              task_info_res,
                              kind="eval")
        return callback_func(res) if callback_func else res
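Because `evaluate` accepts any plain callable as `post_process` (the `callable(post_process)` branch above), the simplest way to hook into the result is to pass a function directly. A usage sketch; `eval_job` and `valid_data` are assumed names, not taken from the example:

def summarize(res):
    # `res` is whatever estimator.evaluate returned for the validation data
    print("evaluation result:", res)
    return res

eval_result = eval_job.evaluate(valid_data, post_process=summarize)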