class TaskListResource(Resource, CurrentUserMixin):
    @parse({
        "query": fields.String(missing=''),
        "offset": fields.Integer(missing=0),
        "limit": fields.Integer(missing=10),
        "order_by": fields.String(missing='-task_id'),
        "doc_type_id": fields.Integer(),
        "job_id": fields.Integer(),
        "job_type": fields.String(
            required=True,
            validate=lambda x: x in ('mark', 'classify_mark', 'relation_mark', 'wordseg_mark')),
        "task_state": fields.String(
            missing="",
            validate=lambda x: x in ("", "processing", "failed", "success", "unaudit", "audited", "unlabel")),
    })
    def get(self: Resource, args: typing.Dict):
        count, processed, result = ManualTaskService().get_user_task_or_mark_task_result_by_role(
            self.get_current_user(), args)
        return {
            "message": "请求成功",
            "result": result,
            "count": count,
            "processed": processed,
        }, 200
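
# A minimal, self-contained sketch of what the `@parse({...})` decorator above is
# assumed to do: validate request arguments with marshmallow-style fields, filling
# defaults via `missing=` (assuming a marshmallow version that still accepts it, as
# the resources in this module do) and rejecting values outside the allowed set.
# The schema name and demo function are illustrative only, not project code.
def _demo_parse_args():
    from marshmallow import Schema, fields, ValidationError

    class TaskListArgs(Schema):
        offset = fields.Integer(missing=0)
        limit = fields.Integer(missing=10)
        job_type = fields.String(
            required=True,
            validate=lambda x: x in ('mark', 'classify_mark', 'relation_mark', 'wordseg_mark'))

    print(TaskListArgs().load({"job_type": "mark"}))
    # -> {'offset': 0, 'limit': 10, 'job_type': 'mark'}
    try:
        TaskListArgs().load({"job_type": "bogus"})
    except ValidationError as err:
        print(err.messages)  # -> {'job_type': ['Invalid value.']}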
class UpdateModelEvaluateResource(Resource):
    @parse({
        "model_evaluate_id": fields.Integer(required=True),
        "model_evaluate_state": fields.String(required=True),
        "model_evaluate_result": fields.Dict(),
        "model_type": fields.String(
            required=True,
            validate=lambda x: x in ('extract', 'classify', 'relation', 'wordseg')),
    })
    def put(self, args):
        """Update an evaluation record."""
        update_params = {}
        if args.get("model_evaluate_state"):
            update_params.update(
                evaluate_task_status=status_str2int_mapper()[args["model_evaluate_state"]])
        if args.get("model_evaluate_result"):
            update_params.update(evaluate_task_result=args["model_evaluate_result"])
        evaluate_task = ModelEvaluateService().update_evaluate_task_by_id(
            evaluate_task_id=args["model_evaluate_id"], args=update_params)
        result = EvaluateTaskSchema().dump(evaluate_task)
        return {
            "message": "更新成功",
            "result": result,
        }, 201
class WordsegMarkJobImportResource(Resource):
    @parse(
        {
            "mark_job_name": fields.String(required=True),
            "mark_job_type": fields.String(required=True),
            "mark_job_desc": fields.String(),
            "doc_type_id": fields.Integer(required=True),
            "files": fields.List(fields.File(), required=True),
            "task_type": fields.String(required=True, validate=lambda x: x in ['machine', 'manual']),
        },
        locations=('form', 'files'))
    def post(self: Resource, args: typing.Dict) -> typing.Tuple[typing.Dict, int]:
        """Upload pre-labeled word segmentation data."""
        files = args['files']
        # Validate file extensions: only txt is accepted for labeled wordseg data.
        for f in files:
            if get_ext(f.filename) not in ["txt"]:
                abort(400, message="上传已标注分词数据仅支持txt格式。")
        result = MarkJobService().import_mark_job(files, args, nlp_task=NlpTaskEnum.wordseg)
        return {"message": "创建成功", "result": result}, 201
class ClassifyMarkJobImportResource(Resource):
    @parse(
        {
            "mark_job_name": fields.String(required=True),
            "mark_job_type": fields.String(required=True),
            "mark_job_desc": fields.String(),
            "doc_type_id": fields.Integer(required=True),
            "files": fields.List(fields.File(), required=True),
        },
        locations=('form', 'files'))
    def post(self: Resource, args: typing.Dict) -> typing.Tuple[typing.Dict, int]:
        files = args['files']
        for f in files:
            if get_ext(f.filename) not in ["csv"]:
                abort(400, message="已标注分类数据仅支持csv格式。")
        try:
            result = MarkJobService().import_mark_job(files, args, nlp_task=NlpTaskEnum.classify)
            return {
                "message": "创建成功",
                "result": result,
            }, 201
        except UnicodeDecodeError:
            abort(400, message="文件编码错误 请上传utf-8编码文件")
        except KeyError:
            abort(400, message="文件格式不合规 请查看csv文件模版")
class TrainTermItemResource(Resource):
    @parse({
        "train_term_state": fields.String(),
        "train_term_result": fields.Dict(),
        "model_type": fields.String(required=True, validate=lambda x: x in ('extract', 'classify')),
    })
    def patch(
        self: Resource,
        args: typing.Dict,
        model_id: int,
        model_train_id: int,
        train_term_id: int,
    ) -> typing.Tuple[typing.Dict, int]:
        """Update the state of a single term of a model training task."""
        update_params = {}
        if args.get("train_term_state"):
            update_params.update(train_term_status=status_str2int_mapper()[args["train_term_state"]])
        if args.get("train_term_result"):
            update_params.update(train_term_result=args["train_term_result"])
        train_term_task = ModelTrainService().update_train_task_term_by_id(
            train_term_task_id=train_term_id, args=update_params)
        result = TrainTermTaskSchema().dump(train_term_task)
        return {
            "message": "更新成功",
            "result": result,
        }, 200
class RelationDocTypeItemResource(Resource, CurrentUserMixin):
    def get(self: Resource, doc_type_id: int) -> typing.Tuple[typing.Dict, int]:
        """Get a single document type."""
        result = DocTypeService().get_doc_type_items(doc_type_id)
        return {
            "message": "请求成功",
            "result": result,
        }, 200

    @parse({
        "doc_type_name": fields.String(),
        "doc_type_desc": fields.String(),
    })
    def patch(self: Resource, args: typing.Dict, doc_type_id: int) -> typing.Tuple[typing.Dict, int]:
        """Update a document type (its terms are not modified here)."""
        result = DocTypeService().update_relation_doc_type(args, doc_type_id)
        return {
            "message": "更新成功",
            "result": result,
        }, 201

    def delete(self: Resource, doc_type_id: int) -> typing.Tuple[typing.Dict, int]:
        """Delete a document type."""
        DocTypeService().delete_doc_type(doc_type_id)
        return {
            "message": "删除成功",
        }, 204
class UpdateTrainTermResource(Resource):
    @parse({
        "model_version": fields.String(required=True),
        "doc_term_id": fields.Integer(required=True),
        "train_term_state": fields.String(required=True),
        "train_term_result": fields.Dict(),
        "term_type": fields.String(required=True),
        "model_type": fields.String(
            required=True,
            validate=lambda x: x in ('extract', 'classify', 'relation', 'wordseg')),
    })
    def put(self, args):
        """Update the state of a single term of a model training task."""
        update_params = {}
        if args.get("train_term_state"):
            update_params.update(train_term_status=status_str2int_mapper()[args["train_term_state"]])
        if args.get("train_term_result"):
            update_params.update(train_term_result=args["train_term_result"])
        train_term_task = ModelTrainService().update_train_term_by_model_version_and_doc_term_id(
            model_version=args["model_version"],
            doc_term_id=args["doc_term_id"],
            args=update_params)
        result = TrainTermTaskSchema().dump(train_term_task)
        return {
            "message": "更新成功",
            "result": result,
        }, 201
class WordsegDocLexiconItemResource(Resource):
    def get(self, doc_type_id, doc_lexicon_id):
        result = DocTermService().get_wordseg_lexicon_item(doc_lexicon_id)
        return {
            "message": "请求成功",
            "result": result,
        }, 200

    @parse({
        "seg_type": fields.String(required=True),
        "word": fields.String(required=True),
        "state": fields.Integer(required=True),
    })
    def put(self, args, doc_type_id, doc_lexicon_id):
        args.update({"is_active": args.pop("state")})
        result = DocTermService().update_wordseg_lexicon(doc_lexicon_id, args)
        return {
            "message": "更新成功",
            "result": result,
        }, 200

    def delete(self, doc_type_id, doc_lexicon_id):
        DocTermService().delete_wordseg_lexicon_by_id(doc_lexicon_id)
        return {
            "message": "删除成功",
        }, 204
class WordsegDocLexiconListResource(Resource):
    @parse({
        "offset": fields.Integer(missing=0),
        "limit": fields.Integer(missing=10),
    })
    def get(self, args, doc_type_id):
        """List lexicon rules."""
        result, count = DocTermService().get_wordseg_lexicon(
            doc_type_id, args.get("offset"), args.get("limit"))
        return {
            "message": "请求成功",
            "result": result,
            "count": count,
        }, 200

    @parse({
        "seg_type": fields.String(required=True),
        "word": fields.String(required=True),
        "state": fields.Integer(required=True),
    })
    def post(self, args, doc_type_id):
        args.update({"doc_type_id": doc_type_id})
        args.update({"is_active": args.pop("state")})
        result = DocTermService().create_wordseg_lexicon(args)
        return {
            "message": "创建成功",
            "result": result,
        }, 201
class EntityDocTermItemResource(Resource):
    @parse({
        "doc_term_name": fields.String(),
        "doc_term_color": fields.String(),
        "doc_term_index": fields.Integer(allow_none=True, default=1),
        "doc_term_desc": fields.String(allow_none=True),
        "doc_term_data_type": fields.String(),
    })
    def patch(self: Resource, args: typing.Dict, doc_type_id: int,
              doc_term_id: int) -> typing.Tuple[typing.Dict, int]:
        """Update a term."""
        result = DocTermService().update_doc_term(doc_term_id, args)
        return {
            "message": "更新成功",
            "result": result,
        }, 201

    def delete(self: Resource, doc_type_id: int, doc_term_id: int) -> typing.Tuple[typing.Dict, int]:
        """Delete a term."""
        if DocTermService().check_term_in_relation(doc_term_id):
            abort(400, message="该条款仍有关联关系,请确保条款没有关联关系后再做清除")
        DocTermService().remove_doc_term(doc_term_id)
        session.commit()
        return {
            "message": "删除成功",
        }, 204
class TaskItemNextResource(Resource, CurrentUserMixin):
    @parse({
        "job_id": fields.Integer(),
        "job_type": fields.String(
            required=True,
            validate=lambda x: x in ('mark', 'classify_mark', 'relation_mark', 'wordseg_mark')),
        "task_state": fields.String(
            missing="",
            validate=lambda x: x in ("", "processing", "failed", "success", "unaudit", "audited", "unlabel")),
        "query": fields.String(missing=""),
    })
    def get(self: Resource, args: typing.Dict, task_id: int) -> typing.Tuple[typing.Dict, int]:
        preview_task_id, next_task_id = ManualTaskService().get_preview_and_next_task_id(
            self.get_current_user(), task_id, args)
        return {
            "message": "请求成功",
            "next_id": next_task_id,
            "prev_id": preview_task_id,
        }, 200
class ExportJobSchema(Schema):
    export_id = fields.Integer(attribute="export_job_id")
    file_path = fields.String(attribute="export_file_path")
    mark_type = fields.Function(lambda obj: NlpTaskEnum(obj.nlp_task_id).name)  # nlp_task_id
    export_state = fields.Function(lambda obj: StatusEnum(obj.export_job_status).name)
    project_name = fields.String(attribute="doc_type_name")
    created_time = fields.String()
    mark_job_ids = fields.List(fields.Integer(), attribute="export_mark_job_ids")
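
# A self-contained illustration (hypothetical names throughout) of how the
# `attribute=` remapping and `fields.Function` serializers used by the schemas in
# this module behave when dumping an ORM row: the API key differs from the model
# column, and the integer status column is rendered as an enum name.
def _demo_schema_dump():
    from enum import IntEnum
    from types import SimpleNamespace
    from marshmallow import Schema, fields

    class DemoStatusEnum(IntEnum):
        success = 3

    class DemoExportSchema(Schema):
        export_id = fields.Integer(attribute="export_job_id")
        export_state = fields.Function(lambda obj: DemoStatusEnum(obj.export_job_status).name)

    row = SimpleNamespace(export_job_id=7, export_job_status=3)
    print(DemoExportSchema().dump(row))
    # -> {'export_id': 7, 'export_state': 'success'}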
class PredictJobSchema(Schema):
    doc_type = fields.Nested(DocTypeSchema)
    task_list = fields.List(fields.Nested(PredictTaskSchema))
    extract_job_id = fields.Integer(attribute="predict_job_id")
    extract_job_name = fields.String(attribute="predict_job_name")
    extract_job_type = fields.String(attribute="predict_job_type.value")
    extract_job_state = fields.Function(lambda obj: StatusEnum(obj.predict_job_status).name)
    extract_job_desc = fields.String(attribute="predict_job_desc")
    is_batch = fields.Boolean()
    created_time = fields.DateTime()
class ExportHistoryResource(Resource, CurrentUserMixin):
    @parse({
        "query": fields.String(missing=""),
        "offset": fields.Integer(missing=0),
        "limit": fields.Integer(missing=10),
        "model_type": fields.String(
            missing="",
            validate=lambda x: x in ('', 'extract', 'classify', 'relation', 'wordseg')),
    })
    def get(self: Resource, args: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
        """Get annotation export records."""
        result, count = ExportService().get_export_history(self.get_current_user(), args)
        return {
            "message": "请求成功",
            "result": result,
            "count": count,
        }, 200

    @parse({
        "mark_type": fields.String(required=True, validate=lambda x: x in DB_TABLE_2_BUSINESS_MAPPING),
        "mark_job_ids": fields.String(required=True),
        "export_type": fields.String(
            required=True,
            validate=lambda x: x in ('BMES', 'BIO', 'label_analysis', 'all')),
    })
    def post(self: Resource, args: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
        """Create an annotation export record."""
        mark_job_ids = args.get('mark_job_ids')
        mark_type = args.get('mark_type')
        if mark_type in DB_TABLE_2_BUSINESS_MAPPING:
            mark_type = DB_TABLE_2_BUSINESS_MAPPING[mark_type]
        else:
            abort(400, message='不支持的任务类型')
        export_type = args.get('export_type')
        ExportService().create_export_task(self.get_current_user(), mark_job_ids, mark_type, export_type)
        return {
            "message": "创建成功",
        }, 201
class EvaluateTaskSchema(Schema):
    model_evaluate_id = fields.Integer(attribute="evaluate_task_id")
    model_evaluate_name = fields.String(attribute="evaluate_task_name")
    model_evaluate_desc = fields.String(attribute="evaluate_task_desc")
    model_evaluate_state = fields.Function(lambda obj: StatusEnum(obj.evaluate_task_status).name)
    model_evaluate_result = fields.Dict(attribute="evaluate_task_result")
    model_id = fields.Integer(attribute="train_job_id")
    mark_job_ids = fields.List(fields.Integer())
    created_time = fields.DateTime()
    last_updated_time = fields.DateTime(attribute="updated_time")
class TrainJobSchema(Schema):  # type: ignore
    model_id = fields.Integer(attribute='train_job_id')
    model_name = fields.String(attribute="train_job_name")
    model_desc = fields.String(attribute="train_job_desc")
    status = fields.Function(lambda obj: not obj.is_deleted)
    doc_type = fields.Nested(DocTypeSchema)
    created_time = fields.String()
    model_version = fields.String()
    train_list = fields.List(fields.Nested(TrainTaskSchema))
    model_evaluate = fields.Nested(EvaluateTaskSchema)
    preprocess = fields.Dict()
class ModelListResource(Resource, CurrentUserMixin):
    @parse({
        "query": fields.String(missing=''),
        "offset": fields.Integer(missing=0),
        "limit": fields.Integer(missing=10),
        "doc_type_id": fields.Integer(missing=0),
        "order_by": fields.String(missing='-created_time'),
    })
    def get(self: Resource, args: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
        """Get model records, paginated."""
        nlp_task_id = Common.get_nlp_task_id_by_route()
        count, train_job_list = ModelService().get_train_job_list_by_nlp_task_id(
            nlp_task_id=nlp_task_id,
            doc_type_id=args['doc_type_id'],
            search=args['query'],
            offset=args['offset'],
            limit=args['limit'],
            current_user=self.get_current_user())
        result = TrainJobSchema().dump(train_job_list, many=True)
        return {
            "message": "请求成功",
            "result": result,
            "count": count,
        }, 200

    @parse({
        "model_name": fields.String(required=True),
        "model_desc": fields.String(missing=""),
        "doc_type_id": fields.Integer(required=True),
        # algorithm_type = ('extract', 'ner', 'seg', 'pos')
        "model_train_config": fields.Raw(required=True),
        "mark_job_ids": fields.List(fields.Integer(), missing=[]),
    })
    def post(self: Resource, args: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
        """Create a model."""
        # Create a new train job.
        train_job = ModelService().create_train_job_by_doc_type_id(
            doc_type_id=args["doc_type_id"],
            train_job_name=args["model_name"],
            train_job_desc=args["model_desc"],
            train_config=args["model_train_config"],
            mark_job_ids=args["mark_job_ids"])
        result = TrainJobSchema().dump(train_job)
        return {"message": "创建成功", "result": result}, 201
class AsyncMQResource(Resource):
    @parse({
        "task_result": fields.Raw(),
        "message": fields.Raw(),
        "task_state": fields.String(required=True, validate=lambda x: x in ('success', 'failed')),
        "error_message": fields.String(),
    })
    def post(self: Resource, args: typing.Dict) -> typing.Tuple[typing.Dict, int]:
        """Unified entry point for message-queue callbacks."""
        message = args['message']
        logger.info(f"receive callback info from mq. response is: {json.dumps(args)}")
        if message['business'] in [
            'label',           # entity pre-labeling
            'classify_label',  # classification pre-labeling
            'relation_label',  # entity relation pre-labeling
            'wordseg_label',   # word segmentation pre-labeling
        ]:
            update_params = {}
            if args.get("task_state"):
                if args['task_state'] == 'success':
                    # If MQ pre-labeling succeeded, the initial status is unlabel.
                    update_params.update(mark_task_status=int(StatusEnum.unlabel))
                else:
                    # If MQ pre-labeling failed, the initial status is fail.
                    update_params.update(mark_task_status=int(StatusEnum.fail))
            if args.get("task_result"):
                update_params.update(mark_task_result=args["task_result"])
            mark_task, user_task_list = MarkJobService() \
                .update_mark_task_and_user_task_by_mark_task_id(
                    mark_task_id=message["task_id"], args=update_params)
            MarkJobService().update_mark_job_status_by_mark_task(mark_task=mark_task)
            result = UserTaskSchema(many=True).dump(user_task_list)
            return {
                "message": "更新成功",
                "result": result,
            }, 201
        elif message['business'] in [
            'extract',           # entity extraction
            'classify_extract',  # classification extraction
            'relation_extract',  # entity relation extraction
            'wordseg_extract',   # word segmentation extraction
        ]:
            update_params = {}
            if args.get("task_state"):
                update_params.update(predict_task_status=status_str2int_mapper()[args["task_state"]])
            if args.get("task_result"):
                update_params.update(predict_task_result=args["task_result"])
            predict_task = PredictService().update_predict_task_by_id(
                predict_task_id=message["task_id"], args=update_params)
            result = PredictTaskSchema().dump(predict_task)
            return {
                "message": "更新成功",
                "result": result,
            }, 201
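
# Hedged example of a worker callback that AsyncMQResource.post appears to accept,
# reconstructed from the handler above: `message.business` selects the pre-labeling
# or extraction branch and `message.task_id` identifies the task to update. The
# endpoint URL and the exact shape of task_result are hypothetical.
def _demo_mq_callback():
    import requests

    payload = {
        "task_state": "success",
        "task_result": [{"value": "demo"}],
        "message": {"business": "label", "task_id": 123},
    }
    requests.post(
        "http://localhost:5000/api/v1/mq/callback",  # hypothetical route
        json=payload,
        timeout=5)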
class DocTypeSchema(Schema):
    doc_terms = fields.List(fields.Integer())
    doc_term_list = fields.List(fields.Nested(DocTermSchema))
    doc_relation_list = fields.List(fields.Nested(EntityDocRelationSchema))
    doc_lexicon_list = fields.List(fields.Nested(WordsegDocLexiconSchema), attribute='doc_rules')
    doc_type_id = fields.Integer()
    doc_type_name = fields.String()
    doc_type_desc = fields.String()
    is_top = fields.Boolean(attribute="is_favorite")
    created_time = fields.DateTime()
    group_id = fields.Integer()
    status = fields.Function(lambda obj: not obj.is_deleted)
class MarkTaskSchema(Schema):  # type: ignore
    task_id = fields.Integer(attribute="mark_task_id")
    doc = fields.Nested(DocSchema)
    doc_type = fields.Nested({
        "doc_type_id": fields.Integer(),
        "doc_type_name": fields.String(),
        "doc_type_desc": fields.String(),
    })
    user_task_list = fields.List(fields.Nested(UserTaskSchema))
    task_state = fields.Function(lambda obj: StatusEnum(obj.mark_task_status).name)
    status = fields.Function(lambda obj: not obj.is_deleted)
    created_time = fields.String()
    task_result = fields.List(fields.Dict, attribute="mark_task_result")
class ClassifyDocRuleListResource(Resource):
    @parse({
        "offset": fields.Integer(missing=0),
        "limit": fields.Integer(missing=10),
        "timestamp": fields.String(),
    })
    def get(self, args, doc_type_id):
        """List classification rules."""
        redis_key = f'classify:rule:{doc_type_id}'
        if args.get("timestamp"):
            try:
                result = json.loads(r.get(redis_key))
                if result['timestamp'] == args["timestamp"]:
                    result.update(update=False)
                    return result
            except Exception:
                pass
        # TODO: optimize this query
        result, count, timestamp = DocTermService().get_classify_doc_rule(
            doc_type_id, args.get("offset"), args.get("limit"))
        data = {
            "message": "请求成功",
            "result": result,
            "timestamp": timestamp,
            "update": True,
            "count": count,
        }
        if args.get("timestamp"):
            r.set(redis_key, json.dumps(data), ex=24 * 60 * 60)
        return data, 200

    @parse({
        "doc_term_id": fields.Integer(required=True),
        "rule_type": fields.String(required=True),
        "rule_content": fields.Dict(required=True),
    })
    def post(self, args, doc_type_id):
        try:
            result = DocTermService().create_new_rule(args)
            DocTermService().update_rule_to_redis(doc_type_id)
            r.delete(f'classify:rule:{doc_type_id}')
            return {
                "message": "创建成功",
                "result": result,
            }, 201
        except ValueError as e:
            abort(400, message=str(e))
class UserTaskSchema(Schema):
    doc = fields.Nested(DocSchema)
    doc_type = fields.Nested({
        "doc_type_id": fields.Integer(),
        "doc_type_name": fields.String(),
        "doc_type_desc": fields.String(),
    })
    task_id = fields.Integer(attribute="user_task_id")
    labeler_id = fields.Integer(attribute="annotator_id")
    manual_task_id = fields.Integer(attribute="mark_task_id")
    task_result = fields.List(fields.Dict, attribute="user_task_result")
    task_state = fields.Function(lambda obj: StatusEnum(obj.user_task_status).name)
    status = fields.Function(lambda obj: not obj.is_deleted)
    created_time = fields.String()
class ClassifyModelListResource(Resource, CurrentUserMixin):
    @parse({
        "query": fields.String(missing=''),
        "offset": fields.Integer(missing=0),
        "limit": fields.Integer(missing=10),
        "doc_type_id": fields.Integer(missing=0),
        "order_by": fields.String(missing='-created_time'),
    })
    def get(self: Resource, args: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
        """Get model records, paginated."""
        count, train_job_list = ModelService().get_train_job_list_by_nlp_task_id(
            nlp_task_id=int(NlpTaskEnum.classify),
            doc_type_id=args['doc_type_id'],
            search=args['query'],
            offset=args['offset'],
            limit=args['limit'],
            current_user=self.get_current_user())
        # Get the serialized result.
        result = TrainJobSchema().dump(train_job_list, many=True)
        return {
            "message": "请求成功",
            "result": result,
            "count": count,
        }, 200

    @parse({
        "model_name": fields.String(required=True),
        "model_desc": fields.String(missing=""),
        "doc_type_id": fields.Integer(required=True),
        "model_train_config": fields.Dict(required=True),
        "mark_job_ids": fields.List(fields.Integer(), missing=[]),
        "custom_id": fields.Integer(missing=0),
    })
    def post(self: Resource, args: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
        train_job = ModelService().create_classify_train_job_by_doc_type_id(
            doc_type_id=args["doc_type_id"],
            train_job_name=args["model_name"],
            train_job_desc=args["model_desc"],
            train_config=args["model_train_config"],
            mark_job_ids=args["mark_job_ids"],
            custom_id=args['custom_id'])
        result = TrainJobSchema().dump(train_job)
        return {"message": "创建成功", "result": result}, 201
class EntityDocRelationItemResource(Resource):
    @parse({
        "doc_relation_name": fields.String(),
        "doc_term_ids": fields.List(fields.Integer(), required=True),
    })
    def patch(self: Resource, args: typing.Dict, doc_type_id: int,
              doc_relation_id: int) -> typing.Tuple[typing.Dict, int]:
        """Update a relation between document terms."""
        doc_term_ids = list(set(args.pop("doc_term_ids", [])))
        if len(doc_term_ids) != 2:
            abort(400, message="文档条款不正确,请确保填写了正确的文档条款")
        # Pass the de-duplicated term ids; they were popped from args above.
        result = DocTypeService().update_relation(
            doc_type_id, args.get("doc_relation_name"), doc_term_ids)
        return {
            "message": "更新成功",
            "result": result,
        }, 201

    def delete(self: Resource, doc_type_id: int, doc_relation_id: int) -> typing.Tuple[typing.Dict, int]:
        """Delete a relation."""
        DocTypeService().delete_relation(doc_relation_id)
        return {
            "message": "删除成功",
        }, 204
class EntityDocRelationListResource(Resource):
    @parse({
        "doc_relation_ids": fields.List(fields.Integer(), missing=[]),
        "offset": fields.Integer(missing=0),
        "limit": fields.Integer(missing=10),
    })
    def get(self, args: typing.Dict, doc_type_id: int) -> typing.Tuple[typing.Dict, int]:
        """Get the relations of a document type, paginated."""
        result, count = DocTypeService().get_relation_list(
            doc_type_id, args.get("offset"), args.get("limit"),
            doc_relation_ids=args.get("doc_relation_ids"))
        return {
            "message": "请求成功",
            "result": result,
            "count": count,
        }, 200

    @parse({
        "doc_relation_name": fields.String(required=True),
        "doc_term_ids": fields.List(fields.Integer(), required=True),
    })
    def post(self, args: typing.Dict, doc_type_id: int) -> typing.Tuple[typing.Dict, int]:
        """Create a relation."""
        result = DocTypeService().create_relation(
            doc_type_id, args.get("doc_term_ids"), args.get("doc_relation_name"))
        return {
            "message": "创建成功",
            "result": result,
        }, 201
class CustomItemResource(Resource):
    def get(self: Resource, custom_id: int) -> Tuple[Dict[str, Any], int]:
        """Get a single custom container record."""
        custom_algorithm = ModelCustomService().get_custom_algorithm_by_id(
            custom_algorithm_id=custom_id)
        result = CustomAlgorithmSchema().dump(custom_algorithm)
        return {
            "message": "请求成功",
            "result": result,
        }, 200

    @parse({
        'custom_ip': fields.String(required=True),
        'custom_port': fields.Integer(required=True),
        'custom_evaluate_port': fields.Integer(required=True),
        'custom_name': fields.String(required=True),
        'custom_desc': fields.String(allow_none=True),
        'custom_config': fields.String(allow_none=True),
    })
    def put(self: Resource, args, custom_id: int) -> Tuple[Dict[str, Any], int]:
        # Generate update params.
        update_params = {
            "custom_algorithm_ip": args["custom_ip"],
            "custom_algorithm_predict_port": args["custom_port"],
            "custom_algorithm_evaluate_port": args["custom_evaluate_port"],
            "custom_algorithm_name": args["custom_name"],
        }
        if args.get("custom_desc"):
            update_params.update(custom_algorithm_desc=args["custom_desc"])
        if args.get("custom_config"):
            update_params.update(custom_algorithm_config=args["custom_config"])
        custom_algorithm = ModelCustomService().update_custom_algorithm_by_id(
            custom_algorithm_id=custom_id, args=update_params)
        result = CustomAlgorithmSchema().dump(custom_algorithm)
        return {
            "message": "更新成功",
            "result": result,
        }, 200

    def delete(self: Resource, custom_id: int) -> Tuple[Dict[str, Any], int]:
        ModelCustomService().delete_custom_algorithm_by_id(custom_algorithm_id=custom_id)
        return {"message": "删除成功"}, 200
class CustomAlgorithmSchema(Schema):
    custom_id = fields.Integer(attribute="custom_algorithm_id")
    custom_ip = fields.String(attribute="custom_algorithm_ip")
    custom_port = fields.Integer(attribute="custom_algorithm_predict_port")
    custom_evaluate_port = fields.Integer(attribute="custom_algorithm_evaluate_port")
    custom_name = fields.String(attribute="custom_algorithm_name")
    custom_id_name = fields.String(attribute="custom_algorithm_alias")
    custom_desc = fields.String(attribute="custom_algorithm_desc")
    custom_type = fields.Function(
        lambda obj: "ner"
        if obj.nlp_task_id == NlpTaskEnum.extract and obj.preprocess.get("split_by_sentence", False)
        else NlpTaskEnum(obj.nlp_task_id).name)
    custom_state = fields.Function(lambda obj: StatusEnum(obj.custom_algorithm_status).name)
    custom_config = fields.String(attribute="custom_algorithm_config")
    created_time = fields.DateTime()
    preprocess = fields.Dict()
class MachineTaskListResource(Resource, CurrentUserMixin):
    @parse({
        "query": fields.String(missing=''),
        "offset": fields.Integer(missing=0),
        "limit": fields.Integer(missing=10),
        "order_by": fields.String(missing='-created_time'),
        "job_type": fields.String(required=True, validate=lambda x: x in AVAILABLE_EXTRACT_JOB_TYPES),
        "task_state": fields.String(),
        "doc_type_id": fields.Integer(),
        "extract_job_id": fields.Integer(required=True),
    })
    def get(self: Resource, args: typing.Dict):
        order_by = args["order_by"][1:]
        order_by_desc = args["order_by"][0] == "-"
        filtered_list = {}
        if args.get("task_state"):
            filtered_list.update(predict_task_status=status_str2int_mapper()[args["task_state"]])
        count, predict_task_list = TaskMachineService().get_predict_task_list_by_predict_job_id(
            predict_job_id=args["extract_job_id"],
            search=args['query'],
            order_by=order_by,
            order_by_desc=order_by_desc,
            offset=args['offset'],
            limit=args['limit'],
            current_user=self.get_current_user(),
            args=filtered_list)
        result = PredictTaskSchema(many=True, exclude=('task_result',)).dump(predict_task_list)
        return {
            "message": "请求成功",
            "result": result,
            "count": count,
        }, 200
class MarkJobSchema(Schema):
    mark_job_id = fields.Integer()
    mark_job_name = fields.String()
    mark_job_type = fields.Function(lambda obj: obj.mark_job_type.value)
    assign_mode = fields.Function(lambda obj: obj.assign_mode.value)
    mark_job_state = fields.Function(lambda obj: StatusEnum(obj.mark_job_status).name)
    mark_job_desc = fields.String()
    task_list = fields.List(fields.Integer())
    created_time = fields.DateTime()
    assessor_id = fields.Function(lambda obj: obj.reviewer_ids[0] if len(obj.reviewer_ids) > 0 else 0)
    doc_type = fields.Nested(DocTypeSchema, exclude=('doc_term_list',))
    labeler_ids = fields.List(fields.Integer(), attribute='annotator_ids')
    stats = fields.Nested({
        "all": fields.Integer(),
        "labeled": fields.Integer(),
        "audited": fields.Integer(),
    })
class TrainTaskSchema(Schema):  # type: ignore
    model_train_id = fields.Integer(attribute="train_task_id")
    model_train_config = fields.Dict(attribute="train_config")
    model_train_state = fields.Function(lambda obj: StatusEnum(obj.train_status).name)
    model_id = fields.Integer(attribute="train_job_id")
    mark_job_ids = fields.List(fields.Integer())
    train_terms = fields.List(fields.Nested(TrainTermTaskSchema))
    created_time = fields.DateTime()
    last_updated_time = fields.DateTime(attribute="updated_time")
    model_version = fields.String()