コード例 #1
0
    def inference(self, data=None, post_process=None, **kwargs):
        """
        Run the wrapped estimator on `data` and optionally post-process
        the prediction.

        Parameters
        ----------
        data: BaseDataSource
            input handed to `estimator.predict` unchanged, see
            `sedna.datasources.BaseDataSource` for more detail.
        post_process: function or a registered method
            a callable is used directly; any other non-None value is
            looked up through `ClassFactory` under `ClassType.CALLBACK`.
        kwargs: Dict
            forwarded as keyword arguments to `estimator.predict`,
            Like:  `ntree_limit` in Xgboost.XGBClassifier

        Returns
        -------
        the prediction, passed through the post-process hook when a
        truthy hook was resolved.
        """

        # Resolve the hook before predicting, so that a failing lookup
        # is raised before the estimator is invoked.
        if callable(post_process):
            hook = post_process
        elif post_process is None:
            hook = None
        else:
            hook = ClassFactory.get_cls(ClassType.CALLBACK, post_process)

        prediction = self.estimator.predict(data, **kwargs)
        return hook(prediction) if hook else prediction
コード例 #2
0
ファイル: base.py プロジェクト: XinYao1994/sedna
 def evaluate(self, data, post_process=None, **kwargs):
     """
     Evaluate `data` with the wrapped estimator.

     Parameters
     ----------
     data:
         passed to `estimator.evaluate` as the `data` keyword.
     post_process: function or a registered method
         a callable is used as-is; any other non-None value is looked up
         through `ClassFactory` under `ClassType.CALLBACK`.
     kwargs: Dict
         extra keyword arguments forwarded to `estimator.evaluate`.

     Returns
     -------
     the evaluation result, run through the post-process hook when a
     truthy hook was resolved.
     """
     if callable(post_process):
         hook = post_process
     elif post_process is None:
         hook = None
     else:
         hook = ClassFactory.get_cls(ClassType.CALLBACK, post_process)
     metrics = self.estimator.evaluate(data=data, **kwargs)
     if hook:
         return hook(metrics)
     return metrics
コード例 #3
0
 def __init__(self, **kwargs):
     """
     Set up round bookkeeping and the aggregation algorithm.

     Recognised keyword arguments: `exit_round` (default 3),
     `aggregation` (default "FedAvg", looked up through `ClassFactory`
     under `ClassType.FL_AGG`) and `participants_count` (default "1").
     """
     super(Aggregator, self).__init__()
     self.exit_round = int(kwargs.get("exit_round", 3))
     algorithm = ClassFactory.get_cls(
         ClassType.FL_AGG, kwargs.get("aggregation", "FedAvg"))
     # Whatever callable the registry returns is called once with no
     # arguments; a non-callable entry is stored unchanged.
     self.aggregation = algorithm() if callable(algorithm) else algorithm
     self.participants_count = int(kwargs.get("participants_count", "1"))
     self.current_round = 0
コード例 #4
0
 def _task_relationship_discovery(self, tasks):
     """
     Merge tasks from task_definition.

     The class named by the "method" entry of
     `self.task_relationship_discovery` is fetched from `ClassFactory`
     (`ClassType.MTL`), built with the parsed "param" entry and then
     called on `tasks`; that call's result is returned.
     """
     settings = self.task_relationship_discovery
     method_name = settings.get("method")
     extend_param = self._parse_param(settings.get("param"))
     discovery_cls = ClassFactory.get_cls(ClassType.MTL, method_name)
     discover = discovery_cls(**extend_param)
     return discover(tasks)
コード例 #5
0
 def _task_definition(self, samples):
     """
     Task attribute extractor and multi-task definition.

     The class is chosen by the "method" entry of `self.task_definition`
     (default "TaskDefinitionByDataAttr"), fetched from `ClassFactory`
     (`ClassType.MTL`), built with the parsed "param" entry and then
     called on `samples`; that call's result is returned.
     """
     settings = self.task_definition
     method_name = settings.get("method", "TaskDefinitionByDataAttr")
     extend_param = self._parse_param(settings.get("param"))
     definer_cls = ClassFactory.get_cls(ClassType.MTL, method_name)
     definer = definer_cls(**extend_param)
     return definer(samples)
コード例 #6
0
 def _inference_integrate(self, tasks):
     """
     Aggregate inference results from target models.

     The integrator class comes from the "method" entry of
     `self.inference_integrate` and is built with `models=self.models`
     plus the parsed "param" entry. If the built object is falsy,
     `tasks` is returned untouched.
     """
     settings = self.inference_integrate
     method_name = settings.get("method")
     extend_param = self._parse_param(settings.get("param"))
     integrator_cls = ClassFactory.get_cls(ClassType.MTL, method_name)
     integrator = integrator_cls(models=self.models, **extend_param)
     if not integrator:
         return tasks
     return integrator(tasks=tasks)
