def rel_statistics():
    flag = 14254  # the test set starts at this index
    # 1. read the two splits of the data
    # 2. collect the relation from each line
    f1s = ct.file_read_all_lines_strip('../data/nlpcc2016/class/q.rdf.m_s.filter.txt')

    f1s_train = f1s[0:flag]
    f1s_test = f1s[flag:]
    r_train = set()
    r_test = set()

    for item in f1s_train:
        if len(str(item).split('\t')) < 4:
            continue
        r_train.add(str(item).split('\t')[3].lower().replace(' ', ''))
    for item in f1s_test:
        if len(str(item).split('\t')) < 4:
            continue
        r_test.add(str(item).split('\t')[3].lower().replace(' ', ''))

    # all relations of the dataset that appear in the test split but not in train
    # (note: (r_train | r_test) - r_train is simply r_test - r_train)
    r3 = (r_train | r_test) - r_train
    for r in r3:
        ct.just_log("../data/nlpcc2016/class/q.rdf.m_s.filter[r_in_test_not_in_train].txt", r)
    print(1)
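A quick standalone check of the set-difference step above (toy data, no file I/O):

r_train = {'出生日期', '身高'}
r_test = {'身高', '国籍'}
assert (r_train | r_test) - r_train == r_test - r_train == {'国籍'}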
Example #2
def re_write(f1, f2):
    """Convert the question format."""
    f1s = ct.file_read_all_lines_strip(f1)
    f2s = []
    for l1 in f1s:
        if 'question id' in str(l1):
            f2s.append(str(l1).split('\t')[1].replace(' ', '').lower())
    ct.file_wirte_list(f2, f2s)
Example #3
def re_write_m2id(f1, f_out):
    f1s = ct.file_read_all_lines_strip(f1)  # read all the questions
    f2 = []
    for l1 in f1s:
        l1 = str(l1).replace(' ', '').replace('|||', '\t')
        l1 = ct.clean_str_s(l1)
        f2.append(l1)
    ct.file_wirte_list(f_out, f2)
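A one-line illustration of the rewrite above (toy input; ct.clean_str_s is skipped since it is project-specific):

line = '谁 是 姚明|||篮球运动员'
print(line.replace(' ', '').replace('|||', '\t'))  # -> '谁是姚明\t篮球运动员'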
Example #4
def stat_all_space(f1):
    f1s = ct.file_read_all_lines_strip(f1)
    t1 = 0
    for l1 in f1s:
        l1_len = len(str(l1).split('\t'))
        if l1_len == 1:
            t1 += 1
    print(t1)
Example #5
def get_shuffle_indices_test(dh, step, train_part, model, train_step):
    """

    :param dh:
    :param step:
    :param train_part:
    :param model: train valid test
    :return:
    """
    if train_part == 'relation':
        if model == "valid":
            id_list = ct.get_static_id_list_debug(
                len(dh.train_question_list_index))
        else:
            id_list = ct.get_static_id_list_debug_test(
                len(dh.test_question_list_index))

        id_list = ct.random_get_some_from_list(id_list,
                                               FLAGS.evaluate_batchsize)

        id_list2 = [str(x) for x in id_list]
        # logged as: step, training mode, training part, then the ids
        ct.just_log(
            config.cc_par('combine_test'), '%s\t%s\t%s\t%s' %
            (train_step, model, train_part, '\t'.join(id_list2)))
    else:
        f1s = ct.file_read_all_lines_strip(config.cc_par('combine_test'))
        line = ''
        exist = False
        for l1 in f1s:
            if str(l1).split('\t')[0] == str(train_step) \
                    and str(l1).split('\t')[1] == model:
                line = str(l1)
                exist = True
                break
        if exist:
            line_split = line.split('\t')
            line_split = line_split[3:]
            line_split = [int(x) for x in line_split]
            id_list = np.array(line_split)
            ct.print(
                'get_shuffle_indices_test exist %s %s ' % (train_step, model),
                'shuffle_indices_test')
        else:  # not cached yet: generate the ids ourselves
            if model == "valid":
                id_list = ct.get_static_id_list_debug(
                    len(dh.train_question_list_index))
            else:
                id_list = ct.get_static_id_list_debug_test(
                    len(dh.test_question_list_index))

            id_list = ct.random_get_some_from_list(id_list,
                                                   FLAGS.evaluate_batchsize)
            ct.print('get_shuffle_indices_test not exist %s ' % train_step,
                     'shuffle_indices_test')

    return id_list
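For reference, a toy round-trip of the tab-separated cache line that get_shuffle_indices_test writes and later re-parses (hypothetical values):

record = '%s\t%s\t%s\t%s' % (10, 'valid', 'relation', '\t'.join(['3', '0', '2', '1']))
parts = record.split('\t')
assert parts[0] == '10' and parts[1] == 'valid'
indices = [int(x) for x in parts[3:]]  # -> [3, 0, 2, 1]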
Example #6
def load_vocabulary(file_name):
    f_name1 = os.path.join(path, file_name)
    print(f_name1)
    f1s = ct.file_read_all_lines_strip(f_name1)
    d1 = dict()
    for x in f1s:
        d1[str(x).split('\t')[0].split('_')[1]] = str(x).split('\t')[1]

    return d1
Example #7
def prepare_data():
    f1 = '../data/nlpcc2016/6-answer/q.rdf.ms.re.v1.txt'
    f3 = '../data/nlpcc2016/4-ner/extract_entitys_all_tj.txt'
    f4 = '../data/nlpcc2016/4-ner/extract_entitys_all_tj.sort_by_ner_lstm.txt'
    f1s = ct.file_read_all_lines_strip(f1)
    f3s = ct.file_read_all_lines_strip(f3)
    f1s_new = []
    f3s_new = []
    for i in range(len(f1s)):
        if 'NULL' in str(f1s[i]):
            continue
        f1s_new.append(f1s[i])
        f3s_new.append(f3s[i])

    # filter out NULL
    # take each candidate entity, substitute it into the question, and judge it

    # cs.append('立建候时么什是♠')
    # read out all candidate entities, score them, and keep the top 3 to check accuracy

    f4s = []
    _index = -1
    for l1 in f1s_new:  # iterate over every question
        _index += 1
        replace_qs = []
        for l3 in f3s_new[_index].split('\t'):
            q_1 = str(l1).split('\t')[0].replace(l3, '♠')
            replace_qs.append((q_1, l3))
        entitys = []
        for content, l3 in replace_qs:
            # content = input("input:")
            r1 = '1'  # placeholder score
            entitys.append((l3, r1))
            # print(content)
            # print(r1)
            # print(score_list)
        entitys.sort(key=lambda x: x[1])
        entitys_new = [x[0] for x in entitys]

        f4s.append('\t'.join(entitys_new))
    ct.file_wirte_list(f4, f4s)
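A standalone sketch of the candidate-substitution step used above (toy data):

question = '姚明的身高是多少'
candidates = ['姚明', '身高']
replace_qs = [(question.replace(c, '♠'), c) for c in candidates]
# -> [('♠的身高是多少', '姚明'), ('姚明的♠是多少', '身高')]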
Example #8
    def __init__(self, f1='../data/nlpcc2016/8-logistics/logistics-2018-03-10.txt_bak.txt',
                 f2='../data/nlpcc2016/8-logistics/logistics-2018-03-10.txt_bak.txt'):
        self.all_datas = []
        # f1 = '../data/nlpcc2016/8-logistics/logistics-2018-03-10.txt_bak.txt'
        f1s = ct.file_read_all_lines_strip(f1)
        self.train_data = []
        self.test_data = []
        for f1_l in f1s:
            if '\tvalid\t' in str(f1_l):
                self.train_data.append(self.extract_line(f1_l))
            else:
                self.test_data.append(self.extract_line(f1_l))

        print('init ok')
Example #9
def get_shuffle_indices_train(total, step, train_part, model, train_step):
    """
    Shuffle the training sample indices for this step; the 'relation' part
    writes the permutation to the combine file, other parts read it back.

    :param total: number of training samples
    :param step: current step
    :param train_part: 'relation' writes the cache, anything else reads it
    :param model: train valid test
    :param train_step: training step used as the lookup key
    :return: the shuffled indices
    """
    if train_part == 'relation':
        shuffle_indices = np.random.permutation(np.arange(total))  # shuffle the sample indices
        shuffle_indices1 = [str(x) for x in list(shuffle_indices)]
        # logged as: step, training mode, training part, then the indices
        ct.just_log(
            config.cc_par('combine'), '%s\t%s\t%s\t%s' %
            (train_step, model, train_part, '\t'.join(shuffle_indices1)))
    else:
        f1s = ct.file_read_all_lines_strip(config.cc_par('combine'))
        line = ''
        exist = False
        for l1 in f1s:
            if str(l1).split('\t')[0] == str(train_step):
                line = str(l1)
                exist = True
                break
        if exist:
            line_split = line.split('\t')
            line_split = line_split[3:]
            line_split = [int(x) for x in line_split]
            shuffle_indices = np.array(line_split)
            ct.print('get_shuffle_indices_train   exist %s' % train_step,
                     'shuffle_indices_train')
        else:  # not cached yet: generate the permutation ourselves
            shuffle_indices = np.random.permutation(np.arange(total))  # shuffle the sample indices
            ct.print('get_shuffle_indices_train   not exist %s' % train_step,
                     'shuffle_indices_train')
            # logged as: step, training mode, training part
            # ct.file_wirte_list(config.cc_par('combine'),
            #                    '%s\t%s\t%s\t%s' % (train_step, model, train_part, '\t'.join(shuffle_indices)))

    return shuffle_indices
Example #10
    def ner_re_writer(f1='../data/nlpcc2016/ner_t1/q.rdf.m_s.filter.txt',
                      f2='../data/nlpcc2016/class/q.rdf.m_s.filter.re_writer.txt'):
        """
        Rewrite the question file.
        """
        # 1. read the question file
        # 2. replace the entity in each question, then write out

        f1s = ct.file_read_all_lines_strip(f1)
        f1s_new = []
        for f1s_l in f1s:
            s1 = str(f1s_l).split('\t')
            e1 = s1[5]
            q1 = s1[0].replace(' ', '').lower()
            q2 = str(q1).replace(e1, '♠')
            s1.append(q2)
            f1s_new.append('\t'.join(s1))
        ct.file_wirte_list(f2, f1s_new)

        print(1)
Example #11
    def __init__(self, f1=''):

        self.min_max_scaler = preprocessing.MinMaxScaler()

        self.all_datas = []
        # f1 = '../data/nlpcc2016/8-logistics/logistics-2018-03-10.txt_bak.txt'
        f1s = ct.file_read_all_lines_strip(f1)
        self.train_data = []
        self.test_data = []
        index = -1
        for f1_l in f1s:
            index += 1
            need_skip = False
            if 'NULL' in str(f1_l):
                need_skip = True
            if '####' in str(f1_l):
                need_skip = True
            # changed: lines with '@@@@@@' now have the marker stripped instead of being skipped
            if '@@@@@@' in str(f1_l):
                f1_l = str(f1_l).replace('1@@@@@@', '').replace('@@@@@@', '')
                # need_skip = True

            if need_skip:  # in practice nothing gets skipped
                print(f1_l, 'skip')
                continue

            # if index < config.cc_par('real_split_train_test_skip'):  # <= int(len(f1s)*0.8):
            m1 = False
            if m1:
                is_train = index < config.cc_par('real_split_train_test_skip')
            else:
                is_train = index < int(len(f1s) * 0.8)  # 80/20 train/test split

            if is_train:
                self.train_data.append(self.extract_line(f1_l))
            else:
                self.test_data.append(self.extract_line(f1_l))
                # normalization (min-max scaling) was applied around here

        print('init ok')
Example #12
    def class2(f5='../data/nlpcc2016/5-class/class2.txt',
               f1="../data/nlpcc2016/3-questions/q.rdf.ms.re.v1.filter.txt"):
        f1s = ct.file_read_all_lines_strip(f1)
        f1s_new = [str(x).split('\t')[6] for x in f1s]  # column 6 holds the question pattern

        q_patten_set = set()
        q_patten_dict = dict()
        q_count_dict = dict()
        for f1_line in f1s_new:
            q_patten_set.add(f1_line)
        # for q1 in q_patten_set:
        #     q_patten_dict[q1] = set()
        #     q_count_dict[q1] = 0

        gc1 = ct.generate_counter()
        for q1 in q_patten_set:  # iterate over the unique question patterns
            for f1_line in f1s:  # iterate over all questions
                index = gc1()
                if index % 100000 == 0:
                    print("%d - %d " % (index // 100000, len(q_patten_set) * len(f1s) // 100000))
                _q1 = str(f1_line).split('\t')[6]
                _ps = str(f1_line).split('\t')[3]
                q1 = str(q1)
                if _q1 != '♠' and q1 in _q1:  # equal, or merely contained?
                    if q1 in q_patten_dict:
                        s1 = q_patten_dict[q1]
                        s1.add(_ps)
                        q_patten_dict[q1] = s1
                        q_count_dict[q1] += 1
                    else:
                        s1 = set()
                        s1.add(_ps)
                        try:
                            q_patten_dict[q1] = s1
                        except Exception as e11:
                            print(e11)
                        q_count_dict[q1] = 1

        tp = ct.sort_dict(q_count_dict)
        f5s = []
        for t in tp:
            f5s.append("%s\t%s\t%s" % (t[0], t[1], '\t'.join(list(q_patten_dict[t[0]]))))
        ct.file_wirte_list(f5, f5s)

        #  -------

        keys = q_patten_dict.keys()
        words_bag_list = []
        for key in keys:

            # words = set(str(key).split('\t'))
            words = q_patten_dict.get(key)  # words, e.g. 规划总面积 建筑面积 显示器尺寸 面积 占地总面积

            exist = False
            wl_index = -1
            for word in words:  # check each word
                for wl_index in range(len(words_bag_list)):  # against every existing bag
                    if word in words_bag_list[wl_index]:
                        exist = True
                        break
                if exist:
                    break
                    # then merge all of the current words into that bag
            if exist:
                wbl = words_bag_list[wl_index]
                for word in words:
                    wbl.add(word)
                words_bag_list[wl_index] = wbl
            else:
                s1 = set()
                for word in words:
                    s1.add(word)
                words_bag_list.append(s1)
        # write out words_bag_list
        f5s = []
        for words_bag in words_bag_list:
            f5s.append('\t'.join(list(words_bag)))
        ct.file_wirte_list(f5 + '.combine.txt', f5s)
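A compact toy illustration of the bag-merging loop above: word sets that share any word end up in the same bag.

bags = []
for words in [{'a', 'b'}, {'c'}, {'b', 'd'}]:
    target = next((bag for bag in bags if bag & words), None)
    if target is not None:
        target |= words  # merge into the first bag that overlaps
    else:
        bags.append(set(words))
# bags -> [{'a', 'b', 'd'}, {'c'}]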
Example #13
    def init_data(self, f1=''):
        self.question_list = ct.file_read_all_lines_strip(f1)
        # convert to indices
        print(1)
Example #14
def valid_step(sess, lstm, step, train_op, test_q, test_r, labels, merged,
               writer, dh, model, global_index):
    start_time = time.time()
    feed_dict = {
        lstm.test_input_q: test_q,
        lstm.test_input_r: test_r,
    }
    question = ''
    relations = []
    for _ in test_q:
        v_s_1 = dh.converter.arr_to_text_no_unk(_)
        valid_msg = model + " test_q 1:" + v_s_1
        ct.just_log2("valid_step", valid_msg)
        question = v_s_1
    for _ in test_r:
        v_s_1 = dh.converter.arr_to_text_no_unk(_)
        valid_msg = model + " test_r 1:" + v_s_1
        ct.just_log2("valid_step", valid_msg)
        relations.append(v_s_1)

    error_test_q = []
    error_test_pos_r = []
    error_test_neg_r = []
    fuzzy_boundary = []

    test_q_r_cosin = sess.run([lstm.test_q_r], feed_dict=feed_dict)

    test_q_r_cosin = test_q_r_cosin[0]
    right, wrong, score = [0.0] * 3
    st_list = []  # score record for each relation

    # comparing the first entry against the rest is an alternative way to judge correctness:
    # for i in range(1, len(test_q_r_cosin)):
    #     compare_res = ct.nump_compare_matix(test_q_r_cosin[0], test_q_r_cosin[i])
    #     ct.print("compare_res:" + str(compare_res))

    for i in range(0, len(test_q_r_cosin)):
        st = ct.new_struct()
        st.index = i
        st.cosine_matix = test_q_r_cosin[i]
        ori_cand_score_mean = np.mean(test_q_r_cosin[i])
        st.score = ori_cand_score_mean
        st_list.append(st)
        # ct.print(ori_cand_score_mean)
    # pair each score with its index, then sort by score (descending)
    st_list.sort(key=ct.get_key)
    st_list.reverse()
    st_list_sort = st_list  # take all of them (was st_list[0:5])

    ct.just_log2("info", "\n ##3 score")
    score_list = []
    for st in st_list_sort:
        # ct.print("index:%d ,score= %f " % (st.index, st.score))
        # mylog.logger.info("index:%d ,score= %f " % (st.index, st.score))
        # take the indices with the top-X scores,
        # map each index back to its relation array,
        # and check the top-scoring relation against the labels to update the stats
        better_index = st.index
        # convert the relation array back to text
        r1 = dh.converter.arr_to_text_no_unk(test_r[better_index])
        # ct.print(r1)
        ct.just_log2(
            "info", "step:%d st.index:%d,score:%f,r:%s" %
            (step, st.index, st.score, r1))
        if st.index == 0:
            _tmp_right = 1
        else:
            _tmp_right = 0
        # training step, relation index, score, correct flag, relation (surface form)
        score_list.append(
            "%d_%d_%f_%s" %
            (st.index, _tmp_right, st.score, r1.replace('_', '-')))
    _tmp_msg1 = "%d\t%s\t%d\t%s\t%s" % (step, model, global_index, question,
                                        '\t'.join(score_list))
    ct.just_log2("logistics", _tmp_msg1)
    # also logged to a separate file

    is_right = False
    msg = " win r =%d  " % st_list_sort[0].index
    ct.log3(msg)
    if st_list_sort[0].index == 0:
        ct.print(
            "================================================================ok"
        )
        is_right = True
    else:
        # todo: record the failing question here, plus the neg relations that scored above the pos
        # q, pos, neg
        # error_test_q.append()
        # find them:
        for st in st_list_sort:
            # record the negs from st_list here
            if st.index == 0:
                break
            else:
                error_test_neg_r.append(test_r[st.index])
                error_test_q.append(test_q[0])
                error_test_pos_r.append(test_r[0])
        ct.print(
            "================================================================error"
        )
        ct.just_log2("info", "!!!!! error %d  " % step)
    ct.just_log2("info", "\n =================================end\n")

    # score-gap ("jump") check; a config file toggles whether it runs
    # (in practice it is only run with a stable model)
    run = ct.file_read_all_lines_strip('config')[0] == '1'
    ct.print("run %s " % run, 'info')
    maybe_list = []
    if run:
        st_list_sort = list(st_list_sort)
        for index in range(len(st_list_sort) - 1):
            # if index == len(st_list_sort) :
            #     continue

            space = st_list_sort[index].score - st_list_sort[index + 1].score
            maybe_list.append(st_list_sort[index])

            if space > config.skip_threshold():
                break

        # check whether the positive relation made it into maybe_list
        pos_in_it = False
        for index in range(len(maybe_list)):
            item = maybe_list[index]
            # log the near-miss relations, record whether the pos is among them, and estimate the overall accuracy
            if item.index == 0 and pos_in_it == False:
                pos_in_it = True
            better_index = item.index
            r1 = dh.converter.arr_to_text_no_unk(test_r[better_index])
            item.relation = r1
            msg1 = "step:%d st.index:%d,score:%f,r:%s" % (step, item.index,
                                                          item.score, r1)
            item.msg1 = msg1
            maybe_list[index] = item
            ct.print(msg1, "maybe")

        if pos_in_it:
            maybe_dict['r'] += 1
        else:
            maybe_dict['e'] += 1
        acc0 = maybe_dict['r'] / (maybe_dict['r'] + maybe_dict['e'])
        # ct.print("%f pos_in_it  %s" % (acc0, pos_in_it), "maybe")
        # ct.print("\n", "maybe")

    time_elapsed = time.time() - start_time
    time_str = datetime.datetime.now().isoformat()
    ct.print("%s: step %s,  score %s, is_right %s, %6.7f secs/batch" %
             (time_str, step, score, str(is_right), time_elapsed))
    return is_right, error_test_q, error_test_pos_r, error_test_neg_r, maybe_list
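A standalone sketch of the score-gap cut that builds maybe_list above (toy scores; the threshold stands in for config.skip_threshold()):

scores = [0.91, 0.90, 0.88, 0.40, 0.35]  # sorted descending, like st_list_sort
threshold = 0.2
maybe = []
for i in range(len(scores) - 1):
    maybe.append(scores[i])
    if scores[i] - scores[i + 1] > threshold:
        break
# maybe -> [0.91, 0.90, 0.88]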
Example #15
def get_ngrams(new_line, n):
    # head reconstructed: the original snippet starts mid-function, so the
    # signature and loop are assumed from the counting body below
    output = {}
    for i in range(len(new_line) - n + 1):
        ngramTemp = new_line[i:i + n]
        if ngramTemp not in output:
            output[ngramTemp] = 0  # the usual dict-counting idiom
        output[ngramTemp] += 1
    return output


def all_gram(new_line):
    # for i in range(new_line_len):
    all_entitys = []
    for i in range(len(new_line)):
        index = len(new_line) - int(i)
        print(index)
        all_entitys.extend(get_ngrams(new_line, index))
    return all_entitys


f1s = ct.file_read_all_lines_strip('../data/nlpcc2016/ner_t1/q.rdf.txt')
f2s = ct.file_read_all_lines_strip('../data/nlpcc2016/ner_t1/q.rdf-1.txt')

l1 = []
l2 = []


def math1(p1):
    count = 0
    p1s = all_gram(p1)
    for p in p1s:
        if p in line:  # 'line' is a module-level variable set elsewhere in the original script
            count += 1
    return count
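A quick sanity check of the n-gram helpers above (toy input; relies on the reconstructed get_ngrams head):

print(get_ngrams('abc', 2))  # {'ab': 1, 'bc': 1}
print(all_gram('abc'))       # ['abc', 'ab', 'bc', 'a', 'b', 'c']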

Example #16
def main(_):
    # prepare_data()
    # FLAGS.start_string = FLAGS.start_string.decode('utf-8')
    # converter = TextConverter(filename=FLAGS.converter_path)
    if os.path.isdir(FLAGS.checkpoint_path):
        FLAGS.checkpoint_path = \
            tf.train.latest_checkpoint(FLAGS.checkpoint_path)

    model_path = os.path.join('model', FLAGS.name)
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    model = 'ner'
    dh = data_helper.DataClass(model)
    train_batch_size = 1
    # g = dh.batch_iter_char_rnn(train_batch_size)  # (FLAGS.num_seqs, FLAGS.num_steps)
    embedding_weight = dh.embeddings

    model = CharRNN(dh.converter.vocab_size,  # vocabulary size; all candidates are generated from it
                    num_seqs=train_batch_size,  # FLAGS.num_seqs  # sentences per batch
                    num_steps=dh.max_document_length,  # FLAGS.num_steps  # length of one sentence
                    lstm_size=FLAGS.lstm_size,
                    num_layers=FLAGS.num_layers,
                    learning_rate=FLAGS.learning_rate,
                    train_keep_prob=FLAGS.train_keep_prob,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size,
                    embedding_weight=embedding_weight,
                    sampling=True,
                    dh=dh
                    )

    model.load(FLAGS.checkpoint_path)
    # cs = []
    # cs.append('♠是什么类型的产品')
    # cs.append('♠是谁')
    # cs.append('♠是哪个公司的长度')
    f1 = '../data/nlpcc2016/6-answer/q.rdf.ms.re.v1.txt'
    f3 = '../data/nlpcc2016/4-ner/extract_entitys_all_tj.v1.txt'
    f4 = '../data/nlpcc2016/4-ner/extract_entitys_all_tj.sort_by_ner_lstm.v1.txt'
    f1s = ct.file_read_all_lines_strip(f1)
    f3s = ct.file_read_all_lines_strip(f3)
    f1s_new = []
    f3s_new = []
    for i in range(len(f1s)):
        # if str(f1s[i]).__contains__('NULL'):
        #     continue
        f1s_new.append(f1s[i])
        f3s_new.append(f3s[i])

    # filter out NULL
    # take each candidate entity, substitute it into the question, and judge it

    # cs.append('立建候时么什是♠')
    # read out all candidate entities, score them, and keep the top 3 to check accuracy

    f4s = []
    _index = -1
    for l1 in f1s_new:  # iterate over every question
        _index += 1
        replace_qs = []
        for l3 in f3s_new[_index].split('\t'):
            q_1 = str(l1).split('\t')[0].replace(l3, '♠')
            replace_qs.append((q_1, l3))
        entitys = []
        for content, l3 in replace_qs:
            # content = input("input:")
            start = dh.convert_str_to_indexlist_2(content, False)

            # arr = model.sample(FLAGS.max_length, start, dh.converter.vocab_size,dh.get_padding_num())
            # #converter.vocab_size
            r1, score_list = model.judge(start, dh.converter.vocab_size)
            entitys.append((l3, r1))
            # print(content)
            # print(r1)
            # print(score_list)
            ct.print("%s\t%s\t%s" % (content, l3, r1), 'debug_process')
        entitys.sort(key=lambda x: x[1])
        entitys_new = [x[0] for x in entitys]
        ct.print('\t'.join(entitys_new))
        f4s.append('\t'.join(entitys_new))
    ct.file_wirte_list(f4, f4s)