コード例 #1
0
def label_counts(label_sys_id, task_id_list):
    """Count how many tagged documents carry each label (or label combination).

    Only documents with ``state == 1`` are considered. When a document has
    several labels of the label system, their names are joined with ``,`` in
    the order the rows come back, and the document is counted once under
    that combined key. (The original note says only single-label counts are
    really handled so far.)

    Args:
        label_sys_id: Id of the label system whose labels are counted.
        task_id_list: Task ids to restrict the documents to. A falsy value
            (``None`` or an empty list) means no task restriction.

    Returns:
        A ``Counter`` mapping label string -> number of documents.
    """
    sess = SessionLocal()
    try:
        query = (
            sess.query(Document_.id, Label_.name)
            .filter(Document_.state == 1)
            .join(TaggingRecords_)
            .join(Label_)
            .filter(Label_.label_sys_id == label_sys_id)
        )
        if task_id_list:
            # SQL "IN (...)" restriction on the owning task.
            query = query.filter(Document_.task_id.in_(task_id_list))
        rows = query.all()
    finally:
        # Always release the connection, even when the query fails.
        sess.close()

    doc_label = {}
    for row in rows:
        if row.id in doc_label:
            doc_label[row.id] = doc_label[row.id] + ',' + row.name
        else:
            doc_label[row.id] = row.name
    return Counter(doc_label.values())
コード例 #2
0
def get_detail_by_id(task_id: int):
    """Return the details of one task together with its label systems.

    Args:
        task_id: Id of the task to look up.

    Returns:
        ``None`` when no task has this id. Otherwise a dict with the keys
        ``name``, ``desc``, ``create_time``, ``label_sys_list`` (one
        ``{"id", "name", "multi"}`` dict per label system linked to the
        task) and ``label_list`` (parallel to ``label_sys_list``: entry *i*
        is the list of ``{"id", "name", "desc"}`` dicts of label system *i*).
    """
    sess = SessionLocal()
    try:
        task = sess.query(Task_.id, Task_.name, Task_.desc, Task_.create_time) \
            .filter(Task_.id == task_id).first()
        if not task:
            return None

        ls_rows = sess.query(LabelSys_.id, LabelSys_.name, LabelSys_.multi) \
            .join(TaskRecords_) \
            .filter(TaskRecords_.task_id == task_id).all()

        label_sys_list = []
        label_list = []
        for ls in ls_rows:
            label_sys_list.append({
                "id": ls.id,
                "name": ls.name,
                "multi": ls.multi,
            })
            label_rows = sess.query(Label_.id, Label_.name, Label_.desc) \
                .join(LabelSys_) \
                .filter(Label_.label_sys_id == ls.id).all()
            label_list.append([
                {"id": row.id, "name": row.name, "desc": row.desc}
                for row in label_rows
            ])

        return {
            "name": task.name,
            "desc": task.desc,
            "create_time": task.create_time,
            "label_sys_list": label_sys_list,
            "label_list": label_list,
        }
    finally:
        # Runs on every exit path, including the early "not found" return.
        sess.close()
コード例 #3
0
def get_tagged_docs(task_id: int, user_id: int):
    """List the tagged documents (``state == 1``) of a task.

    Args:
        task_id: Id of the task whose documents are listed.
        user_id: Only documents with a tagging record of this user are
            returned; ``-1`` disables the per-user restriction.

    Returns:
        ``None`` when nothing matches, otherwise a list of
        ``{'doc_id': ..., 'title': ...}`` dicts.
    """
    sess = SessionLocal()
    try:
        query = sess.query(Document_.id, Document_.title) \
            .filter(Document_.task_id == task_id) \
            .filter(Document_.state == 1)
        if user_id != -1:
            query = query.join(TaggingRecords_) \
                .filter(TaggingRecords_.user_id == user_id)
        docs = query.all()
    finally:
        # Release the session even when the query raises.
        sess.close()

    if not docs:
        return None
    return [{'doc_id': doc.id, 'title': doc.title} for doc in docs]
コード例 #4
0
def get_detail_by_id(label_sys_id: int):
    """Return one label system and all of its labels.

    Args:
        label_sys_id: Id of the label system to look up.

    Returns:
        ``None`` when no label system has this id. Otherwise a dict with the
        keys ``id``, ``name``, ``desc``, ``multi`` (converted with ``str``),
        ``num_labels`` (number of labels) and ``labels`` (a list of
        ``{'id', 'name', 'desc', 'keywords'}`` dicts).
    """
    sess = SessionLocal()
    try:
        label_sys = sess.query(LabelSys_).filter(LabelSys_.id == label_sys_id).first()
        if not label_sys:
            return None
        labels = sess.query(Label_).filter(Label_.label_sys_id == label_sys_id).all()
        # Build the plain-dict result while the ORM objects are still
        # attached to the open session.
        return {
            'id': label_sys_id,
            'name': label_sys.name,
            'desc': label_sys.desc,
            'multi': str(label_sys.multi),
            'num_labels': len(labels),
            'labels': [
                {'id': label.id, 'name': label.name, 'desc': label.desc, 'keywords': label.keywords}
                for label in labels
            ],
        }
    finally:
        sess.close()
