def update_status(self, data: KBUpdateResult = Body(...)):
    """
    Flip the deploy flag of the listed task groups and rewrite the KB index.

    Parameters
    ----------
    data : KBUpdateResult
        `status` flags whether the listed tasks passed evaluation and should
        be deployed; `tasks` is a comma-separated list of task-group entries.

    Returns
    -------
    str
        Download route of the refreshed knowledge-base index file.
    """
    deploy = bool(data.status)
    tasks = data.tasks.split(",") if data.tasks else []
    with Session(bind=engine) as session:
        session.query(TaskGrp).filter(TaskGrp.name.in_(tasks)).update(
            {TaskGrp.deploy: deploy}, synchronize_session=False)
        # fix: the Session context manager does not commit on exit, so
        # without an explicit commit the deploy flags were silently
        # discarded when the session closed (sibling `update` commits).
        session.commit()

    # todo: get from kb
    _index_path = FileOps.join_path(self.save_dir, self.kb_index)
    task_info = joblib.load(_index_path)

    new_task_group = []
    # NOTE(review): assumes the index holds at least one task group; an
    # empty "task_groups" list would raise IndexError here — confirm
    # upstream guarantees.
    default_task = task_info["task_groups"][0]
    # todo: get from transfer learning
    for task_group in task_info["task_groups"]:
        # Keep a group only when its membership in `tasks` matches the
        # requested deploy state; otherwise substitute the default group.
        if not ((task_group.entry in tasks) == deploy):
            new_task_group.append(default_task)
            continue
        new_task_group.append(task_group)
    task_info["task_groups"] = new_task_group

    # Reuse the path computed above instead of joining it a second time.
    FileOps.dump(task_info, _index_path)
    return f"/file/download?files={self.kb_index}&name={self.kb_index}"
def train(self, train_data,
          valid_data=None,
          post_process=None,
          action="initial",
          **kwargs):
    """
    Fit for update the knowledge based on training data.

    Parameters
    ----------
    train_data : BaseDataSource
        Train data, see `sedna.datasources.BaseDataSource` for more detail.
    valid_data : BaseDataSource
        Valid data, BaseDataSource or None.
    post_process : function
        function or a registered method, callback after `estimator` train.
    action : str
        `update` or `initial` the knowledge base
    kwargs : Dict
        parameters for `estimator` training, Like:
        `early_stopping_rounds` in Xgboost.XGBClassifier

    Returns
    -------
    train_history : object
    """
    callback_func = None
    if post_process is not None:
        callback_func = ClassFactory.get_cls(ClassType.CALLBACK, post_process)
    res, task_index_url = self.estimator.train(
        train_data=train_data, valid_data=valid_data, **kwargs
    )
    # todo: Distinguishing incremental update and fully overwrite
    # The estimator may hand back either a path to a dumped task index or
    # the in-memory index itself; normalize to the loaded structure.
    if isinstance(task_index_url, str) and FileOps.exists(task_index_url):
        task_index = FileOps.load(task_index_url)
    else:
        task_index = task_index_url
    extractor = task_index['extractor']
    task_groups = task_index['task_groups']

    # Upload each distinct model file only once, even when several task
    # groups share the same file.
    model_upload_key = {}
    for task in task_groups:
        model_file = task.model.model
        save_model = FileOps.join_path(
            self.config.output_url, os.path.basename(model_file))
        if model_file not in model_upload_key:
            model_upload_key[model_file] = FileOps.upload(
                model_file, save_model)
        model_file = model_upload_key[model_file]
        try:
            model = self.kb_server.upload_file(save_model)
        except Exception as err:
            # Best-effort: when the KB server rejects the upload, fall back
            # to a locally loaded backend model.
            self.log.error(
                f"Upload task model of {model_file} fail: {err}")
            model = set_backend(
                estimator=self.estimator.estimator.base_model)
            model.load(model_file)
        task.model.model = model

        for _task in task.tasks:
            sample_dir = FileOps.join_path(
                self.config.output_url,
                f"{_task.samples.data_type}_{_task.entry}.sample")
            # NOTE(review): the file name is derived from `_task`, but the
            # samples written are the group-level `task.samples` — confirm
            # whether `_task.samples.save(sample_dir)` was intended.
            task.samples.save(sample_dir)
            try:
                sample_dir = self.kb_server.upload_file(sample_dir)
            except Exception as err:
                self.log.error(
                    f"Upload task samples of {_task.entry} fail: {err}")
            _task.samples.data_url = sample_dir

    save_extractor = FileOps.join_path(
        self.config.output_url,
        KBResourceConstant.TASK_EXTRACTOR_NAME.value)
    extractor = FileOps.dump(extractor, save_extractor)
    try:
        extractor = self.kb_server.upload_file(extractor)
    except Exception as err:
        self.log.error(f"Upload task extractor fail: {err}")
    task_info = {"task_groups": task_groups, "extractor": extractor}

    fd, name = tempfile.mkstemp()
    # fix: mkstemp returns an open file descriptor that was never closed,
    # leaking one fd per training run.
    os.close(fd)
    FileOps.dump(task_info, name)
    index_file = self.kb_server.update_db(name)
    if not index_file:
        # fix: removed pointless f-prefix on a placeholder-free message
        self.log.error("KB update Fail !")
        index_file = name
    FileOps.upload(index_file, self.config.task_index)

    task_info_res = self.estimator.model_info(
        self.config.task_index, relpath=self.config.data_path_prefix)
    self.report_task_info(
        None, K8sResourceKindStatus.COMPLETED.value, task_info_res)
    # fix: corrected "idnex" typo in the log message
    self.log.info(f"Lifelong learning Train task Finished, "
                  f"KB index save in {self.config.task_index}")
    return callback_func(self.estimator, res) if callback_func else res
