Ejemplo n.º 1
0
    def validate(self, sess, writer, on_QR_weak, epo):
        """
        Evaluate the model on a dev set and report averaged precision/recall/f1.

        Can only be called by pretrain and finetune.

        Args:
            sess: session used to run ``self.loss`` and ``self.predict``.
            writer: object with a ``writerow`` method (e.g. a csv writer) that
                receives one row of averaged metrics, or None to skip the row.
            on_QR_weak: when true, evaluate once on ``self.QR_weak_dev_path``
                (loaded in "QA" mode); otherwise evaluate on the seven files
                ``self.QR_dev_dir + '<i>.csv'`` for i in 0..6 (loaded in "QR"
                mode) and average the metrics over them.
            epo: zero-based epoch index; written to the row as ``epo + 1``.
        """
        if on_QR_weak:
            dev_sets = [(self.QR_weak_dev_path, "QA")]
            row_tag = 'On QR_weak_dev'
            banner = "[Validate on QR_weak dev]: --precision: "
        else:
            dev_sets = [(self.QR_dev_dir + str(i) + '.csv', "QR") for i in range(7)]
            row_tag = 'On QR_dev'
            banner = "[Validate on QR dev]: --precision: "
        times = len(dev_sets)

        precision_num, recall_num, f1_num = 0, 0, 0
        for dev_path, load_mode in dev_sets:
            QR_dev = utils.generate_train_test(dev_path, self.word_id, load_mode)
            Q_dev_batch, R_dev_batch, Y_dev_batch, last_start = utils.batch_triplet_shuffle(
                QR_dev[0], QR_dev[1], QR_dev[2], self.batch_size, True)

            predictions = []
            y_true = []
            last_index = len(Y_dev_batch) - 1
            for s, y_batch in enumerate(Y_dev_batch):
                # The loss is still fetched (as in the original run call) but
                # only the predictions are used.
                _loss_dev, pred = sess.run(
                    [self.loss, self.predict],
                    {self.Q_ori: Q_dev_batch[s], self.R_ori: R_dev_batch[s],
                     self.labels: y_batch})
                pred = [pre[0] for pre in pred]
                truth = [pre[0] for pre in y_batch]
                if s == last_index:
                    # Only entries from last_start onward are kept for the
                    # final batch; presumably the earlier entries pad it with
                    # examples that were already scored -- TODO confirm
                    # against utils.batch_triplet_shuffle.
                    pred = pred[last_start: self.batch_size]
                    truth = truth[last_start: self.batch_size]
                predictions += pred
                y_true += truth

            # d is the fourth value returned by mertic; it is not averaged, so
            # the value printed below comes from the last dev set evaluated.
            precision, recall, f1, d = mertic(predictions, y_true, self.threshold)
            precision_num += precision
            recall_num += recall
            f1_num += f1

        avg_precision = round(precision_num / times, 6)
        avg_recall = round(recall_num / times, 6)
        avg_f1 = round(f1_num / times, 6)

        if writer is not None:
            # f1 is rounded like the other metrics for both dev sets (the
            # QR_dev row previously logged the raw, unrounded f1).
            writer.writerow([epo + 1, row_tag, avg_precision, avg_recall, avg_f1])
        print(banner, avg_precision, "--recall: ", avg_recall, "--f1: ", avg_f1, '--d:', d)
                          len(Y_batch), "--loss:", Loss_main)
                    step += 1
        # VALID
                Q_dev_batch, R_dev_batch, Y_dev_batch, last_start = preprocess.batch_triplet_shuffle(
                    QR_q_dev, QR_r_dev, QR_y_dev, batch_size, False)
                prediction = []
                for s in range(len(Y_dev_batch)):
                    Loss_dev, pred = sess.run(
                        [main_loss, predict_QR], {
                            Q_ori: Q_dev_batch[s],
                            A_ori: R_dev_batch[s],
                            R_ori: R_dev_batch[s],
                            labels: Y_dev_batch[s]
                        })
                    pred = [pre[0] for pre in pred]
                    if s == len(Y_dev_batch) - 1:
                        prediction += pred[last_start:batch_size]
                    else:
                        prediction += pred
                y_dev_true = [tru[0] for tru in QR_y_dev]
                precision, recall, f1 = mertic(prediction, y_dev_true,
                                               threshold)
                print("[epoch_valid_%d]:" % (epo + 1), "--precision: ",
                      precision, "--recall: ", recall, "--f1: ", f1)
                saver.save(
                    sess, "model/model.ckpt" + "_p-%f_r-%f_f1-%f" %
                    (precision, recall, f1))
                writer.writerow([(m + 1), precision, recall, f1] + prediction)
                csvfile.flush()