コード例 #5
0
def count_docs(task_id: int):
    """Count all documents and the tagged documents of a task.

    Args:
        task_id: Id of the task whose documents are counted.

    Returns:
        ``{"num_docs": int, "num_tagged_docs": int}`` where tagged documents
        are those with ``state == 1``.
    """
    sess = SessionLocal()
    try:
        docs_query = sess.query(Document_.id).filter(Document_.task_id == task_id)
        # Let the database count instead of fetching every id row just to
        # take len() of the result.
        return {
            "num_docs": docs_query.count(),
            "num_tagged_docs": docs_query.filter(Document_.state == 1).count(),
        }
    finally:
        sess.close()
コード例 #6
0
def get_id_by_token(token: str, admin_only=False):
    """Resolve a login token to the id of the account that owns it.

    Admin accounts take precedence: the ``User_`` table is consulted only
    when no admin holds the token and ``admin_only`` is false.

    Args:
        token: Token to look up.
        admin_only: When true, only admin accounts are considered.

    Returns:
        The matching account's ``id``, or ``None`` when nothing matches.
    """
    sess = SessionLocal()
    try:
        user = sess.query(Admin_).filter(Admin_.token == token).first()
        if not user and not admin_only:
            user = sess.query(User_).filter(User_.token == token).first()
        if not user:
            return None
        return user.id
    finally:
        # The original never closed this session on any path.
        sess.close()
コード例 #7
0
def check_task_name(task_name: str):
    """Tell whether a task with the given name already exists.

    Args:
        task_name: Task name to look for (exact match).

    Returns:
        ``1`` when a task with this name exists, ``0`` otherwise.
    """
    sess = SessionLocal()
    try:
        res = sess.query(Task_).filter(Task_.name == task_name).first()
    finally:
        sess.close()
    # Identity check: "== None" would go through the object's __eq__.
    return 0 if res is None else 1
コード例 #8
0
def check_label_sys_name(label_sys_name: str):
    """Tell whether a label system with the given name already exists.

    Args:
        label_sys_name: Label system name to look for (exact match).

    Returns:
        ``1`` when a label system with this name exists, ``0`` otherwise.
    """
    sess = SessionLocal()
    try:
        res = sess.query(LabelSys_).filter(LabelSys_.name == label_sys_name).first()
    finally:
        # Release the session even when the query raises.
        sess.close()
    return 0 if res is None else 1
コード例 #9
0
def get_related_tasks(label_sys_id: int):
    """List the tasks linked to a label system through ``TaskRecords_``.

    Args:
        label_sys_id: Id of the label system.

    Returns:
        A (possibly empty) list of ``{'id': ..., 'name': ...}`` dicts, one
        per joined task row.
    """
    sess = SessionLocal()
    try:
        recs = sess.query(Task_).join(TaskRecords_) \
            .filter(TaskRecords_.label_sys_id == label_sys_id).all()
        return [{'id': rec.id, 'name': rec.name} for rec in recs]
    finally:
        sess.close()
コード例 #10
0
def get_summary():
    """Return an overview of every tagging task.

    Each entry holds the task's id, name, description, creation time, its
    linked label systems and its document progress.

    Returns:
        ``None`` when there are no tasks, otherwise a list shaped like::

            [{'id': ..., 'name': ..., 'desc': ..., 'create_time': ...,
              'num_docs': int, 'num_tagged_docs': int,
              'label_sys_list': [{'id': ..., 'name': ...}]}]

        ``create_time`` is passed through exactly as stored on the task;
        ``num_tagged_docs`` counts documents with ``state == 1``.
    """
    sess = SessionLocal()
    try:
        tasks = sess.query(Task_.id, Task_.name, Task_.desc, Task_.create_time).all()
        if not tasks:
            return None

        task_summary = []
        for task in tasks:
            # Label systems linked to this task.
            lss = sess.query(LabelSys_.id, LabelSys_.name).join(TaskRecords_) \
                .filter(TaskRecords_.task_id == task.id).all()
            label_sys_list = [{'id': ls.id, 'name': ls.name} for ls in lss]

            # Document progress: counted in the database rather than by
            # loading every document id into memory.
            docs_query = sess.query(Document_.id).filter(Document_.task_id == task.id)

            task_summary.append({
                'id': task.id,
                'name': task.name,
                'desc': task.desc,
                'create_time': task.create_time,
                'num_docs': docs_query.count(),
                'num_tagged_docs': docs_query.filter(Document_.state == 1).count(),
                'label_sys_list': label_sys_list,
            })
        return task_summary
    finally:
        sess.close()
