class Predictor(object):
    """Run a trained facial-expression CNN restored from a checkpoint.

    ``init_model`` must be called before any other method: it creates the
    session, restores the graph and looks up the tensors that the other
    methods feed and fetch.
    """

    def __init__(self, model_path):
        """Store the checkpoint prefix; no TensorFlow work happens here.

        Args:
            model_path: Checkpoint path prefix. ``init_model`` imports the
                graph from ``model_path + '.meta'`` and restores the
                variables from ``model_path``.
        """
        self.model_path = model_path
        # Position i is the text drawn for predicted class index i.
        self.class_name = [
            'Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral'
        ]

    def init_model(self):
        """Create the session, restore the checkpoint and cache graph tensors."""
        config = tf.ConfigProto()
        # Let TensorFlow grow GPU memory on demand.
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        self.config = CNNConfig()
        # Only used here for its settings (img_size, class_num,
        # test_batch_size, convert_input); the tensors come from the
        # imported meta graph below.
        self.cnn = CNN(self.config)

        print('Loading model from file:', self.model_path)
        saver = tf.train.import_meta_graph(self.model_path + '.meta')
        saver.restore(self.sess, self.model_path)
        self.graph = tf.get_default_graph()
        # Look the tensors up by operation name in the restored graph.
        self.input_x = self.graph.get_operation_by_name("input_x").outputs[0]
        self.labels = self.graph.get_operation_by_name("labels").outputs[0]
        self.dropout_keep_prob = self.graph.get_operation_by_name(
            "dropout_keep_prob").outputs[0]
        self.score = self.graph.get_operation_by_name('score/Relu').outputs[0]
        self.prediction = self.graph.get_operation_by_name(
            "prediction").outputs[0]
        self.training = self.graph.get_operation_by_name("training").outputs[0]

    def predict(self, batch_x):
        """Run one inference pass with dropout disabled.

        Args:
            batch_x: Batch fed to ``input_x``; ``_detect_sentiment`` passes a
                stack of ``(img_size, img_size, 1)`` grayscale crops.

        Returns:
            Tuple ``(score, pre)``: the fetched values of the ``score/Relu``
            and ``prediction`` tensors.
        """
        feed_dict = {
            self.input_x: batch_x,
            self.dropout_keep_prob: 1.0,
            self.training: False
        }
        score, pre = self.sess.run([self.score, self.prediction], feed_dict)
        return score, pre

    def _mean_score(self, feed_dict):
        """Average the score tensor over ``config.multi_test_num`` runs."""
        repeats = self.config.multi_test_num
        total = 0
        for _ in range(repeats):
            total += self.sess.run(self.score, feed_dict)
        total /= repeats
        return total

    @staticmethod
    def _normalize_rows(cm):
        """Return ``cm`` as floats with every row divided by its row sum.

        Rows whose sum is zero (a class with no true samples) are left as
        zeros instead of turning into NaN.
        """
        cm = np.asarray(cm, dtype=float)
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        return np.divide(cm,
                         row_totals,
                         out=np.zeros_like(cm),
                         where=row_totals != 0)

    def draw_confusion_matrix(self):
        """Evaluate the filtered test file and plot a normalized confusion matrix.

        Prints weighted precision/recall/F1 and the matrix, saves the figure
        to ``./data/confusion_matrix.jpg`` and shows it.

        Raises:
            ValueError: If the test file yields no samples.
        """
        # skip(1) drops the first line of the file (presumably a header row;
        # confirm against the preprocessing step).
        test_dataset = TextLineDataset(
            os.path.join('data', preprocess.FILTERED_TEST_PATH)).skip(1).batch(
                self.cnn.test_batch_size)
        next_test_element = test_dataset.make_one_shot_iterator().get_next()

        y_true = []
        y_pred = []
        while True:
            try:
                lines = self.sess.run(next_test_element)
            except tf.errors.OutOfRangeError:
                # The one-shot iterator is exhausted: whole file consumed.
                break
            batch_x, batch_y = self.cnn.convert_input(lines)
            mean_score = self._mean_score({
                self.input_x: batch_x,
                self.labels: batch_y,
                self.dropout_keep_prob: 1.0,
                self.training: False
            })
            # argmax in NumPy: building tf.argmax here would add a new op to
            # the graph on every batch.
            y_pred.extend(np.argmax(mean_score, axis=1))
            y_true.extend(batch_y)

        if not y_true:
            raise ValueError('No test samples were read; cannot evaluate.')

        test_precision = metrics.precision_score(y_true=y_true,
                                                 y_pred=y_pred,
                                                 average='weighted')
        test_recall = metrics.recall_score(y_true=y_true,
                                           y_pred=y_pred,
                                           average='weighted')
        test_f1_score = metrics.f1_score(y_true=y_true,
                                         y_pred=y_pred,
                                         average='weighted')
        log = ('precision: %0.6f, recall: %0.6f, f1_score: %0.6f' %
               (test_precision, test_recall, test_f1_score))
        print(log)

        # Fix the label set so the matrix always has one row/column per entry
        # of class_name, even when a class is absent from the test data;
        # otherwise the tick labels below would not line up.
        cm = confusion_matrix(y_true,
                              y_pred,
                              labels=np.arange(len(self.class_name)))
        print('Total samples:', np.sum(cm))
        cm = self._normalize_rows(cm)
        print('Confusion matrix:\n', cm)

        # Plot the confusion matrix.
        fig, ax = plt.subplots()
        im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        ax.figure.colorbar(im, ax=ax)
        ax.set(
            xticks=np.arange(cm.shape[1]),
            yticks=np.arange(cm.shape[0]),
            xticklabels=self.class_name,
            yticklabels=self.class_name,
            title="Normalized confusion matrix",
            ylabel='True label',
            xlabel='Predicted label')

        # Rotate the tick labels and set their alignment.
        plt.setp(ax.get_xticklabels(),
                 rotation=45,
                 ha="right",
                 rotation_mode="anchor")

        # Annotate every cell; switch to white text on dark cells.
        fmt = '.2f'
        thresh = cm.max() / 2.
        for row in range(cm.shape[0]):
            for col in range(cm.shape[1]):
                ax.text(
                    col,
                    row,
                    format(cm[row, col], fmt),
                    ha="center",
                    va="center",
                    color="white" if cm[row, col] > thresh else "black")
        fig.tight_layout()
        plt.savefig('./data/confusion_matrix.jpg')
        plt.show()

    def _detect_sentiment(self, detector, img):
        """Detect faces in ``img``, classify each one and annotate ``img``.

        For every detected face this draws a rectangle and the predicted
        class name onto ``img`` (in place), shows the resized crop in a
        ``'cropped'`` window and shows a bar chart of the averaged scores
        (``plt.show()`` is called once per face).

        Args:
            detector: Object providing ``detectMultiScale`` (the callers pass
                a ``cv2.CascadeClassifier``).
            img: BGR image as accepted by ``cv2.cvtColor``.

        Returns:
            The annotated ``img``, or ``None`` when no face was detected.
        """
        # The cascade detector and the network both work on grayscale.
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        faces = detector.detectMultiScale(image=gray,
                                          scaleFactor=1.1,
                                          minNeighbors=2,
                                          minSize=(30, 30),
                                          flags=0)
        if len(faces) == 0:
            return None

        size = self.cnn.img_size
        batch_x = []
        for x, y, w, h in faces:
            cv2.rectangle(img, (x, y), (x + w, y + h), (0, 0, 255), 1)
            # OpenCV images are indexed [y, x]; crop the face and resize it
            # to the network input size.
            img_cropped = cv2.resize(gray[y:y + h, x:x + w], (size, size))
            cv2.imshow('cropped', img_cropped)
            batch_x.append(img_cropped.reshape([size, size, 1]))
        batch_x = np.stack(batch_x)

        mean_score = self._mean_score({
            self.input_x: batch_x,
            self.dropout_keep_prob: 1.0,
            self.training: False
        })
        # NumPy argmax: creating tf.argmax per call would grow the graph on
        # every camera frame.
        pred = np.argmax(mean_score, axis=1)

        for (x, y, _, h), face_score, label in zip(faces, mean_score, pred):
            # Bar chart of the averaged per-class scores for this face.
            plt.bar(range(self.cnn.class_num),
                    face_score,
                    align='center',
                    color='steelblue',
                    alpha=0.8)
            plt.ylabel('Score')
            plt.xticks(range(self.cnn.class_num), self.class_name)
            plt.show()
            # Write the class name just below the face rectangle.
            cv2.putText(img=img,
                        text=self.class_name[label],
                        org=(x, y + h + 20),
                        fontFace=cv2.FONT_HERSHEY_COMPLEX,
                        fontScale=0.6,
                        color=(0, 0, 255))
        return img

    def camera_detect(self):
        """Classify faces from camera 0 until ESC or "s" is pressed.

        ESC quits. "s" writes the most recently displayed annotated frame to
        ``capture.jpg`` and quits; it is ignored while nothing has been
        displayed yet.

        Raises:
            IOError: If camera 0 cannot be opened.
        """
        # Index 0 is the built-in camera; use 1, 2, ... for other devices.
        cam = cv2.VideoCapture(0)
        if not cam.isOpened():
            cam.release()
            raise IOError('Cannot open camera 0')
        detector = cv2.CascadeClassifier(
            './data/haarcascade_frontalface_alt.xml')
        shown = None  # last annotated frame put on screen
        try:
            while True:
                success, frame = cam.read()
                if not success:
                    continue
                annotated = self._detect_sentiment(detector, frame)
                if annotated is not None:
                    shown = annotated
                    cv2.imshow("Sentiment Detection", shown)
                # Keep the window responsive and poll the keyboard.
                k = cv2.waitKey(1)
                if k == 27:
                    break
                if k == ord("s") and shown is not None:
                    cv2.imwrite("capture.jpg", shown)
                    break
        finally:
            # Always give the device and the windows back, even on errors.
            cam.release()
            cv2.destroyAllWindows()

    def image_detect(self, img_path):
        """Classify the faces in an image file and display the result.

        Pressing "s" in the result window writes the annotated image to
        ``result.jpg``; any key closes the window.

        Args:
            img_path: Path of the image to analyse.

        Raises:
            IOError: If the image cannot be read.
        """
        img = cv2.imread(img_path)
        if img is None:
            raise IOError('Cannot read image: %s' % img_path)
        detector = cv2.CascadeClassifier(
            './data/haarcascade_frontalface_alt.xml')
        annotated = self._detect_sentiment(detector, img)
        if annotated is None:
            print('No face detected in:', img_path)
            return
        cv2.imshow('Sentiment Detection Result', annotated)
        k = cv2.waitKey()
        if k == ord("s"):
            cv2.imwrite("result.jpg", annotated)
        cv2.destroyAllWindows()
