コード例 #1
0
    def init_model(self):
        """Create a TF session, restore the checkpoint at ``self.model_path``,
        and cache the graph tensors needed for inference on ``self``."""
        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=sess_config)
        self.config = CNNConfig()
        self.cnn = CNN(self.config)
        # self.cnn.setVGG16()

        print('Loading model from file:', self.model_path)
        restorer = tf.train.import_meta_graph(self.model_path + '.meta')
        restorer.restore(self.sess, self.model_path)
        self.graph = tf.get_default_graph()

        # Look up the tensors we need from the restored graph by op name.
        def tensor(op_name):
            return self.graph.get_operation_by_name(op_name).outputs[0]

        self.input_x = tensor("input_x")
        self.labels = tensor("labels")
        self.dropout_keep_prob = tensor("dropout_keep_prob")
        self.score = tensor('score/Relu')
        self.prediction = tensor("prediction")
        self.training = tensor("training")
コード例 #2
0
def predict():
    """Load the trained TextCNN checkpoint and run inference.

    Restores the CHAR-RANDOM model from ``checkpoints/textcnn`` and predicts
    labels for the first 10000 samples of ``./Data/Data.npy``.

    :return: array of predictions as produced by the restored graph's
        ``output/prediction`` op
    """
    # Test procedure
    # ======================================================
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # TODO: to load a different model, change the paths below
        # Directory holding the checkpoint to restore
        checkpoint_dir = os.path.abspath("checkpoints/textcnn")
        # Checkpoint file name goes here, without extension
        checkpoint_file = os.path.join(checkpoint_dir, "CHAR-RANDOM-25871")
        # The meta graph sits next to the checkpoint with a '.meta' suffix
        saver = tf.train.import_meta_graph(checkpoint_file + '.meta')
        saver.restore(sess, checkpoint_file)
        graph = tf.get_default_graph()

        # The train_mode argument here must match the restored model.
        # NOTE(review): config/cnn look unused for inference (every tensor
        # below comes from the restored graph); kept in case
        # TextCNN.__init__ has side effects — TODO confirm and remove.
        config = CNNConfig('CHAR-RANDOM')
        cnn = TextCNN(config)

        # Fetch the tensors needed for inference from the restored graph
        input_x = graph.get_operation_by_name("input_x").outputs[0]
        dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
        prediction = graph.get_operation_by_name("output/prediction").outputs[0]
        training = graph.get_operation_by_name("training").outputs[0]
        batch_x = np.load('./Data/Data.npy')[0:10000]
        feed_dict = {
            input_x: batch_x,
            dropout_keep_prob: 1.0,  # no dropout at inference time
            training: False
        }
        pre = sess.run(prediction, feed_dict)
        return pre
