Ejemplo n.º 1
0
    def prepare_test_data(self):
        """Load the vocabulary and build a shuffled, batched test-set pipeline.

        Sets ``self.vocab`` from ``preprocess.CHAR_VOCAB_PATH`` when
        ``self.train_mode`` is ``'CHAR-RANDOM'`` and from
        ``preprocess.WORD_VOCAB_PATH`` when it is ``'WORD-NON-STATIC'``.
        Any other mode leaves ``self.vocab`` untouched.

        Returns:
            A ``(dataset, next_element)`` tuple: the shuffled, batched
            dataset and the ``get_next()`` op of a one-shot iterator over it.
        """
        # Load the vocabulary that matches the training mode.
        if self.train_mode == 'CHAR-RANDOM':
            # Character-level vocabulary.
            self.vocab = preprocess.read_vocab(
                os.path.join('data', preprocess.CHAR_VOCAB_PATH))
        elif self.train_mode == 'WORD-NON-STATIC':
            self.vocab = preprocess.read_vocab(
                os.path.join('data', preprocess.WORD_VOCAB_PATH))

        # The test file begins with a header row. It must be dropped *before*
        # shuffling; otherwise the header lands at a random position among
        # the samples and can no longer be skipped downstream.
        dataset = TextLineDataset(
            os.path.join('data', preprocess.TEST_PATH)).skip(1)
        dataset = dataset.shuffle(preprocess.TOTAL_TEST_SIZE).batch(
            self.test_batch_size)

        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()

        return dataset, next_element
    def draw_confusion_matrix(self):
        """Evaluate on the filtered test set and plot the confusion matrix.

        Reads ``preprocess.FILTERED_TEST_PATH`` batch by batch (skipping its
        first line). For every batch, ``self.score`` is evaluated
        ``self.config.multi_test_num`` times and averaged; the arg-max of the
        averaged score along axis 1 is taken as the prediction.

        After the dataset is exhausted the method prints the weighted
        precision, recall and F1 score, prints the row-normalized confusion
        matrix, draws it with matplotlib, saves the figure to
        ``./data/confusion_matrix.jpg`` and shows it.

        If no samples could be read, a message is printed and the method
        returns without computing metrics or plotting.
        """
        test_dataset = TextLineDataset(
            os.path.join('data', preprocess.FILTERED_TEST_PATH)).skip(1).batch(
                self.cnn.test_batch_size)
        # One-shot iterator: a single pass over the test set.
        test_iterator = test_dataset.make_one_shot_iterator()
        next_test_element = test_iterator.get_next()

        num_passes = self.config.multi_test_num
        y_true = []
        y_pred = []
        while True:
            # Only fetching the next batch signals the end of the dataset,
            # so keep the try body limited to that call.
            try:
                lines = self.sess.run(next_test_element)
            except tf.errors.OutOfRangeError:
                break

            batch_x, batch_y = self.cnn.convert_input(lines)
            feed_dict = {
                self.input_x: batch_x,
                self.labels: batch_y,
                self.dropout_keep_prob: 1.0,
                self.training: False
            }
            # Evaluate the batch several times and average the scores.
            mean_score = 0
            for _ in range(num_passes):
                mean_score += self.sess.run(self.score, feed_dict)
            mean_score /= num_passes
            # Arg-max in NumPy: building tf.argmax here would add a new op to
            # the graph on every batch.
            y_pred.extend(np.argmax(mean_score, axis=1))
            y_true.extend(batch_y)

        if not y_true:
            print('No test samples were read; nothing to evaluate.')
            return

        # Whole test set consumed: compute the evaluation metrics.
        test_precision = metrics.precision_score(y_true=y_true,
                                                 y_pred=y_pred,
                                                 average='weighted')
        test_recall = metrics.recall_score(y_true=y_true,
                                           y_pred=y_pred,
                                           average='weighted')
        test_f1_score = metrics.f1_score(y_true=y_true,
                                         y_pred=y_pred,
                                         average='weighted')
        log = ('precision: %0.6f, recall: %0.6f, f1_score: %0.6f' %
               (test_precision, test_recall, test_f1_score))
        print(log)

        cm = confusion_matrix(y_true, y_pred)
        print('Total samples:', np.sum(cm))
        # Normalize each row by its number of true samples. A row that sums
        # to zero (label seen only in predictions) is all zeros already, so
        # divide it by 1 instead of producing NaNs.
        row_sums = cm.sum(axis=1)[:, np.newaxis]
        row_sums[row_sums == 0] = 1
        cm = cm.astype('float') / row_sums
        print('Confusion matrix:\n', cm)

        # Draw the confusion matrix.
        fig, ax = plt.subplots()
        im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        ax.figure.colorbar(im, ax=ax)
        # Show every tick and label it with the corresponding class name.
        ax.set(
            xticks=np.arange(cm.shape[1]),
            yticks=np.arange(cm.shape[0]),
            xticklabels=self.class_name,
            yticklabels=self.class_name,
            title="Normalized confusion matrix",
            ylabel='True label',
            xlabel='Predicted label')

        # Rotate the x tick labels and set their alignment.
        plt.setp(ax.get_xticklabels(),
                 rotation=45,
                 ha="right",
                 rotation_mode="anchor")

        # Annotate every cell; use white text on dark cells for contrast.
        fmt = '.2f'
        thresh = cm.max() / 2.
        for row in range(cm.shape[0]):
            for col in range(cm.shape[1]):
                ax.text(
                    col,
                    row,
                    format(cm[row, col], fmt),
                    ha="center",
                    va="center",
                    color="white" if cm[row, col] > thresh else "black")
        fig.tight_layout()
        plt.savefig('./data/confusion_matrix.jpg')
        plt.show()