def main(_):
    """Train, validate and test a ``DualBilstmCnnModel`` on sentence pairs.

    Steps performed:
      1. Build the vocabulary and load train/valid/test splits from
         ``FLAGS.traning_data_path``.
      2. Create a TF session and the model; restore from ``FLAGS.ckpt_dir`` if a
         ``checkpoint`` index file exists there, otherwise initialize variables
         (optionally assigning pre-trained word embeddings).
      3. Train from the epoch stored in the model up to ``FLAGS.num_epochs``.
      4. Every ``FLAGS.validate_every`` epochs, evaluate on the validation
         split, refresh the per-label weights, and save a checkpoint when the
         metrics improve.
      5. Evaluate once on the test split and print the metrics.

    Args:
        _: Unused. (Presumably the argv list passed by ``tf.app.run`` --
            confirm against the caller.)
    """
    # 1. Build vocabulary and load data.
    vocabulary_word2index, vocabulary_index2word, vocabulary_label2index, vocabulary_index2label = create_vocabulary(
        FLAGS.traning_data_path,
        FLAGS.vocab_size,
        name_scope=FLAGS.model_name,
        tokenize_style=FLAGS.tokenize_style)
    vocab_size = len(vocabulary_word2index)
    print("cnn_model.vocab_size:", vocab_size)
    num_classes = len(vocabulary_index2label)
    print("num_classes:", num_classes)
    train, valid, test, true_label_percent = load_data(
        FLAGS.traning_data_path,
        vocabulary_word2index,
        vocabulary_label2index,
        FLAGS.sentence_len,
        FLAGS.model_name,
        tokenize_style=FLAGS.tokenize_style)
    trainX1, trainX2, trainBlueScores, trainY = train
    validX1, validX2, validBlueScores, validY = valid
    testX1, testX2, testBlueScores, testY = test
    length_data_mining_features = len(trainBlueScores[0])
    print("length_data_mining_features:", length_data_mining_features)

    # Print some messages for debugging purposes. The validation length is
    # taken from the validation split (it previously reported the test split).
    print("model_name:", FLAGS.model_name, ";length of training data:",
          len(trainX1), ";length of validation data:", len(validX1),
          ";true_label_percent:", true_label_percent, ";tokenize_style:",
          FLAGS.tokenize_style, ";vocabulary size:", vocab_size)
    print("train_x1:", trainX1[0], ";train_x2:", trainX2[0])
    print("data mining features.length:", len(trainBlueScores[0]),
          "data_mining_features:", trainBlueScores[0], ";train_y:", trainY[0])

    # Build checkpoint paths with os.path.join so they stay inside
    # FLAGS.ckpt_dir whether or not the flag ends with a path separator.
    checkpoint_index_path = os.path.join(FLAGS.ckpt_dir, "checkpoint")
    save_path = os.path.join(FLAGS.ckpt_dir, "model.ckpt")

    # 2. Create session.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Instantiate the model.
        textCNN = DualBilstmCnnModel(
            filter_sizes,
            FLAGS.num_filters,
            num_classes,
            FLAGS.learning_rate,
            FLAGS.batch_size,
            FLAGS.decay_steps,
            FLAGS.decay_rate,
            FLAGS.sentence_len,
            vocab_size,
            FLAGS.embed_size,
            FLAGS.hidden_size,
            FLAGS.is_training,
            model=FLAGS.model_name,
            similiarity_strategy=FLAGS.similiarity_strategy,
            top_k=FLAGS.top_k,
            max_pooling_style=FLAGS.max_pooling_style,
            length_data_mining_features=length_data_mining_features)

        # Restore from an existing checkpoint, or initialize from scratch.
        saver = tf.train.Saver()
        if os.path.exists(checkpoint_index_path):
            print("Restoring Variables from Checkpoint.")
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
            if FLAGS.decay_lr_flag:
                # Run the "decay by half" op twice when resuming.
                for i in range(2):
                    print(i, "Going to decay learning rate by half.")
                    sess.run(textCNN.learning_rate_decay_half_op)
        else:
            print('Initializing Variables')
            sess.run(tf.global_variables_initializer())
            if not os.path.exists(FLAGS.ckpt_dir):
                os.makedirs(FLAGS.ckpt_dir)

            if FLAGS.use_pretrained_embedding:  # load pre-trained word embedding
                print("===>>>going to use pretrained word embeddings...")
                assign_pretrained_word_embedding(sess, vocabulary_index2word,
                                                 vocab_size, textCNN,
                                                 FLAGS.word2vec_model_path)
        curr_epoch = sess.run(textCNN.epoch_step)

        # 3. Feed data and train.
        number_of_training_data = len(trainX1)
        batch_size = FLAGS.batch_size
        iteration = 0
        # Initial floors: no checkpoint is written until validation metrics
        # pass the comparison against these values below.
        best_acc = 0.60
        best_f1_score = 0.20
        weights_dict = init_weights_dict(
            vocabulary_label2index)  # init weights dict.
        # Same number of steps as the original zip over (start, end) offsets,
        # which stopped before the final (possibly partial) slice.
        steps_per_epoch = len(
            range(batch_size, number_of_training_data, batch_size))
        for epoch in range(curr_epoch, FLAGS.num_epochs):
            print("Auto.Going to shuffle data")
            trainX1, trainX2, trainBlueScores, trainY = shuffle_data(
                trainX1, trainX2, trainBlueScores, trainY)
            loss, train_acc, counter = 0.0, 0.0, 0
            for _step in range(steps_per_epoch):
                iteration = iteration + 1
                # Each batch is produced by generate_batch_training_data from
                # the full training set rather than sliced by offset.
                input_x1, input_x2, input_bluescores, input_y = generate_batch_training_data(
                    trainX1, trainX2, trainBlueScores, trainY,
                    number_of_training_data, batch_size)
                weights = get_weights_for_current_batch(input_y, weights_dict)

                feed_dict = {
                    textCNN.input_x1: input_x1,
                    textCNN.input_x2: input_x2,
                    textCNN.input_bluescores: input_bluescores,
                    textCNN.input_y: input_y,
                    textCNN.weights: np.array(weights),
                    textCNN.dropout_keep_prob: FLAGS.dropout_keep_prob,
                    textCNN.iter: iteration,
                    textCNN.tst: not FLAGS.is_training
                }
                curr_loss, curr_acc, lr, _, _ = sess.run([
                    textCNN.loss_val, textCNN.accuracy, textCNN.learning_rate,
                    textCNN.update_ema, textCNN.train_op
                ], feed_dict)
                loss, train_acc, counter = loss + curr_loss, train_acc + curr_acc, counter + 1
                if counter % 100 == 0:
                    # Report running averages over the current epoch.
                    print(
                        "Epoch %d\tBatch %d\tTrain Loss:%.3f\tAcc:%.3f\tLearning rate:%.5f"
                        % (epoch, counter, loss / float(counter),
                           train_acc / float(counter), lr))

            # Epoch increment (stored in the graph so a restored run resumes here).
            print("going to increment epoch counter....")
            sess.run(textCNN.epoch_increment)

            # 4. Validation.
            print(epoch, FLAGS.validate_every,
                  (epoch % FLAGS.validate_every == 0))

            if epoch % FLAGS.validate_every == 0:
                eval_loss, eval_accc, f1_scoree, precision, recall, weights_label = do_eval(
                    sess, textCNN, validX1, validX2, validBlueScores, validY,
                    iteration, vocabulary_index2word)
                # Per-label results from validation drive the sample weights
                # used for the next training batches.
                weights_dict = get_weights_label_as_standard_dict(
                    weights_label)
                print("label accuracy(used for label weight):==========>>>>",
                      weights_dict)
                print(
                    "【Validation】Epoch %d\t Loss:%.3f\tAcc %.3f\tF1 Score:%.3f\tPrecision:%.3f\tRecall:%.3f"
                    % (epoch, eval_loss, eval_accc, f1_scoree, precision,
                       recall))
                # Save when F1 improves and accuracy is within 5% of the best.
                if eval_accc * 1.05 > best_acc and f1_scoree > best_f1_score:
                    print("going to save model. eval_f1_score:", f1_scoree,
                          ";previous best f1 score:", best_f1_score,
                          ";eval_acc", str(eval_accc), ";previous best_acc:",
                          str(best_acc))
                    saver.save(sess, save_path, global_step=epoch)
                    best_acc = eval_accc
                    best_f1_score = f1_scoree

                # NOTE(review): this scheduled decay sits inside the validation
                # branch, so it only fires on epochs that are also validation
                # epochs -- confirm that is intended.
                if FLAGS.decay_lr_flag and epoch in (1, 3, 5, 8):
                    for i in range(2):  # decay learning rate if necessary.
                        print(i, "Going to decay learning rate by half.")
                        sess.run(textCNN.learning_rate_decay_half_op)

        # 5. Finally, evaluate on the test set and report the test metrics.
        test_loss, acc_t, f1_score_t, precision, recall, weights_label = do_eval(
            sess, textCNN, testX1, testX2, testBlueScores, testY, iteration,
            vocabulary_index2word)
        print(
            "Test Loss:%.3f\tAcc:%.3f\tF1 Score:%.3f\tPrecision:%.3f\tRecall:%.3f:"
            % (test_loss, acc_t, f1_score_t, precision, recall))