コード例 #7
0
 def _task_remodeling(self, samples, mappings):
     """
     Remodeling tasks from task mining.

     The remodeling class comes from the "method" entry of
     `self.task_remodeling`, is built with `models=self.models` plus the
     parsed "param" entry, and is then called with `samples` and
     `mappings` as keyword arguments.
     """
     settings = self.task_remodeling
     method_name = settings.get("method")
     extend_param = self._parse_param(settings.get("param"))
     remodel_cls = ClassFactory.get_cls(ClassType.MTL, method_name)
     remodel = remodel_cls(models=self.models, **extend_param)
     return remodel(samples=samples, mappings=mappings)
コード例 #8
0
ファイル: base.py プロジェクト: XinYao1994/sedna
    def inference(self, x=None, post_process=None, **kwargs):
        """
        Run the wrapped estimator on `x` and optionally post-process the
        prediction.

        Parameters
        ----------
        x:
            input handed to `estimator.predict` as its first positional
            argument.
        post_process: function or a registered method
            a callable is used directly; any other non-None value is
            looked up through `ClassFactory` under `ClassType.CALLBACK`.
        kwargs: Dict
            forwarded to `estimator.predict` as individual keyword
            arguments.

        Returns
        -------
        the prediction, passed through the post-process hook when a
        truthy hook was resolved.
        """

        # Forward the options themselves; passing `kwargs=kwargs` would
        # hand the estimator one keyword literally named "kwargs".
        res = self.estimator.predict(x, **kwargs)
        callback_func = None
        if callable(post_process):
            callback_func = post_process
        elif post_process is not None:
            callback_func = ClassFactory.get_cls(ClassType.CALLBACK,
                                                 post_process)
        return callback_func(res) if callback_func else res
コード例 #9
0
    def predict(self, data: BaseDataSource, post_process=None, **kwargs):
        """
        predict the result for input data based on training knowledge.

        Parameters
        ----------
        data : BaseDataSource
            inference sample, see `sedna.datasources.BaseDataSource` for
            more detail.
        post_process: function or a registered method
            applied to every per-model prediction as
            `post_process(pred, samples)`, like: label transform.
            A callable is used directly; any other truthy value is
            looked up through `ClassFactory` under `ClassType.CALLBACK`
            and instantiated.
        kwargs: Dict
            parameters for `estimator` predict, Like:
            `ntree_limit` in Xgboost.XGBClassifier

        Returns
        -------
        result :
            whatever `_inference_integrate` returns for the collected
            tasks.
        tasks : List
            one `Task` per sample group that had a `Model` assigned.
        """

        # Load lazily when either the models or the extractor is missing.
        if not (self.models and self.extractor):
            self.load()

        data, mappings = self._task_mining(samples=data)
        samples, models = self._task_remodeling(samples=data,
                                                mappings=mappings)

        # Accept a plain callable as documented, instead of sending it to
        # the registry lookup, which is meant for registered names.
        callback = None
        if callable(post_process):
            callback = post_process
        elif post_process:
            callback = ClassFactory.get_cls(ClassType.CALLBACK, post_process)()

        tasks = []
        for inx, df in enumerate(samples):
            m = models[inx]
            if not isinstance(m, Model):
                # No usable model for this sample group: skip it.
                continue
            if isinstance(m.model, str):
                # A string is treated as a location to load the model from.
                evaluator = set_backend(estimator=self.base_model)
                evaluator.load(m.model)
            else:
                evaluator = m.model
            pred = evaluator.predict(df.x, **kwargs)
            if callable(callback):
                pred = callback(pred, df)
            task = Task(entry=m.entry, samples=df)
            task.result = pred
            task.model = m
            tasks.append(task)
        res = self._inference_integrate(tasks)
        return res, tasks
コード例 #10
0
    def inference(self, data=None, post_process=None, **kwargs):
        """
        Inference task with JointInference

        Parameters
        ----------
        data: BaseDataSource
            datasource use for inference, see
            `sedna.datasources.BaseDataSource` for more detail.
            `data.tolist()` is what gets sent to the cloud service.
        post_process: function or a registered method
            a callable is used directly; any other non-None value is
            looked up through `ClassFactory` under `ClassType.CALLBACK`.
            The original argument is also forwarded to the cloud call.
        kwargs: Dict
            parameters for `estimator` inference,
            Like:  `ntree_limit` in Xgboost.XGBClassifier

        Returns
        -------
        a four-item list:
        if is hard sample : value returned by the mining algorithm,
            `False` when no algorithm is configured
        inference result : the cloud result when it was fetched
            successfully, otherwise the (post-processed) edge result
        edge result : deep copy of the raw estimator prediction
        cloud result : `None` unless the cloud call succeeded
        """

        if callable(post_process):
            hook = post_process
        elif post_process is None:
            hook = None
        else:
            hook = ClassFactory.get_cls(ClassType.CALLBACK, post_process)

        res = self.estimator.predict(data, **kwargs)
        # Snapshot the raw prediction before the hook can touch it.
        edge_result = deepcopy(res)
        if hook:
            res = hook(res)

        self.lc_reporter.update_for_edge_inference()

        is_hard_example = False
        cloud_result = None

        miner = self.hard_example_mining_algorithm
        if not miner:
            return [is_hard_example, res, edge_result, cloud_result]

        is_hard_example = miner(res)
        if not is_hard_example:
            return [is_hard_example, res, edge_result, cloud_result]

        # Hard example: ask the cloud model; on failure keep the edge
        # result and only log the error.
        try:
            cloud_data = self.cloud.inference(
                data.tolist(), post_process=post_process, **kwargs)
            cloud_result = cloud_data["result"]
        except Exception as err:
            self.log.error(f"get cloud result error: {err}")
        else:
            res = cloud_result
        self.lc_reporter.update_for_collaboration_inference()
        return [is_hard_example, res, edge_result, cloud_result]