コード例 #11
0
def get_label_sys_list():
    """Return the id and name of every label system.

    Returns:
        ``None`` when there are no label systems, otherwise a list of
        ``{'id': ..., 'name': ...}`` dicts.
    """
    sess = SessionLocal()
    try:
        res = sess.query(LabelSys_.id, LabelSys_.name).all()
    finally:
        # Release the session even when the query raises.
        sess.close()
    if not res:
        return None
    return [{'id': ls.id, 'name': ls.name} for ls in res]
コード例 #12
0
def update_label_sys(label_sys: LabelSys, admin_id: int):
    """Update a label system and, optionally, its labels.

    The set of updatable fields is controlled by the ``update`` calls below:
    ``name`` and ``desc`` for the label system (``multi`` is deliberately not
    updatable), and ``name``, ``desc`` and ``keywords`` for each label.

    All ids are validated before the database is touched, and everything is
    committed in one transaction, including when ``label_sys.labels`` is
    empty or ``None``.

    Args:
        label_sys: Object carrying ``id``, ``name``, ``desc`` and ``labels``;
            each label carries ``id``, ``name``, ``desc`` and ``keywords``.
        admin_id: Not used by this function.

    Returns:
        ``1`` on success.

    Raises:
        AssertionError: If the label system or any label has no id.
    """
    # Raised explicitly (not via ``assert``) so the checks survive ``python -O``
    # while callers still see the same exception type and message.
    if label_sys.id is None:
        raise AssertionError("label_sys的id没给我!")
    labels = list(label_sys.labels or [])
    for label in labels:
        if label.id is None:
            raise AssertionError("label的id没给我!")

    sess = SessionLocal()
    try:
        sess.query(LabelSys_).filter(LabelSys_.id == label_sys.id) \
            .update({LabelSys_.name: label_sys.name, LabelSys_.desc: label_sys.desc})
        for label in labels:
            sess.query(Label_).filter(Label_.id == label.id) \
                .update({Label_.name: label.name, Label_.desc: label.desc,
                         Label_.keywords: label.keywords})
        # Commit even when there are no labels; the original returned early
        # in that case and silently dropped the label-system update.
        sess.commit()
    finally:
        sess.close()
    return 1
コード例 #13
0
def get_detail_by_name(label_sys_name: str):
    """Return the details of the label system with the given name.

    Args:
        label_sys_name: Name of the label system (exact match; the first
            matching row is used).

    Returns:
        ``None`` when no label system has this name. Otherwise a dict::

            {'id': ..., 'name': ..., 'desc': ..., 'multi': str,
             'num_labels': int,
             'labels': [{'id': ..., 'name': ..., 'desc': ...}, ...]}

        ``multi`` is converted with ``str``; the other values are passed
        through as stored.
    """
    sess = SessionLocal()
    try:
        label_sys = sess.query(LabelSys_).filter(LabelSys_.name == label_sys_name).first()
        if not label_sys:
            return None
        labels = sess.query(Label_).filter(Label_.label_sys_id == label_sys.id).all()
        return {
            'id': label_sys.id,
            'name': label_sys.name,
            'desc': label_sys.desc,
            'multi': str(label_sys.multi),
            'num_labels': len(labels),
            'labels': [
                {'id': label.id, 'name': label.name, 'desc': label.desc}
                for label in labels
            ],
        }
    finally:
        sess.close()
