def set_right_data(data_id_info=None):
    """Mark predictions as judged-correct and upsert them into XIsCategoryTable.

    For every (data_oid, category) pair, the matching XUncertainCategoryTable
    row is flagged as finished (judgment_type=1, judged category recorded),
    and the pair is inserted into XIsCategoryTable — or the existing row's
    category is updated when it differs.  data_oid must be unique in
    XIsCategoryTable.

    Args:
        data_id_info (dict | None): mapping of data_oid -> category.
            NOTE(review): batch_id is not part of the key yet, so rows are
            matched on data_oid alone — confirm this is sufficient.

    Returns:
        None
    """
    # BUG FIX: a mutable default ``{}`` is shared across calls; use None.
    if data_id_info is None:
        data_id_info = {}
    with database.db_session() as db:
        for data_oid, category in data_id_info.items():
            db.query(XUncertainCategoryTable).filter(
                XUncertainCategoryTable.data_oid == data_oid).update({
                    XUncertainCategoryTable.finished: 1,
                    XUncertainCategoryTable.judgment_type: 1,
                    XUncertainCategoryTable.judgment_category: category
                })
            existing = db.query(XIsCategoryTable).filter(
                XIsCategoryTable.data_oid == data_oid)
            if existing.count() == 0:
                db.add(XIsCategoryTable(data_oid=data_oid, category=category))
                db.commit()
            else:
                row = existing.one()
                # Only touch the row (and commit) when the category changed.
                if row.category != category:
                    row.category = category
                    db.commit()
def get_x_is_not_category_data():
    """Collect prediction rows used to measure prediction accuracy.

    Reads every (data_oid, category) pair from XUncertainCategoryTable,
    ordered by data_oid.

    NOTE(review): despite the function name, this queries
    XUncertainCategoryTable, not an "is-not" table — confirm intent.

    Returns:
        dict: data_oid -> {"data_oid": ..., "category": ...}
    """
    collected = {}
    with database.db_session() as db:
        rows = db.query(
            database.XUncertainCategoryTable.data_oid,
            database.XUncertainCategoryTable.category).order_by(
                database.XUncertainCategoryTable.data_oid)
        for row in rows:
            collected[row.data_oid] = {
                "data_oid": row.data_oid,
                "category": row.category,
            }
        return collected
def get_regular_list(batch_id: str, category_id=None) -> list:
    """List rule definitions for a batch, optionally filtered by category.

    Args:
        batch_id (str): batch to list rules for.
        category_id: when truthy, restrict to this category.

    Returns:
        list[dict]: each with category_id, content, regular_type,
            regular_id and category_name.
    """
    regular_info_list = []
    with database.db_session() as db:
        query = db.query(BatchRegularInfoTable).filter(
            BatchRegularInfoTable.batch_id == batch_id)
        if category_id:
            query = query.filter(
                BatchRegularInfoTable.category_id == category_id)
        # PERF FIX: the original issued one BatchCategoryInfoTable query per
        # rule row (N+1); memoize name lookups per category_id instead.
        name_cache = {}
        for row in query:
            if row.category_id not in name_cache:
                name_cache[row.category_id] = db.query(
                    BatchCategoryInfoTable).filter(
                        BatchCategoryInfoTable.category_id ==
                        row.category_id).one().category
            regular_info_list.append({
                "category_id": row.category_id,
                "content": row.content,
                "regular_type": row.regular_type,
                "regular_id": row.regular_id,
                "category_name": name_cache[row.category_id]
            })
    return regular_info_list
def update_metric_info(meta_info, hai_type, batch_id=0, hai_id=0):
    """Upsert one MetricInfoTable row keyed by (hai_type, batch_id, hai_id).

    ``meta_info`` is serialized to pretty-printed JSON (via NpEncoder, so
    numpy values survive) and stored; an existing row is overwritten in
    place, otherwise a new row is inserted.

    Args:
        meta_info: JSON-serializable metric payload.
        hai_type: metric owner type.
        batch_id (int, optional): defaults to 0.
        hai_id (int, optional): defaults to 0.
    """
    with database.db_session() as db:
        matches = db.query(database.MetricInfoTable).filter(
            database.MetricInfoTable.hai_type == hai_type).filter(
                database.MetricInfoTable.batch_id == batch_id).filter(
                    database.MetricInfoTable.hai_id == hai_id)
        serialized = json.dumps(meta_info,
                                ensure_ascii=False,
                                indent=4,
                                cls=NpEncoder)
        if matches.count() > 0:
            matches.one().meta_info = serialized
        else:
            db.add(
                database.MetricInfoTable(batch_id=batch_id,
                                         hai_id=hai_id,
                                         hai_type=hai_type,
                                         meta_info=serialized))
        db.commit()