コード例 #11
0
ファイル: __init__.py プロジェクト: XinYao1994/sedna
 def __init__(self, data_type="train", func=None):
     """
     Create an empty data source.

     Parameters
     ----------
     data_type : str
         sample type: train/eval/test
     func : function or a registered method
         a callable is stored as `process_func` directly; any other
         truthy value is looked up through `ClassFactory` under
         `ClassType.CALLBACK` and instantiated; otherwise
         `process_func` stays `None`.
     """
     self.data_type = data_type
     if callable(func):
         processor = func
     elif func:
         processor = ClassFactory.get_cls(ClassType.CALLBACK, func)()
     else:
         processor = None
     self.process_func = processor
     # Feature, label and lifelong-learning metadata start out empty.
     self.x = None
     self.y = None
     self.meta_attr = None
コード例 #12
0
    def evaluate(self, data, post_process=None, **kwargs):
        """
        Evaluate task for IncrementalLearning

        Every candidate model is evaluated in turn and the collected
        model info is reported once with kind "eval".

        Parameters
        ----------
        data: BaseDataSource
            datasource use for evaluation, see
            `sedna.datasources.BaseDataSource` for more detail.
        post_process: function or a registered method
            applied to each evaluation result. A callable is used
            directly; any other truthy value is looked up through
            `ClassFactory` under `ClassType.CALLBACK`.
        kwargs: Dict
            parameters for `estimator` evaluate,
            Like:  `metric_name` in Xgboost.XGBClassifier

        Returns
        -------
        evaluate metrics : List
            one `estimator.model_info` entry per evaluated model.
        """

        if callable(post_process):
            hook = post_process
        elif post_process:
            hook = ClassFactory.get_cls(ClassType.CALLBACK, post_process)
        else:
            hook = None

        # The ";"-separated `model_urls` takes precedence over the single
        # configured `model_url`.
        if self.model_urls:
            candidates = self.model_urls.split(";")
        elif self.config.model_url:
            candidates = [self.config.model_url]
        else:
            candidates = []

        final_res = []
        for model_url in candidates:
            if not model_url.strip():
                continue  # blank entry produced by the split
            self.estimator.model_save_path = model_url
            res = self.estimator.evaluate(
                data=data, model_path=model_url, **kwargs)
            if hook:
                res = hook(res)
            self.log.info(f"Evaluation with {model_url} : {res} ")
            info = self.estimator.model_info(
                model_url, result=res, relpath=self.config.data_path_prefix)
            # Keep only the first entry of a non-empty list/tuple.
            if isinstance(info, (list, tuple)) and len(info):
                info = list(info)[0]
            final_res.append(info)

        self.report_task_info(
            None,
            K8sResourceKindStatus.COMPLETED.value,
            final_res,
            kind="eval")
        return final_res
コード例 #13
0
    def inference(self, data=None, post_process=None, **kwargs):
        """
        predict the result for input data based on training knowledge.

        The task index is downloaded before every prediction.

        Parameters
        ----------
        data : BaseDataSource
            inference sample, see `sedna.datasources.BaseDataSource` for
            more detail.
        post_process: function
            function or a registered method, handed on to
            `estimator.predict`, like: label transform.
        kwargs: Dict
            parameters for `estimator` predict, Like:
            `ntree_limit` in Xgboost.XGBClassifier

        Returns
        -------
        result : array_like
            first value returned by `estimator.predict`.
        is_unseen_task : bool
            value returned by the unseen-task detector; `False` when no
            detector is configured or its construction raised ValueError.
        tasks : List
            second value returned by `estimator.predict`.
        """
        remote_index = self.get_parameters("MODEL_URLS",
                                           self.config.task_index)
        local_index = self.estimator.estimator.task_index_url
        FileOps.download(remote_index, local_index)
        res, tasks = self.estimator.predict(
            data=data, post_process=post_process, **kwargs)

        is_unseen_task = False
        detector_spec = self.unseen_task_detect
        if not detector_spec:
            return res, is_unseen_task, tasks

        # Only building the detector is guarded; errors raised while
        # running it propagate to the caller.
        try:
            if callable(detector_spec):
                detector = detector_spec()
            else:
                detector = ClassFactory.get_cls(
                    ClassType.UTD, detector_spec)()
        except ValueError as err:
            self.log.error("Lifelong learning "
                           "Inference [UTD] : {}".format(err))
            return res, is_unseen_task, tasks

        is_unseen_task = detector(
            tasks=tasks, result=res, **self.unseen_task_detect_param)
        return res, is_unseen_task, tasks
