Esempio n. 1
0
    def prodeuce_embedding_vec_file(
            filename, path="../data/simple_questions/fb_0_2m_files"):
        """Dump one word2vec vector line per vocabulary word to wiki.vector.

        Words missing from the model fall back to the vector of 'end'; each
        miss is logged to wiki.vector.log.  The 'end' word itself is written
        once more at the end of the file.

        :param filename: path of the saved gensim Word2Vec model.
        :param path: output directory for wiki.vector / wiki.vector.log.
        """
        dh = data_helper.DataClass("sq")
        model = models.Word2Vec.load(filename)

        def _write_vector(word, vec):
            # one output line: "<word> <v0> <v1> ..."
            line = "%s %s" % (word, ' '.join(str(x) for x in list(vec)))
            ct.just_log(path + "/wiki.vector", line)

        # sanity print of the fallback vector
        v_base = model['end']
        ct.print(v_base)

        # walk the vocabulary, look each word up in word2vec and write it out
        for word in dh.converter.vocab:
            try:
                v = model[word]
            except Exception as e1:  # gensim raises KeyError for OOV words
                msg1 = "%s : %s " % (word, e1)
                ct.print(msg1)
                ct.just_log(path + "/wiki.vector.log", msg1)
                v = model['end']  # fall back to the 'end' vector
            _write_vector(word, v)
        # record one extra word: 'end'
        _write_vector('end', model['end'])
    def __init__(self):
        """Load the question data and set up the text/index converter."""
        # NOTE(review): init_data() appears to populate the question
        # structures that get_max_length() below reads — confirm the
        # ordering requirement before reordering.
        self.init_data()
        ct.print("init_cc_questions finish.")
        # converter maps text <-> index lists using the vocab file from config
        self.converter = read_utils.TextConverter(filename=config.par('cc_vocab'), type="zh-cn")


        # self.load_all_q_r_tuple(config.get_static_q_num_debug(), config.get_static_num_debug(), is_record=True)
        self.get_max_length()
Esempio n. 3
0
def build_spo(dh, line, s_list):
    """Build the (subject, predicate, object) candidate set for one question.

    For every candidate subject in *s_list*, fetch up to 99 (p, o) pairs from
    the KB, then convert question/entity/predicate/answer strings to index
    lists and pack everything with dh.return_dict.

    :param dh: data helper exposing bh (KB access), convert_str_to_indexlist
        and return_dict.
    :param line: the raw question text.
    :param s_list: candidate subject entities for the question.
    :return: the dict produced by dh.return_dict(...).
    """
    spo_tuple = []
    all_cands = s_list  # candidate entity list
    q_current = line  # current question
    for cand_s_neg_item in all_cands:
        num = 99  # at most 99 predicates per subject
        temp_cand_ps_neg, temp_cand_as_neg = \
            dh.bh.kb_get_p_o_by_s_limit(cand_s_neg_item, [], num)
        for _cand_ps_neg_item, _as in \
                zip(temp_cand_ps_neg, temp_cand_as_neg):
            ct.print("%s\t%s\t%s" % (cand_s_neg_item, _cand_ps_neg_item, _as))
            spo_tuple.append((cand_s_neg_item, _cand_ps_neg_item, _as))

    # fill the training containers; the *_pos / label_* lists are
    # intentionally left empty here and just passed through to return_dict
    q_ = []  # questions
    q_p = []  # questions used to train the predicate part
    q_s = []  # questions used to train the entity part
    q_a = []  # questions used to train the answer part
    s_pos = []  # correct entities
    s_neg = []  # candidate (negative) entities
    p_pos = []  # correct predicates
    p_neg = []  # candidate (negative) predicates
    a_pos = []  # correct answers
    a_neg = []  # candidate (negative) answers
    label_p_pos = []  # predicate-match labels
    label_s_pos = []  # entity-match labels
    for cand_s_neg_item, _cand_ps_neg_item, _as in spo_tuple:
        # use the original question to avoid clashing with predicate checks
        q_.append(dh.convert_str_to_indexlist(q_current))
        q_s.append(
            dh.convert_str_to_indexlist(
                q_current.replace(_cand_ps_neg_item, '♢')))
        # question with the entity masked out, used for predicate training
        q_current_for_p = q_current.replace(cand_s_neg_item, '♠')
        q_p.append(dh.convert_str_to_indexlist(q_current_for_p))
        # question with entity and predicate masked, used for answer training
        q_current_for_a = q_current_for_p.replace(_cand_ps_neg_item, '♢')
        q_a.append(dh.convert_str_to_indexlist(q_current_for_a))
        s_neg.append(dh.convert_str_to_indexlist(cand_s_neg_item))
        p_neg.append(dh.convert_str_to_indexlist(_cand_ps_neg_item))
        a_neg.append(dh.convert_str_to_indexlist(_as))
    data_dict = dh.return_dict(a_neg, a_pos, p_neg, p_pos, q_, q_a, q_p, q_s,
                               s_neg, s_pos, label_s_pos, label_p_pos)
    return data_dict
Esempio n. 4
0
def get_shuffle_indices_test(dh, step, train_part, model, train_step):
    """Return the evaluation id list for this train_step, synced on disk.

    When training the relation part, a fresh random id list is drawn and
    appended to the 'combine_test' log so the other training part can replay
    the exact same ids.  Otherwise the log is searched for the matching
    (train_step, model) line; a new list is drawn only if none exists.

    :param dh: data helper holding the train/test question index lists.
    :param step: epoch number (unused, kept for signature compatibility).
    :param train_part: 'relation' writes the id list, anything else reads it.
    :param model: 'valid' samples from the train questions, otherwise test.
    :param train_step: global training step used as the lookup key.
    :return: list / numpy array of question ids to evaluate.
    """
    if train_part == 'relation':
        if model == "valid":
            id_list = ct.get_static_id_list_debug(
                len(dh.train_question_list_index))
        else:
            id_list = ct.get_static_id_list_debug_test(
                len(dh.test_question_list_index))

        id_list = ct.random_get_some_from_list(id_list,
                                               FLAGS.evaluate_batchsize)

        id_list2 = [str(x) for x in id_list]
        # record: step, mode, part, then the ids
        ct.just_log(
            config.cc_par('combine_test'), '%s\t%s\t%s\t%s' %
            (train_step, model, train_part, '\t'.join(id_list2)))
    else:
        f1s = ct.file_read_all_lines_strip(config.cc_par('combine_test'))
        line = ''
        exist = False
        for l1 in f1s:
            parts = str(l1).split('\t')  # split once instead of per-field
            if parts[0] == str(train_step) and parts[1] == model:
                line = str(l1)
                exist = True
                break
        if exist:
            # ids start at column 3 (after step, mode, part)
            id_list = np.array([int(x) for x in line.split('\t')[3:]])
            ct.print(
                'get_shuffle_indices_test exist %s %s ' % (train_step, model),
                'shuffle_indices_test')
        else:  # no recorded line: draw a fresh list ourselves
            if model == "valid":
                id_list = ct.get_static_id_list_debug(
                    len(dh.train_question_list_index))
            else:
                id_list = ct.get_static_id_list_debug_test(
                    len(dh.test_question_list_index))

            id_list = ct.random_get_some_from_list(id_list,
                                                   FLAGS.evaluate_batchsize)
            ct.print('get_shuffle_indices_test not exist %s ' % train_step,
                     'shuffle_indices_test')

    return id_list