def add_category(batch_id, category, category_mapping_id, category_desc):
    """Insert one category row unless an identical row already exists.

    Duplicate detection matches on all four fields; ``category_desc`` is
    only stored when truthy.
    """
    with database.db_session() as db:
        duplicates = db.query(database.BatchCategoryInfoTable).filter(
            database.BatchCategoryInfoTable.batch_id == str(batch_id),
            database.BatchCategoryInfoTable.category == category,
            database.BatchCategoryInfoTable.category_mapping_id ==
            category_mapping_id,
            database.BatchCategoryInfoTable.category_desc == category_desc)
        if duplicates.count() != 0:
            return
        fields = dict(batch_id=batch_id,
                      category=category,
                      category_mapping_id=category_mapping_id)
        if category_desc:
            fields["category_desc"] = category_desc
        db.add(database.BatchCategoryInfoTable(**fields))
        db.commit()
def get_not_sure_data(page_no: int = 0, page_size: int = 20, category="预计的业绩"):
    """Page through unfinished uncertain rows of one category.

    Rows are ordered by descending probability; the matching original text
    is looked up per row from OriginTextDataTable.

    Args:
        page_no (int): zero-based page number.
        page_size (int): rows per page.
        category (str): category to filter on.

    Returns:
        list[schemas.ValidationDataEntity]
    """
    page_no, page_size = int(page_no), int(page_size)
    with database.db_session() as db:
        rows = db.query(XUncertainCategoryTable).filter(
            XUncertainCategoryTable.finished == 0,
            XUncertainCategoryTable.category == category).order_by(
                XUncertainCategoryTable.prob.desc()).limit(page_size).offset(
                    page_size * page_no).all()
        entities = []
        for row in rows:
            # NOTE(review): .first()[0] raises TypeError if the text row is
            # missing — assumes referential integrity holds.
            text = db.query(OriginTextDataTable.text).filter(
                OriginTextDataTable.data_oid == row.data_oid).first()[0]
            entities.append(
                schemas.ValidationDataEntity(data_oid=row.data_oid,
                                             category=row.category,
                                             category_from=row.category_from,
                                             prob=row.prob,
                                             accuray=row.prob,
                                             text=text))
        return entities
def set_wrong_data(data_id_info=None):
    """Mark predictions as judged-wrong and record them in XIsNotCategoryTable.

    For each (data_oid, category) pair, the unfinished
    XUncertainCategoryTable row is flagged finished with judgment_type=2.
    The pair is then appended to XIsNotCategoryTable unless an identical
    (data_oid, category) row already exists — duplicates are skipped, not
    updated, so one data_oid may accumulate several "is not" categories.

    Args:
        data_id_info (dict | None): mapping of data_oid -> category.
            NOTE(review): batch_id is still missing from the key, so exact
            targeting is not possible yet.

    Returns:
        None
    """
    # BUG FIX: a mutable default ``{}`` is shared across calls; use None.
    if data_id_info is None:
        data_id_info = {}
    with database.db_session() as db:
        for data_oid, category in data_id_info.items():
            db.query(XUncertainCategoryTable).filter(
                XUncertainCategoryTable.data_oid == data_oid,
                XUncertainCategoryTable.finished == 0).update({
                    XUncertainCategoryTable.finished: 1,
                    XUncertainCategoryTable.judgment_type: 2
                })
            existing = db.query(XIsNotCategoryTable).filter(
                XIsNotCategoryTable.data_oid == data_oid)
            # BUG FIX: the original ``return`` aborted the WHOLE function on
            # the first duplicate, silently dropping the remaining items.
            # Skip only the current pair instead.
            if any(row.category == category for row in existing):
                continue
            db.add(XIsNotCategoryTable(data_oid=data_oid, category=category))
            db.commit()
def delete_batch(batch_id):
    """Delete the BatchInfoTable row(s) matching ``batch_id``."""
    with database.db_session() as db:
        matches = db.query(database.BatchInfoTable).filter(
            database.BatchInfoTable.batch_id == batch_id)
        matches.delete()
        db.commit()
def view(cls, oid: int):
    """Fetch the ``cls.ClassName`` row whose oid matches, or None."""
    with database.db_session() as db:
        matches = db.query(cls.ClassName).filter(cls.ClassName.oid == oid)
        if matches.count() == 0:
            return None
        return matches.one()
