예제 #1
0
def predict(lstTxt, model_preDir):
    """Classify every text in *lstTxt* and summarise the predicted labels.

    The entries of *model_preDir* are scanned with ``os.listdir``:

    * for an entry whose name contains ``'.pb'``, each text is classified with
      ``Bert_Class.predict_on_pb`` and the rows are written with
      ``csv.writer`` to ``<arg_dic['output_predict']>/test_results.tsv``;
    * for an entry whose name contains ``'checkpoint'``,
      ``Bert_Class.predict_on_ckpt`` is called instead (presumably it writes
      the same results file -- confirm in Bert_Class).

    The results file is then read back, one integer label per line.

    Args:
        lstTxt: Texts to classify.
        model_preDir: Directory holding the exported model; also passed to
            ``Bert_Class``.

    Returns:
        ``[timestamp, labels, counts]`` where *timestamp* is
        ``int(time.time())`` taken after prediction, *labels* is the list of
        integer labels read from the results file (blank lines skipped), and
        *counts* is a ``Counter`` of those labels excluding label 0 (the
        "other" class).
        Returns ``''`` when the results file cannot be read or contains a
        line that is not an integer.
    """
    filename = os.path.join(arg_dic['output_predict'], 'test_results.tsv')
    bc = Bert_Class(model_preDir)
    # .pb predictions are computed at most once and reused. Rebuilding them
    # into a shared list for every matching entry appended duplicate rows.
    pb_rows = None
    for entry in os.listdir(model_preDir):
        # NOTE: substring match, so names such as '*.pbtxt' also hit here.
        if '.pb' in entry:
            if pb_rows is None:
                pb_rows = [bc.predict_on_pb(txt) for txt in lstTxt]
            # NOTE(review): csv.writer treats each row as an iterable of
            # fields; if predict_on_pb returns a bare str, every character
            # becomes its own field -- confirm its return type.
            with open(filename, 'w', newline='') as csvfile:
                csv.writer(csvfile).writerows(pb_rows)
        elif 'checkpoint' in entry:
            bc.predict_on_ckpt(lstTxt)
    result = [int(time.time())]
    try:
        with open(filename, 'r') as f:
            lines = [line for line in f.read().splitlines() if line]
        labels = [int(x) for x in lines]
    except (OSError, ValueError):
        # Best effort: a missing/unreadable file or a non-integer line
        # yields '' (UnicodeDecodeError is a ValueError subclass).
        return ''
    result.append(labels)
    # Statistics exclude the "other" class, label 0; Counter merges repeats.
    result.append(Counter(x for x in labels if x != 0))
    return result
예제 #2
0
def skl_getMatrix(
        path=r'D:\work_space\Weibo\Weibo_multi-label-classifier\data\dataAll\dat_20200403',
        model_preDir=r'./model_predict/'):
    """Evaluate the checkpoint model on ``dev.tsv`` and print a report.

    ``<path>/dev.tsv`` is read as tab-separated rows. The first column is the
    integer gold label and the second column is the text. The texts are
    classified with ``Bert_Class(model_preDir).predict_on_ckpt``. The
    predicted labels are then read back from
    ``<arg_dic['output_predict']>/test_results.tsv``, one integer per line.
    Finally, the gold labels, the predicted labels and sklearn's
    ``classification_report`` are printed. Despite the function's name, no
    confusion matrix is produced.

    Args:
        path: Directory containing ``dev.tsv``. Defaults to the originally
            hard-coded data directory.
        model_preDir: Model directory passed to ``Bert_Class``.

    Returns:
        0 on completion.
    """
    filename = os.path.join(arg_dic['output_predict'], 'test_results.tsv')
    dev = os.path.join(path, 'dev.tsv')
    y_true = []
    lines = []
    # The csv module expects files it reads to be opened with newline=''.
    with open(dev, 'r', encoding='utf-8', newline='') as f:
        for row in csv.reader(f, delimiter='\t'):
            y_true.append(int(row[0]))
            lines.append(row[1])
    bc = Bert_Class(model_preDir)
    # NOTE(review): assumes predict_on_ckpt writes one label per line to
    # `filename`, in the same order as `lines` -- confirm in Bert_Class.
    bc.predict_on_ckpt(lines)
    with open(filename, 'r') as f:
        y_pred = [int(x) for x in f.read().splitlines()]
    print(y_true)
    print(y_pred)
    eval_report = classification_report(y_true, y_pred)
    print(eval_report)
    return 0
예제 #3
0
# -*- coding: utf-8 -*-
'''
@author: [email protected]
@license: (C) Copyright 2019
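@desc: Sanic HTTP service around Bert_Class; the "/" route reads the user_id and question fields from GET query args or POST form data.
@DateTime: Created on 2019/7/22, at 05:07 PM by PyCharm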
'''

from sanic import Sanic
from sanic.response import json as Rjson
from predict_GPU import Bert_Class

# Sanic application object that the routes below are registered on. An
# explicit name is passed because recent Sanic releases (21.x and later)
# refuse to create an unnamed app. Older releases accept the name argument too.
app = Sanic("bert_classifier")
# Classifier instance, created once when this module is imported (presumably
# used by the route handlers below -- confirm).
my = Bert_Class()


@app.route("/", methods=['GET', 'POST'])
async def home(request):
    # 1,首先要从HTTP请求获取用户的字符串
    dict1 = {'tips': '请用POST方法,传递“用户id、question”字段'}
    if request.method == 'GET':
        user, key_str = request.args.get('user_id'), request.args.get(
            'question')
    elif request.method == 'POST':
        for k, v in request.form.items():
            dict1[k] = v  # 最关心的问题字段是keyword
        user, key_str = request.form.get('user_id'), request.form.get(
            'question')
    else:
        return Rjson(dict1)
    if not key_str or not user:  # 如果有空的字段,返回警告信息。