示例#1
0
    def update_train_task_by_id(train_job_id, train_task_id, is_check_train_terms, model_type, args):
        """
        Update one TrainTask from ``args`` and commit.

        Depending on the arguments this:
        1. derives ``train_status`` from the task's TrainTermTask rows when
           ``is_check_train_terms`` is truthy (any term still training -> training;
           otherwise any failed term -> fail; otherwise success);
        2. otherwise applies ``args`` as given, e.g. setting success/fail directly;
        3. when ``args["train_status"]`` requests ``online``: checks the task is in
           ``success`` (and not already online), asks the online service to load it
           (classify and extract use different endpoints), and demotes every other
           online task of the same doc type back to ``success``.

        :param train_job_id: id of the parent TrainJob; its ``doc_type_id`` scopes
            which online models get demoted.
        :param train_task_id: id of the TrainTask to update.
        :param is_check_train_terms: derive ``train_status`` from term statuses when truthy.
        :param model_type: ``"classify"`` uses CLASSIFY_MODEL_ONLINE, anything else
            uses EXTRACT_MODEL_ONLINE.
        :param args: column values forwarded to ``TrainTaskModel().update``; its
            ``train_status`` entry may be overwritten in place.
        :return: the updated TrainTask.
        """
        train_job = TrainJobModel().get_by_id(train_job_id)
        train_task = TrainTaskModel().get_by_id(train_task_id)

        if is_check_train_terms:
            # Only whether matching terms exist matters; the counts are ignored.
            _, training_terms = TrainTermTaskModel().get_by_filter(limit=99999, train_task_id=train_task_id,
                                                                   train_term_status=int(StatusEnum.training))
            _, failed_terms = TrainTermTaskModel().get_by_filter(limit=99999, train_task_id=train_task_id,
                                                                 train_term_status=int(StatusEnum.fail))
            if training_terms:
                args["train_status"] = int(StatusEnum.training)
            elif failed_terms:
                args["train_status"] = int(StatusEnum.fail)
            else:
                args["train_status"] = int(StatusEnum.success)
        elif args.get("train_status") == int(StatusEnum.online):
            # `.get`: callers may update other columns without sending train_status.
            if train_task.train_status == int(StatusEnum.online):
                abort(400, message="该模型已经上线")
            if train_task.train_status != int(StatusEnum.success):
                abort(400, message="只能上线训练成功的模型")

            # Ask the online service to load this model version.
            service_url = _get("CLASSIFY_MODEL_ONLINE") if model_type == "classify" else _get("EXTRACT_MODEL_ONLINE")
            try:
                # `params=` URL-encodes the value and copes with a service_url that
                # already carries a query string.
                # NOTE(review): timeout values are a conservative guess; tune them to
                # how long the online service really takes to load a model.
                resp = requests.post(service_url,
                                     params={"model_version": train_task.model_version},
                                     timeout=(10, 300))
            except requests.RequestException as e:
                abort(500, message=f"上线服务 <{service_url}> 出现错误: {e}")
            if not 200 <= resp.status_code < 300:
                abort(500, message=f"上线服务 <{service_url}> 出现错误: {resp.text}")

            # Demote the models currently online for this doc type. The task being
            # promoted is not among them: it was required to be `success` above.
            online_models = TrainTaskModel().get_by_doc_type_id(doc_type_id=train_job.doc_type_id, train_status=int(StatusEnum.online))
            TrainTaskModel().bulk_update([train.train_task_id for train in online_models], train_status=int(StatusEnum.success))

        train_task = TrainTaskModel().update(train_task_id, **args)
        session.commit()
        return train_task
示例#2
0
    def update_train_task_by_model_version(model_version, is_check_train_terms, args):
        """
        Update the TrainTask identified by ``model_version`` and commit.

        When ``is_check_train_terms`` is truthy, ``args["train_status"]`` is derived
        from the task's TrainTermTask rows: any term still training -> training;
        otherwise any failed term -> fail; otherwise success.

        :param model_version: model version string of the target TrainTask.
        :param is_check_train_terms: derive ``train_status`` from term statuses when truthy.
        :param args: column values forwarded to ``TrainTaskModel().update``; its
            ``train_status`` entry may be overwritten in place.
        :return: the updated TrainTask.
        Aborts with 404 when no TrainTask has this model_version. Previously this
        case surfaced as a bare IndexError.
        """
        train_tasks = TrainTaskModel().get_by_filter(model_version=model_version)[1]
        if not train_tasks:
            abort(404, message=f"找不到 model_version 为 <{model_version}> 的训练任务")
        train_task = train_tasks[0]

        if is_check_train_terms:
            # Only whether matching terms exist matters; the counts are ignored.
            _, training_terms = TrainTermTaskModel().get_by_filter(limit=99999, train_task_id=train_task.train_task_id,
                                                                   train_term_status=int(StatusEnum.training))
            _, failed_terms = TrainTermTaskModel().get_by_filter(limit=99999, train_task_id=train_task.train_task_id,
                                                                 train_term_status=int(StatusEnum.fail))
            if training_terms:
                args["train_status"] = int(StatusEnum.training)
            elif failed_terms:
                args["train_status"] = int(StatusEnum.fail)
            else:
                args["train_status"] = int(StatusEnum.success)

        train_task = TrainTaskModel().update(train_task.train_task_id, **args)
        session.commit()
        return train_task
