Exemple #1
0
    def inference(self, data=None, post_process=None, **kwargs):
        """
        predict the result for input data based on training knowledge.

        Parameters
        ----------
        data : BaseDataSource
            inference sample, see `sedna.datasources.BaseDataSource` for
            more detail.
        post_process: function
            function or a registered method,  effected after `estimator`
            prediction, like: label transform.
        kwargs: Dict
            parameters for `estimator` predict, Like:
            `ntree_limit` in Xgboost.XGBClassifier

        Returns
        -------
        result : array_like
            results array, contain all inference results in each sample.
        is_unseen_task : bool
            `true` means detect an unseen task, `false` means not
        tasks : List
            tasks assigned to each sample.
        """
        # Prefer the MODEL_URLS parameter; fall back to the configured
        # task-index path.
        task_index_url = self.get_parameters("MODEL_URLS",
                                             self.config.task_index)
        index_url = self.estimator.estimator.task_index_url
        # Refresh the local copy of the knowledge-base index before predict.
        FileOps.download(task_index_url, index_url)
        res, tasks = self.estimator.predict(data=data,
                                            post_process=post_process,
                                            **kwargs)

        is_unseen_task = False
        if self.unseen_task_detect:

            try:
                # `unseen_task_detect` is either a callable (a class) or the
                # name of a class registered under ClassType.UTD.
                if callable(self.unseen_task_detect):
                    unseen_task_detect_algorithm = self.unseen_task_detect()
                else:
                    unseen_task_detect_algorithm = ClassFactory.get_cls(
                        ClassType.UTD, self.unseen_task_detect)()
            except ValueError as err:
                # Unknown detector name: log and keep is_unseen_task False.
                self.log.error("Lifelong learning "
                               "Inference [UTD] : {}".format(err))
            else:
                # Detector instantiated successfully: run it on this batch.
                is_unseen_task = unseen_task_detect_algorithm(
                    tasks=tasks, result=res, **self.unseen_task_detect_param)
        return res, is_unseen_task, tasks
    def __init__(self, estimator, hard_example_mining: dict = None):
        """Set up incremental-learning job state and the HEM algorithm.

        `hard_example_mining` is a dict like ``{"method": ..., "param": {...}}``;
        when absent it is read from the job configuration.
        """
        super(IncrementalLearning, self).__init__(estimator=estimator)

        # Model urls are only consumed during the evaluation phase.
        self.model_urls = self.get_parameters("MODEL_URLS")
        self.job_kind = K8sResourceKind.INCREMENTAL_JOB.value
        # Make sure the model folder exists without wiping its content.
        FileOps.clean_folder([self.config.model_url], clean=False)

        self.hard_example_mining_algorithm = None
        hem_config = hard_example_mining or self.get_hem_algorithm_from_config()
        if hem_config:
            method_name = hem_config.get("method", "IBT")
            method_params = hem_config.get("param", {})
            hem_cls = ClassFactory.get_cls(ClassType.HEM, method_name)
            self.hard_example_mining_algorithm = hem_cls(**method_params)
Exemple #3
0
 def upload_file(self, files, name=""):
     """Upload a local file to the KB server and return its remote url.

     Parameters
     ----------
     files : str
         local file path; returned unchanged when it does not exist.
     name : str
         upload name, defaults to the file's basename.

     Returns
     -------
     str
         remote url of the uploaded file.
     """
     if not (files and os.path.isfile(files)):
         return files
     if not name:
         name = os.path.basename(files)
     LOGGER.info(f"Try to upload file {name}")
     _url = f"{self.kbserver}/file/upload"
     with open(files, "rb") as fin:
         # Keep the multipart payload in its own variable: the previous code
         # rebound `files` to this dict, so the cleanup below deleted the
         # wrong object and the original path was lost.
         payload = {"file": fin}
         outurl = http_request(url=_url, method="POST", files=payload)
     if FileOps.is_remote(outurl):
         return outurl
     outurl = outurl.lstrip("/")
     FileOps.delete(files)
     return f"{self.kbserver}/{outurl}"
Exemple #4
0
 async def file_upload(self, file: UploadFile = File(...)):
     """Persist an uploaded file under `save_dir`; return its download route."""
     content = await file.read()
     filename = str(file.filename)
     output = FileOps.join_path(self.save_dir, filename)
     with open(output, "wb") as fout:
         fout.write(content)
     # Mirror the sibling endpoints: link back to /file/download with the
     # stored file name (the previous literal "(unknown)" placeholders were
     # corrupted text and produced a dead link).
     return f"/file/download?files={filename}&name={filename}"