def update(self, task: UploadFile = File(...)):
    """
    Ingest an uploaded task-index file into the KB database.

    Registers/updates task groups, their tasks, samples, models and
    transfer relations, then dumps the uploaded index to the KB save dir.

    Parameters
    ----------
    task : UploadFile
        Serialized (joblib) task index uploaded by the client.

    Returns
    -------
    str
        Download route of the stored knowledge-base index file.
    """
    tasks = task.file.read()
    # fix: write through the fd returned by mkstemp instead of opening the
    # same path a second time and closing the raw fd separately.
    fd, name = tempfile.mkstemp()
    with os.fdopen(fd, "wb") as fout:
        fout.write(tasks)
    # SECURITY NOTE(review): joblib.load unpickles the uploaded payload; a
    # malicious upload can execute arbitrary code. Only expose this
    # endpoint to trusted callers.
    upload_info = joblib.load(name)

    with Session(bind=engine) as session:
        for task_group in upload_info["task_groups"]:
            grp, g_create = get_or_create(
                session=session, model=TaskGrp, name=task_group.entry)
            if g_create:
                grp.sample_num = 0
                grp.task_num = 0
                session.add(grp)
            grp.sample_num += len(task_group.samples)
            grp.task_num += len(task_group.tasks)

            t_id = []
            # fix: loop variable renamed — it used to shadow the `task`
            # parameter (the UploadFile), a latent source of confusion.
            for sub_task in task_group.tasks:
                t_obj, t_create = get_or_create(
                    session=session, model=Tasks, name=sub_task.entry)
                if sub_task.meta_attr:
                    t_obj.task_attr = json.dumps(sub_task.meta_attr)
                if t_create:
                    session.add(t_obj)
                sample_obj = Samples(
                    data_type=sub_task.samples.data_type,
                    sample_num=len(sub_task.samples),
                    data_url=getattr(sub_task, 'data_url', ''))
                session.add(sample_obj)
                session.flush()
                session.commit()
                tsample = TaskSample(sample=sample_obj, task=t_obj)
                session.add(tsample)
                session.flush()
                t_id.append(t_obj.id)

            model_obj, m_create = get_or_create(
                session=session, model=TaskModel, task=grp)
            model_obj.model_url = task_group.model.model
            model_obj.is_current = False
            if m_create:
                session.add(model_obj)
                session.flush()
            session.commit()

            # NOTE(review): a group with zero tasks would raise
            # ZeroDivisionError here — confirm uploads always carry tasks.
            transfer_radio = 1 / grp.task_num
            for t in t_id:
                t_obj, t_create = get_or_create(
                    session=session, model=TaskRelation, task_id=t, grp=grp)
                t_obj.transfer_radio = transfer_radio
                if t_create:
                    session.add(t_obj)
                    session.flush()
            session.commit()
            session.query(TaskRelation).filter(
                TaskRelation.grp == grp).update(
                {"transfer_radio": transfer_radio})
        session.commit()

    # todo: get from kb
    _index_path = FileOps.join_path(self.save_dir, self.kb_index)
    _index_path = FileOps.dump(upload_info, _index_path)
    return f"/file/download?files={self.kb_index}&name={self.kb_index}"