示例#3
0
 def update_train_task_term_by_id(train_term_task_id, args):
     """
     Apply ``args`` to a single TrainTermTask row, commit, and return the row.

     :param train_term_task_id: primary key of the TrainTermTask to update.
     :param args: column values forwarded as keywords to ``TrainTermTaskModel().update``.
     """
     updated_term = TrainTermTaskModel().update(train_term_task_id, **args)
     session.commit()
     return updated_term
示例#4
0
 def get_train_term_list_by_train_task_id(train_task_id, **kwargs) -> "tuple[int, list[TrainTermTask]]":
     """
     Return ``(count, terms)`` for the TrainTermTask rows of one train task.

     :param train_task_id: id of the owning TrainTask.
     :param kwargs: extra filters forwarded to ``TrainTermTaskModel().get_by_filter``.
         ``limit`` defaults to 99999 and may be overridden by the caller. Passing it
         used to fail with a duplicate-keyword TypeError.
     """
     kwargs.setdefault("limit", 99999)
     count, result = TrainTermTaskModel().get_by_filter(train_task_id=train_task_id, **kwargs)
     return count, result
示例#5
0
 def update_train_term_by_model_version_and_doc_term_id(model_version, doc_term_id, args):
     """
     Update the TrainTermTask matching ``(model_version, doc_term_id)`` and commit.

     :param model_version: model version of the owning train task.
     :param doc_term_id: doc term the TrainTermTask belongs to.
     :param args: column values forwarded as keywords to ``TrainTermTaskModel().update``.
     :return: the updated TrainTermTask.

     NOTE(review): the lookup result is used without a None check. If
     ``get_by_model_version_and_doc_term_id`` can return None, this raises
     AttributeError. Confirm against TrainTermTaskModel.
     """
     matched = TrainTermTaskModel().get_by_model_version_and_doc_term_id(
         model_version=model_version,
         doc_term_id=doc_term_id,
     )
     term_id = matched.train_term_task_id
     updated_term = TrainTermTaskModel().update(term_id, **args)
     session.commit()
     return updated_term
示例#6
0
    def create_classify_train_job_by_doc_type_id(doc_type_id, train_job_name,
                                                 train_job_desc, train_config,
                                                 mark_job_ids, custom_id):
        """
        Create a classification TrainJob with its TrainTask, enqueue it, and commit.

        Steps, in this order:
        1. load the doc type;
        2. generate a model version;
        3. insert the TrainJob (``preprocess={}``) and its TrainM2mMark links;
        4. insert the TrainTask;
        5. insert one TrainTermTask per doc term of the doc type;
        6. resolve the optional custom-algorithm settings;
        7. push the task to redis, then commit.

        :param custom_id: optional CustomAlgorithm id; falsy means no custom algorithm.
        :return: the TrainJob, with ``train_list``, ``doc_type`` and ``model_version``
            attached for dumping (and ``mark_job_ids`` on the TrainTask).
        """
        # Presumably raises if doc_type_id is unknown (original comment: "verify doc_type").
        doc_type = DocTypeModel().get_by_id(doc_type_id)
        task_kind = NlpTaskEnum.classify
        version = generate_model_version_by_nlp_task(doc_type_id, mark_job_ids, task_kind)

        job = TrainJobModel().create(
            train_job_name=train_job_name,
            train_job_desc=train_job_desc,
            doc_type_id=doc_type_id,
            train_job_status=int(StatusEnum.training),
            preprocess={},
        )
        TrainM2mMarkbModel().bulk_create(
            [dict(train_job_id=job.train_job_id, mark_job_id=mark_id) for mark_id in mark_job_ids]
        )

        task = TrainTaskModel().create(
            train_job_id=job.train_job_id,
            train_model_name=train_job_name,
            train_model_desc=train_job_desc,
            train_config=train_config,
            train_status=int(StatusEnum.training),
            model_version=version,
        )

        # NOTE(review): other get_by_filter calls in this code return a (count, items)
        # pair, but here the raw return value is iterated and stored on doc_type.
        # Confirm DocTermModel.get_by_filter returns a plain sequence of terms.
        terms = DocTermModel().get_by_filter(limit=99999, doc_type_id=doc_type_id)
        term_rows = []
        for term in terms:
            term_rows.append({
                "train_task_id": task.train_task_id,
                "doc_term_id": term.doc_term_id,
                "train_term_status": int(StatusEnum.training),
            })
        TrainTermTaskModel().bulk_create(term_rows)
        doc_type.doc_term_list = terms

        custom = {}
        if custom_id:
            algorithm = CustomAlgorithmModel().get_by_id(custom_id)
            custom = CustomAlgorithmSchema(
                only=("custom_id_name", "custom_ip", "custom_port")
            ).dump(algorithm)

        # NOTE(review): the task is pushed to redis before session.commit(). A consumer
        # that reads these rows before the commit lands would not see them. Confirm
        # this ordering is intended.
        push_train_task_to_redis(task_kind, doc_type, task.train_task_id,
                                 version, train_config, mark_job_ids, custom)
        session.commit()

        # Extra attributes consumed when serializing the response.
        task.mark_job_ids = mark_job_ids
        job.train_list = [task]
        job.doc_type = doc_type
        job.model_version = version
        return job