Exemple #5
0
 def parse(self, *args, **kwargs):
     """Parse plain-text sample files into `self.x` / `self.y` numpy arrays.

     Each positional argument is a file path; missing files are skipped.
     With ``use_raw`` every parsed row is kept whole, otherwise the first
     field is the feature and (for non-test data) the second the label,
     defaulting to 0 when absent.
     """
     use_raw = kwargs.get("use_raw")
     features, labels = [], []
     for path in args:
         if not (path and FileOps.exists(path)):
             continue
         with open(path) as fin:
             stripped = [line.strip() for line in fin.readlines()]
         if self.process_func:
             rows = list(map(self.process_func, stripped))
         else:
             rows = [line.split() for line in stripped]
         for row in rows:
             if not len(row):
                 continue
             if use_raw:
                 features.append(row)
                 continue
             features.append(row[0])
             if self.is_test_data:
                 continue
             labels.append(row[1] if len(row) > 1 else 0)
     self.x = np.array(features)
     self.y = np.array(labels)
Exemple #6
0
 def parse(self, *args, **kwargs):
     """Parse CSV files (or dict/list payloads) into `self.x` / `self.y`.

     Parameters
     ----------
     args : str | dict | list
         csv file paths, or json-like structures handled by `parse_json`.
     kwargs : Dict
         `label`: target column name; `usecols`: comma-separated string or
         list of columns; all remaining keys go to `pandas.read_csv`.
     """
     x_data = []
     y_data = []
     label = kwargs.pop("label") if "label" in kwargs else ""
     usecols = kwargs.get("usecols", "")
     if usecols and isinstance(usecols, str):
         usecols = usecols.split(",")
     if len(usecols):
         # always read the label column alongside the requested features
         if label and label not in usecols:
             usecols.append(label)
         kwargs["usecols"] = usecols
     for f in args:
         if isinstance(f, (dict, list)):
             res = self.parse_json(f, **kwargs)
         else:
             if not (f and FileOps.exists(f)):
                 continue
             res = pd.read_csv(f, **kwargs)
         if self.process_func and callable(self.process_func):
             res = self.process_func(res)
         if label:
             if label not in res.columns:
                 continue
             y_data.append(res[label])
             res.drop(label, axis=1, inplace=True)
         x_data.append(res)
     if not x_data:
         return
     self.x = pd.concat(x_data)
     # Guard the label concat: without a `label` kwarg (or when no frame
     # contains the label column) y_data is empty, and `pd.concat([])`
     # raises ValueError("No objects to concatenate").
     self.y = pd.concat(y_data) if y_data else None
Exemple #7
0
    def __init__(self,
                 estimator,
                 task_definition=None,
                 task_relationship_discovery=None,
                 task_mining=None,
                 task_remodeling=None,
                 inference_integrate=None,
                 unseen_task_detect=None):
        """Build a lifelong-learning job around a multi-task estimator.

        Each optional argument is a dict of the form
        ``{"method": <registered class name>, "param": {...}}``; defaults
        are provided for task definition and unseen-task detection.
        """
        if not task_definition:
            task_definition = {"method": "TaskDefinitionByDataAttr"}
        if not unseen_task_detect:
            unseen_task_detect = {"method": "TaskAttrFilter"}
        # Wrap the base estimator with the multi-task learning pipeline.
        e = MulTaskLearning(
            estimator=estimator,
            task_definition=task_definition,
            task_relationship_discovery=task_relationship_discovery,
            task_mining=task_mining,
            task_remodeling=task_remodeling,
            inference_integrate=inference_integrate)
        self.unseen_task_detect = unseen_task_detect.get(
            "method", "TaskAttrFilter")
        self.unseen_task_detect_param = e._parse_param(
            unseen_task_detect.get("param", {}))
        # KB server address and output location come from the runtime context.
        config = dict(ll_kb_server=Context.get_parameters("KB_SERVER"),
                      output_url=Context.get_parameters("OUTPUT_URL", "/tmp"))
        task_index = FileOps.join_path(config['output_url'],
                                       KBResourceConstant.KB_INDEX_NAME.value)
        config['task_index'] = task_index
        super(LifelongLearning, self).__init__(estimator=e, config=config)
        self.job_kind = K8sResourceKind.LIFELONG_JOB.value
        # Client used to sync tasks/models with the knowledge-base server.
        self.kb_server = KBClient(kbserver=self.config.ll_kb_server)
