def push_train_task_to_redis(nlp_task, doc_type, train_task_id, model_version, train_config, mark_job_ids, custom=None):
    """Queue a model-training message onto the appropriate Redis list.

    Classify jobs go to the classify model queue; every other NLP task
    (extract/relation/wordseg) goes to its task-specific train queue,
    keyed by the upper-cased task name.
    """
    if nlp_task == NlpTaskEnum.classify:
        payload = {
            "version": model_version,
            "task_type": 'train',
            "event_id": train_task_id,
            "configs": train_config,
            "data_path": generate_classify_data(mark_job_ids),
            "label": DocTypeSchema().dump(doc_type),
            "custom": custom,
            "use_rule": 0,
        }
        r.lpush(_get('CLASSIFY_MODEL_QUEUE_KEY'), json.dumps(payload))
    else:
        # doctype prefix differs per task family (NER/RE/WS)
        prefix_map = {"extract": "NER", "relation": "RE", "wordseg": "WS"}
        queue_key = _get("{}_TRAIN_QUEUE_KEY".format(nlp_task.name.upper()))
        payload = {
            "version": model_version,
            "doctype": prefix_map[nlp_task.name] + str(doc_type.doc_type_id),
            "tasks": mark_job_ids,
            "model_type": nlp_task.name,
            # non-classify consumers expect each config entry pre-serialized
            "configs": [json.dumps(x) for x in train_config],
        }
        r.lpush(queue_key, json.dumps(payload))
def push_evaluate_task_to_redis(nlp_task, evaluate_task: EvaluateTask, train_task: TrainTask, doc_type: DocType, mark_job_ids, doc_term_ids, doc_relation_ids, use_rule):
    """Queue a model-evaluation message onto the appropriate Redis list.

    Classify evaluations reuse the classify model queue; all other task
    types share the extract-evaluate queue, with relation tasks adding
    their relation field ids to the message.
    """
    if nlp_task == NlpTaskEnum.classify:
        message = {
            "version": train_task.model_version,
            "task_type": 'evaluate',
            "event_id": evaluate_task.evaluate_task_id,
            "configs": train_task.train_config,
            "data_path": generate_classify_data(mark_job_ids),
            "label": DocTypeSchema().dump(doc_type),
            "use_rule": use_rule,
        }
        r.lpush(_get('CLASSIFY_MODEL_QUEUE_KEY'), json.dumps(message))
    else:
        message = {
            "evaluate_id": evaluate_task.evaluate_task_id,
            "model_version": train_task.model_version,
            "tasks": mark_job_ids,
            "fields": doc_term_ids,
            'model_type': nlp_task.name,
            'model_reload_result': {},
        }
        if nlp_task == NlpTaskEnum.relation:
            message["relation_fields"] = doc_relation_ids
        r.lpush(_get("EXTRACT_EVALUATE_QUEUE_KEY"), json.dumps(message))
def update_train_task_by_id(train_job_id, train_task_id, is_check_train_terms, model_type, args):
    """
    Update a train task's status and results.

    1. When ``is_check_train_terms`` is set, derive the overall train
       status from the per-term task states.
    2. Otherwise apply the status/result from ``args`` directly.
    3. Handles the model online-state transition (classify and extract
       use different online service endpoints).

    :param train_job_id: id of the owning train job
    :param train_task_id: id of the train task to update
    :param is_check_train_terms: whether to recompute status from train_term states
    :param model_type: "classify" selects the classify online endpoint; anything else the extract one
    :param args: fields to apply via ``TrainTaskModel().update``; may be mutated here
    :return: the updated train task record
    """
    train_job = TrainJobModel().get_by_id(train_job_id)
    train_task = TrainTaskModel().get_by_id(train_task_id)
    if is_check_train_terms:  # whether the train_term states need checking
        _, training_terms = TrainTermTaskModel().get_by_filter(limit=99999, train_task_id=train_task_id, train_term_status=int(StatusEnum.training))
        _, failed_terms = TrainTermTaskModel().get_by_filter(limit=99999, train_task_id=train_task_id, train_term_status=int(StatusEnum.fail))
        if not training_terms:  # no term is still training
            if not failed_terms:  # and none failed -> overall success
                args["train_status"] = int(StatusEnum.success)
            else:
                args["train_status"] = int(StatusEnum.fail)
        else:
            args["train_status"] = int(StatusEnum.training)
    else:
        # no limit to set model_train_state=success/failed
        if args["train_status"] == int(StatusEnum.online):
            # validation: only a successfully trained, not-yet-online model may go online
            if train_task.train_status == StatusEnum.online:
                abort(400, message="该模型已经上线")
            if train_task.train_status != StatusEnum.success:
                abort(400, message="只能上线训练成功的模型")
            # send model online http request to the serving side
            service_url = _get("CLASSIFY_MODEL_ONLINE") if model_type == "classify" else _get("EXTRACT_MODEL_ONLINE")
            resp = requests.post(f"{service_url}?model_version={train_task.model_version}")
            if resp.status_code < 200 or resp.status_code >= 300:
                abort(500, message=f"上线服务 <{service_url}> 出现错误: {resp.text}")
            # find all online models under this doc_type_id
            online_models = TrainTaskModel().get_by_doc_type_id(doc_type_id=train_job.doc_type_id, train_status=int(StatusEnum.online))
            # demote currently-online models back to success (only one model online per doc type)
            TrainTaskModel().bulk_update([train.train_task_id for train in online_models], train_status=int(StatusEnum.success))
    # update train task with the (possibly recomputed) args
    train_task = TrainTaskModel().update(train_task_id, **args)
    session.commit()
    return train_task