def get_vec_by_data_oid(data_oid):
    """Load the stored comma-separated vector for ``data_oid`` as floats.

    Raises if no TextVecTable row matches (query uses ``.one()``).

    Returns:
        list[float]
    """
    with database.db_session() as db:
        row = db.query(database.TextVecTable.vec).filter(
            database.TextVecTable.data_oid == data_oid).one()
        return [float(part) for part in row.vec.split(',')]
def get_category_mapping_info():
    """Build a bidirectional category <-> category_id lookup table.

    Both directions live in one dict (name -> id and id -> name), so
    collisions between names and ids would silently overwrite.

    Returns:
        dict
    """
    with database.db_session() as db:
        pairs = db.query(database.BatchCategoryInfoTable.category,
                         database.BatchCategoryInfoTable.category_id)
        mapping = {}
        for name, cid in pairs:
            mapping[name] = cid
            mapping[cid] = name
        return mapping
def delete(cls, oid: int):
    """Delete the ``cls.ClassName`` row with ``oid``; 1 on success, 0 if absent."""
    with database.db_session() as db:
        matches = db.query(cls.ClassName).filter(cls.ClassName.oid == oid)
        if matches.count() == 0:
            return 0
        matches.delete()
        db.commit()
        return 1
def list_view(cls, parameter: BasicQueryParameter):
    """Return one page of ``cls.ClassName`` rows.

    Args:
        parameter (BasicQueryParameter): supplies page_index and page_size.

    Returns:
        list: rows of the requested page; empty list when none match.
    """
    with database.db_session() as db:
        page = db.query(cls.ClassName).limit(parameter.page_size).offset(
            parameter.page_index * parameter.page_size)
        # FIX: removed the leftover per-row debug print(result.jsonify());
        # list() also covers the empty page, so no count() pre-check needed.
        return list(page)
def delete_regular(regular_id: int) -> bool:
    """Delete the rule row with ``regular_id``.

    Returns:
        bool: False when no such row exists, True after deletion.
    """
    with database.db_session() as db:
        matches = db.query(BatchRegularInfoTable).filter(
            BatchRegularInfoTable.regular_id == regular_id)
        if matches.count() == 0:
            return False
        matches.delete()
        db.commit()
        return True
def get_category_mapping_id(batch_id: str, category: str):
    """Look up the mapping id for (batch_id, category).

    Returns:
        the row's category_mapping_id, or -1 when no row matches.
    """
    with database.db_session() as db:
        # FIX: the original ran the identical filter twice (count() and
        # then first()); a single .first() with a None check is enough.
        row = db.query(database.BatchCategoryInfoTable).filter(
            database.BatchCategoryInfoTable.batch_id == batch_id,
            database.BatchCategoryInfoTable.category == category).first()
        return -1 if row is None else row.category_mapping_id
def add_rule_predict_task(dataset_id: int, batch_id: int):
    """Insert a RulePredictTask row.

    Returns:
        int: the new task's oid, or -1 on any error (best-effort; the
        exception is printed, not re-raised).
    """
    try:
        with database.db_session() as db:
            task = database.RulePredictTask(dataset_id=dataset_id,
                                            batch_id=batch_id)
            db.add(task)
            db.commit()
            return task.oid
    except Exception as err:
        print(err)
        return -1
def detail_view(cls, Data: CreateView):
    """Join classification items with their standards and serialize them.

    Args:
        Data (CreateView): accepted for interface compatibility; not read
            by this implementation.

    Returns:
        list | int: jsonified joined rows, or 0 when the join is empty.
    """
    with database.db_session() as db:
        item = aliased(ClassificationItemInfoTable)
        standard = aliased(ClassificationStandardInfoTable)
        # two-table join on item.csi_oid == standard.oid
        joined = db.query(item).join(standard, item.csi_oid == standard.oid)
        if joined.count() == 0:
            return 0
        return [row.jsonify() for row in joined]
def update(cls, oid: int, **value_dict):
    """Patch attributes of the row with ``oid``.

    Only attributes the row actually has are set; the primary key ``oid``
    itself is never overwritten.

    Returns:
        int: 1 on success, 0 when no row matches.
    """
    with database.db_session() as db:
        matches = db.query(cls.ClassName).filter(cls.ClassName.oid == oid)
        if matches.count() == 0:
            return 0
        target = matches[0]
        for attr, new_value in value_dict.items():
            # skip unknown attributes and protect the primary key
            if attr != "oid" and hasattr(target, attr):
                setattr(target, attr, new_value)
        db.commit()
        return 1