示例#7
0
    def create_train_job_by_doc_type_id(doc_type_id, train_job_name,
                                        train_job_desc, train_config,
                                        mark_job_ids):
        """
        Create a TrainJob + TrainTask for the doc type's own NLP task, enqueue it, and commit.

        The NLP task comes from ``doc_type.nlp_task_id``. For extract/relation tasks,
        one TrainTermTask is created per ``train_config`` entry, using its
        ``field_id`` as the doc_term_id.

        :param train_config: list of per-field config dicts. Each entry gets a
            ``version`` key set to the generated model version in the stored copy;
            the caller's dicts are no longer mutated. Must be non-empty, because
            ``train_config[0]`` is read.
        :param mark_job_ids: mark job ids linked to the TrainJob via TrainM2mMark.
        :return: the TrainJob, with ``train_list``, ``doc_type`` and ``model_version``
            attached for dumping (and ``mark_job_ids`` on the TrainTask).
        """
        # Presumably raises if doc_type_id is unknown (original comment: "verify doc_type").
        doc_type = DocTypeModel().get_by_id(doc_type_id)
        nlp_task = NlpTaskEnum(doc_type.nlp_task_id)
        model_version = generate_model_version_by_nlp_task(
            doc_type_id, mark_job_ids, nlp_task)

        # Sentence splitting is requested only for ner/wordseg, judged from the first entry.
        preprocess_type = {
            "split_by_sentence":
            train_config[0].get("train_type", "") in ["ner", "wordseg"]
        }

        # The backend service reads model_version from each config entry. Build new
        # dicts so the caller's train_config is not mutated in place.
        train_config = [dict(config, version=model_version) for config in train_config]

        train_job = TrainJobModel().create(train_job_name=train_job_name,
                                           train_job_desc=train_job_desc,
                                           doc_type_id=doc_type_id,
                                           train_job_status=int(StatusEnum.training),
                                           preprocess=preprocess_type)
        TrainM2mMarkbModel().bulk_create([
            {"train_job_id": train_job.train_job_id, "mark_job_id": _id}
            for _id in mark_job_ids
        ])
        train_task = TrainTaskModel().create(
            train_job_id=train_job.train_job_id,
            train_model_name=train_job_name,
            train_model_desc=train_job_desc,
            train_config=train_config,
            train_status=int(StatusEnum.training),
            model_version=model_version)

        if nlp_task in (NlpTaskEnum.extract, NlpTaskEnum.relation):
            # One TrainTermTask per configured field; field_id is the doc_term_id.
            TrainTermTaskModel().bulk_create([
                {
                    "train_task_id": train_task.train_task_id,
                    "doc_term_id": field_config["field_id"],
                    "train_term_status": int(StatusEnum.training),
                }
                for field_config in train_config
            ])

        # NOTE(review): the task is pushed to redis before session.commit(). A consumer
        # that reads these rows before the commit lands would not see them. Confirm
        # this ordering is intended.
        push_train_task_to_redis(nlp_task, doc_type, train_task.train_task_id,
                                 model_version, train_config, mark_job_ids)
        session.commit()

        # Extra attributes consumed when serializing the response.
        train_task.mark_job_ids = mark_job_ids
        train_job.train_list = [train_task]
        train_job.doc_type = doc_type
        train_job.model_version = model_version
        return train_job