Esempio n. 5
0
    def checkpoint(self, sess, state):
        """Save the current TF session under the state's log directory.

        :param sess: the live tf.Session to snapshot.
        :param state: label forwarded to ct.log_path_checkpoint to pick the
            output directory.
        """
        # Output directory for models and summaries

        out_dir = ct.log_path_checkpoint(state)
        ct.print("Writing to {}\n".format(out_dir))
        # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        # checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=2)
        # NOTE(review): saves into out_dir, not the checkpoint_dir created
        # above, and global_step is pinned to 1 so the same file is
        # overwritten each call — confirm both are intended.
        save_path = saver.save(sess, os.path.join(out_dir, "model.ckpt"), 1)
        # load_path = saver.restore(sess, save_path)
        # reload once after saving to verify (disabled)
        msg1 = "save_path:%s" % save_path
        ct.just_log2('model', msg1)
Esempio n. 6
0
def valid_batch_debug(sess, lstm, step, train_op, merged, writer, dh,
                      batchsize, train_question_list_index,
                      train_relation_list_index, model,
                      test_question_global_index, train_part, id_list):
    """Evaluate up to *batchsize* questions from id_list and collect errors.

    Each question is scored with valid_step; wrongly answered questions
    contribute their (q, pos_r, neg_r) triples to the error lists so they
    can be replayed during training.

    :return: (accuracy, error_q_list, error_pos_r_list, error_neg_r_list,
        maybe_list_list, maybe_global_index_list, questions_ok_dict)
    """
    right = 0
    wrong = 0
    error_test_q_list = []
    error_test_pos_r_list = []
    error_test_neg_r_list = []
    maybe_list_list = []
    maybe_global_index_list = []  # global index of each question
    questions_ok_dict = dict()
    if batchsize > len(id_list):
        batchsize = len(id_list)
        ct.print('batchsize too big ,now is %d' % batchsize, 'error')
    for i in range(batchsize):
        index = id_list[i]
        # the 'test' and 'valid' modes read the same global-index mapping
        global_index = test_question_global_index[index]
        ct.print("valid_batch_debug:%s %d ,index = %d ;global_index=%d " %
                 (model, i, index, global_index))
        test_q, test_r, labels = \
            dh.batch_iter_wq_test_one_debug(train_question_list_index, train_relation_list_index, model, index,
                                            train_part)

        ok, error_test_q, error_test_pos_r, error_test_neg_r, maybe_list = valid_step(
            sess, lstm, step, train_op, test_q, test_r, labels, merged, writer,
            dh, model, global_index)
        error_test_q_list.extend(error_test_q)
        error_test_pos_r_list.extend(error_test_pos_r)
        error_test_neg_r_list.extend(error_test_neg_r)
        maybe_list_list.append(maybe_list)
        maybe_global_index_list.append(global_index)
        if ok:
            right += 1
        else:
            wrong += 1

        questions_ok_dict[global_index] = ok
    total = right + wrong
    acc = right / total if total else 0.0  # guard against an empty id_list
    ct.print("right:%d wrong:%d" % (right, wrong), "debug")
    return acc, error_test_q_list, error_test_pos_r_list, error_test_neg_r_list, maybe_list_list, maybe_global_index_list, questions_ok_dict
