예제 #1
0
    def load(self, task_index_url=None):
        """
        Load task details (task groups, models, extractor) from the task
        index file.

        It'll be loaded automatically during the `inference` and
        `evaluation` phases.

        Parameters
        ----------
        task_index_url : str, optional
            Task index file path. A non-empty value replaces
            `self.task_index_url`; otherwise the current value is used.

        Raises
        ------
        FileNotFoundError
            If `FileOps.exists` reports that the task index is missing.
        """

        if task_index_url:
            self.task_index_url = task_index_url
        # Raise explicitly rather than `assert`: assert statements are
        # stripped under `python -O`, which would silently skip this check.
        if not FileOps.exists(self.task_index_url):
            raise FileNotFoundError(
                f"Task index miss: {self.task_index_url}")
        task_index = FileOps.load(self.task_index_url)
        self.extractor = task_index['extractor']
        # A string entry is handed to `FileOps.load` (presumably a path to
        # the dumped extractor); any other value is used as the extractor.
        if isinstance(self.extractor, str):
            self.extractor = FileOps.load(self.extractor)
        self.task_groups = task_index['task_groups']
        # One model per task group, in the same order as `task_groups`.
        self.models = [task.model for task in self.task_groups]
예제 #2
0
    def train(self,
              train_data,
              valid_data=None,
              post_process=None,
              action="initial",
              **kwargs):
        """
        Fit the estimator on training data and update the knowledge base.

        After `self.estimator.train` returns, each task-group model, each
        task's samples and the task extractor are written under
        `self.config.output_url` and pushed through `self.kb_server`.
        Upload failures are logged rather than raised. A task index is then
        registered with `self.kb_server.update_db`, copied to
        `self.config.task_index`, and reported via `report_task_info`.

        Parameters
        ----------
        train_data : BaseDataSource
            Train data, see `sedna.datasources.BaseDataSource` for more detail.
        valid_data : BaseDataSource
            Valid data, BaseDataSource or None.
        post_process : function
            function or a registered method, callback after `estimator` train.
        action : str
            `update` or `initial` the knowledge base. This method does not
            currently read the value (see the todo on the train call).
        kwargs : Dict
            parameters for `estimator` training, Like:
            `early_stopping_rounds` in Xgboost.XGBClassifier

        Returns
        -------
        train_history : object
            The first value returned by `self.estimator.train`, passed
            through the `post_process` callback when one is given.
        """

        callback_func = None
        if post_process is not None:
            callback_func = ClassFactory.get_cls(ClassType.CALLBACK,
                                                 post_process)
        res, task_index_url = self.estimator.train(
            train_data=train_data, valid_data=valid_data, **kwargs
        )  # todo: Distinguishing incremental update and fully overwrite

        # The second return value is either the path of a dumped index or
        # the index object itself.
        if isinstance(task_index_url, str) and FileOps.exists(task_index_url):
            task_index = FileOps.load(task_index_url)
        else:
            task_index = task_index_url

        extractor = task_index['extractor']
        task_groups = task_index['task_groups']

        # Cache upload results so a model file shared by several task
        # groups is uploaded to the output directory only once.
        model_upload_key = {}
        for task in task_groups:
            model_file = task.model.model
            save_model = FileOps.join_path(self.config.output_url,
                                           os.path.basename(model_file))
            if model_file not in model_upload_key:
                model_upload_key[model_file] = FileOps.upload(
                    model_file, save_model)
            model_file = model_upload_key[model_file]

            try:
                model = self.kb_server.upload_file(save_model)
            except Exception as err:
                # Best effort: when the KB upload fails, fall back to a
                # backend model loaded from the uploaded model file.
                self.log.error(
                    f"Upload task model of {model_file} fail: {err}")
                model = set_backend(
                    estimator=self.estimator.estimator.base_model)
                model.load(model_file)
            task.model.model = model

            for _task in task.tasks:
                sample_dir = FileOps.join_path(
                    self.config.output_url,
                    f"{_task.samples.data_type}_{_task.entry}.sample")
                # NOTE(review): this saves the group's samples
                # (`task.samples`), not `_task.samples`, under a per-task
                # file name -- confirm that this is intended.
                task.samples.save(sample_dir)
                try:
                    sample_dir = self.kb_server.upload_file(sample_dir)
                except Exception as err:
                    # Best effort: keep the local sample path on failure.
                    self.log.error(
                        f"Upload task samples of {_task.entry} fail: {err}")
                _task.samples.data_url = sample_dir

        save_extractor = FileOps.join_path(
            self.config.output_url,
            KBResourceConstant.TASK_EXTRACTOR_NAME.value)
        extractor = FileOps.dump(extractor, save_extractor)
        try:
            extractor = self.kb_server.upload_file(extractor)
        except Exception as err:
            self.log.error(f"Upload task extractor fail: {err}")
        task_info = {"task_groups": task_groups, "extractor": extractor}
        fd, name = tempfile.mkstemp()
        # mkstemp returns an already-open OS-level descriptor; only the
        # path is used below, so close it right away to avoid leaking it.
        os.close(fd)
        FileOps.dump(task_info, name)

        index_file = self.kb_server.update_db(name)
        if not index_file:
            # Fall back to the locally dumped index when the KB update
            # returns nothing.
            self.log.error("KB update Fail !")
            index_file = name
        FileOps.upload(index_file, self.config.task_index)

        task_info_res = self.estimator.model_info(
            self.config.task_index, relpath=self.config.data_path_prefix)
        self.report_task_info(None, K8sResourceKindStatus.COMPLETED.value,
                              task_info_res)
        self.log.info(f"Lifelong learning Train task Finished, "
                      f"KB idnex save in {self.config.task_index}")
        return callback_func(self.estimator, res) if callback_func else res