コード例 #3
0
def train():
    """Train the TextCNN (CHAR-RANDOM mode), evaluating one test batch every
    10 steps and checkpointing whenever test accuracy improves.

    Relies on module-level names: ``tf``, ``os``, ``datetime``, ``CNNConfig``,
    ``TextCNN``, ``train_dataset``, ``test_dataset``.
    """
    # Training procedure
    # ======================================================
    # Let GPU memory grow on demand instead of grabbing it all up front
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # NOTE(review): `config` is rebound from the session ConfigProto to
        # the model configuration from here on.
        config = CNNConfig('CHAR-RANDOM')
        cnn = TextCNN(config)
        cnn.prepare_data()
        cnn.setCNN()

        print('Setting Tensorboard and Saver...')
        # Set up the Saver and checkpoint paths to persist the model
        # ===================================================
        checkpoint_dir = os.path.join(os.path.abspath("checkpoints"),
                                      "textcnn")
        checkpoint_prefix = os.path.join(checkpoint_dir, cnn.train_mode)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        # NOTE(review): the Saver is built before `global_step` is created
        # below, so global_step is not included in the checkpoint.
        saver = tf.train.Saver(tf.global_variables())
        # =====================================================

        # Configure Tensorboard. When retraining, delete the tensorboard
        # folder first, otherwise the new graph overwrites the old one.
        # ====================================================================
        train_tensorboard_dir = 'tensorboard/textcnn/train/' + config.train_mode
        valid_tensorboard_dir = 'tensorboard/textcnn/valid/' + config.train_mode
        if not os.path.exists(train_tensorboard_dir):
            os.makedirs(train_tensorboard_dir)
        if not os.path.exists(valid_tensorboard_dir):
            os.makedirs(valid_tensorboard_dir)

        # Training log file (NOTE(review): opened and closed but never
        # written to anywhere in this function)
        log_file = open(valid_tensorboard_dir + '/log.txt', mode='w')

        merged_summary = tf.summary.merge([
            tf.summary.scalar('Trainloss', cnn.loss),
            tf.summary.scalar('Trainaccuracy', cnn.accuracy)
        ])
        merged_summary_t = tf.summary.merge([
            tf.summary.scalar('Testloss', cnn.loss),
            tf.summary.scalar('Testaccuracy', cnn.accuracy)
        ])
        train_summary_writer = tf.summary.FileWriter(train_tensorboard_dir,
                                                     sess.graph)
        # =========================================================================

        global_step = tf.Variable(0, trainable=False)

        # Make sure batch-normalization statistics get updated
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(
                update_ops):  # run train_op only after update_ops have run
            train_op = tf.train.AdamOptimizer(config.learning_rate).minimize(
                cnn.loss, global_step)

        # One optimization step; metrics are re-evaluated without dropout
        def train_step(batch_x, batch_y, keep_prob=config.dropout_keep_prob):
            feed_dict = {
                cnn.input_x: batch_x,
                cnn.labels: batch_y,
                cnn.dropout_keep_prob: keep_prob,
                cnn.training: True
            }
            sess.run(train_op, feed_dict=feed_dict)
            step, loss, accuracy, summery = sess.run(
                [global_step, cnn.loss, cnn.accuracy, merged_summary],
                feed_dict={
                    cnn.input_x: batch_x,
                    cnn.labels: batch_y,
                    cnn.dropout_keep_prob: 1.0,
                    cnn.training: False
                })
            t = datetime.datetime.now().strftime('%m-%d %H:%M')
            print('TRAIN  %s: epoch: %d, step: %d, loss: %f, accuracy: %f' %
                  (t, epoch, step, loss, accuracy))
            # Write the result to Tensorboard
            train_summary_writer.add_summary(summery, step)

        # Evaluate one test batch; returns its accuracy
        def test_step(batch_x, batch_y):

            step, loss, accuracy, summery = sess.run(
                [global_step, cnn.loss, cnn.accuracy, merged_summary_t],
                feed_dict={
                    cnn.input_x: batch_x,
                    cnn.labels: batch_y,
                    cnn.dropout_keep_prob: 1.0,
                    cnn.training: False
                })
            t = datetime.datetime.now().strftime('%m-%d %H:%M')
            print('TEST %s: epoch: %d, step: %d, loss: %f, accuracy: %f' %
                  (t, epoch, step, loss, accuracy))
            # Write the result to Tensorboard (NOTE(review): test summaries
            # go to the *train* writer; no separate test writer is created)
            train_summary_writer.add_summary(summery, step)
            return accuracy

        print('Start training TextCNN, training mode=' + cnn.train_mode)
        sess.run(tf.global_variables_initializer())

        last = 0  # best test accuracy seen so far
        # Training loop (NOTE(review): each `epoch` here is really one
        # 128-sample batch step, not a full pass over the data)
        for epoch in range(1000000):
            batch_x, batch_y = train_dataset.next_batch(128)
            train_step(batch_x, batch_y, config.dropout_keep_prob)
            if epoch % 10 == 0:
                batch_x, batch_y = test_dataset.next_batch(128)
                accuracy = test_step(batch_x, batch_y)
                if accuracy > last:
                    path = saver.save(sess,
                                      checkpoint_prefix,
                                      global_step=global_step)
                    print("Saved model checkpoint to {}\n".format(path))
                    last = accuracy

        train_summary_writer.close()
        log_file.close()
コード例 #4
0
ファイル: run_cnn.py プロジェクト: Inistlwq/text_similarity
                                x0=default_parameters)
    space = search_result.space
    print('sorted validation result:')
    print(sorted(zip(search_result.func_vals, search_result.x_iters)))
    print('best parameter:')
    print(space.point_to_dict(search_result.x))
    pass