Esempio n. 7
0
    def batch_iter(self, data, batch_size):
        """Yield one (questions, scores, labels, relations, right_index)
        tuple per sample, visiting the data set in random order."""
        total = len(data)
        shuffle_indices = np.random.permutation(np.arange(total))  # shuffled sample order

        info1 = "q total:%d ; epohches-size:%s " % (total, len(data) // batch_size)
        ct.print(info1, 'info')

        right_answer_count = 0
        right_index = 0
        for sample_pos in shuffle_indices:
            data_current = data[sample_pos]
            questions, scores, labels, relations = [], [], [], []
            for cand_idx, ts in enumerate(data_current[2]):
                questions.append(data_current[1])
                # t1 = (index, right1, score, relation, right, z_score)
                scores.append((ts[2], ts[5]))  # keep walking the candidates
                labels.append(ts[1])
                relations.append(ts[3])

                # question / z-score / NN score
                if int(ts[1][0]) == 1:
                    right_index = cand_idx
                    right_answer_count += 1
                msg = "%s\t%s\t%s\t%s\t%s" % (data_current[1], ts[5], ts[2], ts[3], ts[4])
                ct.print(msg, 'debug1')

            yield (np.array(questions), np.array(scores), np.array(labels),
                   np.array(relations), right_index)
Esempio n. 8
0
def get_shuffle_indices_train(total, step, train_part, model, train_step):
    """Return the shuffled training order for this train_step, synced on disk.

    When training the relation part a fresh permutation is drawn and appended
    to the 'combine' log; other parts look the permutation up by train_step so
    both parts train on the identical order.  A new permutation is drawn only
    if no matching line exists.

    :param total: number of training samples to permute.
    :param step: epoch number (unused, kept for signature compatibility).
    :param train_part: 'relation' writes the permutation, others read it.
    :param model: train/valid/test tag written into the log line.
    :param train_step: global training step used as the lookup key.
    :return: numpy array of shuffled sample indices.
    """
    if train_part == 'relation':
        shuffle_indices = np.random.permutation(np.arange(total))  # shuffled order
        shuffle_indices1 = [str(x) for x in list(shuffle_indices)]
        # record: step, mode, part, then the indices
        ct.just_log(
            config.cc_par('combine'), '%s\t%s\t%s\t%s' %
            (train_step, model, train_part, '\t'.join(shuffle_indices1)))
    else:
        f1s = ct.file_read_all_lines_strip(config.cc_par('combine'))
        line = ''
        exist = False
        for l1 in f1s:
            if str(l1).split('\t')[0] == str(train_step):
                line = str(l1)
                exist = True
                break
        if exist:
            # indices start at column 3 (after step, mode, part)
            shuffle_indices = np.array([int(x) for x in line.split('\t')[3:]])
            ct.print('get_shuffle_indices_train   exist %s' % train_step,
                     'shuffle_indices_train')
        else:  # no recorded line: draw a fresh permutation ourselves
            shuffle_indices = np.random.permutation(np.arange(total))
            ct.print('get_shuffle_indices_train   not exist %s' % train_step,
                     'shuffle_indices_train')

    return shuffle_indices
Esempio n. 9
0
def test_step1(data, model):
    """Run one evaluation pass over *data*, logging loss/accuracy per batch.

    Relies on module-level graph nodes and state (sess, test_x/test_y,
    test_z*, test_loss, test_accuracy, w, b, lh, batch_size, epoch) rather
    than parameters.

    :param data: samples consumed by lh.batch_iter.
    :param model: label used only in the summary log line.
    :return: the module-level (w, b) tensors (not their evaluated values).
    """
    total_loss = 0.0
    total_acc = 0.0
    total = 0
    gc_valid = lh.batch_iter(data, batch_size)
    error_count = 0
    right_count = 0

    for gc_valid_item in gc_valid:
        total += 1
        x1 = gc_valid_item[1]
        y1 = gc_valid_item[2]
        p1 = gc_valid_item[3]  # relations — unused in this function
        r_i = gc_valid_item[4]  # index of the correct candidate
        x1 = x1.reshape(-1, 2)
        y1 = y1.reshape(-1, 2)

        _z, _z1, _z_max, _z_right, _loss, _w, _b, _accuracy = \
            sess.run([test_z, test_z1, test_z_max, test_z_right, test_loss, w, b, test_accuracy],
                     feed_dict={test_x: x1, test_y: y1, test_right_index: r_i})

        ct.print(_w, 'w')
        ct.print(_b, 'b')
        if _accuracy:

            right_count += 1
        else:
            error_count += 1
        # print(_loss)
        total_loss += _loss
        total_acc += _accuracy
    # NOTE(review): _w/_b are unbound and total is zero when data is empty,
    # which would raise below — confirm callers never pass an empty set.
    ct.print(
        'model %s epoch %s   loss = %s,  acc = %s error_count %s right_count %s total:%s' %
        (model, epoch, total_loss / total, right_count / total, error_count, right_count, total),
        'debug')
    ct.print(_w, 'w')
    ct.print(_b, 'b')
    return w, b
Esempio n. 10
0
def run_step2(sess, lstm, step, trainstep, train_op, train_q, train_cand,
              train_neg, merged, writer, dh, use_error):
    """Run one training step on a (question, positive, negative) batch.

    Logs loss/accuracy/score, counts a pair as 'right' when the mean
    positive score > 0.55 and the mean negative score < 0.4, checkpoints
    every FLAGS.check steps, and hard-exits the process after
    FLAGS.stop_loss_zeor_count consecutive zero-loss / full-accuracy steps.
    """
    start_time = time.time()
    feed_dict = {
        lstm.ori_input_quests: train_q,  # ori_batch
        lstm.cand_input_quests: train_cand,  # cand_batch
        lstm.neg_input_quests: train_neg  # neg_batch
    }

    # ct.check_len(train_q,15)
    # ct.check_len(train_cand, 15)
    # ct.check_len(train_neg, 15)
    summary, l1, acc1, embedding1, train_op1, \
    ori_cand_score, ori_neg_score, ori_quests_out = sess.run(
        [merged, lstm.loss, lstm.acc, lstm.embedding, train_op,
         lstm.ori_cand, lstm.ori_neg, lstm.ori_quests],
        feed_dict=feed_dict)

    time_str = datetime.datetime.now().isoformat()
    right, wrong, score = [0.0] * 3
    for i in range(0, len(train_q)):
        ori_cand_score_mean = np.mean(ori_cand_score[i])
        ori_neg_score_mean = np.mean(ori_neg_score[i])
        # a pair counts as right only with a confident margin on both sides
        if ori_cand_score_mean > 0.55 and ori_neg_score_mean < 0.4:
            right += 1.0
        else:
            wrong += 1.0
        score += ori_cand_score_mean - ori_neg_score_mean
    time_elapsed = time.time() - start_time

    writer.add_summary(summary, trainstep)
    # ct.print("STEP:" + str(step) + " loss:" + str(l1) + " acc:" + str(acc1))
    info = "use_error %s %s: step %s, loss %s, acc %s, score %s,right %s wrong %s, %6.7f secs/batch " % (
        use_error, time_str, trainstep, l1, acc1, score, right, wrong,
        time_elapsed)
    ct.just_log2("info", info)
    ct.print(info)
    # zero loss while replaying error samples is suspicious — dump the batch
    if use_error and l1 == 0.0 and acc1 == 1.0:
        ct.just_log2("debug",
                     "step=%s,train_step=%s------" % (step, trainstep))
        dh.log_error_r(train_q, "train_q")
        dh.log_error_r(train_cand, "train_cand")
        dh.log_error_r(train_neg, "train_neg")
        ct.print("??????")

    if l1 == 0.0 and acc1 == 1.0:
        dh.loss_ok += 1
        ct.log3("loss = 0.0  %d " % dh.loss_ok)
        ct.print("loss == 0.0 and acc == 1.0 checkpoint and exit now = %d" %
                 dh.loss_ok)
        if dh.loss_ok == FLAGS.stop_loss_zeor_count:
            # converged: save once, then terminate the whole process
            checkpoint(sess)
            os._exit(0)
    else:
        dh.loss_ok = 0
    # ct.prin(1)
    if (trainstep + 1) % FLAGS.check == 0:
        checkpoint(sess)
Esempio n. 11
0
def pre_ner(dh, res, top_k):
    """Extract up to *top_k* candidate entities from raw NER results.

    Pipeline: sort the candidates, pull the 《》-marked entities, drop
    duplicates, drop candidates fully contained in another candidate,
    then truncate to top_k.

    :param dh: data helper exposing bkt.sort_sentence / bh entity helpers.
    :param res: raw candidate sentences.
    :param top_k: maximum number of entities to return.
    :return: list of candidate entity strings.
    """
    # sort the candidates
    res2 = dh.bkt.sort_sentence(res, dh.bh)
    ct.print(res2)
    # entity extraction (《》 markers)
    list1_new = [
        dh.bh.entity_re_extract_one_repeat(ct.clean_str_zh2en(x)) for x in res2
    ]
    # drop duplicates
    list1_new = ct.list_no_repeat(list1_new)
    # 5.8.3 drop words contained in other candidates (slight accuracy gain),
    # e.g. "有一首歌叫" ⊃ "有一首歌" ⊃ "一首歌"
    list1_new = [
        word for word in list1_new if not ct.be_contains(word, list1_new)
    ]
    ct.print(list1_new)
    return list1_new[0:top_k]
Esempio n. 12
0
def main():
    """Top-level training entry point.

    Builds the LSTM network and its SGD training op, then runs up to
    FLAGS.epoches epochs: each epoch picks a data generator (error replay,
    debug subset, or plain batches), trains on it with run_step2, and then
    validates/checkpoints via valid_test_checkpoint.
    """
    time.sleep(0.5)  # sleep 0.5 s so the previous process can exit
    now = "\n\n\n" + str(datetime.datetime.now().isoformat())
    # mode: test = full data; small = a small subset; debug = a single pass
    model = FLAGS.mode
    ct.print("tf:%s should be 1.2.1 model:%s " %
             (str(tf.__version__), model))  # 1.2.1
    ct.just_log2("info", now)
    ct.just_log2("valid", now)
    ct.just_log2("test", now)
    ct.just_log2("info", get_config_msg())
    ct.log3(now)

    embedding_weight = None
    error_test_dict = dict()
    valid_test_dict = dict()
    # 1 read all the data; returns labelled batches {data.x, data.label}
    dh = data_helper.DataClass(model)
    if FLAGS.word_model == "word2vec_train":
        embedding_weight = dh.embeddings

    # 3 build the LSTM network
    ct.print("max_document_length=%s,vocab_size=%s " %
             (str(dh.max_document_length), str(dh.converter.vocab_size)))
    lstm = mynn.CustomNetwork(
        max_document_length=dh.max_document_length,  # timesteps
        word_dimension=FLAGS.word_dimension,  # dimension of one word vector
        vocab_size=dh.converter.vocab_size,  # embedding_size of W for the embedding lookup
        rnn_size=FLAGS.rnn_size,  # hidden layer size
        model=model,
        need_cal_attention=FLAGS.need_cal_attention,
        need_max_pooling=FLAGS.need_max_pooling,
        word_model=FLAGS.word_model,
        embedding_weight=embedding_weight,
        need_gan=False)

    # 4 ----------------------------------- set up the loss -----------------------------------
    global_step = tf.Variable(0, name="globle_step", trainable=False)
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(lstm.loss, tvars),
                                      FLAGS.max_grad_norm)
    optimizer = tf.train.GradientDescentOptimizer(1e-1)
    # NOTE(review): apply_gradients is called twice; this first call's op is
    # discarded but still adds update ops to the graph — confirm intended.
    optimizer.apply_gradients(zip(grads, tvars))
    train_op = optimizer.apply_gradients(zip(grads, tvars),
                                         global_step=global_step)

    # initialization
    init = tf.global_variables_initializer()
    merged = tf.summary.merge_all()

    with tf.Session().as_default() as sess:
        writer = tf.summary.FileWriter("log/", sess.graph)
        sess.run(init)

        embeddings = []
        use_error = False
        error_test_q_list = []
        error_test_pos_r_list = []
        error_test_neg_r_list = []

        # debug helper: dump all training and test questions
        # dh.build_train_test_q()
        #
        train_step = 0
        max_acc = 0
        for step in range(FLAGS.epoches):

            toogle_line = ">>>>>>>>>>>>>>>>>>>>>>>>>step=%d,total_train_step=%d " % (
                step, len(dh.q_neg_r_tuple))
            ct.log3(toogle_line)
            ct.just_log2("info", toogle_line)

            # data preparation: choose this epoch's batch generator
            my_generator = ''
            if FLAGS.fix_model and len(error_test_q_list) != 0:
                # replay the previously-failed samples
                my_generator = dh.batch_iter_wq_debug_fix_model(
                    error_test_q_list, error_test_pos_r_list,
                    error_test_neg_r_list, FLAGS.batch_size)
                use_error = True
                toogle_line = "\n\n\n\n\n------------------use_error to train"
                ct.log3(toogle_line)
                ct.just_log2("info", toogle_line)
                ct.just_log2("valid", 'use_error to train')
                ct.just_log2("test", 'use_error to train')
            elif ct.is_debug_few():
                toogle_line = "\n------------------is_debug_few to train"
                ct.log3(toogle_line)
                ct.just_log2("info", toogle_line)
                train_part = config.cc_par('train_part')
                model = 'train'
                # the relation part generates the shuffle order; others read it
                shuffle_indices = get_shuffle_indices_train(
                    len(dh.q_neg_r_tuple_train), step, train_part, model,
                    train_step)

                if train_part == 'relation':
                    my_generator = dh.batch_iter_wq_debug(
                        dh.train_question_list_index,
                        dh.train_relation_list_index, shuffle_indices,
                        FLAGS.batch_size, train_part)
                else:
                    my_generator = dh.batch_iter_wq_debug(
                        dh.train_question_list_index,
                        dh.train_answer_list_index, shuffle_indices,
                        FLAGS.batch_size, train_part)
            else:
                # unused path
                train_q, train_cand, train_neg = \
                    dh.batch_iter_wq(dh.train_question_list_index, dh.train_relation_list_index,
                                     FLAGS.batch_size)

            toogle_line = "\n==============================train_step=%d\n" % train_step
            ct.just_log2("info", toogle_line)
            ct.log3(toogle_line)

            # train on the generated batches
            for gen in my_generator:
                toogle_line = "\n==============================train_step=%d\n" % train_step
                ct.just_log2("info", toogle_line)
                ct.log3(toogle_line)

                if not use_error:
                    train_step += 1

                train_q = gen[0]
                train_cand = gen[1]
                train_neg = gen[2]
                run_step2(sess, lstm, step, train_step, train_op, train_q,
                          train_cand, train_neg, merged, writer, dh, use_error)

                if use_error:
                    continue
                    # -------------------------test
                    # 1 source data: training OR validation OR test data

            # validation
            valid_test_dict, error_test_dict, max_acc, all_right,\
                error_test_q_list, error_test_pos_r_list, error_test_neg_r_list \
                = valid_test_checkpoint(train_step, dh, step, sess, lstm, merged, writer,
                                        train_op,
                                        valid_test_dict, error_test_dict, max_acc)

            if config.cc_par('keep_run') and all_right and step > 2:
                del lstm  # release resources
                del sess
                return True

            if use_error:
                error_test_q_list.clear()
                error_test_pos_r_list.clear()
                error_test_neg_r_list.clear()
                use_error = False
            toogle_line = "<<<<<<<<<<<<<<<<<<<<<<<<<<<<step=%d\n" % step
            # ct.just_log2("test", toogle_line)
            ct.just_log2("info", toogle_line)

            ct.log3(toogle_line)