コード例 #14
0
def label_sys_tagged_data_download(label_sys_id, task_id_list):
    """Export the tagged documents of a label system to a CSV file.

    Documents with ``state == 1`` that carry labels of the label system are
    written to ``download_datasets/<file_name>.csv`` with the columns
    ``doc_id``, ``title``, ``content`` and ``label`` (plus the default
    pandas index column). Several labels on one document are joined with
    ``,``.

    Args:
        label_sys_id: Id of the label system to export.
        task_id_list: Task ids to restrict the documents to. A falsy value
            (``None`` or an empty list) means no task restriction.

    Returns:
        The file name without directory and extension, built as
        ``<label system name>__<YYYY-MM-DD-HH-MM-SS>``, or ``None`` when the
        label system does not exist (nothing is written then).
    """
    import os

    sess = SessionLocal()
    try:
        ls_info = sess.query(LabelSys_).filter(LabelSys_.id == label_sys_id).first()
        if ls_info is None:
            # Previously this crashed with AttributeError on ``ls_info.name``.
            return None
        ls_name = ls_info.name

        query = (
            sess.query(Document_.id, Document_.title, Document_.content,
                       Label_.label_sys_id, Label_.name)
            .filter(Document_.state == 1)
            .join(TaggingRecords_)
            .join(Label_)
            .filter(Label_.label_sys_id == label_sys_id)
        )
        if task_id_list:
            # SQL "IN (...)" restriction on the owning task.
            query = query.filter(Document_.task_id.in_(task_id_list))
        rows = query.all()
    finally:
        sess.close()

    file_name = ls_name + '__' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

    # One record per document; dicts preserve first-seen document order.
    records = {}
    for row in rows:
        if row.id in records:
            records[row.id]['label'] = records[row.id]['label'] + ',' + row.name
        else:
            records[row.id] = {'doc_id': row.id, 'title': row.title,
                               'content': row.content, 'label': row.name}

    # Explicit columns keep the header stable even when there are no rows.
    df = pd.DataFrame(list(records.values()),
                      columns=['doc_id', 'title', 'content', 'label'])
    os.makedirs('download_datasets', exist_ok=True)
    df.to_csv(os.path.join('download_datasets', '%s.csv' % file_name))
    return file_name
コード例 #15
0
def get_summary():
    """Return an overview of the label systems with their label counts.

    Returns:
        ``None`` when the query yields no rows, otherwise a list shaped
        like::

            [{'id': ..., 'name': ..., 'desc': ..., 'multi': str,
              'num_labels': str}, ...]

        ``multi`` and ``num_labels`` are converted with ``str``.
    """
    sess = SessionLocal()
    try:
        # NOTE(review): this is an inner join, so a label system without any
        # label produces no row here. Confirm whether such label systems
        # can exist and should be listed with a count of 0.
        res = sess.query(LabelSys_.id, LabelSys_.name, LabelSys_.desc, LabelSys_.multi,
                         func.count(Label_.id).label('num_labels')) \
            .join(Label_) \
            .group_by(LabelSys_.id).all()
    finally:
        sess.close()

    if not res:
        return None
    return [
        {'id': each.id, 'name': each.name, 'desc': each.desc,
         'multi': str(each.multi), 'num_labels': str(each.num_labels)}
        for each in res
    ]
コード例 #16
0
def check_login_info(username: str, password: str, user_type: str):
    """Check login credentials and issue a token on success.

    Args:
        username: Login name to look up.
        password: Password supplied by the caller.
        user_type: ``'admin'`` checks the ``Admin_`` table; any other value
            checks the ``User_`` table.

    Returns:
        ``{"res": 0, "msg": ...}`` when the user name is unknown or the
        password does not match, otherwise ``{"res": 1, "token": ...}``.
        On success the new token is also stored on the account row(s) with
        that user name.
    """
    model = Admin_ if user_type == 'admin' else User_
    sess = SessionLocal()
    try:
        user = sess.query(model).filter(model.username == username).first()
        if not user:
            return {"res": 0, "msg": "该用户名不存在!"}
        # NOTE(review): the stored password is compared directly with the
        # supplied one, which suggests passwords are stored unhashed.
        # Confirm, and consider hashing with a constant-time comparison.
        if user.password != password:
            return {"res": 0, "msg": "密码错误!"}

        token = generate_token(username, password)
        # Persist the token so later requests can be mapped back to the user.
        sess.query(model).filter(model.username == user.username) \
            .update({model.token: token})
        sess.commit()
        return {"res": 1, "token": token}
    finally:
        # One session for the whole call, closed on every path. The original
        # opened two sessions and closed neither.
        sess.close()
コード例 #17
0
def delete_task(task_id: int):
    """Delete the task with the given id.

    Args:
        task_id: Id of the task to delete.

    Returns:
        The value returned by the query's ``delete()`` call (the matched
        row count).
    """
    sess = SessionLocal()
    try:
        res = sess.query(Task_).filter(Task_.id == task_id).delete()
        sess.commit()
        return res
    finally:
        # Close on failure too, so a failed delete/commit does not leave the
        # session and its connection open.
        sess.close()
コード例 #18
0
def delete_label_sys(label_sys_id:int):
    """Delete the label system with the given id.

    Args:
        label_sys_id: Id of the label system to delete.

    Returns:
        The value returned by the query's ``delete()`` call (the matched
        row count).
    """
    sess = SessionLocal()
    try:
        res = sess.query(LabelSys_).filter(LabelSys_.id == label_sys_id).delete()
        sess.commit()
        return res
    finally:
        # Close on failure too, so a failed delete/commit does not leave the
        # session and its connection open.
        sess.close()