Exemple #8
0
 def load(self, model_url="", model_name=None, **kwargs):
     """Instantiate the estimator if needed, then load model weights.

     The model file location is resolved from `model_url` (possibly
     remote) or the configured save path, then `estimator.load` is
     delegated to when a loadable file exists.
     """
     weights_name = model_name or self.model_name
     if callable(self.estimator):
         # lazily construct the estimator with whichever kwargs it accepts
         varkw = self.parse_kwargs(self.estimator, **kwargs)
         self.estimator = self.estimator(**varkw)
     if model_url and os.path.isfile(model_url):
         self.model_save_path, weights_name = os.path.split(model_url)
     elif os.path.isfile(self.model_save_path):
         self.model_save_path, weights_name = os.path.split(
             self.model_save_path)
     model_path = FileOps.join_path(self.model_save_path, weights_name)
     if model_url:
         # fetch remote weights next to the local save path
         model_path = FileOps.download(model_url, model_path)
     self.has_load = True
     loadable = hasattr(self.estimator, "load") and os.path.exists(model_path)
     if not loadable:
         return
     return self.estimator.load(model_url=model_path)
Exemple #9
0
    def update_db(self, task_info_file):
        """POST a serialized task-info file to the KB `update` endpoint.

        Returns the index url on success (prefixed with the server address
        when the server returned a relative path) or None on failure; the
        local file is removed after a successful upload.
        """
        _url = f"{self.kbserver}/update"
        try:
            with open(task_info_file, "rb") as fin:
                outurl = http_request(url=_url, method="POST",
                                      files={"task": fin})
        except Exception as err:
            LOGGER.error(f"Update kb error: {err}")
            return None
        if not FileOps.is_remote(outurl):
            outurl = f"{self.kbserver}/{outurl.lstrip('/')}"
        FileOps.delete(task_info_file)
        return outurl
Exemple #10
0
 def model_info(self, model, relpath=None, result=None):
     """Describe a saved model as a list of {format, url, metrics} records.

     When the model is saved as a TensorFlow `.pb` file, its checkpoint
     directory is reported as an additional entry.
     """
     ckpt_dir = os.path.dirname(model)
     fmt = os.path.splitext(model)[-1].lstrip(".").lower()
     if relpath:
         model_url = FileOps.remove_path_prefix(model, relpath)
         ckpt_url = FileOps.remove_path_prefix(ckpt_dir, relpath)
     else:
         model_url, ckpt_url = model, ckpt_dir
     info = [{"format": fmt, "url": model_url, "metrics": result}]
     if fmt == "pb":  # report ckpt path when model save as pb file
         info.append({"format": "ckpt", "url": ckpt_url, "metrics": result})
     return info
Exemple #11
0
 def train(self, train_data, valid_data=None, **kwargs):
     """Instantiate / fine-tune the estimator as needed, then train it."""
     if callable(self.estimator):
         self.estimator = self.estimator()
     # resume from a previously saved model when fine-tuning is enabled
     if self.fine_tune and FileOps.exists(self.model_save_path):
         self.finetune()
     self.has_load = True
     accepted = self.parse_kwargs(self.estimator.train, **kwargs)
     return self.estimator.train(train_data=train_data,
                                 valid_data=valid_data,
                                 **accepted)
    def load(self, task_index_url=None):
        """
        load task_detail (tasks/models etc ...) from task index file.
        It'll automatically loaded during `inference` and `evaluation` phases.

        Parameters
        ----------
        task_index_url : str
            task index file path, default self.task_index_url.

        Raises
        ------
        FileExistsError
            if the task index file is missing.
        """

        if task_index_url:
            self.task_index_url = task_index_url
        # Raise explicitly instead of using `assert`: assertions are
        # stripped under `python -O`, which would silently skip the check.
        if not FileOps.exists(self.task_index_url):
            raise FileExistsError(
                f"Task index miss: {self.task_index_url}")
        task_index = FileOps.load(self.task_index_url)
        self.extractor = task_index['extractor']
        if isinstance(self.extractor, str):
            # the extractor may be stored indirectly as a file path
            self.extractor = FileOps.load(self.extractor)
        self.task_groups = task_index['task_groups']
        self.models = [task.model for task in self.task_groups]
    def register(self, timeout=300):
        """
        Deprecated, Client proactively subscribes to the aggregation service.

        Parameters
        ----------
        timeout: int, connect timeout. Default: 300
        """
        self.log.info(
            f"Node {self.worker_name} connect to : {self.config.agg_uri}")
        # Client towards the aggregation service; `ping_timeout` doubles as
        # the connection timeout here.
        self.node = AggregationClient(
            url=self.config.agg_uri,
            client_id=self.worker_name,
            ping_timeout=timeout
        )

        # Ensure the model folder exists without wiping existing content.
        FileOps.clean_folder([self.config.model_url], clean=False)
        # Instantiate the lazily-configured aggregation algorithm and
        # estimator classes.
        self.aggregation = self.aggregation()
        self.log.info(f"{self.worker_name} model prepared")
        if callable(self.estimator):
            self.estimator = self.estimator()
