def test():
    """Evaluate the saved model on the test set and write predictions to disk.

    Restores the checkpoint at ``save_path``, prints the loss and accuracy
    reported by ``evaluate`` for the test data, then runs ``model.predict_label``
    batch by batch and writes one ``<label>\\t<content>`` line per test sample
    to ``predict.txt`` in the current working directory.

    Relies on module-level names (``model``, ``config``, ``test_dir``,
    ``word_to_id``, ``label_to_id``, ``id_to_label``, ``save_path``) that are
    defined outside this function.
    """
    print('Loading test data...')
    x_test, y_test, seq_lens_test = process_file(
        test_dir, word_to_id, label_to_id, config.seq_length)
    # Only the raw text is needed here; it is written next to each prediction.
    content_test, _ = read_corpus(test_dir)

    data_len = len(x_test)
    predict_result = np.zeros(shape=[data_len], dtype=np.int32)

    # The context manager guarantees the session is closed, even if restoring
    # or evaluating raises.
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        # Load the trained model.
        saver.restore(sess=session, save_path=save_path)

        print('Start testing...')
        loss_test, acc_test = evaluate(
            session, x_test, y_test, seq_lens_test, config.batch_size)
        msg = 'Test Loss: {0:>2.2f}, Test Acc: {1:>2.2%}'
        print(msg.format(loss_test, acc_test))

        # Predict in batches; the last batch may be smaller than batch_size.
        for start in range(0, data_len, config.batch_size):
            end = min(data_len, start + config.batch_size)
            feed_dict = {
                model.content: x_test[start:end],
                model.label: y_test[start:end],
                model.sequence_lengths: seq_lens_test[start:end],
                model.batch_size: end - start,
            }
            predict_result[start:end] = session.run(
                model.predict_label, feed_dict=feed_dict)

    print('Writing predict result to predict.txt...')
    with open('predict.txt', 'w', encoding='utf-8') as f:
        for i, label_id in enumerate(predict_result):
            f.write(id_to_label[label_id] + '\t' + content_test[i] + '\n')
def train():
    """Train the model, logging to TensorBoard and checkpointing the best one.

    Runs ``config.epoch`` passes over the training data. Every
    ``config.save_pre_batch`` batches the merged loss/accuracy summary is
    written to ``tensorboard_dir``; every ``config.print_pre_batch`` batches
    the current batch and the validation set are evaluated, a progress line is
    printed, and the model is saved to ``save_path`` whenever validation
    accuracy exceeds the best value seen so far.

    Relies on module-level names (``model``, ``config``, ``train_dir``,
    ``val_dir``, ``word_to_id``, ``label_to_id``, ``save_dir``, ``save_path``,
    ``tensorboard_dir``) that are defined outside this function.
    """
    # Configure TensorBoard and the Saver.
    print('Configuring TensorBoard and Saver...')
    saver = tf.train.Saver()
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)

    tf.summary.scalar('loss', model.loss)
    tf.summary.scalar('accuracy', model.accuracy)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)

    try:
        # Load and preprocess the data.
        print('Loading training data and validation data...')
        start_time = time.time()
        x_train, y_train, seq_lens_train = process_file(
            train_dir, word_to_id, label_to_id, config.seq_length)
        x_val, y_val, seq_lens_val = process_file(
            val_dir, word_to_id, label_to_id, config.seq_length)
        print('Time usage:', get_time_dif(start_time))

        # Create the session; the context manager closes it on exit.
        with tf.Session() as session:
            session.run(tf.global_variables_initializer())

            # Add the graph to TensorBoard.
            writer.add_graph(session.graph)

            # Start training.
            print('Start training...')
            start_time = time.time()
            best_acc_val = 0.0
            total_batch = 0
            msg = 'Iter: {0:>2}, Train Loss: {1:>2.2f}, Train Acc: {2:>2.2%}, ' \
                  'Val Loss: {3:>2.2f}, Val Acc: {4:>2.2%}, Time: {5} {6}'

            for epoch in range(config.epoch):
                batch_train = batch_iter(
                    x_train, y_train, seq_lens_train, config.batch_size)
                for x_batch, y_batch, seq_lens_batch in batch_train:
                    # NOTE(review): test() also feeds model.batch_size, but it
                    # is not fed here; presumably the model supplies a default
                    # for training - confirm against the model definition.
                    feed_dict = {
                        model.content: x_batch,
                        model.label: y_batch,
                        model.sequence_lengths: seq_lens_batch,
                    }

                    # Write the training summaries to TensorBoard.
                    if total_batch % config.save_pre_batch == 0:
                        s = session.run(merged_summary, feed_dict=feed_dict)
                        writer.add_summary(s, total_batch)

                    # Report training and validation results and keep the
                    # best model.
                    if total_batch % config.print_pre_batch == 0:
                        loss_train, acc_train = session.run(
                            [model.loss, model.accuracy], feed_dict=feed_dict)
                        loss_val, acc_val = evaluate(
                            session, x_val, y_val, seq_lens_val,
                            config.batch_size)

                        # Only the best-scoring model is saved.
                        if acc_val > best_acc_val:
                            best_acc_val = acc_val
                            saver.save(session, save_path)
                            improved_str = '*'
                        else:
                            improved_str = ''

                        time_dif = get_time_dif(start_time)
                        print(msg.format(total_batch, loss_train, acc_train,
                                         loss_val, acc_val, time_dif,
                                         improved_str))

                    # Run one optimization step on the current batch.
                    session.run(model.optimizer, feed_dict=feed_dict)
                    total_batch += 1
    finally:
        # Flush pending summaries and release the event file, even if
        # training is interrupted by an exception.
        writer.close()