Пример #1
0
def test():
    """Evaluate the saved model on the test set and dump its predictions.

    Loads the test data from ``test_dir``, restores the checkpoint stored at
    ``save_path``, prints the loss and accuracy reported by ``evaluate``,
    then predicts a label for every test sample batch by batch and writes one
    "label<TAB>content" line per sample to ``predict.txt`` in the current
    working directory.

    The session is used as a context manager so it is closed even when
    restoring or predicting raises.
    """
    print('Loading test data...')
    x_test, y_test, seq_lens_test = process_file(test_dir, word_to_id,
                                                 label_to_id,
                                                 config.seq_length)
    # Only the raw text is needed for the output file; labels are discarded.
    content_test, _ = read_corpus(test_dir)

    data_len = len(x_test)
    predict_result = np.zeros(shape=[data_len], dtype=np.int32)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        # Load the trained model from the checkpoint.
        saver.restore(sess=session, save_path=save_path)

        print('Start testing...')
        loss_test, acc_test = evaluate(session, x_test, y_test,
                                       seq_lens_test, config.batch_size)
        msg = 'Test Loss: {0:>2.2f}, Test Acc: {1:>2.2%}'
        print(msg.format(loss_test, acc_test))

        # Step through the data one batch at a time; an empty test set
        # yields no batches instead of a single empty one.
        for start in range(0, data_len, config.batch_size):
            end = min(data_len, start + config.batch_size)
            x_batch = x_test[start:end]
            feed_dict = {
                model.content: x_batch,
                model.label: y_test[start:end],
                model.sequence_lengths: seq_lens_test[start:end],
                model.batch_size: len(x_batch)
            }
            predict_result[start:end] = session.run(model.predict_label,
                                                    feed_dict=feed_dict)

    # Predictions are held in a NumPy array, so the session is no longer
    # needed while writing the file.
    print('Writing predict result to predict.txt...')
    with open('predict.txt', 'w', encoding='utf-8') as f:
        for i, label_id in enumerate(predict_result):
            f.write(id_to_label[label_id] + '\t' + content_test[i] + '\n')
Пример #2
0
def _run_training(session, saver, writer, merged_summary, train_data,
                  val_data):
    """Run the epoch/batch loop: log summaries, validate, checkpoint, step."""
    x_train, y_train, seq_lens_train = train_data
    x_val, y_val, seq_lens_val = val_data

    print('Start training...')
    start_time = time.time()
    best_acc_val = 0.0
    total_batch = 0
    msg = 'Iter: {0:>2}, Train Loss: {1:>2.2f}, Train Acc: {2:>2.2%}, ' \
          'Val Loss: {3:>2.2f}, Val Acc: {4:>2.2%}, Time: {5} {6}'

    for epoch in range(config.epoch):
        batch_train = batch_iter(x_train, y_train, seq_lens_train,
                                 config.batch_size)
        for x_batch, y_batch, seq_lens_batch in batch_train:
            feed_dict = {
                model.content: x_batch,
                model.label: y_batch,
                model.sequence_lengths: seq_lens_batch
            }

            # Write the training summaries to TensorBoard.
            if total_batch % config.save_pre_batch == 0:
                summary = session.run(merged_summary, feed_dict=feed_dict)
                writer.add_summary(summary, total_batch)

            # Report training and validation metrics, and keep the best
            # model seen so far.
            if total_batch % config.print_pre_batch == 0:
                loss_train, acc_train = session.run(
                    [model.loss, model.accuracy], feed_dict=feed_dict)
                loss_val, acc_val = evaluate(session, x_val, y_val,
                                             seq_lens_val, config.batch_size)
                # Only checkpoint when validation accuracy improves.
                if acc_val > best_acc_val:
                    best_acc_val = acc_val
                    saver.save(session, save_path)
                    improved_str = '*'
                else:
                    improved_str = ''
                time_dif = get_time_dif(start_time)
                print(
                    msg.format(total_batch, loss_train, acc_train, loss_val,
                               acc_val, time_dif, improved_str))

            # The optimizer step runs after logging, so the metrics above
            # describe the model before it is updated on this batch.
            session.run(model.optimizer, feed_dict=feed_dict)
            total_batch += 1


def train():
    """Train the model, logging to TensorBoard and saving the best checkpoint.

    Creates ``save_dir`` and ``tensorboard_dir`` if needed, registers scalar
    summaries for ``model.loss`` and ``model.accuracy``, loads the training
    and validation data, then runs ``config.epoch`` epochs over the training
    batches. Every ``config.save_pre_batch`` batches the merged summary is
    written to TensorBoard. Every ``config.print_pre_batch`` batches the
    training and validation metrics are printed, and the model is saved to
    ``save_path`` whenever validation accuracy exceeds the best value so far.

    The session and the summary writer are both closed when training
    finishes or raises, so pending TensorBoard events are flushed and
    session resources are released.
    """
    # Configure TensorBoard and the Saver.
    print('Configuring TensorBoard and Saver...')
    saver = tf.train.Saver()
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)
    tf.summary.scalar('loss', model.loss)
    tf.summary.scalar('accuracy', model.accuracy)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)

    try:
        # Load and preprocess the data.
        print('Loading training data and validation data...')
        start_time = time.time()
        x_train, y_train, seq_lens_train = process_file(
            train_dir, word_to_id, label_to_id, config.seq_length)
        x_val, y_val, seq_lens_val = process_file(
            val_dir, word_to_id, label_to_id, config.seq_length)
        print('Time usage:', get_time_dif(start_time))

        # Create the session; the context manager closes it on exit.
        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            # Add the graph to TensorBoard.
            writer.add_graph(session.graph)

            _run_training(session, saver, writer, merged_summary,
                          (x_train, y_train, seq_lens_train),
                          (x_val, y_val, seq_lens_val))
    finally:
        # Flush buffered events to disk and release the event file.
        writer.close()