コード例 #14
0
    def __init__(self, estimator, hard_example_mining: dict = None):
        """
        Prepare an incremental-learning job around `estimator`.

        Parameters
        ----------
        estimator :
            passed to the parent constructor.
        hard_example_mining : dict
            optional settings with keys "method" (default "IBT") and
            "param" (default {}). When falsy, the settings come from
            `get_hem_algorithm_from_config()`; when those are falsy too,
            no mining algorithm is configured.
        """
        super(IncrementalLearning, self).__init__(estimator=estimator)

        # use in evaluation
        self.model_urls = self.get_parameters("MODEL_URLS")
        self.job_kind = K8sResourceKind.INCREMENTAL_JOB.value
        FileOps.clean_folder([self.config.model_url], clean=False)
        self.hard_example_mining_algorithm = None

        hem_config = (hard_example_mining
                      or self.get_hem_algorithm_from_config())
        if not hem_config:
            return
        hem_name = hem_config.get("method", "IBT")
        hem_kwargs = hem_config.get("param", {})
        hem_cls = ClassFactory.get_cls(ClassType.HEM, hem_name)
        self.hard_example_mining_algorithm = hem_cls(**hem_kwargs)
コード例 #15
0
    def _task_mining(self, samples):
        """
        Mining tasks of inference sample base on task attribute extractor.

        Uses the "method"/"param" entries of `self.task_mining`. When no
        method is set, the method is derived from the task definition
        method through `self._method_pair` (default
        'TaskMiningByDataAttr') and the task definition parameters are
        used instead.
        """
        method_name = self.task_mining.get("method")
        extend_param = self._parse_param(self.task_mining.get("param"))

        if not method_name:
            definition = self.task_definition
            method_name = self._method_pair.get(
                definition.get("method", "TaskDefinitionByDataAttr"),
                'TaskMiningByDataAttr')
            extend_param = self._parse_param(definition.get("param"))

        miner_cls = ClassFactory.get_cls(ClassType.MTL, method_name)
        miner = miner_cls(task_extractor=self.extractor, **extend_param)
        return miner(samples=samples)
コード例 #16
0
    def __init__(self, estimator, aggregation="FedAvg"):
        """
        Prepare a federated-learning worker and register it.

        The aggregation endpoint is assembled from the `AGG_PROTOCOL`
        (default "ws"), `AGG_IP` (default "127.0.0.1") and `AGG_PORT`
        (default "7363") parameters, with `aggregation` as the URI path.
        `CONNECT_TIMEOUT` (default "300") is passed to `register`.

        Parameters
        ----------
        estimator :
            passed to the parent constructor.
        aggregation : str
            name looked up through `ClassFactory` under
            `ClassType.FL_AGG`.
        """

        protocol = Context.get_parameters("AGG_PROTOCOL", "ws")
        agg_ip = Context.get_parameters("AGG_IP", "127.0.0.1")
        agg_port = int(Context.get_parameters("AGG_PORT", "7363"))
        config = {
            "protocol": protocol,
            "agg_ip": agg_ip,
            "agg_port": agg_port,
            "agg_uri": f"{protocol}://{agg_ip}:{agg_port}/{aggregation}",
        }
        super(FederatedLearning, self).__init__(
            estimator=estimator, config=config)
        self.aggregation = ClassFactory.get_cls(ClassType.FL_AGG, aggregation)

        timeout_seconds = int(
            Context.get_parameters("CONNECT_TIMEOUT", "300"))
        self.node = None
        self.register(timeout=timeout_seconds)
コード例 #17
0
    def inference(self, data=None, post_process=None, **kwargs):
        """
        Inference task for IncrementalLearning

        The model at `self.model_path` is loaded first when the
        estimator reports `has_load` as falsy.

        Parameters
        ----------
        data: BaseDataSource
            datasource use for inference, see
            `sedna.datasources.BaseDataSource` for more detail.
        post_process: function or a registered method
            a callable is used directly; any other non-None value is
            looked up through `ClassFactory` under `ClassType.CALLBACK`.
        kwargs: Dict
            parameters for `estimator` inference,
            Like:  `ntree_limit` in Xgboost.XGBClassifier

        Returns
        -------
        inference result : object
            the raw estimator prediction
        result after post_process : object
            the same object as the raw prediction when no hook is used
        if is hard sample : bool
            value returned by the mining algorithm, `False` when none is
            configured
        """

        if not self.estimator.has_load:
            self.estimator.load(self.model_path)

        if callable(post_process):
            hook = post_process
        elif post_process is None:
            hook = None
        else:
            hook = ClassFactory.get_cls(ClassType.CALLBACK, post_process)

        infer_res = self.estimator.predict(data, **kwargs)
        # The hook works on a deep copy so the raw prediction that is
        # returned cannot be modified by it.
        res = hook(deepcopy(infer_res)) if hook else infer_res

        miner = self.hard_example_mining_algorithm
        is_hard_example = miner(res) if miner else False
        return infer_res, res, is_hard_example