# Script entry point: configure the CNN model and dispatch on the command
# line argument (train / test / tune_hyper).
if __name__ == '__main__':

    if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test', 'tune_hyper']:
        raise ValueError("""usage: python run_cnn.py [train / test / tune_hyper]""")

    print('Configuring CNN model...')
    config = CNNConfig()
    if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist
        build_vocab(vocab_dir, min_count = config.min_count, vocab_size = config.vocab_size)
    categories, id2cate, cate2id = read_category()
    #pdb.set_trace()
    chars, char2id, id2char = read_vocab(vocab_dir)
    config.vocab_size = len(chars)# in case the actual vocabulary is smaller than the configured size
    config.id2char = id2char
    config.num_classes = len(categories)
    model = CNNModel(config)

    if sys.argv[1] == 'train':
        train()
    elif sys.argv[1] == 'test':
        if not os.path.exists(encoder_save_path):
            raise Exception(encoder_save_path + ' not found.')
コード例 #5
0
def train():
    """Train the audio-tagging CNN (VGG13 head): one dataset pass per epoch,
    then a full validation sweep logging loss and lwlrap; checkpoints on
    completion or on Ctrl-C.
    """
    # Training procedure
    # ======================================================
    # Let GPU memory grow on demand instead of grabbing it all up front
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # NOTE(review): `config` is rebound from the session ConfigProto to
        # the model configuration from here on.
        config = CNNConfig()
        cnn = CNN(config)
        train_init_op, valid_init_op = cnn.prepare_data()
        cnn.setVGG13()

        print('Setting Tensorboard and Saver...')
        # Set up the Saver and checkpoint paths to persist the model
        # ===================================================
        checkpoint_dir = os.path.abspath("checkpoints")
        checkpoint_prefix = checkpoint_dir + '/cnn'
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        # NOTE(review): the Saver is built before `global_step` is created
        # below, so global_step is not included in the checkpoint.
        saver = tf.train.Saver(tf.global_variables())
        # =====================================================

        # Configure Tensorboard. When retraining, delete the tensorboard
        # folder first, otherwise the new graph overwrites the old one.
        # ====================================================================
        train_tensorboard_dir = 'tensorboard/train/'
        valid_tensorboard_dir = 'tensorboard/valid/'
        if not os.path.exists(train_tensorboard_dir):
            os.makedirs(train_tensorboard_dir)
        if not os.path.exists(valid_tensorboard_dir):
            os.makedirs(valid_tensorboard_dir)

        # Training log (CSV: one row per validated epoch)
        log_file = open(valid_tensorboard_dir + '/log.csv',
                        mode='w',
                        encoding='utf-8')
        log_file.write(','.join(['epoch', 'loss', 'lwlrap']) + '\n')

        merged_summary = tf.summary.merge(
            [tf.summary.scalar('loss', cnn.loss)])

        train_summary_writer = tf.summary.FileWriter(train_tensorboard_dir,
                                                     sess.graph)
        # =========================================================================

        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(config.learning_rate,
                                                   global_step,
                                                   decay_steps=2000,
                                                   decay_rate=0.94,
                                                   staircase=False)

        # Make sure batch-normalization statistics get updated
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(
                update_ops):  # run train_op only after update_ops have run
            train_op = tf.train.GradientDescentOptimizer(
                learning_rate).minimize(cnn.loss, global_step)

        # One optimization step over the next batch from the dataset iterator
        def train_step(keep_prob=config.dropout_keep_prob):
            feed_dict = {cnn.dropout_keep_prob: keep_prob, cnn.training: True}
            _, step, loss, y_pred, y_true, summery = sess.run(
                [
                    train_op, global_step, cnn.loss, cnn.prediction,
                    cnn.input_y, merged_summary
                ],
                feed_dict=feed_dict)
            # Compute lwlrap (label-weighted label-ranking average precision)
            lrap = lwlrap.calculate_overall_lwlrap_sklearn(truth=y_true,
                                                           scores=y_pred)
            # per_class_lwlrap, weight_per_class = calculate_per_class_lwlrap(truth=batch_y, scores=y_pred)
            # mean_lwlrap = np.sum(per_class_lwlrap * weight_per_class)
            t = datetime.datetime.now().strftime('%m-%d %H:%M')
            print('%s: epoch: %d, step: %d, loss: %f, lwlrap: %f' %
                  (t, epoch, step, loss, lrap))
            # Write the result to Tensorboard
            train_summary_writer.add_summary(summery, step)

        # NOTE(review): early-stopping state stashed on the function object;
        # none of these attributes are read again within this function.
        train.current_lrap = 0.0
        train.best_lrap = 0.0
        train.patience = 5
        train.current_iter = 0

        # Validation sweep over the whole validation split
        def valid_step():
            # Reset the running loss / prediction accumulators
            y_true = []
            y_pred = []
            i = 0
            losses = 0.0
            while True:
                try:
                    feed_dict = {
                        cnn.dropout_keep_prob: 1.0,
                        cnn.training: False
                    }
                    loss, pred, true = sess.run(
                        [cnn.loss, cnn.prediction, cnn.input_y], feed_dict)
                    y_pred.extend(pred)
                    y_true.extend(true)
                    losses += loss
                    i += 1
                except tf.errors.OutOfRangeError:
                    # Validation set exhausted: compute the epoch metrics
                    valid_loss = losses / i
                    y_true = np.asarray(y_true)
                    y_pred = np.asarray(y_pred)
                    lrap = lwlrap.calculate_overall_lwlrap_sklearn(
                        truth=y_true, scores=y_pred)
                    t = datetime.datetime.now().strftime('%m-%d %H:%M')
                    log = '%s: epoch %d, validation loss: %0.6f, lwlrap: %0.6f' % (
                        t, epoch, valid_loss, lrap)
                    print(log)
                    log_file.write(log + '\n')
                    time.sleep(3)
                    return

        print('Start training CNN...')
        sess.run(tf.global_variables_initializer())
        # Training loop
        for epoch in range(config.epoch_num):
            if cnn.use_img_input:
                sess.run(train_init_op)
            else:
                sess.run(train_init_op,
                         feed_dict={
                             cnn.features_placeholder: cnn.features,
                             cnn.labels_placeholder: cnn.labels
                         })
            while True:
                try:
                    train_step(config.dropout_keep_prob)
                except tf.errors.OutOfRangeError:
                    # Training set exhausted: initialize the validation iterator
                    if cnn.use_img_input:
                        sess.run(valid_init_op)
                    else:
                        sess.run(valid_init_op,
                                 feed_dict={
                                     cnn.features_placeholder: cnn.features,
                                     cnn.labels_placeholder: cnn.labels
                                 })
                    # Evaluate on the validation set
                    valid_step()
                    break
                except KeyboardInterrupt:
                    # Ctrl-C: flush writers and save a checkpoint before exiting
                    train_summary_writer.close()
                    log_file.close()
                    path = saver.save(sess,
                                      checkpoint_prefix,
                                      global_step=global_step)
                    print("Saved model checkpoint to {}\n".format(path))
                    return
        train_summary_writer.close()
        log_file.close()
        # Save the final parameters once training completes
        path = saver.save(sess, checkpoint_prefix, global_step=global_step)
        print("Saved model checkpoint to {}\n".format(path))