def train(self, train_data: BaseDataSource,
          valid_data: BaseDataSource = None,
          post_process=None, **kwargs):
    """
    Fit for update the knowledge based on training data.

    Parameters
    ----------
    train_data : BaseDataSource
        Train data, see `sedna.datasources.BaseDataSource` for more detail.
    valid_data : BaseDataSource
        Valid data, BaseDataSource or None.
    post_process : function
        function or a registered method, callback after `estimator` train.
    kwargs : Dict
        parameters for `estimator` training, Like:
        `early_stopping_rounds` in Xgboost.XGBClassifier

    Returns
    -------
    feedback : Dict
        contain all training result in each tasks.
    task_index_url : str
        task extractor model path, used for task mining.
    """
    tasks, task_extractor, train_data = self._task_definition(train_data)
    self.extractor = task_extractor
    task_groups = self._task_relationship_discovery(tasks)
    self.models = []
    callback = None
    if isinstance(post_process, str):
        callback = ClassFactory.get_cls(ClassType.CALLBACK, post_process)()
    self.task_groups = []
    feedback = {}
    # Indexes of groups that are malformed or too small to train alone;
    # they get the global fallback model below. Placeholders (None) keep
    # self.models / self.task_groups aligned with task_groups by index.
    rare_task = []
    for i, task in enumerate(task_groups):
        if not isinstance(task, TaskGroup):
            rare_task.append(i)
            self.models.append(None)
            self.task_groups.append(None)
            continue
        if not (task.samples and
                len(task.samples) > self.min_train_sample):
            self.models.append(None)
            self.task_groups.append(None)
            rare_task.append(i)
            n = len(task.samples)
            LOGGER.info(f"Sample {n} of {task.entry} will be merge")
            continue
        LOGGER.info(f"MTL Train start {i} : {task.entry}")

        model = None
        # Reuse the first already-trained model among the group's tasks
        # instead of retraining from scratch.
        for t in task.tasks:
            # if model has train in tasks
            if not (t.model and t.result):
                continue
            model_path = t.model.save(model_name=f"{task.entry}.model")
            t.model = model_path
            model = Model(index=i, entry=t.entry,
                          model=model_path, result=t.result)
            model.meta_attr = t.meta_attr
            break
        if not model:
            # No pre-trained task model found: train one on the group's
            # own samples.
            model_obj = set_backend(estimator=self.base_model)
            res = model_obj.train(train_data=task.samples, **kwargs)
            if callback:
                res = callback(model_obj, res)
            model_path = model_obj.save(model_name=f"{task.entry}.model")
            model = Model(index=i, entry=task.entry,
                          model=model_path, result=res)
            model.meta_attr = [t.meta_attr for t in task.tasks]
        task.model = model
        self.models.append(model)
        feedback[task.entry] = model.result
        self.task_groups.append(task)

    # idiom fix: was `if len(rare_task):`
    if rare_task:
        # Train one global fallback model on the full data set and assign
        # it to every group that could not be trained individually.
        model_obj = set_backend(estimator=self.base_model)
        res = model_obj.train(train_data=train_data, **kwargs)
        model_path = model_obj.save(model_name="global.model")
        for i in rare_task:
            task = task_groups[i]
            entry = getattr(task, 'entry', "global")
            if not isinstance(task, TaskGroup):
                task = TaskGroup(entry=entry, tasks=[])
            model = Model(index=i, entry=entry,
                          model=model_path, result=res)
            model.meta_attr = [t.meta_attr for t in task.tasks]
            task.model = model
            task.samples = train_data
            self.models[i] = model
            feedback[entry] = res
            self.task_groups[i] = task

    task_index = {
        "extractor": self.extractor,
        "task_groups": self.task_groups
    }
    if valid_data:
        feedback, _ = self.evaluate(valid_data, **kwargs)
    try:
        FileOps.dump(task_index, self.task_index_url)
    except TypeError:
        # Index holds something unserializable: hand back the in-memory
        # index instead of a path.
        return feedback, task_index
    return feedback, self.task_index_url
def save(self, output=""):
    """Serialize this object to *output* via FileOps.dump and return its result."""
    dumped = FileOps.dump(self, output)
    return dumped