def export_with_annotation(self, labels: typing.List[typing.Dict]) -> typing.Dict:
    """Render an annotated PDF through the PDF printer service, then export it.

    :param labels: annotation data; each entry is shaped like::

        {
            "index": 12,             # position in the full text
            "word": "中国",           # text content
            "color": "#ccc",         # color, must start with '#'
            "annotation": "注释名称"   # tooltip text
        }

    :return: { "file_name": '', "path": '' }
    :raises Exception: when the PDF printer service does not answer 200
    """
    fileset = FileSet(folder='')
    # the printer service expects integer colors, not hex strings
    for entry in labels:
        entry["color"] = self.hex_to_int(entry["color"])
    payload = {
        "input_path": fileset.get_relative_path(self.unique_name, replace_ext='pdf'),
        "output_path": fileset.get_relative_path(self.unique_name, replace_ext='pdf', suffix='_print'),
        "pdf_json_path": fileset.get_relative_path(self.unique_name, replace_ext='json'),
        "content_path": fileset.get_relative_path(self.unique_name, replace_ext='txt'),
        "labels": labels,
    }
    resp = requests.post(_get('PDF_PRINTER'), json=payload, timeout=600)
    if resp.status_code != 200:
        logger.error(f'label export request failed, response is {resp.text}')
        raise Exception("导出PDF服务出现异常,请联系运维人员进行解决")
    return self.export(replace_ext='pdf', suffix='_print')
def push_mark_task_message(mark_job, mark_task, doc, business, use_rule=False):
    """Queue a single-document extraction mark-task message onto Redis."""
    file_entry = {
        'file_name': doc.doc_unique_name,
        # OCR jobs are flagged as scans for the downstream extractor
        'is_scan': mark_job.mark_job_type == FileTypeEnum.ocr,
        'doc_id': doc.doc_id,
        'doc_type': mark_job.doc_type_id,
    }
    message = {
        'files': [file_entry],
        'is_multi': False,
        "use_rule": use_rule,
        'doc_id': doc.doc_id,
        'doc_type': mark_job.doc_type_id,
        'business': business,
        'task_id': mark_task.mark_task_id,
        'app_id': g.app_id,
    }
    r.lpush(_get('EXTRACT_TASK_QUEUE_KEY'), json.dumps(message))
def create_export_task(current_user: CurrentUser, mark_job_ids, mark_type, export_type):
    """Create an export-job record and queue the export message for offline NLP.

    :param current_user: the requesting user (owner of the export job)
    :param mark_job_ids: comma-separated mark job id string
    :param mark_type: one of 'wordseg' / 'relation' / other task types
    :param export_type: export format identifier passed through to the worker
    """
    # raise no result found exception
    message = {}
    job_id_list = mark_job_ids.split(',')
    doc_type_id = MarkJobModel().get_by_id(int(job_id_list[0])).doc_type_id
    doc_terms = [str(row.doc_term_id) for row in DocTermModel().get_by_filter(doc_type_id=doc_type_id)]
    if mark_type == 'wordseg':
        doc_terms = ['10086']
    elif mark_type == 'relation':
        # map each relation term to its comma-separated entity term ids
        mapping = [
            {pair[0]: pair[1].split(",")}
            for pair in RelationM2mTermModel.get_relation_term_mapping(doc_type_id)
        ]
        message['relation_2_entity_mapping'] = mapping
    stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    version = '{}{}_{}_{}'.format(stamp, str(uuid.uuid4())[:4], doc_type_id, mark_job_ids)
    file_path = 'upload/export/{}.zip'.format(version)
    export_job = ExportJobModel().create(
        export_file_path=file_path,
        doc_type_id=doc_type_id,
        created_by=current_user.user_id,
        export_job_status=StatusEnum.processing.value,
        export_mark_job_ids=[int(i) for i in job_id_list],
    )
    export_id = export_job.export_job_id
    session.commit()
    # hand the job off to the offline NLP worker
    message.update({
        'export_id': export_id,
        'export_type': export_type,
        'file_path': file_path,
        'version': version,
        'doc_type': doc_type_id,
        'fields': ','.join(doc_terms),
        'mark_job_ids': mark_job_ids,
        'task_type': mark_type,
    })
    r.lpush(_get('DATA_EXPORT_QUEUE_KEY'), json.dumps(message))