Exemple #14
0
 def model_info(self, model, relpath=None, result=None):
     """Return a single {format, url, metrics} record for a saved model."""
     fmt = os.path.splitext(model)[-1].lstrip(".")
     url = FileOps.remove_path_prefix(model, relpath) if relpath else model
     return [{"format": fmt, "url": url, "metrics": result}]
Exemple #15
0
    def update_status(self, data: KBUpdateResult = Body(...)):
        """Deploy/undeploy task groups and rewrite the local KB index."""
        deploy = True if data.status else False
        tasks = data.tasks.split(",") if data.tasks else []
        # Persist the new deploy flag for the named task groups.
        with Session(bind=engine) as session:
            session.query(TaskGrp).filter(TaskGrp.name.in_(tasks)).update(
                {TaskGrp.deploy: deploy}, synchronize_session=False)

        # todo: get from kb
        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
        task_info = joblib.load(_index_path)
        new_task_group = []

        # NOTE(review): groups whose deploy state does not match the request
        # are replaced by the first group — presumably a default/fallback
        # model; confirm against the transfer-learning design.
        default_task = task_info["task_groups"][0]
        # todo: get from transfer learning
        for task_group in task_info["task_groups"]:
            if not ((task_group.entry in tasks) == deploy):
                new_task_group.append(default_task)
                continue
            new_task_group.append(task_group)
        task_info["task_groups"] = new_task_group
        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
        FileOps.dump(task_info, _index_path)
        return f"/file/download?files={self.kb_index}&name={self.kb_index}"
Exemple #16
0
 def run(self):
     """Poll the hot-update config and reload the model on version changes.

     Loops every `check_time` seconds until `run_flag` is cleared:
     downloads the config file, extracts the model path/version, fetches
     the model, and reloads `production_estimator` under a semaphore.
     """
     while self.run_flag:
         time.sleep(self.check_time)
         conf = FileOps.download(self.hot_update_conf)
         if not (conf and FileOps.exists(conf)):
             continue
         with open(conf, "r") as fin:
             try:
                 conf_msg = json.load(fin)
                 model_msg = conf_msg["model_config"]
                 latest_version = str(model_msg["model_update_time"])
                 # stage the model under a version-suffixed temp name
                 model = FileOps.download(
                     model_msg["model_path"],
                     FileOps.join_path(self.temp_path,
                                       f"model.{latest_version}"))
             except (json.JSONDecodeError, KeyError):
                 LOGGER.error(f"fail to parse model hot update config: "
                              f"{self.hot_update_conf}")
                 continue
         if not (model and FileOps.exists(model)):
             continue
         if latest_version == self.version:
             # already running this version; nothing to do
             continue
         self.version = latest_version
         # serialize model swaps with other model-manipulation paths
         with self.MODEL_MANIPULATION_SEM:
             LOGGER.info(f"Update model start with version {self.version}")
             try:
                 self.production_estimator.load(model)
                 status = K8sResourceKindStatus.COMPLETED.value
                 LOGGER.info(f"Update model complete "
                             f"with version {self.version}")
             except Exception as e:
                 LOGGER.error(f"fail to update model: {e}")
                 status = K8sResourceKindStatus.FAILED.value
             if self.callback:
                 # report deploy status back to the controller
                 self.callback(task_info=None, status=status, kind="deploy")
         gc.collect()
Exemple #17
0
    def __init__(self,
                 host: str,
                 http_port: int = 8080,
                 workers: int = 1,
                 save_dir=""):
        """Knowledge-base REST server exposing update/query/file routes."""
        servername = "knowledgebase"

        super(KBServer, self).__init__(servername=servername,
                                       host=host,
                                       http_port=http_port,
                                       workers=workers)
        # Ensure the storage folder exists (clean=False keeps its content).
        self.save_dir = FileOps.clean_folder([save_dir], clean=False)[0]
        self.url = f"{self.url}/{servername}"
        self.kb_index = KBResourceConstant.KB_INDEX_NAME.value
        # All routes are namespaced under /<servername>/...
        self.app = FastAPI(
            routes=[
                APIRoute(
                    f"/{servername}/update",
                    self.update,
                    methods=["POST"],
                ),
                APIRoute(
                    f"/{servername}/update/status",
                    self.update_status,
                    methods=["POST"],
                ),
                APIRoute(
                    f"/{servername}/query",
                    self.query,
                    response_model=TaskItem,
                    response_class=JSONResponse,
                    methods=["POST"],
                ),
                APIRoute(
                    f"/{servername}/file/download",
                    self.file_download,
                    methods=["GET"],
                ),
                APIRoute(
                    f"/{servername}/file/upload",
                    self.file_upload,
                    methods=["POST"],
                ),
            ],
            log_level="trace",
            timeout=600,
        )