コード例 #6
0
def train():
    """Train the VGG16-style CNN: one dataset pass per epoch, then a full
    sweep over the test split logging loss / precision / recall / f1 to CSV.

    Bug fixed vs. the original: in ``test_step`` the batch counter ``i`` was
    clobbered by the inner multi-test loop variable (also ``i``), so the
    averaged test loss was divided by ``config.multi_test_num`` instead of
    the number of test batches.
    """
    # Training procedure
    # ======================================================
    # Let GPU memory grow on demand instead of grabbing it all up front
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # NOTE: `config` is rebound from the session ConfigProto to the
        # model configuration from here on.
        config = CNNConfig()
        cnn = CNN(config)
        cnn.setVGG16()

        print('Setting Tensorboard and Saver...')
        # Set up the Saver and checkpoint paths to persist the model
        # ===================================================
        checkpoint_dir = os.path.join(os.path.abspath("checkpoints"), "cnn")
        checkpoint_prefix = os.path.join(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        # NOTE: the Saver is built before `global_step` is created below,
        # so global_step is not included in the checkpoint.
        saver = tf.train.Saver(tf.global_variables())
        # =====================================================

        # Configure Tensorboard. When retraining, delete the tensorboard
        # folder first, otherwise the new graph overwrites the old one.
        # ====================================================================
        train_tensorboard_dir = 'tensorboard/cnn/train/'
        test_tensorboard_dir = 'tensorboard/cnn/test/'
        if not os.path.exists(train_tensorboard_dir):
            os.makedirs(train_tensorboard_dir)
        if not os.path.exists(test_tensorboard_dir):
            os.makedirs(test_tensorboard_dir)

        # Training log (CSV: one row per tested epoch)
        log_file = open(test_tensorboard_dir + '/log.csv',
                        mode='w',
                        encoding='utf-8')
        log_file.write(
            ','.join(['epoch', 'loss', 'precision', 'recall', 'f1_score']) +
            '\n')

        merged_summary = tf.summary.merge([
            tf.summary.scalar('loss', cnn.loss),
            tf.summary.scalar('accuracy', cnn.accuracy)
        ])

        train_summary_writer = tf.summary.FileWriter(train_tensorboard_dir,
                                                     sess.graph)
        # =========================================================================

        global_step = tf.Variable(0, trainable=False)
        # Decaying learning rate: 2% decay every 5000 steps
        learning_rate = tf.train.exponential_decay(config.learning_rate,
                                                   global_step,
                                                   decay_steps=5000,
                                                   decay_rate=0.98,
                                                   staircase=False)

        # Make sure batch-normalization statistics get updated
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(
                update_ops):  # run train_op only after update_ops have run
            train_op = tf.train.AdamOptimizer(learning_rate).minimize(
                cnn.loss, global_step)

        # One optimization step; metrics are re-evaluated without dropout
        def train_step(batch_x, batch_y, keep_prob=config.dropout_keep_prob):
            feed_dict = {
                cnn.input_x: batch_x,
                cnn.labels: batch_y,
                cnn.dropout_keep_prob: keep_prob,
                cnn.training: True
            }
            sess.run(train_op, feed_dict=feed_dict)
            step, loss, accuracy, summery = sess.run(
                [global_step, cnn.loss, cnn.accuracy, merged_summary],
                feed_dict={
                    cnn.input_x: batch_x,
                    cnn.labels: batch_y,
                    cnn.dropout_keep_prob: 1.0,
                    cnn.training: False
                })
            t = datetime.datetime.now().strftime('%m-%d %H:%M')
            print('%s: epoch: %d, step: %d, loss: %f, accuracy: %f' %
                  (t, epoch, step, loss, accuracy))
            # Write the result to Tensorboard
            train_summary_writer.add_summary(summery, step)

        # Evaluation sweep over the whole test split
        def test_step(next_test_element):
            # Reset the running metric accumulators
            y_true = []
            y_pred = []
            test_loss = 0.0
            test_accuracy = 0.0
            test_precision = 0.0
            test_recall = 0.0
            test_f1_score = 0.0
            batch_count = 0  # FIX: was `i`, clobbered by the inner loop below
            while True:
                try:
                    lines = sess.run(next_test_element)
                    batch_x, batch_y = cnn.convert_input(lines)
                    feed_dict = {
                        cnn.input_x: batch_x,
                        cnn.labels: batch_y,
                        cnn.dropout_keep_prob: 1.0,
                        cnn.training: False
                    }
                    # loss, pred, true = sess.run([cnn.loss, cnn.prediction, cnn.labels], feed_dict)
                    # Evaluate several times and average loss and score
                    mean_loss = 0
                    mean_score = 0
                    for _ in range(config.multi_test_num):
                        loss, score = sess.run([cnn.loss, cnn.score],
                                               feed_dict)
                        mean_loss += loss
                        mean_score += score
                    mean_loss /= config.multi_test_num
                    mean_score /= config.multi_test_num
                    # NOTE(review): tf.argmax here adds a new op to the graph
                    # on every batch — consider hoisting or using np.argmax
                    pred = sess.run(tf.argmax(mean_score, 1))
                    y_pred.extend(pred)
                    y_true.extend(batch_y)
                    test_loss += mean_loss
                    batch_count += 1
                except tf.errors.OutOfRangeError:
                    # Test set exhausted: average loss, compute sklearn metrics
                    test_loss /= batch_count
                    test_accuracy = metrics.accuracy_score(y_true=y_true,
                                                           y_pred=y_pred)
                    test_precision = metrics.precision_score(
                        y_true=y_true, y_pred=y_pred, average='weighted')
                    test_recall = metrics.recall_score(y_true=y_true,
                                                       y_pred=y_pred,
                                                       average='weighted')
                    test_f1_score = metrics.f1_score(y_true=y_true,
                                                     y_pred=y_pred,
                                                     average='weighted')

                    t = datetime.datetime.now().strftime('%m-%d %H:%M')
                    log = '%s: epoch %d, testing loss: %0.6f, accuracy: %0.6f' % (
                        t, epoch, test_loss, test_accuracy)
                    log = log + '\n' + (
                        'precision: %0.6f, recall: %0.6f, f1_score: %0.6f' %
                        (test_precision, test_recall, test_f1_score))
                    print(log)
                    log_file.write(','.join([
                        str(epoch),
                        str(test_loss),
                        str(test_precision),
                        str(test_recall),
                        str(test_f1_score)
                    ]) + '\n')
                    time.sleep(3)
                    return

        print('Start training CNN...')
        sess.run(tf.global_variables_initializer())
        train_init_op, test_init_op, next_train_element, next_test_element = cnn.prepare_data(
        )
        # Training loop
        for epoch in range(config.epoch_num):
            sess.run(train_init_op)
            while True:
                try:
                    lines = sess.run(next_train_element)
                    batch_x, batch_y = cnn.convert_input(lines)
                    train_step(batch_x, batch_y, config.dropout_keep_prob)
                except tf.errors.OutOfRangeError:
                    # Training set exhausted: initialize the test iterator
                    sess.run(test_init_op)
                    # Evaluate on the test set
                    test_step(next_test_element)
                    break
        train_summary_writer.close()
        log_file.close()
        # Save the final parameters once training completes
        path = saver.save(sess, checkpoint_prefix, global_step=global_step)
        print("Saved model checkpoint to {}\n".format(path))
コード例 #7
0
def train():
    """Train the TextCNN in MULTI mode: one pass per epoch over a reshuffled
    dataset, followed by a full validation sweep; saves a final checkpoint.
    """
    # Training procedure
    # ======================================================
    # Let GPU memory grow on demand instead of grabbing it all up front
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # NOTE(review): `config` is rebound from the session ConfigProto to
        # the model configuration from here on.
        config = CNNConfig('MULTI')
        cnn = TextCNN(config)
        cnn.prepare_data()
        cnn.setCNN()

        print('Setting Tensorboard and Saver...')
        # Set up the Saver and checkpoint paths to persist the model
        # ===================================================
        checkpoint_dir = os.path.join(os.path.abspath("checkpoints"),
                                      "textcnn")
        checkpoint_prefix = os.path.join(checkpoint_dir, cnn.train_mode)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        # NOTE(review): the Saver is built before `global_step` is created
        # below, so global_step is not included in the checkpoint.
        saver = tf.train.Saver(tf.global_variables())
        # =====================================================

        # Configure Tensorboard. When retraining, delete the tensorboard
        # folder first, otherwise the new graph overwrites the old one.
        # ====================================================================
        train_tensorboard_dir = 'tensorboard/textcnn/train/' + config.train_mode
        valid_tensorboard_dir = 'tensorboard/textcnn/valid/' + config.train_mode
        if not os.path.exists(train_tensorboard_dir):
            os.makedirs(train_tensorboard_dir)
        if not os.path.exists(valid_tensorboard_dir):
            os.makedirs(valid_tensorboard_dir)

        # Training log file
        log_file = open(valid_tensorboard_dir + '/log.txt', mode='w')

        merged_summary = tf.summary.merge([
            tf.summary.scalar('loss', cnn.loss),
            tf.summary.scalar('accuracy', cnn.accuracy)
        ])

        train_summary_writer = tf.summary.FileWriter(train_tensorboard_dir,
                                                     sess.graph)
        # =========================================================================

        global_step = tf.Variable(0, trainable=False)

        # Make sure batch-normalization statistics get updated
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(
                update_ops):  # run train_op only after update_ops have run
            train_op = tf.train.AdamOptimizer(config.learning_rate).minimize(
                cnn.loss, global_step)

        # One optimization step; metrics are re-evaluated without dropout
        def train_step(batch_x, batch_y, keep_prob=config.dropout_keep_prob):
            feed_dict = {
                cnn.input_x: batch_x,
                cnn.labels: batch_y,
                cnn.dropout_keep_prob: keep_prob,
                cnn.training: True
            }
            sess.run(train_op, feed_dict=feed_dict)
            step, loss, accuracy, summery = sess.run(
                [global_step, cnn.loss, cnn.accuracy, merged_summary],
                feed_dict={
                    cnn.input_x: batch_x,
                    cnn.labels: batch_y,
                    cnn.dropout_keep_prob: 1.0,
                    cnn.training: False
                })
            t = datetime.datetime.now().strftime('%m-%d %H:%M')
            print('%s: epoch: %d, step: %d, loss: %f, accuracy: %f' %
                  (t, epoch, step, loss, accuracy))
            # Write the result to Tensorboard
            train_summary_writer.add_summary(summery, step)

        # Validation sweep over the whole validation split
        def valid_step(next_valid_element):
            # Reset the running metric accumulators
            valid_loss = 0.0
            valid_accuracy = 0.0
            valid_precision = 0.0
            valid_recall = 0.0
            valid_f1_score = 0.0
            i = 0
            while True:
                try:
                    lines = sess.run(next_valid_element)
                    batch_x, batch_y = cnn.convert_input(lines)
                    feed_dict = {
                        cnn.input_x: batch_x,
                        cnn.labels: batch_y,
                        cnn.dropout_keep_prob: 1.0,
                        cnn.training: False
                    }
                    loss, accuracy, prediction, y_true = sess.run(
                        [cnn.loss, cnn.accuracy, cnn.prediction, cnn.labels],
                        feed_dict)

                    precision = sk.metrics.precision_score(y_true=y_true,
                                                           y_pred=prediction,
                                                           average='weighted')
                    recall = sk.metrics.recall_score(y_true=y_true,
                                                     y_pred=prediction,
                                                     average='weighted')
                    f1_score = sk.metrics.f1_score(y_true=y_true,
                                                   y_pred=prediction,
                                                   average='weighted')

                    valid_loss += loss
                    valid_accuracy += accuracy
                    valid_precision += precision
                    valid_recall += recall
                    valid_f1_score += f1_score
                    i += 1

                except tf.errors.OutOfRangeError:
                    # Validation set exhausted: average loss and metrics over
                    # batches (NOTE(review): averaging per-batch precision/
                    # recall/f1 only approximates the全-set metrics — TODO
                    # confirm this is acceptable here)
                    valid_loss /= i
                    valid_accuracy /= i
                    valid_precision /= i
                    valid_recall /= i
                    valid_f1_score /= i

                    t = datetime.datetime.now().strftime('%m-%d %H:%M')
                    log = '%s: epoch %d, validation loss: %0.6f, accuracy: %0.6f' % (
                        t, epoch, valid_loss, valid_accuracy)
                    log = log + '\n' + (
                        'precision: %0.6f, recall: %0.6f, f1_score: %0.6f' %
                        (valid_precision, valid_recall, valid_f1_score))
                    print(log)
                    log_file.write(log + '\n')
                    time.sleep(3)
                    return

        print('Start training TextCNN, training mode=' + cnn.train_mode)
        sess.run(tf.global_variables_initializer())

        # Training loop: reshuffle the dataset at the start of every epoch
        for epoch in range(config.epoch_num):
            train_init_op, valid_init_op, next_train_element, next_valid_element = cnn.shuffle_datset(
            )
            sess.run(train_init_op)
            while True:
                try:
                    lines = sess.run(next_train_element)
                    batch_x, batch_y = cnn.convert_input(lines)
                    train_step(batch_x, batch_y, config.dropout_keep_prob)
                except tf.errors.OutOfRangeError:
                    # Initialize the validation iterator
                    sess.run(valid_init_op)
                    # Evaluate on the validation set
                    valid_step(next_valid_element)
                    break

        train_summary_writer.close()
        log_file.close()
        # Save the final parameters once training completes
        path = saver.save(sess, checkpoint_prefix, global_step=global_step)
        print("Saved model checkpoint to {}\n".format(path))