def upload_to_db(df, task_id): """ df: 通过检查的dataframe task_id: task_id """ # 开始往数据库插入: sess = SessionLocal() num_uploaded_docs = len(df) num_success_docs = 0 for item in df.iterrows(): try: title = item[1]['title'] content = item[1]['content'] doc_id = int(my_snow.get_id()) db_doc = Document_(id=doc_id, task_id=task_id, title=title, content=content, state=0) #初次上传,state都为0 sess.add(db_doc) # sess.refresh(db_doc) num_success_docs += 1 except: sess.rollback() # 报错的话需要通过rollback来撤销当前session的操作 print(traceback.format_exc()) sess.commit() sess.close() return { "num_uploaded_docs": num_uploaded_docs, "num_success_docs": num_success_docs }
def add_task_info(task: Task, task_id, admin_id): """ 添加新的打标任务的基础信息. (当前默认admin_id=1; 暂不考虑state,doc_type信息) 从前端接收json信息(暂定) json格式如下: { 'name':str, 'desc':str, 'label_sys_ids':[str] } """ sess = SessionLocal() # 插入task表: current_time_str = datetime.fromtimestamp(int(time.time())) db_task = Task_(id=task_id, name=task.name, desc=task.desc, admin_id=admin_id, create_time=current_time_str) # ?? sess.add(db_task) sess.commit() # 插入task_records表 for label_sys_id in task.label_sys_ids: db_task_record = TaskRecords_(task_id=int(task_id), label_sys_id=label_sys_id) sess.add(db_task_record) sess.commit() sess.close() return {'task_id': task_id}
def add_label_sys(label_sys: LabelSys, admin_id: int): """ 添加新的分类体系. label_sys的json格式如下: { 'name':str, 'desc':str, 'multi':int, 'labels':[{'name':str,'desc':str,'keywords':str}, {'name':str,'desc':str,'keywords':str},...] } """ if check_label_sys_name(label_sys.name): return None sess = SessionLocal() label_sys_id = int(my_snow.get_id()) db_label_sys = LabelSys_(id=label_sys_id, name=label_sys.name, desc=label_sys.desc, multi=label_sys.multi, admin_id=admin_id) # (当前默认admin_id=1) sess.add(db_label_sys) sess.commit() sess.refresh(db_label_sys) label_id_list = [] for label in label_sys.labels: label_id = int(my_snow.get_id()) label_id_list.append(label_id) db_label = Label_(id=label_id, name=label.name, desc=label.desc, keywords=label.keywords, label_sys_id=label_sys_id) sess.add(db_label) sess.commit() sess.refresh(db_label) res = {'label_sys_id': label_sys_id, 'label_id_list': label_id_list} sess.close() return res
def update_label_sys(label_sys: LabelSys, admin_id: int): """ 可以直接在下面的update语句里面,调整允许更新的字段。 目前multi字段是不允许更新的 """ assert label_sys.id is not None, "label_sys的id没给我!" sess = SessionLocal() res = sess.query(LabelSys_).filter(LabelSys_.id == label_sys.id)\ .update({LabelSys_.name : label_sys.name, LabelSys_.desc : label_sys.desc}) print('ls_res:',res) labels = label_sys.labels if not labels: return 1 for label in labels: assert label.id is not None, "label的id没给我!" res = sess.query(Label_).filter(Label_.id == label.id)\ .update({Label_.name : label.name, Label_.desc : label.desc, Label_.keywords : label.keywords}) print('l_res:',res) sess.commit() sess.close() return 1
def check_login_info(username: str, password: str, user_type: str): sess = SessionLocal() if user_type == 'admin': user = sess.query(Admin_).filter(Admin_.username == username).first() else: user = sess.query(User_).filter(User_.username == username).first() if not user: return {"res": 0, "msg": "该用户名不存在!"} if user.password != password: return {"res": 0, "msg": "密码错误!"} token = generate_token(username, password) # 把token存入数据库中,便于后面查询对应的用户 sess = SessionLocal() if user_type == 'admin': sess.query(Admin_).filter(Admin_.username == user.username).update( {Admin_.token: token}) else: sess.query(User_).filter(User_.username == user.username).update( {User_.token: token}) sess.commit() return {"res": 1, "token": token}
def delete_task(task_id: int): sess = SessionLocal() res = sess.query(Task_).filter(Task_.id == task_id).delete() sess.commit() sess.close() return res
def delete_label_sys(label_sys_id:int): sess = SessionLocal() res = sess.query(LabelSys_).filter(LabelSys_.id == label_sys_id).delete() sess.commit() sess.close() return res