Exemple #18
0
 def update_task_status(self, tasks: str, new_status=1):
     """Mark the given comma-separated task groups with a new status.

     Returns the index url reported by the server (absolute when needed)
     or None on failure.
     """
     payload = {"tasks": tasks, "status": int(new_status)}
     _url = f"{self.kbserver}/update/status"
     try:
         outurl = http_request(url=_url, method="POST", json=payload)
     except Exception as err:
         LOGGER.error(f"Update kb error: {err}")
         return None
     if not outurl:
         return None
     if FileOps.is_remote(outurl):
         return outurl
     return f"{self.kbserver}/{outurl.lstrip('/')}"
Exemple #19
0
    def save(self, model_url="", model_name=None):
        """Save the estimator locally, then optionally upload to `model_url`."""
        weights_name = model_name or self.model_name
        if os.path.isfile(self.model_save_path):
            # a file path was configured: split it into folder + file name
            self.model_save_path, weights_name = os.path.split(
                self.model_save_path)
        FileOps.clean_folder([self.model_save_path], clean=False)
        model_path = FileOps.join_path(self.model_save_path, weights_name)
        self.estimator.save(model_path)
        if model_url and FileOps.exists(model_path):
            FileOps.upload(model_path, model_url)
            return model_url
        return model_path
Exemple #20
0
 def model_path(self):
     """Resolve the estimator's model file path from config or parameters."""
     if os.path.isfile(self.config.model_url):
         return self.config.model_url
     explicit = self.get_parameters('model_path')
     return explicit or FileOps.join_path(self.config.model_url,
                                          self.estimator.model_name)
Exemple #21
0
 def save(self, output=""):
     """Serialize this object to `output` via FileOps and return the result."""
     dumped = FileOps.dump(self, output)
     return dumped
Exemple #22
0
    def evaluate(self, data, post_process=None, **kwargs):
        """
        evaluated the performance of each task from training, filter tasks
        based on the defined rules.

        Parameters
        ----------
        data : BaseDataSource
            valid data, see `sedna.datasources.BaseDataSource` for more detail.
        kwargs: Dict
            parameters for `estimator` evaluate, Like:
            `ntree_limit` in Xgboost.XGBClassifier
        """

        callback_func = None
        if callable(post_process):
            callback_func = post_process
        elif post_process is not None:
            # resolve a registered callback by name
            callback_func = ClassFactory.get_cls(ClassType.CALLBACK,
                                                 post_process)
        # sync the KB index locally before evaluation
        task_index_url = self.get_parameters("MODEL_URLS",
                                             self.config.task_index)
        index_url = self.estimator.estimator.task_index_url
        self.log.info(
            f"Download kb index from {task_index_url} to {index_url}")
        FileOps.download(task_index_url, index_url)
        res, tasks_detail = self.estimator.evaluate(data=data, **kwargs)
        drop_tasks = []

        # Filter rule: compare each task's scores against a threshold with a
        # configurable operator (defaults: `>` and 0.1).
        model_filter_operator = self.get_parameters("operator", ">")
        model_threshold = float(self.get_parameters('model_threshold', 0.1))

        operator_map = {
            ">": lambda x, y: x > y,
            "<": lambda x, y: x < y,
            "=": lambda x, y: x == y,
            ">=": lambda x, y: x >= y,
            "<=": lambda x, y: x <= y,
        }
        if model_filter_operator not in operator_map:
            self.log.warn(f"operator {model_filter_operator} use to "
                          f"compare is not allow, set to <")
            model_filter_operator = "<"
        operator_func = operator_map[model_filter_operator]

        for detail in tasks_detail:
            scores = detail.scores
            entry = detail.entry
            self.log.info(f"{entry} scores: {scores}")
            # NOTE(review): the condition drops a task when ANY score matches
            # the operator/threshold, but the log message says "all" — confirm
            # whether `any` or `all` is the intended filtering rule.
            if any(
                    map(lambda x: operator_func(float(x), model_threshold),
                        scores.values())):
                self.log.warn(
                    f"{entry} will not be deploy because all "
                    f"scores {model_filter_operator} {model_threshold}")
                drop_tasks.append(entry)
                continue
        drop_task = ",".join(drop_tasks)
        # mark dropped tasks as not-deployed in the knowledge base
        index_file = self.kb_server.update_task_status(drop_task, new_status=0)
        if not index_file:
            self.log.error(f"KB update Fail !")
            index_file = str(index_url)
        self.log.info(
            f"upload kb index from {index_file} to {self.config.task_index}")
        FileOps.upload(index_file, self.config.task_index)
        task_info_res = self.estimator.model_info(
            self.config.task_index,
            result=res,
            relpath=self.config.data_path_prefix)
        self.report_task_info(None,
                              K8sResourceKindStatus.COMPLETED.value,
                              task_info_res,
                              kind="eval")
        return callback_func(res) if callback_func else res