def get_rule_prediction_result(batch_id=0, model_id=0):
    """Fetch rule-system prediction results for one (batch, model).

    The per-category precision from the model's metric info is used as the
    probability estimate (falls back to 0.01 when the category has no
    precision recorded or meta_info cannot be parsed).

    Args:
        batch_id (int, optional): batch number. Defaults to 0.
        model_id (int, optional): rule model number. Defaults to 0.

    Returns:
        dict: data_oid -> {"data_oid", "category", "category_id", "proba"}
    """
    ans_data = {}
    with database.db_session() as db:
        rows = db.query(
            database.RulePredictionOfResultsTable.data_oid,
            database.RulePredictionOfResultsTable.meta_info,
            database.RulePredictionOfResultsTable.category).filter(
                database.RulePredictionOfResultsTable.batch_id ==
                batch_id).filter(
                    database.RulePredictionOfResultsTable.model_id == model_id)
        metric_info = get_metric_info(model_id, "hi_{}".format(model_id))
        for row in rows:
            # Fallbacks re-initialized per row.  BUG FIX: category_id was
            # unbound (NameError) when parsing failed on the first row, and
            # carried a stale value from the previous row afterwards.
            proba = 0.01
            category_id = None
            try:
                # FIX: meta_info was parsed twice; parse once here.
                info = json.loads(row.meta_info)
                category_id = info["category_id"]
                category = info["category"]
                # NOTE: using precision as the probability differs from the
                # AI path; a Bayesian estimate may be better later.
                proba = metric_info["precision"].get(category, proba)
            except Exception as err:
                print(err)
            ans_data[row.data_oid] = {
                "data_oid": row.data_oid,
                "category": row.category,
                "category_id": category_id,
                "proba": proba
            }
    return ans_data
def delete_origin_text_info(oid: int):
    """Delete one OriginTextDataTable row by oid.

    Returns:
        int: 1 when a row was deleted, 0 when absent or on error
        (errors are printed, not re-raised).
    """
    with database.db_session() as db:
        try:
            matches = db.query(database.OriginTextDataTable).filter(
                database.OriginTextDataTable.oid == oid)
            if matches.count() == 0:
                return 0
            db.query(database.OriginTextDataTable).filter(
                database.OriginTextDataTable.oid == oid).delete()
            db.commit()
            return 1
        except Exception as err:
            print(err)
            return 0
def list_dataset_info():
    """List every dataset as DatasetInfoEntity objects.

    Returns:
        list[schemas.DatasetInfoEntity]: possibly partial/empty when the
        query fails (best-effort: the error is printed, not raised).
    """
    ans_list = []
    with database.db_session() as db:
        try:
            for ds_obj in db.query(database.DataSetInfoTable).all():
                ans_list.append(
                    schemas.DatasetInfoEntity(dataset_id=ds_obj.oid,
                                              name=ds_obj.dataset_name,
                                              desc=ds_obj.dataset_desc))
        except Exception as err:
            print(err)
    # FIX: removed the leftover debug print(ans_list).
    return ans_list
def delete_dataset_info(dataset_id: int):
    """Delete one DataSetInfoTable row by oid.

    Returns:
        int: 1 when a row was deleted, 0 when absent or on error
        (errors are printed, not re-raised).
    """
    with database.db_session() as db:
        try:
            matches = db.query(database.DataSetInfoTable).filter(
                database.DataSetInfoTable.oid == dataset_id)
            if matches.count() == 0:
                return 0
            db.query(database.DataSetInfoTable).filter(
                database.DataSetInfoTable.oid == dataset_id).delete()
            db.commit()
            return 1
        except Exception as err:
            print(err)
            return 0
def update_regular_info(regular_id: int, content: str, regular_type: str):
    """Overwrite content and regular_type of one rule row.

    Returns:
        bool: False when no row matches ``regular_id``, True after update.
    """
    with database.db_session() as db:
        matches = db.query(BatchRegularInfoTable).filter(
            BatchRegularInfoTable.regular_id == regular_id)
        if matches.count() == 0:
            return False
        rule = matches.one()
        rule.content = content
        rule.regular_type = regular_type
        db.commit()
        return True