コード例 #18
0
    def __init__(self, estimator=None, hard_example_mining: dict = None):
        """
        Prepare a joint-inference worker.

        Starts the `LCReporter`, loads the model at `self.model_path`,
        creates the `ModelClient` for the big model and configures the
        optional hard-example-mining algorithm.

        Parameters
        ----------
        estimator :
            passed to the parent constructor; when the stored estimator
            is callable it is called once with no arguments.
        hard_example_mining : dict
            optional settings with keys "method" (default "IBT") and
            "param" (default {}). When falsy, the settings come from
            `get_hem_algorithm_from_config()`.

        Raises
        ------
        FileExistsError
            when `self.model_path` does not exist.
        """
        super(JointInference, self).__init__(estimator=estimator)
        self.job_kind = K8sResourceKind.JOINT_INFERENCE_SERVICE.value
        self.local_ip = get_host_ip()
        # The big model is looked for on the local host unless configured.
        self.remote_ip = self.get_parameters("BIG_MODEL_IP", self.local_ip)
        self.port = int(self.get_parameters("BIG_MODEL_PORT", "5000"))

        report_msg = dict(
            name=self.worker_name,
            namespace=self.config.namespace,
            ownerName=self.job_name,
            ownerKind=self.job_kind,
            kind="inference",
            results=[],
        )
        report_period = int(self.get_parameters("LC_PERIOD", "30"))
        self.lc_reporter = LCReporter(lc_server=self.config.lc_server,
                                      message=report_msg,
                                      period_interval=report_period)
        self.lc_reporter.setDaemon(True)
        self.lc_reporter.start()

        if callable(self.estimator):
            self.estimator = self.estimator()
        # NOTE(review): FileExistsError is raised for a *missing* path;
        # the type is kept because callers may already catch it.
        if not os.path.exists(self.model_path):
            raise FileExistsError(f"{self.model_path} miss")
        self.estimator.load(self.model_path)

        self.cloud = ModelClient(service_name=self.job_name,
                                 host=self.remote_ip,
                                 port=self.port)

        self.hard_example_mining_algorithm = None
        hem_config = (hard_example_mining
                      or self.get_hem_algorithm_from_config())
        if not hem_config:
            return
        hem_name = hem_config.get("method", "IBT")
        hem_kwargs = hem_config.get("param", {})
        hem_cls = ClassFactory.get_cls(ClassType.HEM, hem_name)
        self.hard_example_mining_algorithm = hem_cls(**hem_kwargs)
コード例 #19
0
    def train(self, train_data, valid_data=None, post_process=None, **kwargs):
        """
        Training task for IncrementalLearning

        Trains the estimator, saves it to `self.model_path` and reports
        the resulting model info with COMPLETED status.

        Parameters
        ----------
        train_data: BaseDataSource
            datasource use for train, see
            `sedna.datasources.BaseDataSource` for more detail.
        valid_data:  BaseDataSource
            datasource use for evaluation, see
            `sedna.datasources.BaseDataSource` for more detail.
        post_process: function or a registered method
            called with the estimator after training. A callable is used
            directly; any other non-None value is looked up through
            `ClassFactory` under `ClassType.CALLBACK`.
        kwargs: Dict
            parameters for `estimator` training,
            Like:  `early_stopping_rounds` in Xgboost.XGBClassifier

        Returns
        -------
        estimator, or the value returned by `post_process(estimator)`
        when a truthy hook was resolved.
        """

        # Accept a plain callable as documented, matching `evaluate` and
        # `inference`, instead of sending it to the registry lookup.
        callback_func = None
        if callable(post_process):
            callback_func = post_process
        elif post_process is not None:
            callback_func = ClassFactory.get_cls(ClassType.CALLBACK,
                                                 post_process)

        res = self.estimator.train(train_data=train_data,
                                   valid_data=valid_data,
                                   **kwargs)
        model_paths = self.estimator.save(self.model_path)
        task_info_res = self.estimator.model_info(
            model_paths, result=res, relpath=self.config.data_path_prefix)
        self.report_task_info(None, K8sResourceKindStatus.COMPLETED.value,
                              task_info_res)
        return callback_func(
            self.estimator) if callback_func else self.estimator