Exemple #23
0
    def train(self,
              train_data,
              valid_data=None,
              post_process=None,
              action="initial",
              **kwargs):
        """
        fit for update the knowledge based on training data.

        Parameters
        ----------
        train_data : BaseDataSource
            Train data, see `sedna.datasources.BaseDataSource` for more detail.
        valid_data : BaseDataSource
            Valid data, BaseDataSource or None.
        post_process : function
            function or a registered method, callback after `estimator` train.
        action : str
            `update` or `initial` the knowledge base
        kwargs : Dict
            parameters for `estimator` training, Like:
            `early_stopping_rounds` in Xgboost.XGBClassifier

        Returns
        -------
        train_history : object
        """

        callback_func = None
        if post_process is not None:
            callback_func = ClassFactory.get_cls(ClassType.CALLBACK,
                                                 post_process)
        res, task_index_url = self.estimator.train(
            train_data=train_data, valid_data=valid_data, **kwargs
        )  # todo: Distinguishing incremental update and fully overwrite

        # The estimator may hand back either an index dict or a dumped file.
        if isinstance(task_index_url, str) and FileOps.exists(task_index_url):
            task_index = FileOps.load(task_index_url)
        else:
            task_index = task_index_url

        extractor = task_index['extractor']
        task_groups = task_index['task_groups']

        # Upload each group's model once (several groups may share one file).
        model_upload_key = {}
        for task in task_groups:
            model_file = task.model.model
            save_model = FileOps.join_path(self.config.output_url,
                                           os.path.basename(model_file))
            if model_file not in model_upload_key:
                model_upload_key[model_file] = FileOps.upload(
                    model_file, save_model)
            model_file = model_upload_key[model_file]

            try:
                model = self.kb_server.upload_file(save_model)
            except Exception as err:
                self.log.error(
                    f"Upload task model of {model_file} fail: {err}")
                # fall back to a locally loaded model backend
                model = set_backend(
                    estimator=self.estimator.estimator.base_model)
                model.load(model_file)
            task.model.model = model

            # Persist and upload the samples of every sub-task.
            for _task in task.tasks:
                sample_dir = FileOps.join_path(
                    self.config.output_url,
                    f"{_task.samples.data_type}_{_task.entry}.sample")
                task.samples.save(sample_dir)
                try:
                    sample_dir = self.kb_server.upload_file(sample_dir)
                except Exception as err:
                    self.log.error(
                        f"Upload task samples of {_task.entry} fail: {err}")
                _task.samples.data_url = sample_dir

        save_extractor = FileOps.join_path(
            self.config.output_url,
            KBResourceConstant.TASK_EXTRACTOR_NAME.value)
        extractor = FileOps.dump(extractor, save_extractor)
        try:
            extractor = self.kb_server.upload_file(extractor)
        except Exception as err:
            self.log.error(f"Upload task extractor fail: {err}")
        task_info = {"task_groups": task_groups, "extractor": extractor}
        fd, name = tempfile.mkstemp()
        # Close the raw descriptor right away: FileOps.dump reopens the file
        # by name, and leaving `fd` open leaked one descriptor per training
        # round.
        os.close(fd)
        FileOps.dump(task_info, name)

        index_file = self.kb_server.update_db(name)
        if not index_file:
            self.log.error("KB update Fail !")
            index_file = name
        FileOps.upload(index_file, self.config.task_index)

        task_info_res = self.estimator.model_info(
            self.config.task_index, relpath=self.config.data_path_prefix)
        self.report_task_info(None, K8sResourceKindStatus.COMPLETED.value,
                              task_info_res)
        self.log.info(f"Lifelong learning Train task Finished, "
                      f"KB index save in {self.config.task_index}")
        return callback_func(self.estimator, res) if callback_func else res