# Example no. 2 (listing separator; original marker lines: "Exemplo n.º 2" / "0")
def train():
    """Train the CNN, evaluate it after every epoch and save a checkpoint.

    Side effects:
        * Writes training summaries (loss, accuracy) under
          ``tensorboard/cnn/train/``.
        * Writes one CSV row per epoch (epoch, loss, precision, recall,
          f1_score) to ``tensorboard/cnn/test/log.csv``.
        * Saves the model with the prefix ``<cwd>/checkpoints/cnn`` once
          training has finished.
    """
    # Let TensorFlow grow GPU memory on demand.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    with tf.Session(config=session_config) as sess:
        config = CNNConfig()
        cnn = CNN(config)
        cnn.setVGG16()

        print('Setting Tensorboard and Saver...')
        # Saver and checkpoint location.
        checkpoint_dir = os.path.join(os.path.abspath("checkpoints"), "cnn")
        # The save prefix is the directory path itself.
        checkpoint_prefix = checkpoint_dir
        os.makedirs(checkpoint_dir, exist_ok=True)
        # NOTE: built before global_step and the optimizer are created, so
        # variables created further down are not in this saver's list.
        saver = tf.train.Saver(tf.global_variables())

        # TensorBoard setup. When retraining, delete the tensorboard folder
        # first, otherwise the new curves are drawn over the old ones.
        train_tensorboard_dir = 'tensorboard/cnn/train/'
        test_tensorboard_dir = 'tensorboard/cnn/test/'
        os.makedirs(train_tensorboard_dir, exist_ok=True)
        os.makedirs(test_tensorboard_dir, exist_ok=True)

        # Per-epoch evaluation log.
        log_file = open(test_tensorboard_dir + '/log.csv',
                        mode='w',
                        encoding='utf-8')
        train_summary_writer = None
        try:
            log_file.write(
                ','.join(['epoch', 'loss', 'precision', 'recall',
                          'f1_score']) + '\n')

            merged_summary = tf.summary.merge([
                tf.summary.scalar('loss', cnn.loss),
                tf.summary.scalar('accuracy', cnn.accuracy)
            ])

            train_summary_writer = tf.summary.FileWriter(
                train_tensorboard_dir, sess.graph)

            global_step = tf.Variable(0, trainable=False)
            # Smoothly decaying learning rate (staircase=False):
            # lr = base * 0.98 ** (step / 5000), i.e. 2% per 5000 steps.
            learning_rate = tf.train.exponential_decay(config.learning_rate,
                                                       global_step,
                                                       decay_steps=5000,
                                                       decay_rate=0.98,
                                                       staircase=False)

            # Make train_op run after the ops in UPDATE_OPS so that batch
            # normalization statistics are updated on every step.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = tf.train.AdamOptimizer(learning_rate).minimize(
                    cnn.loss, global_step)

            def train_step(batch_x,
                           batch_y,
                           keep_prob=config.dropout_keep_prob):
                """Apply one optimizer step, then log loss/accuracy for the batch."""
                feed_dict = {
                    cnn.input_x: batch_x,
                    cnn.labels: batch_y,
                    cnn.dropout_keep_prob: keep_prob,
                    cnn.training: True
                }
                sess.run(train_op, feed_dict=feed_dict)
                # Second pass in inference mode (no dropout, training=False)
                # to measure the batch after the update.
                step, loss, accuracy, summary = sess.run(
                    [global_step, cnn.loss, cnn.accuracy, merged_summary],
                    feed_dict={
                        cnn.input_x: batch_x,
                        cnn.labels: batch_y,
                        cnn.dropout_keep_prob: 1.0,
                        cnn.training: False
                    })
                t = datetime.datetime.now().strftime('%m-%d %H:%M')
                print('%s: epoch: %d, step: %d, loss: %f, accuracy: %f' %
                      (t, epoch, step, loss, accuracy))
                # Record the result in TensorBoard.
                train_summary_writer.add_summary(summary, step)

            def test_step(next_test_element):
                """Evaluate the whole test set and log the metrics for this epoch."""
                y_true = []
                y_pred = []
                test_loss = 0.0
                batch_count = 0
                while True:
                    try:
                        lines = sess.run(next_test_element)
                    except tf.errors.OutOfRangeError:
                        # Test iterator exhausted.
                        break
                    batch_x, batch_y = cnn.convert_input(lines)
                    feed_dict = {
                        cnn.input_x: batch_x,
                        cnn.labels: batch_y,
                        cnn.dropout_keep_prob: 1.0,
                        cnn.training: False
                    }
                    # Evaluate several times and average loss and score.
                    mean_loss = 0
                    mean_score = 0
                    # `_` keeps this loop from touching batch_count.
                    for _ in range(config.multi_test_num):
                        loss, score = sess.run([cnn.loss, cnn.score],
                                               feed_dict)
                        mean_loss += loss
                        mean_score += score
                    mean_loss /= config.multi_test_num
                    mean_score /= config.multi_test_num
                    # ndarray argmax instead of tf.argmax, which would add a
                    # new graph op on every batch.
                    y_pred.extend(mean_score.argmax(axis=1))
                    y_true.extend(batch_y)
                    test_loss += mean_loss
                    batch_count += 1

                if batch_count == 0:
                    print('Test set is empty; skipping evaluation.')
                    return

                # Average the per-batch losses over the number of batches.
                test_loss /= batch_count
                test_accuracy = metrics.accuracy_score(y_true=y_true,
                                                       y_pred=y_pred)
                test_precision = metrics.precision_score(
                    y_true=y_true, y_pred=y_pred, average='weighted')
                test_recall = metrics.recall_score(y_true=y_true,
                                                   y_pred=y_pred,
                                                   average='weighted')
                test_f1_score = metrics.f1_score(y_true=y_true,
                                                 y_pred=y_pred,
                                                 average='weighted')

                t = datetime.datetime.now().strftime('%m-%d %H:%M')
                log = '%s: epoch %d, testing loss: %0.6f, accuracy: %0.6f' % (
                    t, epoch, test_loss, test_accuracy)
                log = log + '\n' + (
                    'precision: %0.6f, recall: %0.6f, f1_score: %0.6f' %
                    (test_precision, test_recall, test_f1_score))
                print(log)
                log_file.write(','.join([
                    str(epoch),
                    str(test_loss),
                    str(test_precision),
                    str(test_recall),
                    str(test_f1_score)
                ]) + '\n')
                # Pause after the epoch report (presumably so it stays
                # readable before step logs resume; TODO confirm).
                time.sleep(3)

            print('Start training CNN...')
            sess.run(tf.global_variables_initializer())
            (train_init_op, test_init_op, next_train_element,
             next_test_element) = cnn.prepare_data()
            # Training loop: one pass over the training iterator per epoch,
            # then a full evaluation on the test iterator.
            for epoch in range(config.epoch_num):
                sess.run(train_init_op)
                while True:
                    try:
                        lines = sess.run(next_train_element)
                        batch_x, batch_y = cnn.convert_input(lines)
                        train_step(batch_x, batch_y, config.dropout_keep_prob)
                    except tf.errors.OutOfRangeError:
                        # Epoch finished: reset the test iterator, evaluate.
                        sess.run(test_init_op)
                        test_step(next_test_element)
                        break
        finally:
            # Release the writer and the log even if training fails.
            if train_summary_writer is not None:
                train_summary_writer.close()
            log_file.close()
        # Save the parameters once training has completed.
        path = saver.save(sess, checkpoint_prefix, global_step=global_step)
        print("Saved model checkpoint to {}\n".format(path))