コード例 #20
0
    def evaluate(self, data, post_process=None, **kwargs):
        """
        evaluated the performance of each task from training, filter tasks
        based on the defined rules.

        A task is dropped (status set to 0 through
        `kb_server.update_task_status`) when any of its scores satisfies
        `score <operator> model_threshold`. The operator comes from the
        "operator" parameter (default ">") and the threshold from
        "model_threshold" (default 0.1). The resulting index is uploaded
        to `config.task_index` and reported with kind "eval".

        Parameters
        ----------
        data : BaseDataSource
            valid data, see `sedna.datasources.BaseDataSource` for more detail.
        post_process: function or a registered method
            applied to the evaluation result that is returned. A callable
            is used directly; any other non-None value is looked up
            through `ClassFactory` under `ClassType.CALLBACK`.
        kwargs: Dict
            parameters for `estimator` evaluate, Like:
            `ntree_limit` in Xgboost.XGBClassifier

        Returns
        -------
        the first value returned by `estimator.evaluate`, passed through
        the post-process hook when a truthy hook was resolved.
        """

        if callable(post_process):
            hook = post_process
        elif post_process is None:
            hook = None
        else:
            hook = ClassFactory.get_cls(ClassType.CALLBACK, post_process)

        remote_index = self.get_parameters("MODEL_URLS",
                                           self.config.task_index)
        index_url = self.estimator.estimator.task_index_url
        self.log.info(
            f"Download kb index from {remote_index} to {index_url}")
        FileOps.download(remote_index, index_url)
        res, tasks_detail = self.estimator.evaluate(data=data, **kwargs)

        model_filter_operator = self.get_parameters("operator", ">")
        model_threshold = float(self.get_parameters('model_threshold', 0.1))

        comparators = {
            ">": lambda x, y: x > y,
            "<": lambda x, y: x < y,
            "=": lambda x, y: x == y,
            ">=": lambda x, y: x >= y,
            "<=": lambda x, y: x <= y,
        }
        if model_filter_operator not in comparators:
            self.log.warn(f"operator {model_filter_operator} use to "
                          f"compare is not allow, set to <")
            model_filter_operator = "<"
        compare = comparators[model_filter_operator]

        drop_tasks = []
        for detail in tasks_detail:
            scores, entry = detail.scores, detail.entry
            self.log.info(f"{entry} scores: {scores}")
            matched = any(compare(float(score), model_threshold)
                          for score in scores.values())
            if not matched:
                continue
            # NOTE(review): the message says "all scores" but the test
            # above fires as soon as any single score matches.
            self.log.warn(
                f"{entry} will not be deploy because all "
                f"scores {model_filter_operator} {model_threshold}")
            drop_tasks.append(entry)

        index_file = self.kb_server.update_task_status(
            ",".join(drop_tasks), new_status=0)
        if not index_file:
            # Fall back to the locally downloaded index.
            self.log.error("KB update Fail !")
            index_file = str(index_url)
        self.log.info(
            f"upload kb index from {index_file} to {self.config.task_index}")
        FileOps.upload(index_file, self.config.task_index)

        task_info_res = self.estimator.model_info(
            self.config.task_index,
            result=res,
            relpath=self.config.data_path_prefix)
        self.report_task_info(
            None,
            K8sResourceKindStatus.COMPLETED.value,
            task_info_res,
            kind="eval")
        if hook:
            return hook(res)
        return res
コード例 #21
0
    def train(self,
              train_data,
              valid_data=None,
              post_process=None,
              action="initial",
              **kwargs):
        """
        fit for update the knowledge based on training data.

        After training, each task group's model file, the per-task
        sample files and the task extractor are uploaded, the knowledge
        base index is updated and uploaded to `config.task_index`, and
        the job is reported as COMPLETED.

        Parameters
        ----------
        train_data : BaseDataSource
            Train data, see `sedna.datasources.BaseDataSource` for more detail.
        valid_data : BaseDataSource
            Valid data, BaseDataSource or None.
        post_process : function or a registered method
            callback after `estimator` train, invoked as
            `post_process(estimator, train_history)`. A callable is used
            directly; any other non-None value is looked up through
            `ClassFactory` under `ClassType.CALLBACK`.
        action : str
            `update` or `initial` the knowledge base
            (not read by this method body).
        kwargs : Dict
            parameters for `estimator` training, Like:
            `early_stopping_rounds` in Xgboost.XGBClassifier

        Returns
        -------
        train_history : object
            first value returned by `estimator.train`, or the value
            returned by the post-process hook when one was resolved.
        """

        # Accept a plain callable as documented instead of sending it to
        # the registry lookup, which is meant for registered names.
        callback_func = None
        if callable(post_process):
            callback_func = post_process
        elif post_process is not None:
            callback_func = ClassFactory.get_cls(ClassType.CALLBACK,
                                                 post_process)
        res, task_index_url = self.estimator.train(
            train_data=train_data, valid_data=valid_data, **kwargs
        )  # todo: Distinguishing incremental update and fully overwrite

        # The estimator hands back either the path of a dumped index or
        # the index object itself.
        if isinstance(task_index_url, str) and FileOps.exists(task_index_url):
            task_index = FileOps.load(task_index_url)
        else:
            task_index = task_index_url

        extractor = task_index['extractor']
        task_groups = task_index['task_groups']

        # local model file -> upload result, so a model file shared by
        # several task groups is uploaded only once.
        model_upload_key = {}
        for task in task_groups:
            model_file = task.model.model
            save_model = FileOps.join_path(self.config.output_url,
                                           os.path.basename(model_file))
            if model_file not in model_upload_key:
                model_upload_key[model_file] = FileOps.upload(
                    model_file, save_model)
            model_file = model_upload_key[model_file]

            try:
                model = self.kb_server.upload_file(save_model)
            except Exception as err:
                # Best effort: when the upload fails, keep a loaded
                # backend model instead.
                self.log.error(
                    f"Upload task model of {model_file} fail: {err}")
                model = set_backend(
                    estimator=self.estimator.estimator.base_model)
                model.load(model_file)
            task.model.model = model

            for _task in task.tasks:
                sample_dir = FileOps.join_path(
                    self.config.output_url,
                    f"{_task.samples.data_type}_{_task.entry}.sample")
                # NOTE(review): this saves the group's samples
                # (`task.samples`) under each member task's file name;
                # confirm that `_task.samples` was not intended.
                task.samples.save(sample_dir)
                try:
                    sample_dir = self.kb_server.upload_file(sample_dir)
                except Exception as err:
                    self.log.error(
                        f"Upload task samples of {_task.entry} fail: {err}")
                _task.samples.data_url = sample_dir

        save_extractor = FileOps.join_path(
            self.config.output_url,
            KBResourceConstant.TASK_EXTRACTOR_NAME.value)
        extractor = FileOps.dump(extractor, save_extractor)
        try:
            extractor = self.kb_server.upload_file(extractor)
        except Exception as err:
            self.log.error(f"Upload task extractor fail: {err}")
        task_info = {"task_groups": task_groups, "extractor": extractor}
        fd, name = tempfile.mkstemp()
        # mkstemp returns an already-open OS-level handle. The index is
        # written by path below, so close the descriptor instead of
        # leaking it.
        os.close(fd)
        FileOps.dump(task_info, name)

        index_file = self.kb_server.update_db(name)
        if not index_file:
            # Fall back to the locally dumped index.
            self.log.error(f"KB update Fail !")
            index_file = name
        FileOps.upload(index_file, self.config.task_index)

        task_info_res = self.estimator.model_info(
            self.config.task_index, relpath=self.config.data_path_prefix)
        self.report_task_info(None, K8sResourceKindStatus.COMPLETED.value,
                              task_info_res)
        self.log.info(f"Lifelong learning Train task Finished, "
                      f"KB idnex save in {self.config.task_index}")
        return callback_func(self.estimator, res) if callback_func else res
