Example #1
    def update_status(self, data: KBUpdateResult = Body(...)):
        deploy = bool(data.status)
        tasks = data.tasks.split(",") if data.tasks else []
        with Session(bind=engine) as session:
            session.query(TaskGrp).filter(TaskGrp.name.in_(tasks)).update(
                {TaskGrp.deploy: deploy}, synchronize_session=False)

        # todo: get from kb
        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
        task_info = joblib.load(_index_path)
        new_task_group = []

        default_task = task_info["task_groups"][0]
        # todo: get from transfer learning
        for task_group in task_info["task_groups"]:
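            # substitute the default group when the group's presence in
            # `tasks` disagrees with the requested deploy flag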
            if not ((task_group.entry in tasks) == deploy):
                new_task_group.append(default_task)
                continue
            new_task_group.append(task_group)
        task_info["task_groups"] = new_task_group
        FileOps.dump(task_info, _index_path)
        return f"/file/download?files={self.kb_index}&name={self.kb_index}"
Example #2
    def train(self,
              train_data,
              valid_data=None,
              post_process=None,
              action="initial",
              **kwargs):
        """
        Fit for updating the knowledge base with training data.

        Parameters
        ----------
        train_data : BaseDataSource
            Train data, see `sedna.datasources.BaseDataSource` for more detail.
        valid_data : BaseDataSource
            Valid data, BaseDataSource or None.
        post_process : function
            a function or registered callback name, invoked after
            `estimator` training.
        action : str
            `update` or `initial`: update or initialize the knowledge base
        kwargs : Dict
            parameters for `estimator` training, e.g.
            `early_stopping_rounds` in xgboost.XGBClassifier

        Returns
        -------
        train_history : object
        """

        callback_func = None
        if post_process is not None:
            callback_func = ClassFactory.get_cls(ClassType.CALLBACK,
                                                 post_process)
        res, task_index_url = self.estimator.train(
            train_data=train_data, valid_data=valid_data, **kwargs
        )  # todo: distinguish incremental update from full overwrite

        if isinstance(task_index_url, str) and FileOps.exists(task_index_url):
            task_index = FileOps.load(task_index_url)
        else:
            task_index = task_index_url

        extractor = task_index['extractor']
        task_groups = task_index['task_groups']

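        # cache uploaded model files so a model shared by several task
        # groups is transferred only once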
        model_upload_key = {}
        for task in task_groups:
            model_file = task.model.model
            save_model = FileOps.join_path(self.config.output_url,
                                           os.path.basename(model_file))
            if model_file not in model_upload_key:
                model_upload_key[model_file] = FileOps.upload(
                    model_file, save_model)
            model_file = model_upload_key[model_file]

            try:
                model = self.kb_server.upload_file(save_model)
            except Exception as err:
                self.log.error(
                    f"Upload task model of {model_file} fail: {err}")
                model = set_backend(
                    estimator=self.estimator.estimator.base_model)
                model.load(model_file)
            task.model.model = model

            for _task in task.tasks:
                sample_dir = FileOps.join_path(
                    self.config.output_url,
                    f"{_task.samples.data_type}_{_task.entry}.sample")
                _task.samples.save(sample_dir)
                try:
                    sample_dir = self.kb_server.upload_file(sample_dir)
                except Exception as err:
                    self.log.error(
                        f"Upload task samples of {_task.entry} fail: {err}")
                _task.samples.data_url = sample_dir

        save_extractor = FileOps.join_path(
            self.config.output_url,
            KBResourceConstant.TASK_EXTRACTOR_NAME.value)
        extractor = FileOps.dump(extractor, save_extractor)
        try:
            extractor = self.kb_server.upload_file(extractor)
        except Exception as err:
            self.log.error(f"Upload task extractor fail: {err}")
        task_info = {"task_groups": task_groups, "extractor": extractor}
        fd, name = tempfile.mkstemp()
        os.close(fd)  # close the raw descriptor; FileOps.dump reopens by path
        FileOps.dump(task_info, name)

        index_file = self.kb_server.update_db(name)
        if not index_file:
            self.log.error(f"KB update Fail !")
            index_file = name
        FileOps.upload(index_file, self.config.task_index)

        task_info_res = self.estimator.model_info(
            self.config.task_index, relpath=self.config.data_path_prefix)
        self.report_task_info(None, K8sResourceKindStatus.COMPLETED.value,
                              task_info_res)
        self.log.info(f"Lifelong learning Train task Finished, "
                      f"KB idnex save in {self.config.task_index}")
        return callback_func(self.estimator, res) if callback_func else res
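A minimal sketch of driving this train step. The dataset construction mirrors `sedna.datasources`; the file path, the `ll_job` instance, and the "eval_callback" registration are assumptions for illustration (the handler calls the registered object as `callback_func(self.estimator, res)`, so a plain function is registered here).

    from sedna.datasources import TxtDataParse
    from sedna.common.class_factory import ClassFactory, ClassType

    # assumed: ClassFactory accepts plain callables for CALLBACK entries
    @ClassFactory.register(ClassType.CALLBACK, alias="eval_callback")
    def eval_callback(estimator, train_result):
        # invoked after training with (estimator, result)
        return train_result

    train_data = TxtDataParse(data_type="train")
    train_data.parse("./data/train.txt")  # assumed sample index file
    res = ll_job.train(train_data=train_data, post_process="eval_callback")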
Example #3
    def update(self, task: UploadFile = File(...)):
        tasks = task.file.read()
        fd, name = tempfile.mkstemp()
        with open(name, "wb") as fout:
            fout.write(tasks)
        os.close(fd)
        upload_info = joblib.load(name)

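        # upsert task groups, tasks, samples and their relations into
        # the knowledge-base database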
        with Session(bind=engine) as session:
            for task_group in upload_info["task_groups"]:
                grp, g_create = get_or_create(session=session,
                                              model=TaskGrp,
                                              name=task_group.entry)
                if g_create:
                    grp.sample_num = 0
                    grp.task_num = 0
                    session.add(grp)
                grp.sample_num += len(task_group.samples)
                grp.task_num += len(task_group.tasks)
                t_id = []
                for _task in task_group.tasks:
                    t_obj, t_create = get_or_create(session=session,
                                                    model=Tasks,
                                                    name=_task.entry)
                    if _task.meta_attr:
                        t_obj.task_attr = json.dumps(_task.meta_attr)
                    if t_create:
                        session.add(t_obj)

                    sample_obj = Samples(data_type=_task.samples.data_type,
                                         sample_num=len(_task.samples),
                                         data_url=getattr(
                                             _task, 'data_url', ''))
                    session.add(sample_obj)

                    session.flush()
                    session.commit()
                    tsample = TaskSample(sample=sample_obj, task=t_obj)
                    session.add(tsample)
                    session.flush()
                    t_id.append(t_obj.id)

                model_obj, m_create = get_or_create(session=session,
                                                    model=TaskModel,
                                                    task=grp)
                model_obj.model_url = task_group.model.model
                model_obj.is_current = False
                if m_create:
                    session.add(model_obj)
                session.flush()
                session.commit()
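                # split the transfer ratio evenly across the group's tasks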
                transfer_radio = 1 / grp.task_num
                for t in t_id:
                    t_obj, t_create = get_or_create(session=session,
                                                    model=TaskRelation,
                                                    task_id=t,
                                                    grp=grp)
                    t_obj.transfer_radio = transfer_radio
                    if t_create:
                        session.add(t_obj)
                        session.flush()
                    session.commit()
                session.query(TaskRelation).filter(
                    TaskRelation.grp == grp).update(
                        {"transfer_radio": transfer_radio})

            session.commit()

        # todo: get from kb
        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
        _index_path = FileOps.dump(upload_info, _index_path)

        return f"/file/download?files={self.kb_index}&name={self.kb_index}"
Example #4
    def train(self,
              train_data: BaseDataSource,
              valid_data: BaseDataSource = None,
              post_process=None,
              **kwargs):
        """
        Fit for updating the knowledge base with training data.

        Parameters
        ----------
        train_data : BaseDataSource
            Train data, see `sedna.datasources.BaseDataSource` for more detail.
        valid_data : BaseDataSource
            Valid data, BaseDataSource or None.
        post_process : function
            a function or registered callback name, invoked after
            `estimator` training.
        kwargs : Dict
            parameters for `estimator` training, e.g.
            `early_stopping_rounds` in xgboost.XGBClassifier

        Returns
        -------
        feedback : Dict
            contains the training results of each task.
        task_index_url : str
            task extractor model path, used for task mining.
        """

        tasks, task_extractor, train_data = self._task_definition(train_data)
        self.extractor = task_extractor
        task_groups = self._task_relationship_discovery(tasks)
        self.models = []
        callback = None
        if isinstance(post_process, str):
            callback = ClassFactory.get_cls(ClassType.CALLBACK, post_process)()
        self.task_groups = []
        feedback = {}
        rare_task = []
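        # task groups with too few samples are recorded as "rare" and
        # later covered by one global model trained on all the data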
        for i, task in enumerate(task_groups):
            if not isinstance(task, TaskGroup):
                rare_task.append(i)
                self.models.append(None)
                self.task_groups.append(None)
                continue
            if not (task.samples
                    and len(task.samples) > self.min_train_sample):
                self.models.append(None)
                self.task_groups.append(None)
                rare_task.append(i)
                n = len(task.samples)
                LOGGER.info(f"{n} samples of {task.entry} will be merged")
                continue
            LOGGER.info(f"MTL Train start {i} : {task.entry}")

            model = None
            for t in task.tasks:  # reuse a model already trained in this group
                if not (t.model and t.result):
                    continue
                model_path = t.model.save(model_name=f"{task.entry}.model")
                t.model = model_path
                model = Model(index=i,
                              entry=t.entry,
                              model=model_path,
                              result=t.result)
                model.meta_attr = t.meta_attr
                break
            if not model:
                model_obj = set_backend(estimator=self.base_model)
                res = model_obj.train(train_data=task.samples, **kwargs)
                if callback:
                    res = callback(model_obj, res)
                model_path = model_obj.save(model_name=f"{task.entry}.model")
                model = Model(index=i,
                              entry=task.entry,
                              model=model_path,
                              result=res)

                model.meta_attr = [t.meta_attr for t in task.tasks]
            task.model = model
            self.models.append(model)
            feedback[task.entry] = model.result
            self.task_groups.append(task)

        if rare_task:
            model_obj = set_backend(estimator=self.base_model)
            res = model_obj.train(train_data=train_data, **kwargs)
            model_path = model_obj.save(model_name="global.model")
            for i in rare_task:
                task = task_groups[i]
                entry = getattr(task, 'entry', "global")
                if not isinstance(task, TaskGroup):
                    task = TaskGroup(entry=entry, tasks=[])
                model = Model(index=i,
                              entry=entry,
                              model=model_path,
                              result=res)
                model.meta_attr = [t.meta_attr for t in task.tasks]
                task.model = model
                task.samples = train_data
                self.models[i] = model
                feedback[entry] = res
                self.task_groups[i] = task

        task_index = {
            "extractor": self.extractor,
            "task_groups": self.task_groups
        }
        if valid_data:
            feedback, _ = self.evaluate(valid_data, **kwargs)
        try:
            FileOps.dump(task_index, self.task_index_url)
        except TypeError:
            return feedback, task_index
        return feedback, self.task_index_url
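A minimal sketch of calling this multi-task train step directly. The `mtl` instance and the dataset path are assumptions; the two return values match the docstring above.

    from sedna.datasources import TxtDataParse

    train_data = TxtDataParse(data_type="train")
    train_data.parse("./data/train.txt")  # assumed sample index file

    feedback, task_index_url = mtl.train(train_data=train_data)
    for entry, result in feedback.items():
        print(entry, result)  # per-task training result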
Example #5
    def save(self, output=""):
        return FileOps.dump(self, output)
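A brief usage sketch; the instance and output path are illustrative. By its use elsewhere on this page, `FileOps.dump` is expected to return the written path.

    saved_path = task_group.save(output="/tmp/task_group.pkl")
    print(saved_path)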