Esempio n. 13
0
def valid_step(sess, lstm, step, train_op, test_q, test_r, labels, merged,
               writer, dh, model, global_index):
    """Score one question against all its candidate relations.

    test_q repeats the question once per candidate in test_r; candidate 0 is
    treated as the positive relation. Candidates are ranked by the mean of
    the cosine matrix lstm.test_q_r, and the answer counts as right when the
    positive ranks first. Negatives that outrank the positive are returned
    for error replay, and a "maybe" list of near-top candidates is built
    when the 'config' toggle file enables it.

    :return: (is_right, error_test_q, error_test_pos_r, error_test_neg_r,
        maybe_list)
    """
    start_time = time.time()
    feed_dict = {
        lstm.test_input_q: test_q,
        lstm.test_input_r: test_r,
    }
    question = ''
    relations = []
    # log the decoded question and candidate relations
    for _ in test_q:
        v_s_1 = dh.converter.arr_to_text_no_unk(_)
        valid_msg = model + " test_q 1:" + v_s_1
        ct.just_log2("valid_step", valid_msg)
        question = v_s_1
    for _ in test_r:
        v_s_1 = dh.converter.arr_to_text_no_unk(_)
        valid_msg = model + " test_r 1:" + v_s_1
        ct.just_log2("valid_step", valid_msg)
        relations.append(v_s_1)

    error_test_q = []
    error_test_pos_r = []
    error_test_neg_r = []
    fuzzy_boundary = []

    test_q_r_cosin = sess.run([lstm.test_q_r], feed_dict=feed_dict)

    test_q_r_cosin = test_q_r_cosin[0]
    right, wrong, score = [0.0] * 3
    st_list = []  # per-relation scores

    # compare the first candidate with the others — an alternative
    # correctness check (disabled):
    # for i in range(1, len(test_q_r_cosin)):
    #     compare_res = ct.nump_compare_matix(test_q_r_cosin[0], test_q_r_cosin[i])
    #     ct.print("compare_res:" + str(compare_res))

    for i in range(0, len(test_q_r_cosin)):
        st = ct.new_struct()
        st.index = i
        st.cosine_matix = test_q_r_cosin[i]
        ori_cand_score_mean = np.mean(test_q_r_cosin[i])
        st.score = ori_cand_score_mean
        st_list.append(st)
        # ct.print(ori_cand_score_mean)
    # pair each score with its index, then sort by score (descending)
    st_list.sort(key=ct.get_key)
    st_list.reverse()
    st_list_sort = st_list  # keep all (was st_list[0:5])

    ct.just_log2("info", "\n ##3 score")
    score_list = []
    for st in st_list_sort:
        # ct.print("index:%d ,score= %f " % (st.index, st.score))
        # mylog.logger.info("index:%d ,score= %f " % (st.index, st.score))
        # walk the candidates in score order, map each index back to its
        # relation array, and judge the top relation against the labels
        better_index = st.index
        # recover the relation text from its index array
        r1 = dh.converter.arr_to_text_no_unk(test_r[better_index])
        # ct.print(r1)
        ct.just_log2(
            "info", "step:%d st.index:%d,score:%f,r:%s" %
            (step, st.index, st.score, r1))
        if st.index == 0:
            _tmp_right = 1
        else:
            _tmp_right = 0
        # fields: relation index, correct flag, score, relation surface form
        score_list.append(
            "%d_%d_%f_%s" %
            (st.index, _tmp_right, st.score, r1.replace('_', '-')))
    _tmp_msg1 = "%d\t%s\t%d\t%s\t%s" % (step, model, global_index, question,
                                        '\t'.join(score_list))
    ct.just_log2("logistics", _tmp_msg1)
    # logged to a separate file

    is_right = False
    msg = " win r =%d  " % st_list_sort[0].index
    ct.log3(msg)
    if st_list_sort[0].index == 0:
        ct.print(
            "================================================================ok"
        )
        is_right = True
    else:
        # todo: record the failing question and the negative relations that
        # scored above the positive
        # q,pos,neg
        # error_test_q.append()
        # collect them:
        for st in st_list_sort:
            # record every negative ranked above the positive
            if st.index == 0:
                break
            else:
                error_test_neg_r.append(test_r[st.index])
                error_test_q.append(test_q[0])
                error_test_pos_r.append(test_r[0])
        ct.print(
            "================================================================error"
        )
        ct.just_log2("info", "!!!!! error %d  " % step)
    ct.just_log2("info", "\n =================================end\n")

    # score-gap ("jump") check; toggled dynamically via the 'config' file —
    # intended to run only once the model is stable
    run = ct.file_read_all_lines_strip('config')[0] == '1'
    ct.print("run %s " % run, 'info')
    maybe_list = []
    if run:
        st_list_sort = list(st_list_sort)
        for index in range(len(st_list_sort) - 1):
            # if index == len(st_list_sort) :
            #     continue

            space = st_list_sort[index].score - st_list_sort[index + 1].score
            maybe_list.append(st_list_sort[index])

            if space > config.skip_threshold():
                break

        # check whether the positive survived into the kept candidates
        pos_in_it = False
        for index in range(len(maybe_list)):
            item = maybe_list[index]
            # log the close candidates, record whether the positive is among
            # them, and keep a running global accuracy estimate
            if item.index == 0 and pos_in_it == False:
                pos_in_it = True
            better_index = item.index
            r1 = dh.converter.arr_to_text_no_unk(test_r[better_index])
            item.relation = r1
            msg1 = "step:%d st.index:%d,score:%f,r:%s" % (step, item.index,
                                                          item.score, r1)
            item.msg1 = msg1
            maybe_list[index] = item
            ct.print(msg1, "maybe")

        # maybe_dict is a module-level right/error counter
        if pos_in_it:
            maybe_dict['r'] += 1
        else:
            maybe_dict['e'] += 1
        acc0 = maybe_dict['r'] / (maybe_dict['r'] + maybe_dict['e'])
        # ct.print("%f pos_in_it  %s" % (acc0, pos_in_it), "maybe")
        # ct.print("\n", "maybe")

    time_elapsed = time.time() - start_time
    time_str = datetime.datetime.now().isoformat()
    ct.print("%s: step %s,  score %s, is_right %s, %6.7f secs/batch" %
             (time_str, step, score, str(is_right), time_elapsed))
    return is_right, error_test_q, error_test_pos_r, error_test_neg_r, maybe_list
