def prepare_test_data(self):
    """Load the vocabulary and build a shuffled, batched test-set pipeline.

    The vocabulary matching ``self.train_mode`` is read and stored on
    ``self.vocab``; any mode other than ``'CHAR-RANDOM'`` or
    ``'WORD-NON-STATIC'`` leaves ``self.vocab`` untouched.

    Returns:
        A ``(dataset, next_element)`` tuple: the batched ``TextLineDataset``
        and the ``get_next()`` op of a one-shot iterator over it.
    """
    # Load the vocabulary for the configured training mode.
    if self.train_mode == 'CHAR-RANDOM':
        # Character-level vocabulary.
        self.vocab = preprocess.read_vocab(
            os.path.join('data', preprocess.CHAR_VOCAB_PATH))
    elif self.train_mode == 'WORD-NON-STATIC':
        self.vocab = preprocess.read_vocab(
            os.path.join('data', preprocess.WORD_VOCAB_PATH))

    # The test file starts with a header row (per the original note), so it
    # must be dropped *before* shuffling; otherwise the header lands at a
    # random position and is fed to the model as if it were a sample.
    dataset = TextLineDataset(
        os.path.join('data', preprocess.TEST_PATH)).skip(1)
    dataset = dataset.shuffle(preprocess.TOTAL_TEST_SIZE).batch(
        self.test_batch_size)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    return dataset, next_element
def draw_confusion_matrix(self):
    """Evaluate on the filtered test set and plot a normalized confusion matrix.

    Reads ``preprocess.FILTERED_TEST_PATH`` (skipping its first line) in
    batches of ``self.cnn.test_batch_size``. For every batch the score
    output is averaged over ``self.config.multi_test_num`` runs and the
    arg-max of the averaged score is taken as the prediction.

    Prints weighted precision / recall / F1, the total sample count and the
    row-normalized confusion matrix, then saves the figure to
    ``./data/confusion_matrix.jpg`` and shows it. If the dataset yields no
    batches, a message is printed and nothing is plotted.
    """
    test_dataset = TextLineDataset(
        os.path.join('data', preprocess.FILTERED_TEST_PATH)).skip(1).batch(
            self.cnn.test_batch_size)
    test_iterator = test_dataset.make_one_shot_iterator()
    next_test_element = test_iterator.get_next()

    y_true = []
    y_pred = []
    while True:
        # Only the iterator fetch signals end-of-data, so keep the try tight.
        try:
            lines = self.sess.run(next_test_element)
        except tf.errors.OutOfRangeError:
            break

        batch_x, batch_y = self.cnn.convert_input(lines)
        feed_dict = {
            self.input_x: batch_x,
            self.labels: batch_y,
            self.dropout_keep_prob: 1.0,
            self.training: False
        }

        # Run the score op several times and average the results.
        mean_score = 0
        for _ in range(self.config.multi_test_num):
            mean_score += self.sess.run(self.score, feed_dict)
        mean_score /= self.config.multi_test_num

        # Arg-max in NumPy: building a tf.argmax op here would add a new
        # node to the graph on every batch.
        y_pred.extend(np.argmax(mean_score, axis=1))
        y_true.extend(batch_y)

    if not y_true:
        print('No test samples found; confusion matrix not drawn.')
        return

    # Whole test set consumed: compute the aggregate metrics.
    test_precision = metrics.precision_score(
        y_true=y_true, y_pred=y_pred, average='weighted')
    test_recall = metrics.recall_score(
        y_true=y_true, y_pred=y_pred, average='weighted')
    test_f1_score = metrics.f1_score(
        y_true=y_true, y_pred=y_pred, average='weighted')

    log = ('precision: %0.6f, recall: %0.6f, f1_score: %0.6f'
           % (test_precision, test_recall, test_f1_score))
    print(log)

    cm = confusion_matrix(y_true, y_pred)
    print('Total samples:', np.sum(cm))
    # Normalize each row by its true-label count; rows with no true samples
    # stay at 0 instead of turning into NaN.
    row_totals = cm.sum(axis=1)[:, np.newaxis]
    cm = np.divide(
        cm.astype('float'), row_totals,
        out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)
    print('Confusion matrix:\n', cm)

    # Draw the confusion matrix.
    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)
    # Show all ticks and label them with the class names.
    ax.set(
        xticks=np.arange(cm.shape[1]),
        yticks=np.arange(cm.shape[0]),
        xticklabels=self.class_name,
        yticklabels=self.class_name,
        title="Normalized confusion matrix",
        ylabel='True label',
        xlabel='Predicted label')

    # Rotate the x tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Annotate every cell; use white text on dark cells for contrast.
    fmt = '.2f'
    thresh = cm.max() / 2.
    for row in range(cm.shape[0]):
        for col in range(cm.shape[1]):
            ax.text(
                col, row, format(cm[row, col], fmt),
                ha="center", va="center",
                color="white" if cm[row, col] > thresh else "black")
    fig.tight_layout()
    plt.savefig('./data/confusion_matrix.jpg')
    plt.show()