Exemple #24
0
import time
import warnings

import cv2
import numpy as np

from sedna.common.config import Context
from sedna.common.file_ops import FileOps
from sedna.core.incremental_learning import IncrementalLearning
from interface import Estimator

# Folders for hard-example images and inference results; both default to
# /tmp when the corresponding context parameters are not set.
he_saved_url = Context.get_parameters("HE_SAVED_URL", '/tmp')
rsl_saved_url = Context.get_parameters("RESULT_SAVED_URL", '/tmp')
# Detection classes produced by the helmet-detection model.
class_names = ['person', 'helmet', 'helmet_on', 'helmet_off']

# Ensure both output folders exist without wiping existing content.
FileOps.clean_folder([he_saved_url, rsl_saved_url], clean=False)


def draw_boxes(img, labels, scores, bboxes, class_names, colors):
    line_type = 2
    text_thickness = 1
    box_thickness = 1
    #  get color code
    colors = colors.split(",")
    colors_code = []
    for color in colors:
        if color == 'green':
            colors_code.append((0, 255, 0))
        elif color == 'blue':
            colors_code.append((255, 0, 0))
        elif color == 'yellow':
Exemple #25
0
    def update(self, task: UploadFile = File(...)):
        """Ingest an uploaded task-index dump into the KB database.

        The payload is a joblib-serialized dict with `task_groups`; groups,
        tasks, samples and their relations are upserted, then the index
        file under `save_dir` is refreshed and its download route returned.
        """
        tasks = task.file.read()
        # Spill the upload to a temp file so joblib can load it by name.
        fd, name = tempfile.mkstemp()
        with open(name, "wb") as fout:
            fout.write(tasks)
        os.close(fd)
        # NOTE(review): joblib.load deserializes pickled objects — assumes
        # the uploader is trusted; confirm this endpoint is not public.
        upload_info = joblib.load(name)

        with Session(bind=engine) as session:
            for task_group in upload_info["task_groups"]:
                grp, g_create = get_or_create(session=session,
                                              model=TaskGrp,
                                              name=task_group.entry)
                if g_create:
                    grp.sample_num = 0
                    grp.task_num = 0
                    session.add(grp)
                grp.sample_num += len(task_group.samples)
                grp.task_num += len(task_group.tasks)
                t_id = []
                # NOTE: this loop variable shadows the `task` parameter; the
                # parameter is no longer needed at this point.
                for task in task_group.tasks:
                    t_obj, t_create = get_or_create(session=session,
                                                    model=Tasks,
                                                    name=task.entry)
                    if task.meta_attr:
                        t_obj.task_attr = json.dumps(task.meta_attr)
                    if t_create:
                        session.add(t_obj)

                    sample_obj = Samples(data_type=task.samples.data_type,
                                         sample_num=len(task.samples),
                                         data_url=getattr(
                                             task, 'data_url', ''))
                    session.add(sample_obj)

                    session.flush()
                    session.commit()
                    # link the sample row to its task row
                    tsample = TaskSample(sample=sample_obj, task=t_obj)
                    session.add(tsample)
                    session.flush()
                    t_id.append(t_obj.id)

                model_obj, m_create = get_or_create(session=session,
                                                    model=TaskModel,
                                                    task=grp)
                model_obj.model_url = task_group.model.model
                model_obj.is_current = False
                if m_create:
                    session.add(model_obj)
                session.flush()
                session.commit()
                # spread transfer weight evenly across the group's tasks
                transfer_radio = 1 / grp.task_num
                for t in t_id:
                    t_obj, t_create = get_or_create(session=session,
                                                    model=TaskRelation,
                                                    task_id=t,
                                                    grp=grp)
                    t_obj.transfer_radio = transfer_radio
                    if t_create:
                        session.add(t_obj)
                        session.flush()
                    session.commit()
                session.query(TaskRelation).filter(
                    TaskRelation.grp == grp).update(
                        {"transfer_radio": transfer_radio})

            session.commit()

        # todo: get from kb
        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
        _index_path = FileOps.dump(upload_info, _index_path)

        return f"/file/download?files={self.kb_index}&name={self.kb_index}"
Exemple #26
0
 async def file_download(self, files: str, name: str = ""):
     """Serve a file stored under `save_dir` through the shared endpoint."""
     local_path = FileOps.join_path(self.save_dir, files)
     return self._file_endpoint(local_path, name=name)