Esempio n. 14
0
# training-graph accuracy: True when the loss is exactly zero
# NOTE(review): tf.equal on a float loss is an exact comparison — confirm the
# loss construction actually reaches 0.0.
accuracy = tf.equal(loss, 0)

test_x = tf.placeholder(dtype=tf.float32, shape=[None, X_SIZE], name='test_input')  # X_SIZE input features (original note: "784 pixels")
test_y = tf.placeholder(dtype=tf.float32, shape=[None, Y_SIZE], name='test_output')  # Y_SIZE output classes
test_right_index = tf.placeholder(dtype=tf.int32, name='test_right_index')  # index of the correct class
test_z = tf.add(tf.matmul(test_x, w), b)  # raw scores
test_z1 = tf.reduce_mean(test_z, 1)  # mean score per row
test_z_max = tf.reduce_max(test_z1)  # best score
# # score of the correct entry
test_z_right = tf.gather(test_z1, test_right_index)
# loss = tf.maximum(0.0, tf.reduce_max(z_right) - z_max)
# loss = tf.multiply(z_max - tf.reduce_max(z_right),10000)  # loss = gap between the best score and the correct answer's score
test_loss = test_z_max - tf.reduce_max(test_z_right)  # loss = gap between the best score and the correct answer's score
test_accuracy = tf.equal(test_loss, 0)

ct.print("lr %s size %s f1= %s" % (lr, batch_size, f1))


def run_step1(data, model):
    """Accumulate loss/accuracy over batches produced by lh.batch_iter.

    NOTE(review): the body visible in this extract ends right after
    unpacking the batch — the function appears truncated here; the rest of
    the accumulation logic is not visible.
    """
    total_loss = 0.0
    total_acc = 0.0
    total = 0
    gc_valid = lh.batch_iter(data, batch_size)
    error_count = 0
    right_count = 0

    for gc_valid_item in gc_valid:
        total += 1
        x1 = gc_valid_item[1]  # features/scores from the batch tuple
        y1 = gc_valid_item[2]  # labels
        p1 = gc_valid_item[3]  # relations