csvfile.close()
Ejemplo n.º 3
0
            for s in range(len(Y_weak_dev_batch)):
                Loss_dev, pred = sess.run(
                    [all_loss, predict_QR], {
                        Q_ori: Q_weak_dev_batch[s],
                        R_ori: R_weak_dev_batch[s],
                        labels: Y_weak_dev_batch[s]
                    })
                pred = [pre[0] for pre in pred]
                if s == len(Y_weak_dev_batch) - 1:
                    prediction_weak_dev += pred[last_start:batch_size]
                    y_weak_dev_true += [pre[0] for pre in Y_weak_dev_batch[s]
                                        ][last_start:batch_size]
                else:
                    prediction_weak_dev += pred
                    y_weak_dev_true += [pre[0] for pre in Y_weak_dev_batch[s]]
            precision, recall, f1, _ = mertic(prediction_weak_dev,
                                              y_weak_dev_true, threshold)

            pretrain_writer.writerow(
                [epo + 1, 'On QR_weak_dev', precision, recall, f1])
            print("[Validate on QR_weak dev]: --precision: ", precision,
                  "--recall: ", recall, "--f1: ", f1)

            # validation on QR when training on QR_weak pairs
            Q_dev_batch, R_dev_batch, Y_dev_batch, last_start = utils.batch_triplet_shuffle(
                QR_q_dev, QR_r_dev, QR_y_dev, batch_size, True)
            prediction_dev_onWeak = []
            y_dev_true_onWeak = []
            for j in range(len(Y_dev_batch)):
                Loss_dev, pred = sess.run(
                    [all_loss, predict_QR], {
                        Q_ori: Q_dev_batch[j],
Ejemplo n.º 4
0
    def finetune(self):
        """
        Fine-tune the restored model on QR pairs and validate after each epoch.

        Runs ``self.init_op``, restores weights from
        ``self.restore_model_path`` via ``self.saver``, then trains for
        ``self.fineTune_epochs`` epochs on the first ``self.QR_train_slice``
        examples loaded from ``self.QR_train_path``. Per-iteration loss and
        per-epoch training metrics are printed; validation rows are appended
        to ``<self.log_dir>/finetune_log.csv`` by ``self.validate``.
        """
        # Context managers guarantee the log file and the session are released
        # even when training raises part-way through.
        with open(self.log_dir + '/finetune_log.csv', "a+") as csv_file, \
                tf.Session() as sess:
            finetune_writer = csv.writer(csv_file)
            # The FileWriter is given sess.graph (per the TensorFlow docs this
            # records the graph in log_dir); no summaries are added here.
            train_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
            try:
                sess.run(self.init_op)
                self.saver.restore(sess, self.restore_model_path)
                print('#########Fine tune with pre-trained weights.#########')
                QR_train = utils.generate_train_test(self.QR_train_path,
                                                     self.word_id, "QR")
                # Ratio argument of utils.batch_pos_neg (presumably the
                # positive/negative mix of each batch -- confirm in utils).
                ratio = 0.5
                for epo in range(self.fineTune_epochs):
                    # train on QR_train; the fourth return value is not needed
                    Q_train_batch, R_train_batch, Y_train_batch, _ = utils.batch_pos_neg(
                        QR_train[0][:self.QR_train_slice],
                        QR_train[1][:self.QR_train_slice],
                        QR_train[2][:self.QR_train_slice],
                        self.batch_size, ratio, True, True)
                    prediction_QR_train = []
                    y_train_true = []
                    batch_count = len(Y_train_batch)
                    for j, y_batch in enumerate(Y_train_batch):
                        _, Loss_QR_train, pred = sess.run(
                            [self.optimizer, self.loss, self.predict], {
                                self.Q_ori: Q_train_batch[j],
                                self.R_ori: R_train_batch[j],
                                self.labels: y_batch
                            })
                        print("Finetuning epoch:", epo + 1, "--iter:", j + 1, "/",
                              batch_count, "--Loss:", Loss_QR_train)
                        y_train_true += [pre[0] for pre in y_batch]
                        prediction_QR_train += [pre[0] for pre in pred]
                    precision, recall, f1, d = mertic(prediction_QR_train,
                                                      y_train_true, self.threshold)
                    print("[Train on QR]: --precision: ", precision, "--recall: ",
                          recall, "--f1: ", f1, "--d:", d)
                    # Validate on the QR dev files only (on_QR_weak=False).
                    self.validate(sess, finetune_writer, False, epo)
            finally:
                # Flush and release the event file even on failure.
                train_writer.close()

        print('Fintuning detail saved in {0}.'.format(self.timestamp))
Ejemplo n.º 5
0
    def pretrain(self):
        """
        Pre-train on weak QR pairs, save the model, then optionally fine-tune.

        Stage 1 trains for ``self.preTraining_epochs`` epochs on the first
        ``self.QR_weak_train_slice`` examples loaded from
        ``self.QR_weak_train_path``, writes a summary for every training
        step, and after each epoch validates on both the weak dev set and the
        QR dev files, appending rows to ``<self.log_dir>/pretrain_log.csv``.
        The model is then saved to ``self.preTrained_model``.

        Stage 2 runs only when ``self.finetune_flag`` is true: it trains for
        ``self.fineTune_epochs`` epochs on QR pairs in the same session and
        validates on the QR dev files without writing CSV rows.
        """
        saver = tf.train.Saver()
        with tf.Session() as sess:
            train_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
            try:
                sess.run(self.init_op)
                step = 0
                # `with` closes the log file even if an epoch raises.
                with open(self.log_dir + '/pretrain_log.csv', "a+") as csv_file:
                    pretrain_writer = csv.writer(csv_file)
                    QR_weak_train = utils.generate_train_test(
                        self.QR_weak_train_path, self.word_id, "QA")
                    print("#########QR weak training begin.#########")
                    for epo in range(self.preTraining_epochs):
                        # train on QR weak pairs; the fourth return value is
                        # not needed here
                        Q_batch, R_batch, Y_batch, _ = utils.batch_pos_neg(
                            QR_weak_train[0][:self.QR_weak_train_slice],
                            QR_weak_train[1][:self.QR_weak_train_slice],
                            QR_weak_train[2][:self.QR_weak_train_slice],
                            self.batch_size, 0.5, True, True)
                        batch_count = len(Y_batch)
                        for i, y_batch in enumerate(Y_batch):
                            # The predictions are still fetched, as in the
                            # original run call, but are not used.
                            _QR, Loss_aux, _pred_weak, summary = sess.run(
                                [self.optimizer, self.loss, self.predict,
                                 self.merged],
                                {
                                    self.Q_ori: Q_batch[i],
                                    self.R_ori: R_batch[i],
                                    self.labels: y_batch
                                })
                            train_writer.add_summary(summary, step)
                            # Advance the global step so each summary gets its
                            # own step (it was previously stuck at 0).
                            step += 1
                            print("Pre-training epoch:", epo + 1, "--iter:",
                                  i + 1, "/", batch_count, "--Loss:", Loss_aux)
                        # Validate on the weak dev set, then on the QR dev files.
                        self.validate(sess, pretrain_writer, True, epo)
                        self.validate(sess, pretrain_writer, False, epo)

                saver.save(sess, self.preTrained_model)
                print('Pretrained model saved on {}.'.format(self.timestamp))

                # finetune epochs
                if self.finetune_flag:
                    print("#########Finetune on QR pairs.#########")
                    QR_train = utils.generate_train_test(self.QR_train_path,
                                                         self.word_id, "QR")
                    ratio = 0.5
                    for epo in range(self.fineTune_epochs):
                        # train on QR_train
                        Q_train_batch, R_train_batch, Y_train_batch, _ = utils.batch_pos_neg(
                            QR_train[0][:self.QR_train_slice],
                            QR_train[1][:self.QR_train_slice],
                            QR_train[2][:self.QR_train_slice],
                            self.batch_size, ratio, True, True)
                        prediction_QR_train = []
                        y_train_true = []
                        batch_count = len(Y_train_batch)
                        for j, y_batch in enumerate(Y_train_batch):
                            _, Loss_QR_train, pred = sess.run(
                                [self.optimizer, self.loss, self.predict], {
                                    self.Q_ori: Q_train_batch[j],
                                    self.R_ori: R_train_batch[j],
                                    self.labels: y_batch
                                })
                            # This is the finetune stage, so label it as such
                            # (it used to print "Pre-training epoch:").
                            print("Finetuning epoch:", epo + 1, "--iter:",
                                  j + 1, "/", batch_count, "--Loss:",
                                  Loss_QR_train)
                            y_train_true += [pre[0] for pre in y_batch]
                            prediction_QR_train += [pre[0] for pre in pred]
                        precision, recall, f1, d = mertic(prediction_QR_train,
                                                          y_train_true,
                                                          self.threshold)
                        print("[Train on QR]: --precision: ", precision,
                              "--recall: ", recall, "--f1: ", f1, "--d:", d)
                        # No CSV writer here: metrics are only printed.
                        self.validate(sess, None, False, epo)
            finally:
                # Flush pending summaries and release the event file.
                train_writer.close()
Ejemplo n.º 6
0
    def finetune(self):
        """
        Fine-tune the restored model on QR pairs and validate after each epoch.

        Runs ``self.init_op``, restores weights from
        ``self.restore_model_path`` via ``self.saver``, then trains for
        ``self.fineTune_epochs`` epochs on the first ``self.QR_train_slice``
        examples loaded from ``self.QR_train_path``. Per-iteration loss and
        per-epoch training metrics are printed; validation rows are appended
        to ``<self.log_dir>/finetune_log.csv`` by ``self.validate``.
        """
        # Context managers guarantee the log file and the session are released
        # even when training raises part-way through.
        with open(self.log_dir + '/finetune_log.csv', "a+") as csv_file, \
                tf.Session() as sess:
            finetune_writer = csv.writer(csv_file)
            # The FileWriter is given sess.graph (per the TensorFlow docs this
            # records the graph in log_dir); no summaries are added here.
            train_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
            try:
                sess.run(self.init_op)
                self.saver.restore(sess, self.restore_model_path)
                print('#########Fine tune with pre-trained weights.#########')
                QR_train = utils.generate_train_test(self.QR_train_path, self.word_id, "QR")
                # Ratio argument of utils.batch_pos_neg (presumably the
                # positive/negative mix of each batch -- confirm in utils).
                ratio = 0.5
                for epo in range(self.fineTune_epochs):
                    # train on QR_train; the fourth return value is not needed
                    Q_train_batch, R_train_batch, Y_train_batch, _ = utils.batch_pos_neg(
                        QR_train[0][:self.QR_train_slice],
                        QR_train[1][:self.QR_train_slice],
                        QR_train[2][:self.QR_train_slice],
                        self.batch_size,
                        ratio, True,
                        True)
                    prediction_QR_train = []
                    y_train_true = []
                    batch_count = len(Y_train_batch)
                    for j, y_batch in enumerate(Y_train_batch):
                        _, Loss_QR_train, pred = sess.run(
                            [self.optimizer, self.loss, self.predict],
                            {self.Q_ori: Q_train_batch[j], self.R_ori: R_train_batch[j],
                             self.labels: y_batch})
                        print("Finetuning epoch:", epo + 1, "--iter:", j + 1, "/", batch_count, "--Loss:",
                              Loss_QR_train)
                        y_train_true += [pre[0] for pre in y_batch]
                        prediction_QR_train += [pre[0] for pre in pred]
                    precision, recall, f1, d = mertic(prediction_QR_train, y_train_true, self.threshold)
                    print("[Train on QR]: --precision: ", precision, "--recall: ", recall, "--f1: ", f1, "--d:", d)

                    # Validate on the QR dev files only (on_QR_weak=False);
                    # self.validate does the per-file averaging and logging.
                    self.validate(sess, finetune_writer, False, epo)
            finally:
                # Flush and release the event file even on failure.
                train_writer.close()
        print('Fintuning detail saved in {0}.'.format(self.timestamp))