def detail_filter_view(cls, Data: CreateView):
    """Find rows matching ANY field of ``Data`` (OR across all six fields).

    Args:
        Data (CreateView): source of the filter values.

    Returns:
        list | None: matching rows, or None when there are none.
    """
    with database.db_session() as db:
        matches = db.query(cls.ClassName).filter(
            or_(cls.ClassName.name == Data.name,
                cls.ClassName.desc == Data.desc,
                cls.ClassName.create_datetime == Data.create_datetime,
                cls.ClassName.mapping_value == Data.mapping_value,
                cls.ClassName.csi_oid == Data.csi_oid,
                cls.ClassName.csi_name == Data.csi_name)).all()
        # FIX: removed the leftover debug print; an empty result still
        # maps to None exactly as before.
        return matches or None
def list_origin_text_info(dataset_id: int, page_index=0, page_size=50):
    """Page through origin texts of one dataset, newest oid first.

    Args:
        dataset_id (int): dataset to list.
        page_index (int): zero-based page number.
        page_size (int): rows per page.

    Returns:
        list[dict]: each with oid, text and dataset_id.
    """
    with database.db_session() as db:
        rows = db.query(database.OriginTextDataTable).filter(
            OriginTextDataTable.dataset_id == dataset_id).order_by(
                OriginTextDataTable.oid.desc()).limit(page_size).offset(
                    page_size * page_index).all()
        return [{
            "oid": row.oid,
            "text": row.text,
            "dataset_id": row.dataset_id
        } for row in rows]
def get_dataset_info(dataset_id):
    """Fetch one dataset as a DatasetInfoEntity.

    Returns:
        schemas.DatasetInfoEntity | dict: {} when the dataset is missing
        or the query fails (errors are printed, not re-raised).
    """
    with database.db_session() as db:
        try:
            matches = db.query(database.DataSetInfoTable).filter(
                database.DataSetInfoTable.oid == dataset_id)
            if matches.count() == 0:
                return {}
            ds_obj = matches[0]
            return schemas.DatasetInfoEntity(dataset_id=ds_obj.oid,
                                             name=ds_obj.dataset_name,
                                             desc=ds_obj.dataset_desc)
        except Exception as err:
            print(err)
            return {}
def delete_category(category_id):
    """Delete one category row by category_id.

    Returns:
        bool: False when no row matches, True after deletion.
    """
    with database.db_session() as db:
        matches = db.query(database.BatchCategoryInfoTable).filter(
            database.BatchCategoryInfoTable.category_id == category_id)
        if matches.count() == 0:
            return False
        matches.delete()
        db.commit()
        # FIX: dropped the explicit db.close() — the db_session context
        # manager owns the session lifecycle, matching every other helper
        # in this module.
        return True
def filter_list_view(cls, parameter: FilterQueryParameter):
    """Page through ``cls.ClassName`` rows filtered by csi_oid.

    Args:
        parameter (FilterQueryParameter): supplies csi_oid, page_index and
            page_size.

    Returns:
        list | int: jsonified rows, or 0 when the page is empty (kept for
        caller compatibility).
    """
    with database.db_session() as db:
        matches = db.query(cls.ClassName).filter(
            cls.ClassName.csi_oid == parameter.csi_oid).limit(
                parameter.page_size).offset(parameter.page_index *
                                            parameter.page_size)
        # FIX: removed two leftover debug prints (query object and rows).
        if matches.count() == 0:
            return 0
        return [row.jsonify() for row in matches]
def get_need_predict_data():
    """Load every stored vector for prediction.

    Each TextVecTable row's comma-separated ``vec`` string is parsed into
    a list of floats.

    Returns:
        list: [[data_oid, list[float]], ...]
    """
    pairs = []
    with database.db_session() as db:
        rows = db.query(database.TextVecTable.data_oid,
                        database.TextVecTable.vec).all()
        for row in rows:
            vector = [float(part) for part in row.vec.split(',')]
            pairs.append([row.data_oid, vector])
    return pairs
def get_need_predict_text_data():
    """Yield [data_oid, text] for every row of OriginTextDataTable.

    Supplies raw text to the rule system as a generator.

    TODO: all rows are fetched at once, which may be a problem for large
    datasets — add a limit or pagination later.

    Yields:
        list: [data_oid, text]
    """
    # FIX: removed the unused ``ans_data`` local left over from a
    # pre-generator version of this function.
    with database.db_session() as db:
        rows = db.query(database.OriginTextDataTable.data_oid,
                        database.OriginTextDataTable.text).all()
        for row in rows:
            yield [row.data_oid, row.text]