Esempio n. 15
0
    def batch_iter(self,
                   data,
                   batch_size,
                   mode='train',
                   transform=False,
                   random_index=True):
        """Yield per-sample tuples (questions, features, labels, relations,
        right_index, extra) in (optionally) shuffled order.

        With transform=True the feature rows are passed through
        self.min_max_scaler before being yielded.
        """
        total = len(data)
        if random_index:
            order = np.random.permutation(np.arange(total))  # shuffled order
        else:
            order = np.arange(total)  # original order

        info1 = "q total:%d ; epohches-size:%s " % (total,
                                                    len(data) // batch_size)
        ct.print(info1, 'info')

        right_answer_count = 0
        right_index = 0
        for sample_pos in order:
            data_current = data[sample_pos]
            questions, features, labels, relations = [], [], [], []
            for cand_idx, ts in enumerate(data_current[2]):
                questions.append(data_current[1])
                # 0____1____8.718868____9.006550____机械设计基础____8
                # t1 = (index, right1, score, relation, right, z_score)
                features.append((float(ts[4]), float(ts[5]), float(ts[6])))
                labels.append(ts[1])
                relations.append(ts[2])

                # question / z-score / NN score
                if int(ts[1][0]) == 1:
                    right_index = cand_idx
                    right_answer_count += 1
                msg = "%s\t%s\t%s\t%s\t%s" % (data_current[1], ts[5], ts[2],
                                              ts[3], ts[4])
                ct.print(msg, 'debug1')

            if transform:
                features_out = self.min_max_scaler.fit_transform(features)
            else:
                features_out = np.array(features)
            yield (np.array(questions), features_out, np.array(labels),
                   np.array(relations), right_index, data_current[5])
Esempio n. 16
0
def main():
    """Build the discriminator model, optionally restore a checkpoint, and
    run NER + relation analysis on a single hard-coded sample question.

    Returns:
        (hh_dh, hh_discriminator, hh_sess): the data helper, the built
        (and possibly restored) discriminator, and the live TF session,
        so a caller can keep issuing further queries against the graph.
    """
    with tf.device("/gpu"):
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=session_conf)
        now = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
        # Important knobs: whether to restore the model, the loss part,
        # and the number of relations (all read from config below).
        model = FLAGS.mode
        test_style = True  # NOTE(review): unused in this function
        ct.print("tf:%s should be 1.2.1 model:%s " %
                 (str(tf.__version__), model))  # expects TensorFlow 1.2.1
        ct.print("mark:%s " % config.cc_par('mark'), 'mark')
        ct.just_log2("info", now)
        ct.just_log2("result", now)
        ct.just_log2("info", get_config_msg())
        ct.print(get_config_msg(), "mark")
        ct.just_log3(
            "test_check",
            "mode\tid\tglobal_id\tglobal_id_in_origin\tquestion\tentity\tpos\tanswer\tr1\tr2\tr3\n"
        )
        ct.log3(now)
        msg1 = "t_relation_num:%d  train_part:%s loss_part:%s" % \
               (config.cc_par('t_relation_num'),config.cc_par('train_part'), config.cc_par('loss_part'))
        ct.print(msg1)
        msg1 = 'restrore:%s use_alias_dict:%s' % (
            config.cc_par('restore_model'), config.cc_par('use_alias_dict'))
        ct.print(msg1)
        if config.cc_par('restore_model'):
            ct.print(config.cc_par('restore_path'))

        embedding_weight = None
        error_test_dict = dict()
        valid_test_dict = dict()
        # 1. Load all the data; the helper yields labelled batches of the
        #    form {data.x, data.label}.
        dh = data_helper.DataClass(model, "test")
        if FLAGS.word_model == "word2vec_train":
            embedding_weight = dh.embeddings

        # 3. Build the LSTM-based discriminator model.
        # loss_type = "pair"
        discriminator = Discriminator(
            max_document_length=dh.max_document_length,  # timesteps
            word_dimension=FLAGS.word_dimension,  # dimension of one word vector
            vocab_size=dh.converter.
            vocab_size,  # embedding matrix W size (embedding_size)
            rnn_size=FLAGS.rnn_size,  # hidden layer size
            model=model,
            need_cal_attention=config.cc_par('d_need_cal_attention'),
            need_max_pooling=FLAGS.need_max_pooling,
            word_model=FLAGS.word_model,
            embedding_weight=embedding_weight,
            need_gan=True,
            first=True)

        # generator = Generator(
        #     max_document_length=dh.max_document_length,  # timesteps
        #     word_dimension=FLAGS.word_dimension,
        #     vocab_size=dh.converter.vocab_size,
        #     rnn_size=FLAGS.rnn_size,
        #     model=model,
        #     need_cal_attention=config.cc_par('g_need_cal_attention'),  # no attention
        #     need_max_pooling=FLAGS.need_max_pooling,
        #     word_model=FLAGS.word_model,
        #     embedding_weight=embedding_weight,
        #     need_gan=True, first=False)

        ct.print("max_document_length=%s,vocab_size=%s " %
                 (str(dh.max_document_length), str(dh.converter.vocab_size)))
        # Initialize all variables and the summary machinery.
        init = tf.global_variables_initializer()
        merged = tf.summary.merge_all()
        with sess.as_default():
            # NOTE(review): Windows-style path separator; non-portable.
            writer = tf.summary.FileWriter(ct.log_path() + "\\log\\",
                                           sess.graph)
            sess.run(init)
            loss_dict = dict()
            loss_dict['loss'] = 0
            loss_dict['pos'] = 0
            loss_dict['neg'] = 0

            # Restore a previously trained model when configured to.
            if config.cc_par('restore_model'):
                saver = tf.train.Saver(tf.global_variables(),
                                       max_to_keep=FLAGS.num_checkpoints)
                save_path = config.cc_par('restore_path')
                ct.print('restore:%s' % save_path, 'model')
                saver.restore(sess, config.cc_par('restore_path'))

            # NER part 1: load the alias dictionary and NER statistics.
            print('加载别名词典:')
            dh.bh.stat_dict('../data/nlpcc2016/4-ner/extract_entitys_all.txt')
            dh.bh.init_ner(f_in2='../data/nlpcc2016/4-ner/extract_e/e1.tj.txt')

            print('input:')
            # Sample question ("Who is the author of Dream of the Red
            # Chamber?") stands in for interactive input().
            line = '红楼梦的作者是谁?'  # input()
            _best_p, _best_s = ner_rel_analyisis(dh, discriminator, line,
                                                 sess)  # 2. NER via LSTM
            hh_dh = dh
            hh_discriminator = discriminator
            hh_sess = sess
            print(_best_s)
            print(_best_p)
            return hh_dh, hh_discriminator, hh_sess
Esempio n. 17
0
def log_error_questions(dh, model, _1, _2, _3, error_test_dict,
                        maybe_list_list, acc, maybe_global_index_list):
    """Log mis-answered questions and tally candidate-list coverage.

    Args:
        dh: data helper exposing ``converter.arr_to_text_no_unk``.
        model: run label ("valid"/"test") prefixed to every log line.
        _1, _2, _3: parallel lists of index arrays — question, positive
            relation, and negative relation for each error case.
        error_test_dict: running {question message: error count} map,
            updated in place and also returned.
        maybe_list_list: per-question candidate lists; items expose
            ``.index`` and ``.relation`` (index 0 presumably marks the
            gold answer — inferred from the checks below).
        acc: accuracy of this evaluation round, echoed in the summary.
        maybe_global_index_list: global question ids parallel to
            ``maybe_list_list``.

    Returns:
        The updated ``error_test_dict``.
    """
    ct.just_log2("test_error",
                 '\n--------------------------log_test_error:%d\n' % len(_1))
    previous_question = ''
    for i, q_arr in enumerate(_1):
        question_text = dh.converter.arr_to_text_no_unk(q_arr)
        question_msg = model + " test_q 1:" + question_text

        pos_msg = model + " test_r_pos :" + dh.converter.arr_to_text_no_unk(_2[i])
        neg_msg = model + " test_r_neg :" + dh.converter.arr_to_text_no_unk(_3[i])

        if previous_question != question_text:  # first row of a new question
            previous_question = question_text
            error_test_dict[question_msg] = error_test_dict.get(question_msg, 0) + 1
            ct.just_log2("test_error",
                         question_msg + ' %s' % str(error_test_dict[question_msg]))
            ct.just_log2("test_error", pos_msg)

        # Every row (new question or not) logs its negative relation.
        ct.just_log2("test_error", neg_msg)
    ct.just_log2("test_error", '--------------%d' % len(_1))
    ct.print("==========%s" % model, "maybe_possible")

    # Re-log the failing questions ordered by error count.
    ct.just_log2('error_count', "\n\n")
    for entry in ct.sort_dict(error_test_dict):
        ct.just_log2('error_count', "%s\t%s" % (entry[0], entry[1]))

    # Coverage stats: r = gold present in candidates, e = gold missing,
    # m1 = gold present but not ranked first, m2 = (unused counter).
    stats = dict()
    stats['r'] = 0
    stats['e'] = 0
    stats['m1'] = 0
    stats['m2'] = 0
    for pos, maybe_list in enumerate(maybe_list_list):
        gold_present = any(item.index == 0 for item in maybe_list)
        top_is_gold = maybe_list[0].index == 0

        if gold_present:
            stats['r'] += 1
            if not top_is_gold:
                stats['m1'] += 1
                relations = [cand.relation for cand in maybe_list]
                msg = "%d\t%s" % (maybe_global_index_list[pos],
                                  '\t'.join(relations))
                if maybe_global_index_list[pos] != -1 and msg != '':
                    ct.print(msg, "maybe_possible")
        else:
            stats['e'] += 1

    total = (stats['r'] + stats['e'])
    acc0 = stats['r'] / total
    maybe_canget = stats['m1'] / total
    msg = "==== %s %f 正确答案数(%d)/总数(%d):%f;候补(%d)/总数:%f " \
          % (model, acc, stats['r'], total, acc0, stats['m1'], maybe_canget)
    ct.print(msg, "maybe_possible")

    return error_test_dict