コード例 #22
0
    def train(self, train_data,
              valid_data=None,
              post_process=None,
              **kwargs):
        """
        Training task for FederatedLearning

        Loops over rounds: train locally, send the weights to the
        aggregation node, wait for the aggregated weights, apply and
        save them, then report progress. The loop ends when the received
        message carries `exit_flag == "ok"`.

        Parameters
        ----------
        train_data: BaseDataSource
            datasource use for train, see
            `sedna.datasources.BaseDataSource` for more detail.
            `len(train_data)` is sent as `num_samples`.
        valid_data:  BaseDataSource
            datasource use for evaluation, see
            `sedna.datasources.BaseDataSource` for more detail.
        post_process: function or a registered method
            called with the estimator once training has finished. A
            callable is used directly; any other truthy value is looked
            up through `ClassFactory` under `ClassType.CALLBACK`.
        kwargs: Dict
            parameters for `estimator` training,
            Like:  `early_stopping_rounds` in Xgboost.XGBClassifier

        Returns
        -------
        estimator, or the value returned by `post_process(estimator)`
        when a truthy hook was resolved.
        """

        # Accept a plain callable as documented instead of sending it to
        # the registry lookup, which is meant for registered names.
        callback_func = None
        if callable(post_process):
            callback_func = post_process
        elif post_process:
            callback_func = ClassFactory.get_cls(
                ClassType.CALLBACK, post_process)

        round_number = 0
        num_samples = len(train_data)
        # _flag is True when a new local round must be trained and sent;
        # it is False while still waiting for the previous reply.
        _flag = True
        start = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        res = None
        while True:
            if _flag:
                round_number += 1
                start = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                self.log.info(
                    f"Federated learning start, round_number={round_number}")
                res = self.estimator.train(
                    train_data=train_data, valid_data=valid_data, **kwargs)

                current_weights = self.estimator.get_weights()
                send_data = {"num_samples": num_samples,
                             "weights": current_weights}
                self.node.send(
                    send_data, msg_type="update_weight", job_name=self.job_name
                )
            received = self.node.recv(wait_data_type="recv_weight")
            if not received:
                # Nothing received yet: poll again without retraining.
                _flag = False
                continue
            _flag = True

            rec_data = received.get("data", {})
            exit_flag = rec_data.get("exit_flag", "")
            server_round = int(rec_data.get("round_number"))
            total_size = int(rec_data.get("total_sample"))
            self.log.info(
                f"Federated learning recv weight, "
                f"round: {server_round}, total_sample: {total_size}"
            )
            n_weight = rec_data.get("weights")
            self.estimator.set_weights(n_weight)
            task_info = {
                'currentRound': round_number,
                'sampleCount': total_size,
                'startTime': start,
                'updateTime': time.strftime(
                    "%Y-%m-%d %H:%M:%S", time.localtime())
            }
            model_paths = self.estimator.save()
            task_info_res = self.estimator.model_info(
                model_paths, result=res, relpath=self.config.data_path_prefix)
            if exit_flag == "ok":
                self.report_task_info(
                    task_info,
                    K8sResourceKindStatus.COMPLETED.value,
                    task_info_res)
                self.log.info(f"exit training from [{self.worker_name}]")
                return callback_func(
                    self.estimator) if callback_func else self.estimator
            else:
                self.report_task_info(
                    task_info,
                    K8sResourceKindStatus.RUNNING.value,
                    task_info_res)
