Example #1
def pred_input(str_input, path_hyper_parameter=path_hyper_parameters):
    # predict for a single input string
    # load hyper-parameters
    hyper_parameters = load_json(path_hyper_parameter)
    pt = PreprocessTextMulti()
    # model initialization and loading
    graph = Graph(hyper_parameters)
    graph.load_model()
    ra_ed = graph.word_embedding
    ques = str_input
    # str to token
    ques_embed = ra_ed.sentence2idx(ques)
    if hyper_parameters['embedding_type'] == 'bert':  # bert input: [token_ids, segment_ids]
        x_val_1 = np.array([ques_embed[0]])
        x_val_2 = np.array([ques_embed[1]])
        x_val = [x_val_1, x_val_2]
    else:
        x_val = ques_embed
    # predict
    pred = graph.predict(x_val)
    print(pred)
    # map ids back to (label, probability) pairs
    pre = pt.prereocess_idx(pred[0])
    ls_nulti = []
    for ls in pre[0]:
        if ls[1] >= 0.73:  # keep labels above the 0.73 probability threshold
            ls_nulti.append(ls)
    print(str_input)
    print(pre[0])
    print(ls_nulti)
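# Usage sketch (hypothetical, not part of the original example): predict one
# query; only labels whose probability clears the 0.73 cut-off above are kept.
if __name__ == "__main__":
    pred_input('我要打王者荣耀')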
def pred_tet(path_hyper_parameter=path_hyper_parameters,
             path_test=None,
             rate=1.0):
    """
        测试集测试与模型评估
    :param hyper_parameters: json, 超参数
    :param path_test:str, path of test data, 测试集
    :param rate: 比率, 抽出rate比率语料取训练
    :return: None
    """
    hyper_parameters = load_json(path_hyper_parameter)
    if path_test:  # override the test data path from outside
        hyper_parameters['data']['val_data'] = path_test
    time_start = time.time()
    # graph initialization
    graph = Graph(hyper_parameters)
    print("graph init ok!")
    graph.load_model()
    print("graph load ok!")
    ra_ed = graph.word_embedding
    # data preprocessing
    pt = PreprocessText()
    y, x = read_and_process(hyper_parameters['data']['val_data'])
    # evaluate on only a fraction (rate) of the dataset
    len_rate = int(len(y) * rate)
    x = x[1:len_rate]  # start at 1: the first row is the 'label,ques' header
    y = y[1:len_rate]
    y_pred = []
    count = 0
    for x_one in x:
        count += 1
        ques_embed = ra_ed.sentence2idx(x_one)
        if hyper_parameters['embedding_type'] == 'bert':  # bert input: [token_ids, segment_ids]
            x_val_1 = np.array([ques_embed[0]])
            x_val_2 = np.array([ques_embed[1]])
            x_val = [x_val_1, x_val_2]
        else:
            x_val = ques_embed
        # predict
        pred = graph.predict(x_val)
        pre = pt.prereocess_idx(pred[0])
        label_pred = pre[0][0][0]
        if count % 1000 == 0:
            print(label_pred)
        y_pred.append(label_pred)

    print("data pred ok!")
    # convert gold and predicted labels to integer indices
    index_y = [pt.l2i_i2l['l2i'][i] for i in y]
    index_pred = [pt.l2i_i2l['l2i'][i] for i in y_pred]
    target_names = [
        pt.l2i_i2l['i2l'][str(i)] for i in list(set((index_pred + index_y)))
    ]
    # evaluate
    report_predict = classification_report(index_y,
                                           index_pred,
                                           target_names=target_names,
                                           digits=9)
    print(report_predict)
    print("耗时:" + str(time.time() - time_start))
Example #3
def pred_input(path_hyper_parameter=path_hyper_parameters):
    # predict for console input
    # load hyper-parameters
    hyper_parameters = load_json(path_hyper_parameter)
    pt = PreprocessTextMulti(path_model_dir)
    # model initialization and loading
    graph = Graph(hyper_parameters)
    graph.load_model()
    ra_ed = graph.word_embedding
    ques = '我要打王者荣耀'
    # str to token
    ques_embed = ra_ed.sentence2idx(ques)
    if hyper_parameters['embedding_type'] in ['bert', 'albert']:
        x_val_1 = np.array([ques_embed[0]])
        x_val_2 = np.array([ques_embed[1]])
        x_val = [x_val_1, x_val_2]
    else:
        x_val = ques_embed
    # predict
    pred = graph.predict(x_val)
    print(pred)
    # map ids back to (label, probability) pairs
    pre = pt.prereocess_idx(pred[0])
    ls_nulti = []
    for ls in pre[0]:
        if ls[1] >= 0.5:
            ls_nulti.append(ls)
    print(pre[0])
    print(ls_nulti)
    while True:
        print("请输入: ")
        ques = input()
        ques_embed = ra_ed.sentence2idx(ques)
        print(ques_embed)
        if hyper_parameters['embedding_type'] in ['bert', 'albert']:
            x_val_1 = np.array([ques_embed[0]])
            x_val_2 = np.array([ques_embed[1]])
            x_val = [x_val_1, x_val_2]
        else:
            x_val = ques_embed
        pred = graph.predict(x_val)
        pre = pt.prereocess_idx(pred[0])
        ls_nulti = []
        for ls in pre[0]:
            if ls[1] >= 0.5:
                ls_nulti.append(ls)
        print(pre[0])
        print(ls_nulti)
def pred_input(path_hyper_parameter=path_hyper_parameters):
    """
       输入预测
    :param path_hyper_parameter: str, 超参存放地址
    :return: None
    """
    # load hyper-parameters
    hyper_parameters = load_json(path_hyper_parameter)
    pt = PreprocessSim(path_model_dir)
    # model initialization and loading
    graph = Graph(hyper_parameters)
    graph.load_model()
    ra_ed = graph.word_embedding
    sen1 = '我要打王者荣耀'
    sen2 = '我要打梦幻西游'

    # str to token
    ques_embed = ra_ed.sentence2idx(text=sen1, second_text=sen2)
    if hyper_parameters['embedding_type'] in ['bert', 'albert']:
        x_val_1 = np.array([ques_embed[0]])
        x_val_2 = np.array([ques_embed[1]])
        x_val = [x_val_1, x_val_2]
        # predict
        pred = graph.predict(x_val)
        # map ids back to (label, probability) pairs
        pre = pt.prereocess_idx(pred[0])
        print(pre)
        while True:
            print("请输入sen1: ")
            sen1 = input()
            print("请输入sen2: ")
            sen2 = input()

            ques_embed = ra_ed.sentence2idx(text=sen1, second_text=sen2)
            print(ques_embed)
            if hyper_parameters['embedding_type'] in ['bert', 'albert']:
                x_val_1 = np.array([ques_embed[0]])
                x_val_2 = np.array([ques_embed[1]])
                x_val = [x_val_1, x_val_2]
                pred = graph.predict(x_val)
                pre = pt.prereocess_idx(pred[0])
                print(pre)
            else:
                print("error, just support bert or albert")

    else:
        print("error, just support bert or albert")
def pred_input(path_hyper_parameter=path_hyper_parameters):
    """
       输入预测
    :param path_hyper_parameter: str, 超参存放地址
    :return: None
    """
    # load hyper-parameters
    hyper_parameters = load_json(path_hyper_parameter)
    pt = PreprocessText()
    # model initialization and loading
    graph = Graph(hyper_parameters)
    graph.load_model()
    ra_ed = graph.word_embedding
    ques = '我要打王者荣耀'
    # str to token
    ques_embed = ra_ed.sentence2idx(ques)
    if hyper_parameters['embedding_type'] == 'bert':
        x_val_1 = np.array([ques_embed[0]])
        x_val_2 = np.array([ques_embed[1]])
        x_val = [x_val_1, x_val_2]
    else:
        x_val = ques_embed
    # predict
    pred = graph.predict(x_val)
    # map ids back to (label, probability) pairs
    pre = pt.prereocess_idx(pred[0])
    print(pre)
    while True:
        print("请输入: ")
        ques = input()
        ques_embed = ra_ed.sentence2idx(ques)
        print(ques_embed)
        if hyper_parameters['embedding_type'] == 'bert':
            x_val_1 = np.array([ques_embed[0]])
            x_val_2 = np.array([ques_embed[1]])
            x_val = [x_val_1, x_val_2]
        else:
            x_val = ques_embed
        pred = graph.predict(x_val)
        pre = pt.prereocess_idx(pred[0])
        print(pre)
    def preprocess_label_ques_to_idx(self,
                                     embedding_type,
                                     batch_size,
                                     path,
                                     embed,
                                     rate=1,
                                     epcoh=20):
        label_set, len_all = self.preprocess_get_label_set(path)
        # build the label<->index dicts; if the l2i_i2l json already exists (e.g. on the dev set), load it instead of rebuilding
        if not os.path.exists(self.path_fast_text_model_l2i_i2l):
            count = 0
            label2index = {}
            index2label = {}
            for label_one in label_set:
                label2index[label_one] = count
                index2label[count] = label_one
                count = count + 1

            l2i_i2l = {}
            l2i_i2l['l2i'] = label2index
            l2i_i2l['i2l'] = index2label
            save_json(l2i_i2l, self.path_fast_text_model_l2i_i2l)
        else:
            l2i_i2l = load_json(self.path_fast_text_model_l2i_i2l)

        # fraction of the data to read
        len_ql = int(rate * len_all)
        if len_ql <= 500:  # don't subsample small corpora; keep enough data to train
            len_ql = len_all

        def process_line(line):
            # per line: extract the label and the question token ids
            line_sp = line.split(",")
            ques = str(line_sp[1]).strip().upper()
            label = str(line_sp[0]).strip().upper()
            label = "NAN" if label == "" else label
            que_embed = embed.sentence2idx(ques)
            label_zeros = [0] * len(l2i_i2l['l2i'])
            label_zeros[l2i_i2l['l2i'][label]] = 1
            return que_embed, label_zeros

        for _ in range(epcoh):
            file_csv = open(path, "r", encoding="utf-8")
            cout_all_line = 0
            cnt = 0
            x, y = [], []
            for line in file_csv:
                cout_all_line += 1
                # stop this epoch once the sampled fraction (rate) has been read
                if cout_all_line > len_ql:
                    break
                if cout_all_line > 1:  # the first line is the 'label,ques' header, skip it
                    x_line, y_line = process_line(line)
                    x.append(x_line)
                    y.append(y_line)
                    cnt += 1
                    if cnt == batch_size:
                        if embedding_type in ['bert', 'albert']:
                            # bert-style input: split [token_ids, segment_ids] into two arrays
                            x_, y_ = np.array(x), np.array(y)
                            x_1 = np.array([x[0] for x in x_])
                            x_2 = np.array([x[1] for x in x_])
                            x_all = [x_1, x_2]
                        elif embedding_type == 'xlnet':
                            x_, y_ = x, np.array(y)
                            x_1 = np.array([x[0][0] for x in x_])
                            x_2 = np.array([x[1][0] for x in x_])
                            x_3 = np.array([x[2][0] for x in x_])
                            x_all = [x_1, x_2, x_3]
                        else:
                            x_all, y_ = np.array(x), np.array(y)

                        cnt = 0
                        yield (x_all, y_)
                        x, y = [], []
            file_csv.close()
        print("preprocess_label_ques_to_idx ok")
    def preprocess_label_ques_to_idx(self,
                                     embedding_type,
                                     batch_size,
                                     path,
                                     embed,
                                     rate=1,
                                     epcoh=20):
        label_set, len_all = self.preprocess_get_label_set(path)
        # build the label<->index dicts; if the l2i_i2l json already exists (e.g. on the dev set), load it instead of rebuilding
        if not os.path.exists(self.path_fast_text_model_l2i_i2l):
            count = 0
            label2index = {}
            index2label = {}
            for label_one in label_set:
                label2index[label_one] = count
                index2label[count] = label_one
                count = count + 1

            l2i_i2l = {}
            l2i_i2l['l2i'] = label2index
            l2i_i2l['i2l'] = index2label
            save_json(l2i_i2l, self.path_fast_text_model_l2i_i2l)
        else:
            l2i_i2l = load_json(self.path_fast_text_model_l2i_i2l)

        # fraction of the data to read
        len_ql = int(rate * len_all)
        if len_ql <= 500:  # don't subsample small corpora; keep enough data to train
            len_ql = len_all

        def process_line(line):
            # per line: parse the json record and build the label and feature ids
            data = json.loads(line)
            label = data['label']
            ques_1 = data['sentence1']
            ques_2 = data['sentence2']
            offset = data['offset']
            mention_1 = data["mention"]
            offset_i = int(offset)
            que_embed_1 = embed.sentence2idx(text=ques_1)
            que_embed_2 = embed.sentence2idx(text=ques_2)
            """ques1"""
            [input_id_1, input_type_id_1, input_mask_1] = que_embed_1
            input_start_mask_1 = [0] * len(input_id_1)
            input_start_mask_1[offset_i] = 1
            input_end_mask_1 = [0] * len(input_id_1)
            input_end_mask_1[offset_i + len(mention_1) - 1] = 1
            input_entity_mask_1 = [0] * len(input_id_1)
            input_entity_mask_1[offset_i:offset_i +
                                len(mention_1)] = [1] * len(mention_1)
            """ques2"""
            [input_id_2, input_type_id_2, input_mask_2] = que_embed_2
            kind_2 = [0] * len(input_type_id_2)
            kind_21 = [0] * len(input_type_id_2)
            que_2_sp = ques_2.split("|")
            if len(que_2_sp) >= 2:
                que_2_sp_sp = que_2_sp[0].split(":")
                if len(que_2_sp_sp) == 2:
                    kind_2_start = len(que_2_sp_sp[0]) - 1
                    kind_2_end = kind_2_start + len(que_2_sp_sp[1]) - 1
                    kind_2[kind_2_start:kind_2_end] = [1] * (kind_2_end -
                                                             kind_2_start)
                if "标签:" in que_2_sp[1]:
                    que_21_sp_sp = que_2_sp[1].split(":")
                    kind_21_start = len(que_2_sp[0]) + len(que_21_sp_sp[0]) - 1
                    kind_21_end = len(que_2_sp[0]) + len(
                        que_21_sp_sp[0]) + len(que_21_sp_sp[1]) - 1
                    kind_21[kind_21_start:kind_21_end] = [1] * (kind_21_end -
                                                                kind_21_start)
            que_embed_x = [
                input_id_1, input_type_id_1, input_mask_1, input_start_mask_1,
                input_end_mask_1, input_entity_mask_1, input_id_2,
                input_type_id_2, input_mask_2, kind_2, kind_21
            ]
            label_zeros = [0] * len(l2i_i2l['l2i'])
            label_zeros[l2i_i2l['l2i'][label]] = 1
            return que_embed_x, label_zeros

        for _ in range(epcoh):
            file_csv = open(path, "r", encoding="utf-8")
            cout_all_line = 0
            cnt = 0
            x, y = [], []
            for line in file_csv:
                cout_all_line += 1
                # stop this epoch once the sampled fraction (rate) has been read
                if cout_all_line > len_ql:
                    break
                x_line, y_line = process_line(line)
                x.append(x_line)
                y.append(y_line)
                cnt += 1
                if cnt == batch_size:
                    if embedding_type in ['bert', 'albert']:
                        # split the parallel input fields into one array per field
                        x_, y_ = np.array(x), np.array(y)
                        x_all = []
                        for i in range(len(x_[0])):
                            x_1 = np.array([x[i] for x in x_])
                            x_all.append(x_1)
                    elif embedding_type == 'xlnet':
                        x_, y_ = x, np.array(y)
                        x_1 = np.array([x[0][0] for x in x_])
                        x_2 = np.array([x[1][0] for x in x_])
                        x_3 = np.array([x[2][0] for x in x_])
                        x_all = [x_1, x_2, x_3]
                    else:
                        x_all, y_ = np.array(x), np.array(y)

                    cnt = 0
                    yield (x_all, y_)
                    x, y = [], []
            file_csv.close()
        print("preprocess_label_ques_to_idx ok")
    def __init__(self, path_model_dir):
        self.l2i_i2l = None
        self.path_fast_text_model_vocab2index = path_model_dir + 'vocab2index.json'
        self.path_fast_text_model_l2i_i2l = path_model_dir + 'l2i_i2l.json'
        if os.path.exists(self.path_fast_text_model_l2i_i2l):
            self.l2i_i2l = load_json(self.path_fast_text_model_l2i_i2l)
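# Standalone illustration (made-up labels) of the l2i_i2l structure loaded
# above: after a json round-trip the keys of 'i2l' become strings, which is
# why callers index it as l2i_i2l['i2l'][str(i)].
import json
demo = {"l2i": {"GAME": 0, "MUSIC": 1}, "i2l": {0: "GAME", 1: "MUSIC"}}
restored = json.loads(json.dumps(demo))
assert restored["i2l"]["0"] == "GAME"  # integer dict keys serialize as strings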
    def preprocess_label_ques_to_idx_old(self,
                                         embedding_type,
                                         batch_size,
                                         path,
                                         embed,
                                         rate=1,
                                         epcoh=20):
        label_set, len_all = self.preprocess_get_label_set(path)
        # build the label<->index dicts; if the l2i_i2l json already exists (e.g. on the dev set), load it instead of rebuilding
        if not os.path.exists(self.path_fast_text_model_l2i_i2l):
            count = 0
            label2index = {}
            index2label = {}
            for label_one in label_set:
                label2index[label_one] = count
                index2label[count] = label_one
                count = count + 1

            l2i_i2l = {}
            l2i_i2l['l2i'] = label2index
            l2i_i2l['i2l'] = index2label
            save_json(l2i_i2l, self.path_fast_text_model_l2i_i2l)
        else:
            l2i_i2l = load_json(self.path_fast_text_model_l2i_i2l)

        # fraction of the data to read
        len_ql = int(rate * len_all)
        if len_ql <= 500:  # don't subsample small corpora; keep enough data to train
            len_ql = len_all

        def process_line(line):
            # per line: parse the json record and build the label and feature ids
            data = json.loads(line)
            label = data['label']
            ques_1 = data['sentence1']
            ques_2 = data['sentence2']
            offset = data['offset']
            mention = data["mention"]
            offset_i = int(offset)
            # if data.get("label_l2i"):
            #     ques_entity = data.get("label_l2i") + "#" + ques_1[:offset_i] + "#" + mention + "#" + ques_1[offset_i+len(mention):]
            # else:
            #     ques_entity = ques_1[:offset_i] + "#" + mention + "#" + ques_1[offset_i+len(mention):] + "$$" + ques_2
            # que_embed = embed.sentence2idx(text=ques_entity)
            que_embed = embed.sentence2idx(ques_1, second_text=ques_2)
            label_zeros = [0] * len(l2i_i2l['l2i'])
            label_zeros[l2i_i2l['l2i'][label]] = 1
            return que_embed, label_zeros

        for _ in range(epcoh):
            file_csv = open(path, "r", encoding="utf-8")
            cout_all_line = 0
            cnt = 0
            x, y = [], []
            for line in file_csv:
                cout_all_line += 1
                # stop this epoch once the sampled fraction (rate) has been read
                if cout_all_line > len_ql:
                    break
                x_line, y_line = process_line(line)
                x.append(x_line)
                y.append(y_line)
                cnt += 1
                if cnt == batch_size:
                    if embedding_type in ['bert', 'albert']:
                        # bert-style input: split [token_ids, segment_ids] into two arrays
                        x_, y_ = np.array(x), np.array(y)
                        x_1 = np.array([x[0] for x in x_])
                        x_2 = np.array([x[1] for x in x_])
                        x_all = [x_1, x_2]
                    elif embedding_type == 'xlnet':
                        x_, y_ = x, np.array(y)
                        x_1 = np.array([x[0][0] for x in x_])
                        x_2 = np.array([x[1][0] for x in x_])
                        x_3 = np.array([x[2][0] for x in x_])
                        x_all = [x_1, x_2, x_3]
                    else:
                        x_all, y_ = np.array(x), np.array(y)

                    cnt = 0
                    yield (x_all, y_)
                    x, y = [], []
            file_csv.close()
        print("preprocess_label_ques_to_idx ok")
Example #10
def evaluate(path_hyper_parameter=path_hyper_parameters, rate=1.0):
    # evaluation over the validation set
    # load hyper-parameters
    hyper_parameters = load_json(path_hyper_parameter)
    pt = PreprocessTextMulti()
    # model initialization and loading
    graph = Graph(hyper_parameters)
    graph.load_model()
    ra_ed = graph.word_embedding
    # get validation data
    ques_list, val_list, que, val = pt.preprocess_label_ques_to_idx(
        hyper_parameters['embedding_type'],
        hyper_parameters['data']['val_data'],
        ra_ed,
        rate=rate,
        shuffle=True)
    print(len(ques_list))
    print("que:", len(que))
    # print(val)

    # str to token
    ques_embed_list = []
    count = 0
    acc_count = 0
    not_none_count = 0
    not_none_acc_count = 0
    sum_iou = 0
    sum_all_iou = 0
    for index, que___ in enumerate(que):
        # print("原句 ", index, que[index])
        # print("真实分类 ", index, val[index])
        # print("ques: ", ques)
        ques_embed = ra_ed.sentence2idx(que[index])
        if hyper_parameters['embedding_type'] in ['bert', 'albert']:
            x_val_1 = np.array([ques_embed[0]])
            x_val_2 = np.array([ques_embed[1]])
            x_val = [x_val_1, x_val_2]
        else:
            x_val = ques_embed
        # print("ques_embed: ", ques_embed)
        # print("x_val", x_val)
        ques_embed_list.append(x_val)
        # predict
        pred = graph.predict(x_val)
        # print(pred)
        # map ids back to (label, probability) pairs
        pre = pt.prereocess_idx(pred[0])
        # print("pre",pre)
        ls_nulti = []
        threshold = 0.44  # minimum probability for keeping labels beyond the top-1
        top_threshold = 0
        for i, ls in enumerate(pre[0]):
            if i == 0 or ls[1] > threshold:
                ls_nulti.append(ls)
                top_threshold = ls[1]
            elif abs(ls[1] - top_threshold) < top_threshold / 4.0:
                ls_nulti.append(ls)
        # print("预测结果", index, pre[0])
        # print(ls_nulti)
        res = cal_acc(ls_nulti, val[index].split(","))
        res_iou, res_all_iou = cal_iou(ls_nulti, val[index].split(","))
        sum_iou += res_iou
        sum_all_iou += res_all_iou
        if res:
            if val[index] != "无":
                not_none_acc_count += 1
            acc_count += 1
        else:
            print("原句 ", index, que[index])
            print("真实分类 ", index, val[index])
            print("pre ", pre)
            print("iou ", res_iou)
        count += 1
        if val[index] != "无":
            not_none_count += 1
    print("acc: ", acc_count / count)
    print("not none acc: ", not_none_acc_count / not_none_count)
    print("average iou: ", sum_iou / sum_all_iou)
    # log
    append_log(hyper_parameters, acc_count / count,
               not_none_acc_count / not_none_count, threshold)
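# Standalone sketch of the label-selection rule used in the loop above
# (pre0 is assumed to be a list of (label, probability) pairs sorted by
# probability; 0.44 and the top/4 tolerance are the values hard-coded above):
def select_labels(pre0, threshold=0.44):
    picked, top = [], 0
    for i, (label, prob) in enumerate(pre0):
        if i == 0 or prob > threshold:
            picked.append((label, prob))
            top = prob
        elif abs(prob - top) < top / 4.0:
            picked.append((label, prob))  # keep near-ties with the last kept score
    return picked

# e.g. select_labels([('A', 0.9), ('B', 0.5), ('C', 0.1)]) keeps A and B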
Example #11
def evaluate(path_hyper_parameter=path_hyper_parameters, rate=1.0):
    # evaluation over the test set
    # load hyper-parameters
    hyper_parameters = load_json(path_hyper_parameter)
    pt = PreprocessTextMulti()
    # model initialization and loading
    graph = Graph(hyper_parameters)
    graph.load_model()
    ra_ed = graph.word_embedding

    # init confusion table
    dict_all = initConfusion()

    # get validation data
    ques_list, val_list, que, val = pt.preprocess_label_ques_to_idx(
        hyper_parameters['embedding_type'],
        hyper_parameters['data']['test_data'],
        ra_ed,
        rate=rate,
        shuffle=True)
    print(len(ques_list))
    print("que:", len(que))
    # print(val)

    # str to token
    ques_embed_list = []
    count = 0
    acc_count = 0
    not_none_count = 0
    not_none_acc_count = 0
    sum_iou = 0
    sum_all_iou = 0
    for index, que___ in enumerate(que):
        # print("原句 ", index, que[index])
        # print("真实分类 ", index, val[index])
        # print("ques: ", ques)
        ques_embed = ra_ed.sentence2idx(que[index])
        if hyper_parameters['embedding_type'] in ['bert', 'albert']:
            x_val_1 = np.array([ques_embed[0]])
            x_val_2 = np.array([ques_embed[1]])
            x_val = [x_val_1, x_val_2]
        else:
            x_val = ques_embed
        # print("ques_embed: ", ques_embed)
        # print("x_val", x_val)
        ques_embed_list.append(x_val)
        # predict
        pred = graph.predict(x_val)
        # print(pred)
        # map ids back to (label, probability) pairs
        pre = pt.prereocess_idx(pred[0])
        # print("pre",pre)
        ls_nulti = []
        threshold = 0.65  # probability cutoff for labels other than ['多发', '散发', '无']
        has_scope = False
        has_dense = False
        for i, ls in enumerate(pre[0]):
            if ls[0] in ['多发', '散发', '无']:
                if not has_scope:
                    has_scope = True
                    ls_nulti.append(ls)
                    if ls[0] in val[index].split(","):
                        dict_all[ls[0]]['TP'] += 1
                    else:
                        dict_all[ls[0]]['FN'] += 1
                else:
                    if ls[0] in val[index].split(","):
                        dict_all[ls[0]]['FP'] += 1
                    else:
                        dict_all[ls[0]]['TN'] += 1
            if ls[0] not in ['多发', '散发', '无']:
                if ls[1] > threshold or not has_dense:
                    ls_nulti.append(ls)
                    if ls[0] in val[index].split(","):
                        dict_all[ls[0]]['TP'] += 1
                    else:
                        dict_all[ls[0]]['FP'] += 1
                    has_dense = True
                else:
                    if ls[0] in val[index].split(","):
                        dict_all[ls[0]]['FN'] += 1
                    else:
                        dict_all[ls[0]]['TN'] += 1
        # print("预测结果", index, pre[0])
        # print(ls_nulti)
        res = cal_acc(ls_nulti, val[index].split(","))
        res_iou = cal_iou(ls_nulti, val[index].split(","))
        sum_iou += res_iou
        # sum_all_iou+=res_all_iou
        if res:
            # if val[index] != "无":
            #     not_none_acc_count += 1
            acc_count += 1
        else:
            print("原句 ", index, que[index])
            print("真实分类 ", index, val[index])
            print("pre ", pre)
            print("iou ", res_iou)
            print(ls_nulti)
        count += 1
        if val[index] != "无":
            not_none_count += 1
    print("acc: ", acc_count / count)
    # print("not none acc: ", not_none_acc_count / not_none_count)
    print("average iou: ", sum_iou / count)
    import prettytable  # don't shadow the PreprocessTextMulti instance 'pt'
    tb = prettytable.PrettyTable()
    tb.field_names = [" ", "Recall", "Precision", "TP", "FP", "TN", "FN"]
    for item in dict_all:
        if dict_all[item]['TP'] + dict_all[item]['FN'] == 0:
            recall = 1
        else:
            recall = dict_all[item]['TP'] / (dict_all[item]['TP'] +
                                             dict_all[item]['FN'])
        if dict_all[item]['TP'] + dict_all[item]['FP'] == 0:
            precision = 1
        else:
            precision = dict_all[item]['TP'] / (dict_all[item]['TP'] +
                                                dict_all[item]['FP'])
        # print(item,recall,precision)
        tb.add_row([
            item, recall, precision, dict_all[item]['TP'],
            dict_all[item]['FP'], dict_all[item]['TN'], dict_all[item]['FN']
        ])
    print(tb)
    # log
    append_log(hyper_parameters, acc_count / count,
               not_none_acc_count / not_none_count, threshold)
import os
import pathlib
import sys

project_path = str(pathlib.Path(os.path.abspath(__file__)).parent.parent.parent)
sys.path.append(project_path)
# paths
from keras_textclassification.conf.path_config import path_model, path_fineture, path_model_dir, path_hyper_parameters
# train/validation data paths
from keras_textclassification.conf.path_config import path_baidu_qa_2019_train, path_baidu_qa_2019_valid
# data preprocessing
from keras_textclassification.data_preprocess.text_preprocess import PreprocessText, read_and_process, load_json
# model graph
from keras_textclassification.m02_TextCNN.graph import TextCNNGraph as Graph
import numpy as np
# flask
from flask import Flask, request, jsonify
app = Flask(__name__)

hyper_parameters = load_json(path_hyper_parameters)
pt = PreprocessText(path_model_dir)
# model initialization and loading
graph = Graph(hyper_parameters)
graph.load_model()
ra_ed = graph.word_embedding
ques = '我要打王者荣耀'
# str to token
ques_embed = ra_ed.sentence2idx(ques)
if hyper_parameters['embedding_type'] in ['bert', 'albert']:
    x_val_1 = np.array([ques_embed[0]])
    x_val_2 = np.array([ques_embed[1]])
    x_val = [x_val_1, x_val_2]
else:
    x_val = ques_embed
# predict
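# Hypothetical route (the route name and response shape are assumptions, not
# part of the original snippet), wiring the objects above into an HTTP endpoint:
@app.route("/predict", methods=["POST"])
def predict():
    ques = request.json.get("ques", "")
    ques_embed = ra_ed.sentence2idx(ques)
    if hyper_parameters['embedding_type'] in ['bert', 'albert']:
        x_val = [np.array([ques_embed[0]]), np.array([ques_embed[1]])]
    else:
        x_val = ques_embed
    pred = graph.predict(x_val)
    return jsonify(pt.prereocess_idx(pred[0]))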
    def __init__(self):
        self.l2i_i2l = None
        if os.path.exists(path_fast_text_model_l2i_i2l):
            self.l2i_i2l = load_json(path_fast_text_model_l2i_i2l)
def pred_tet(path_hyper_parameter=path_hyper_parameters,
             path_test=None,
             rate=1.0):
    """
        测试集测试与模型评估
    :param hyper_parameters: json, 超参数
    :param path_test:str, path of test data, 测试集
    :param rate: 比率, 抽出rate比率语料取训练
    :return: None
    """
    hyper_parameters = load_json(path_hyper_parameter)
    if path_test:  # override the test data path from outside
        hyper_parameters['data']['test_data'] = path_test
    time_start = time.time()
    # graph initialization
    graph = Graph(hyper_parameters)
    print("graph init ok!")
    graph.load_model()
    print("graph load ok!")
    ra_ed = graph.word_embedding
    # data preprocessing
    pt = PreprocessSim(path_model_dir)

    data = pd.read_csv(hyper_parameters['data']['test_data'])
    sentence_1 = data["sentence1"].values.tolist()
    sentence_2 = data["sentence2"].values.tolist()
    labels = data["label"].values.tolist()
    sentence_1 = [extract_chinese(str(line1).upper()) for line1 in sentence_1]
    sentence_2 = [extract_chinese(str(line2).upper()) for line2 in sentence_2]
    labels = [extract_chinese(str(line3).upper()) for line3 in labels]

    # evaluate on only a fraction (rate) of the dataset
    len_rate = int(len(labels) * rate)
    sentence_1 = sentence_1[0:len_rate]
    sentence_2 = sentence_2[0:len_rate]
    labels = labels[0:len_rate]
    y_pred = []
    count = 0
    for i in range(len_rate):
        count += 1
        ques_embed = ra_ed.sentence2idx(text=sentence_1[i],
                                        second_text=sentence_2[i])
        # print(hyper_parameters['embedding_type'])
        if hyper_parameters['embedding_type'] in ['bert', 'albert']:  # bert input: [token_ids, segment_ids]
            x_val_1 = np.array([ques_embed[0]])
            x_val_2 = np.array([ques_embed[1]])
            x_val = [x_val_1, x_val_2]
            # predict
            pred = graph.predict(x_val)
            pre = pt.prereocess_idx(pred[0])
            label_pred = pre[0][0][0]
            if count % 1000 == 0:
                print(label_pred)
            y_pred.append(label_pred)

    print("data pred ok!")
    # convert gold and predicted labels to integer indices
    index_y = [pt.l2i_i2l['l2i'][i] for i in labels]
    index_pred = [pt.l2i_i2l['l2i'][i] for i in y_pred]
    target_names = [
        pt.l2i_i2l['i2l'][str(i)] for i in list(set((index_pred + index_y)))
    ]
    # evaluate
    report_predict = classification_report(index_y,
                                           index_pred,
                                           target_names=target_names,
                                           digits=9)
    print(report_predict)
    print("耗时:" + str(time.time() - time_start))