# Example #2
0
def main():
    """Train and evaluate the sentence-pair model built by ``auto_encoder``.

    Builds the vocabulary from ``data/atec_nlp_sim_train2.csv``, loads the
    cached train/valid/test splits from ``./cache_SWEM_1/train_valid_test.pik``,
    optionally loads a saved embedding matrix, builds the graph, optionally
    restores matching variables from ``opt.save_path``, and then trains for
    ``opt.max_epochs`` epochs. After every epoch the model is evaluated on the
    validation split, the per-label weights are refreshed, and a checkpoint is
    saved to ``opt.ckpt_dir`` when both accuracy and F1 improve. The test split
    is evaluated once after training. ``KeyboardInterrupt`` stops training
    gracefully.
    """
    opt = Options()
    vocabulary_word2index, vocabulary_index2word, vocabulary_label2index, vocabulary_index2label = create_vocabulary(
        "data/atec_nlp_sim_train2.csv",
        opt.vocab_size,
        name_scope=opt.name_scope,
        tokenize_style=opt.tokenize_style)
    vocab_size = len(vocabulary_word2index)
    print("vocab_size:", vocab_size)
    num_classes = len(vocabulary_index2label)
    print("num_classes:", num_classes)
    # Pickle data is binary: text mode fails on Python 3 and can corrupt the
    # stream on platforms that translate newlines.
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # load caches produced by this project.
    with open("./cache_SWEM_1/train_valid_test.pik", "rb") as f:
        train, valid, test, true_label_percent = pickle.load(f)
    train_q, train_a, _, train_lab = train
    print("train_nums:", len(train_q))
    val_q, val_a, _, val_lab = valid
    test_q, test_a, _, test_lab = test

    opt.n_words = len(vocabulary_index2word)

    print(dict(opt))
    print('Total words: %d' % opt.n_words)

    # When opt.part_data is set, keep only a fraction (opt.portion) of the
    # training set, sampled without replacement using a fixed seed.
    if opt.part_data:
        np.random.seed(123)
        train_ind = np.random.choice(len(train_q),
                                     int(len(train_q) * opt.portion),
                                     replace=False)
        train_q = [train_q[t] for t in train_ind]
        train_a = [train_a[t] for t in train_ind]
        train_lab = [train_lab[t] for t in train_ind]

    # Use the saved embedding only if its shape matches
    # (opt.n_words, opt.embed_size); otherwise fall back to trainable
    # embeddings by clearing opt.fix_emb.
    try:
        params = np.load('./data/snli_emb.p')
        if params[0].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = np.array(params[0], dtype='float32')
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params[0].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:0'):
        # The input is a pair of sentences, so the token-id and mask
        # placeholders are defined in pairs.
        x_1_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.maxlen])
        x_2_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.maxlen])
        x_mask_1_ = tf.placeholder(tf.float32,
                                   shape=[opt.batch_size, opt.maxlen])
        x_mask_2_ = tf.placeholder(tf.float32,
                                   shape=[opt.batch_size, opt.maxlen])
        y_ = tf.placeholder(tf.float32, shape=[opt.batch_size, opt.category])
        keep_prob = tf.placeholder(tf.float32)
        # auto_encoder builds the model graph and returns the tensors/ops that
        # are later passed to sess.run.
        accuracy_, loss_, train_op_, W_emb, logits_ = auto_encoder(
            x_1_, x_2_, x_mask_1_, x_mask_2_, y_, keep_prob, opt)
        merged = tf.summary.merge_all()

    def do_eval(sess, questions, answers, labels):
        """Evaluate the model over a whole dataset in minibatches.

        Args:
            sess: Active ``tf.Session`` holding the model variables.
            questions: First sentences of each pair.
            answers: Second sentences of each pair.
            labels: Integer class labels, one per pair.

        Returns:
            Tuple ``(mean_loss, mean_accuracy, f1, precision, recall,
            weights_label)``. Loss and accuracy are averaged over batches.
            Precision, recall and F1 are 0.0 when their denominator is zero.
        """
        eval_loss, eval_accc, eval_counter = 0.0, 0.0, 0
        eval_true_positive, eval_false_positive, eval_true_negative, eval_false_negative = 0, 0, 0, 0
        weights_label = {}  # weight_label[label_index]=(number,correct)
        # Evaluation uses uniform sample weights.
        weights = np.ones((opt.batch_size))
        batches = get_minibatches_idx(len(questions),
                                      opt.batch_size,
                                      shuffle=True)
        for _, batch_index in batches:
            sents_1 = [questions[t] for t in batch_index]
            sents_2 = [answers[t] for t in batch_index]
            labels_array = np.array([labels[t] for t in batch_index])
            # One-hot encode the integer labels.
            labels_one_hot = np.eye(opt.category)[labels_array]
            x_batch_1, x_mask_1 = prepare_data_for_emb(sents_1, opt)
            x_batch_2, x_mask_2 = prepare_data_for_emb(sents_2, opt)

            curr_eval_loss, curr_accc, logits = sess.run(
                [loss_, accuracy_, logits_],
                feed_dict={
                    x_1_: x_batch_1,
                    x_2_: x_batch_2,
                    x_mask_1_: x_mask_1,
                    x_mask_2_: x_mask_2,
                    y_: labels_one_hot,
                    opt.weights_label: weights,
                    keep_prob: 1.0  # no dropout during evaluation
                })
            true_positive, false_positive, true_negative, false_negative = compute_confuse_matrix(
                logits, labels_one_hot
            )  # logits:[batch_size,label_size]-->logits[0]:[label_size]
            # Accumulate per-batch sums here; they are normalized by the batch
            # count when returned.
            eval_loss, eval_accc, eval_counter = eval_loss + curr_eval_loss, eval_accc + curr_accc, eval_counter + 1
            weights_label = compute_labels_weights(
                weights_label, logits, labels_array)
            eval_true_positive, eval_false_positive = eval_true_positive + true_positive, eval_false_positive + false_positive
            eval_true_negative, eval_false_negative = eval_true_negative + true_negative, eval_false_negative + false_negative
        print("true_positive:", eval_true_positive, ";false_positive:",
              eval_false_positive, ";true_negative:", eval_true_negative,
              ";false_negative:", eval_false_negative)
        # Guard every ratio: a model that predicts a single class (common in
        # early epochs) would otherwise raise ZeroDivisionError and abort
        # training.
        predicted_positive = eval_true_positive + eval_false_positive
        actual_positive = eval_true_positive + eval_false_negative
        p = float(eval_true_positive) / float(
            predicted_positive) if predicted_positive else 0.0
        r = float(eval_true_positive) / float(
            actual_positive) if actual_positive else 0.0
        f1_score = (2 * p * r) / (p + r) if (p + r) else 0.0
        print("eval_counter:", eval_counter, ";eval_acc:", eval_accc)
        if eval_counter == 0:
            # Empty dataset: nothing was evaluated.
            return 0.0, 0.0, f1_score, p, r, weights_label
        return eval_loss / float(eval_counter), eval_accc / float(
            eval_counter), f1_score, p, r, weights_label

    # NOTE(review): max_test_accuracy is never updated in this function, so the
    # value printed on interrupt is always 0.
    max_test_accuracy = 0.
    weights_dict = init_weights_dict(
        vocabulary_label2index)  # init weights dict.
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:  # restore previously saved parameters
            try:
                t_vars = tf.trainable_variables()
                save_keys = tensors_key_in_file(opt.save_path)

                # Names present both in the graph and in the saved file.
                vars_by_name = {var.name: var for var in t_vars}
                saved_names = set(key + ":0" for key in save_keys.keys())
                ss = set(vars_by_name) & saved_names

                # only restore variables with correct shape
                ss_right_shape = set(
                    s for s in ss
                    if vars_by_name[s].get_shape() == save_keys[s[:-2]])

                loader = tf.train.Saver(var_list=[
                    var for var in t_vars if var.name in ss_right_shape
                ])
                loader.restore(sess, opt.save_path)

                print("Loading variables from '%s'." % opt.save_path)
                print("Loaded variables:" + str(ss))

            except Exception as err:
                # Catch Exception rather than everything, so Ctrl-C and
                # SystemExit still propagate; report why restore failed.
                print("No saving session, using random initialization")
                print("Restore failed with:", repr(err))
                sess.run(tf.global_variables_initializer())

        try:
            best_acc = 0
            best_f1_score = 0
            for epoch in range(opt.max_epochs):
                print("Starting epoch %d" % epoch)
                loss, acc, uidx = 0.0, 0.0, 0
                # Create shuffled minibatch index lists for this epoch.
                kf = get_minibatches_idx(len(train_q),
                                         opt.batch_size,
                                         shuffle=True)
                for _, train_index in kf:
                    uidx += 1
                    # Gather the examples selected by this minibatch's indices.
                    sents_1 = [train_q[t] for t in train_index]
                    sents_2 = [train_a[t] for t in train_index]
                    x_labels_array = np.array(
                        [train_lab[t] for t in train_index])
                    # One-hot encode the integer labels (a plain reshape to
                    # (batch, category) would not be valid here).
                    x_labels = np.eye(opt.category)[x_labels_array]

                    # Presumably converts sentences to fixed-length id arrays
                    # plus masks -- verify against prepare_data_for_emb.
                    x_batch_1, x_batch_mask_1 = prepare_data_for_emb(
                        sents_1, opt)
                    x_batch_2, x_batch_mask_2 = prepare_data_for_emb(
                        sents_2, opt)
                    weights = get_weights_for_current_batch(
                        list(x_labels_array), weights_dict)

                    _, curr_loss, curr_accuracy = sess.run(
                        [train_op_, loss_, accuracy_],
                        feed_dict={
                            x_1_: x_batch_1,
                            x_2_: x_batch_2,
                            x_mask_1_: x_batch_mask_1,
                            x_mask_2_: x_batch_mask_2,
                            y_: x_labels,
                            opt.weights_label: weights,
                            keep_prob: opt.dropout_ratio
                        })
                    loss, acc = loss + curr_loss, acc + curr_accuracy
                    if uidx % 100 == 0:
                        # Report running averages over the current epoch.
                        print(
                            "Epoch %d\tBatch %d\tTrain Loss:%.3f\tAcc:%.3f\t" %
                            (epoch, uidx, loss / float(uidx),
                             acc / float(uidx)))

                # Validate after every epoch on the held-out validation split.
                # (Previously this evaluated the training split, which made
                # the "Validation" metrics and checkpoint selection
                # meaningless.)
                eval_loss, eval_accc, f1_scoree, precision, recall, weights_label = do_eval(
                    sess, val_q, val_a, val_lab)
                weights_dict = get_weights_label_as_standard_dict(
                    weights_label)
                print(
                    "【Validation】Epoch %d\t Loss:%.3f\tAcc %.3f\tF1 Score:%.3f\tPrecision:%.3f\tRecall:%.3f"
                    % (epoch, eval_loss, eval_accc, f1_scoree, precision,
                       recall))
                # Save a checkpoint only when both accuracy and F1 improve.
                if eval_accc > best_acc and f1_scoree > best_f1_score:
                    save_path = opt.ckpt_dir + "/model.ckpt"
                    print("going to save model. eval_f1_score:", f1_scoree,
                          ";previous best f1 score:", best_f1_score,
                          ";eval_acc", str(eval_accc),
                          ";previous best_acc:", str(best_acc))
                    saver.save(sess, save_path, global_step=epoch)
                    best_acc = eval_accc
                    best_f1_score = f1_scoree

            # Final evaluation on the test split.
            test_loss, acc_t, f1_score_t, precision, recall, weights_label = do_eval(
                sess, test_q, test_a, test_lab)
            print(
                "Test Loss:%.3f\tAcc:%.3f\tF1 Score:%.3f\tPrecision:%.3f\tRecall:%.3f:"
                % (test_loss, acc_t, f1_score_t, precision, recall))

        except KeyboardInterrupt:
            print('Training interupted')
            print("Max Test accuracy %f " % max_test_accuracy)