コード例 #23
0
    def train(self,
              train_data: BaseDataSource,
              valid_data: BaseDataSource = None,
              post_process=None,
              **kwargs):
        """
        fit for update the knowledge based on training data.

        Task groups with more than `min_train_sample` samples get their
        own model. Every other group is treated as rare and shares a
        single "global.model" trained on the full `train_data`.

        Parameters
        ----------
        train_data : BaseDataSource
            Train data, see `sedna.datasources.BaseDataSource` for more detail.
        valid_data : BaseDataSource
            Valid data, BaseDataSource or None. When truthy, the
            returned feedback is the first value of `self.evaluate`.
        post_process : function or a registered method
            callback after `estimator` train, invoked as
            `post_process(model_obj, result)` for each per-group model
            trained here. A string is looked up through `ClassFactory`
            under `ClassType.CALLBACK` and instantiated; a callable is
            used directly.
        kwargs : Dict
            parameters for `estimator` training, Like:
            `early_stopping_rounds` in Xgboost.XGBClassifier

        Returns
        -------
        feedback : Dict
            contain all training result in each tasks.
        task_index_url : str
            `self.task_index_url` after the index was dumped there; the
            in-memory index dict instead when dumping raised TypeError.
        """

        tasks, task_extractor, train_data = self._task_definition(train_data)
        self.extractor = task_extractor
        task_groups = self._task_relationship_discovery(tasks)
        self.models = []
        callback = None
        if isinstance(post_process, str):
            callback = ClassFactory.get_cls(ClassType.CALLBACK, post_process)()
        elif callable(post_process):
            # A plain function is documented as valid input; use it
            # instead of silently ignoring it.
            callback = post_process
        self.task_groups = []
        feedback = {}
        rare_task = []  # indices of groups that fall back to the global model
        for i, task in enumerate(task_groups):
            if not isinstance(task, TaskGroup):
                rare_task.append(i)
                self.models.append(None)
                self.task_groups.append(None)
                continue
            if not (task.samples
                    and len(task.samples) > self.min_train_sample):
                self.models.append(None)
                self.task_groups.append(None)
                rare_task.append(i)
                # task.samples may be None here; len(None) would raise
                # TypeError just for the sake of a log line.
                n = len(task.samples) if task.samples is not None else 0
                LOGGER.info(f"Sample {n} of {task.entry} will be merge")
                continue
            LOGGER.info(f"MTL Train start {i} : {task.entry}")

            model = None
            for t in task.tasks:  # if model has train in tasks
                if not (t.model and t.result):
                    continue
                model_path = t.model.save(model_name=f"{task.entry}.model")
                t.model = model_path
                model = Model(index=i,
                              entry=t.entry,
                              model=model_path,
                              result=t.result)
                model.meta_attr = t.meta_attr
                break
            if not model:
                # No member task brought a trained model: train one on
                # the group's samples.
                model_obj = set_backend(estimator=self.base_model)
                res = model_obj.train(train_data=task.samples, **kwargs)
                if callback:
                    res = callback(model_obj, res)
                model_path = model_obj.save(model_name=f"{task.entry}.model")
                model = Model(index=i,
                              entry=task.entry,
                              model=model_path,
                              result=res)

                model.meta_attr = [t.meta_attr for t in task.tasks]
            task.model = model
            self.models.append(model)
            feedback[task.entry] = model.result
            self.task_groups.append(task)

        if len(rare_task):
            # One shared model, trained on all data, backs every rare group.
            model_obj = set_backend(estimator=self.base_model)
            res = model_obj.train(train_data=train_data, **kwargs)
            model_path = model_obj.save(model_name="global.model")
            for i in rare_task:
                task = task_groups[i]
                entry = getattr(task, 'entry', "global")
                if not isinstance(task, TaskGroup):
                    task = TaskGroup(entry=entry, tasks=[])
                model = Model(index=i,
                              entry=entry,
                              model=model_path,
                              result=res)
                model.meta_attr = [t.meta_attr for t in task.tasks]
                task.model = model
                task.samples = train_data
                self.models[i] = model
                feedback[entry] = res
                self.task_groups[i] = task

        task_index = {
            "extractor": self.extractor,
            "task_groups": self.task_groups
        }
        if valid_data:
            feedback, _ = self.evaluate(valid_data, **kwargs)
        try:
            FileOps.dump(task_index, self.task_index_url)
        except TypeError:
            # Dumping failed: hand back the in-memory index instead.
            return feedback, task_index
        return feedback, self.task_index_url