Esempio n. 18
0
def ner_rel_analyisis(dh, discriminator, line, sess):
    """Score every (entity, relation) candidate for one question string.

    Args:
        dh: data helper (NER/KB access via ``dh.bkt`` / ``dh.bh``).
        discriminator: trained LSTM model exposing the test placeholders
            and the cosine-score tensors evaluated below.
        line: the raw question text.
        sess: live TF session hosting the discriminator graph.

    Returns:
        (_best_p, _best_s): relation text and entity text of the candidate
        ranked first by the combined score.
    """
    res = dh.bkt.ner_q(dh.bh, line)
    s_list = pre_ner(dh, res, 2)  # top-ranked candidate entities
    # Build (s, p, o) candidates plus the question variants for each branch.
    item = build_spo(dh, line, s_list)
    q_origin = item['q_']
    q_origin_for_s = item['q_s']
    q_origin_for_r = item['q_p']
    q_origin_for_a = item['q_a']
    p_neg = item['p_neg']  # candidate relations
    s_neg = item['s_neg']  # candidate entities
    a_neg = item['a_neg']  # candidate answers
    # ner_labels = item['ner_labels']
    # rel_labels = item['rel_labels']
    feed_dict = dict()
    feed_dict[discriminator.test_input_q] = q_origin_for_r  # relation branch
    feed_dict[discriminator.test_input_r] = p_neg
    feed_dict[discriminator.ner_test_input_q] = q_origin_for_s  # NER branch
    feed_dict[discriminator.ner_test_input_r] = s_neg
    feed_dict[discriminator.ans_test_input_q] = q_origin_for_a
    feed_dict[discriminator.ans_test_input_r] = a_neg
    [test_q_r_cosin, _test_q_r, _ner_test_q_r] = \
        sess.run([discriminator.q_r_ner_cosine, discriminator.test_q_r,
                  discriminator.ner_test_q_r], feed_dict=feed_dict)
    # Build one score struct per candidate.
    st_list = []  # per-candidate scores
    # NOTE(review): q_r_ner_cosine appears to combine the two branch
    # scores, so /mean_num looks like an average — confirm in the model.
    mean_num = 2
    for i in range(0, len(test_q_r_cosin)):
        st = ct.new_struct()
        st.index = i
        # st.label = labes[i]
        # st.ner_label = ner_labels[i]
        # st.rel_label = rel_labels[i]
        st.cosine_matix = test_q_r_cosin[i]
        st.score = test_q_r_cosin[i] / mean_num
        st.r_score = _test_q_r[i]
        st.ner_score = _ner_test_q_r[i]
        st_list.append(st)
    # print(item)
    # Rank candidates by combined score, best first.
    st_list.sort(key=ct.get_key)
    st_list.reverse()
    st_list_sort = st_list  # keep all (was: st_list[0:5]); alias, not a copy
    # Best candidate by NER score alone (sorted on a copy so the
    # combined-score ordering above is preserved).
    ner_stlist = st_list.copy()
    ner_stlist.sort(key=ct.get_ner_key)
    ner_stlist.reverse()
    _s1 = dh.converter.arr_to_text_no_unk(s_neg[ner_stlist[0].index])
    ct.print("ner:%s" % _s1)
    # Best candidate by relation score alone (again on a copy).
    r_stlist = st_list.copy()
    r_stlist.sort(key=ct.get_r_key)
    r_stlist.reverse()
    _p1 = dh.converter.arr_to_text_no_unk(p_neg[r_stlist[0].index])
    ct.print("rel:%s" % _p1)
    index = -1
    _best_s = ''
    _best_p = ''
    for st in st_list_sort:
        index += 1
        # ct.print("index:%d ,score= %f " % (st.index, st.score))
        # mylog.logger.info("index:%d ,score= %f " % (st.index, st.score))
        # Walk the combined ranking; recover the candidate's text for each
        # branch via its original index.
        better_index = st.index
        _q1 = dh.converter.arr_to_text_no_unk(q_origin_for_r[better_index])
        _s1 = dh.converter.arr_to_text_no_unk(s_neg[better_index])
        _p1 = dh.converter.arr_to_text_no_unk(p_neg[better_index])

        _msg = "st.index:%d,score:%f ner:%f rel:%f,q:%s s:%s  r:%s  " % (
            better_index,
            st.score,
            st.ner_score,
            st.r_score,
            _q1,
            _s1,
            _p1,
        )
        ct.just_log2("info", _msg)
        ct.print(_msg)
        # The first entry of the combined ranking is the final answer.
        if index == 0:
            _best_s = _s1
            _best_p = _p1

            # label_s_pos.append(cand_s_neg_item == r_pos1)  # same s
            # label_p_pos.append(_cand_ps_neg_item == entity1)  # same p
    return _best_p, _best_s
Esempio n. 19
0
    if False:
        dh = DataClass("cc")
        dh.check_spo()

    # 输入一个问句输出对应的NER 前3
    # 1 分词
    if False:
        print('加载别名词典:')
        bkh = baike_helper()
        bkt = baike_test()
        bkh.stat_dict('../data/nlpcc2016/4-ner/extract_entitys_all.txt')
        bkh.init_ner(f_in2='../data/nlpcc2016/4-ner/extract_e/e1.tj.txt')
        print('input:')
        line = '民生路有哪些结点?'  # input()
        res = bkt.ner_q(bkh, line)
        ct.print(res)
        # 排序
        res2 = bkt.sort_sentence(res, bkh)
        ct.print(res2)
        # 实体抽取《》
        list1_new = [
            baike_helper.entity_re_extract_one_repeat(ct.clean_str_zh2en(x))
            for x in res2
        ]
        # 去掉重复
        list1_new = ct.list_no_repeat(list1_new)  # 去掉重复
        # 去掉包含
        # 5.8.3 去掉词语包含试试 有一首歌叫	有一首歌	一首歌
        if True:
            # 能略微提高
            list1_new_2 = []