Exemple #27
0
 def _get_db_index(self):
     """Return the local KB index path (fetching it from KB is still TODO)."""
     _index_path = FileOps.join_path(self.save_dir, self.kb_index)
     if not FileOps.exists(_index_path):
         pass  # todo: get from kb
     return _index_path
Exemple #28
0
 def load_weights(self):
     """Load saved weights into the estimator when the file exists."""
     weights = FileOps.join_path(self.model_save_path, self.model_name)
     if os.path.exists(weights):
         self.estimator.load_weights(weights)
    def train(self,
              train_data: BaseDataSource,
              valid_data: BaseDataSource = None,
              post_process=None,
              **kwargs):
        """
        fit for update the knowledge based on training data.

        Parameters
        ----------
        train_data : BaseDataSource
            Train data, see `sedna.datasources.BaseDataSource` for more detail.
        valid_data : BaseDataSource
            Valid data, BaseDataSource or None.
        post_process : function
            function or a registered method, callback after `estimator` train.
        kwargs : Dict
            parameters for `estimator` training, Like:
            `early_stopping_rounds` in Xgboost.XGBClassifier

        Returns
        -------
        feedback : Dict
            contain all training result in each tasks.
        task_index_url : str
            task extractor model path, used for task mining.
        """

        # Split training data into tasks and keep the attribute extractor.
        tasks, task_extractor, train_data = self._task_definition(train_data)
        self.extractor = task_extractor
        task_groups = self._task_relationship_discovery(tasks)
        self.models = []
        callback = None
        if isinstance(post_process, str):
            callback = ClassFactory.get_cls(ClassType.CALLBACK, post_process)()
        self.task_groups = []
        feedback = {}
        rare_task = []  # indexes of groups without enough samples to train
        for i, task in enumerate(task_groups):
            if not isinstance(task, TaskGroup):
                rare_task.append(i)
                self.models.append(None)
                self.task_groups.append(None)
                continue
            if not (task.samples
                    and len(task.samples) > self.min_train_sample):
                # Too few samples: defer to the shared global model below.
                self.models.append(None)
                self.task_groups.append(None)
                rare_task.append(i)
                n = len(task.samples)
                LOGGER.info(f"Sample {n} of {task.entry} will be merge")
                continue
            LOGGER.info(f"MTL Train start {i} : {task.entry}")

            model = None
            for t in task.tasks:  # if model has train in tasks
                if not (t.model and t.result):
                    continue
                # Reuse an already-trained sub-task model for this group.
                model_path = t.model.save(model_name=f"{task.entry}.model")
                t.model = model_path
                model = Model(index=i,
                              entry=t.entry,
                              model=model_path,
                              result=t.result)
                model.meta_attr = t.meta_attr
                break
            if not model:
                # No pre-trained sub-task model: train one for the group.
                model_obj = set_backend(estimator=self.base_model)
                res = model_obj.train(train_data=task.samples, **kwargs)
                if callback:
                    res = callback(model_obj, res)
                model_path = model_obj.save(model_name=f"{task.entry}.model")
                model = Model(index=i,
                              entry=task.entry,
                              model=model_path,
                              result=res)

                model.meta_attr = [t.meta_attr for t in task.tasks]
            task.model = model
            self.models.append(model)
            feedback[task.entry] = model.result
            self.task_groups.append(task)

        if len(rare_task):
            # Train one global model on all data for the undertrained groups.
            model_obj = set_backend(estimator=self.base_model)
            res = model_obj.train(train_data=train_data, **kwargs)
            model_path = model_obj.save(model_name="global.model")
            for i in rare_task:
                task = task_groups[i]
                entry = getattr(task, 'entry', "global")
                if not isinstance(task, TaskGroup):
                    task = TaskGroup(entry=entry, tasks=[])
                model = Model(index=i,
                              entry=entry,
                              model=model_path,
                              result=res)
                model.meta_attr = [t.meta_attr for t in task.tasks]
                task.model = model
                task.samples = train_data
                self.models[i] = model
                feedback[entry] = res
                self.task_groups[i] = task

        task_index = {
            "extractor": self.extractor,
            "task_groups": self.task_groups
        }
        if valid_data:
            feedback, _ = self.evaluate(valid_data, **kwargs)
        try:
            FileOps.dump(task_index, self.task_index_url)
        except TypeError:
            # index holds something non-serializable: hand it back as-is
            return feedback, task_index
        return feedback, self.task_index_url