Esempio n. 20
0
def main(_):
    """Re-rank candidate entities of every question with a char-RNN score.

    For each question, each candidate mention is replaced by the
    placeholder '♠', the masked sentence is scored by the language model,
    and the candidates are written back out tab-separated, sorted by that
    score in ascending order.
    """
    # A checkpoint directory is resolved to its latest checkpoint file.
    if os.path.isdir(FLAGS.checkpoint_path):
        FLAGS.checkpoint_path = \
            tf.train.latest_checkpoint(FLAGS.checkpoint_path)

    model_path = os.path.join('model', FLAGS.name)
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    model = 'ner'
    dh = data_helper.DataClass(model)
    train_batch_size = 1
    embedding_weight = dh.embeddings

    # Sampling-mode char-RNN over the full vocabulary.
    char_rnn = CharRNN(dh.converter.vocab_size,  # vocabulary size
                       num_seqs=train_batch_size,  # sentences per batch
                       num_steps=dh.max_document_length,  # sentence length
                       lstm_size=FLAGS.lstm_size,
                       num_layers=FLAGS.num_layers,
                       learning_rate=FLAGS.learning_rate,
                       train_keep_prob=FLAGS.train_keep_prob,
                       use_embedding=FLAGS.use_embedding,
                       embedding_size=FLAGS.embedding_size,
                       embedding_weight=embedding_weight,
                       sampling=True,
                       dh=dh
                       )
    char_rnn.load(FLAGS.checkpoint_path)

    f1 = '../data/nlpcc2016/6-answer/q.rdf.ms.re.v1.txt'
    f3 = '../data/nlpcc2016/4-ner/extract_entitys_all_tj.v1.txt'
    f4 = '../data/nlpcc2016/4-ner/extract_entitys_all_tj.sort_by_ner_lstm.v1.txt'
    f1s = ct.file_read_all_lines_strip(f1)
    f3s = ct.file_read_all_lines_strip(f3)
    # NULL filtering is disabled: copy rows through unchanged, keeping the
    # candidate file (f3) aligned row-for-row with the question file (f1).
    f1s_new = list(f1s)
    f3s_new = [f3s[i] for i in range(len(f1s))]

    f4s = []
    for row, l1 in enumerate(f1s_new):
        question = str(l1).split('\t')[0]
        scored = []
        # Mask each candidate in the question and let the LM judge it.
        for cand in f3s_new[row].split('\t'):
            masked = question.replace(cand, '♠')
            start = dh.convert_str_to_indexlist_2(masked, False)
            score, _score_list = char_rnn.judge(start, dh.converter.vocab_size)
            scored.append((cand, score))
            ct.print("%s\t%s\t%s" % (masked, cand, score), 'debug_process')
        scored.sort(key=lambda pair: pair[1])  # ascending by LM score
        ranked = '\t'.join(pair[0] for pair in scored)
        ct.print(ranked)
        f4s.append(ranked)
    ct.file_wirte_list(f4, f4s)
Esempio n. 21
0
def valid_test_checkpoint(train_step,
                          dh,
                          step,
                          sess,
                          lstm,
                          merged,
                          writer,
                          train_op,
                          valid_test_dict,
                          error_test_dict,
                          acc_max=0):
    test_batchsize = FLAGS.test_batchsize  # 暂时统一 验证和测试的数目
    # if (train_step + 1) % FLAGS.evaluate_every == 0:
    if True:
        model = "valid"
        train_part = config.cc_par('train_part')
        if train_part == 'relation':
            train_part_1 = dh.train_relation_list_index
        else:
            train_part_1 = dh.train_answer_list_index

        id_list = get_shuffle_indices_test(dh, step, train_part, model,
                                           train_step)

        # if model == "valid":
        #     id_list = ct.get_static_id_list_debug(len(dh.train_question_list_index))
        # else:
        #     id_list = ct.get_static_id_list_debug_test(len(dh.test_question_list_index))

        # id_list = ct.random_get_some_from_list(id_list, FLAGS.evaluate_batchsize)

        acc_valid, error_test_q_list, error_test_pos_r_list, error_test_neg_r_list, maybe_list_list, \
        maybe_global_index_list, questions_ok_dict = \
            valid_batch_debug(sess, lstm, 0, train_op, merged, writer,
                              dh, test_batchsize, dh.train_question_list_index,
                              train_part_1,
                              model, dh.train_question_global_index, train_part, id_list)

        msg = "step:%d train_step %d valid_batchsize:%d  acc:%f " % (
            step, train_step, test_batchsize, acc_valid)
        ct.print(msg)
        ct.just_log2("valid", msg)
        valid_test_dict = log_error_questions(dh, model, error_test_q_list,
                                              error_test_pos_r_list,
                                              error_test_neg_r_list,
                                              valid_test_dict, maybe_list_list,
                                              acc_valid,
                                              maybe_global_index_list)
        # ct.print("===========step=%d"%step, "maybe_possible")

    # if FLAGS.need_test and (train_step + 1) % FLAGS.test_every == 0:
    if True:
        model = "test"
        train_part = config.cc_par('train_part')
        if train_part == 'relation':
            train_part_1 = dh.test_relation_list_index
        else:
            train_part_1 = dh.test_answer_list_index

        id_list = get_shuffle_indices_test(dh, step, train_part, model,
                                           train_step)

        acc_test, _1, _2, _3, maybe_list_list, maybe_global_index_list, questions_ok_dict = \
            valid_batch_debug(sess, lstm, step, train_op, merged, writer,
                              dh, test_batchsize, dh.test_question_list_index,
                              train_part_1, model, dh.test_question_global_index, train_part, id_list)
        # 测试 集合不做训练 但是将其记录下来

        error_test_dict = log_error_questions(dh, model, _1, _2, _3,
                                              error_test_dict, maybe_list_list,
                                              acc_test,
                                              maybe_global_index_list)

        # _1.clear()
        # _2.clear()
        # _3.clear()
        msg = "step:%d train_step %d valid_batchsize:%d  acc:%f " % (
            step, train_step, test_batchsize, acc_test)
        ct.print(msg)
        ct.just_log2("test", msg)
        ct.print("===========step=%d" % step, "maybe_possible")

    checkpoint(sess, step)

    # 输出记录
    all_right = False
    if acc_test >= acc_max and len(dh.maybe_test_questions) > 0:
        msg_list = []
        acc_max = acc_test
        all_right = True

        for index in dh.maybe_test_questions:
            # try:
            ok = questions_ok_dict[int(index)]
            # except Exception as ee1:
            #     print(ee1)
            if not ok:
                all_right = False
            msg = "%s_%s" % (index, ok)
            msg_list.append(msg)
        acc_str = "%s_%s" % (acc_valid, acc_test)
        ct.just_log(
            config.cc_par('test_ps_result'),
            "%s\t%s\t%s\t%s" % (step, ct.log_path().split('runs\\')[1],
                                acc_str, '\t'.join(msg_list)))

    return valid_test_dict, error_test_dict, acc_max, all_right, error_test_q_list, error_test_pos_r_list, error_test_neg_r_list