Example 1
def restore_model(testPicArr):
    # Instantiate a graph and use it as the default graph for the TensorFlow runtime
    with tf.Graph().as_default() as tg:
        # Placeholder for the input x
        x = tf.placeholder(
            tf.float32,
            [1, config.img_width, config.img_height, fer_forward.NUM_CHANNELS])
        # Build the forward-pass graph that produces the output y
        y = fer_forward.forward(x, False, None)
        # The prediction is the index of the largest entry of y
        preValue = tf.argmax(y, 1)
        # Define the exponential moving average
        variable_averages = tf.train.ExponentialMovingAverage(
            fer_backward.MOVING_AVERAGE_DECAY)
        # Map the shadow (EMA) variables back onto the variables themselves
        variables_to_restore = variable_averages.variables_to_restore()
        # Create a Saver object for the model
        saver = tf.train.Saver(variables_to_restore)

        # Create a session
        with tf.Session() as sess:
            # Find the model file name via the checkpoint file
            ckpt = tf.train.get_checkpoint_state(config.MODEL_SAVE_PATH)
            # If a saved model exists
            if ckpt and ckpt.model_checkpoint_path:
                # Restore the model
                saver.restore(sess, ckpt.model_checkpoint_path)
                reshape_x = np.reshape(testPicArr,
                                       (1, config.img_width, config.img_height,
                                        fer_forward.NUM_CHANNELS))
                # Compute the prediction
                preValue = sess.run(preValue, feed_dict={x: reshape_x})
                # Return the prediction
                return preValue
            # If no saved model exists
            else:
                # Report that the checkpoint file was not found
                print("No checkpoint file found")
                # Return -1
                return -1
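
A minimal usage sketch for restore_model (not part of the original listing): prepare_test_pic is a hypothetical helper that mirrors the preprocessing used in Example 7 (grayscale, resize to the configured input size, scale pixel values to [0, 1]); config is assumed to be the same module the listing imports.

import numpy as np
from PIL import Image

def prepare_test_pic(path):
    # Grayscale, resize to the network's input size, scale to [0, 1]
    img = Image.open(path).convert("L")
    img = img.resize((config.img_width, config.img_height))
    return np.asarray(img, dtype=np.float32) / 255.0

pre_value = restore_model(prepare_test_pic('./picture_to_test/pic1.jpg'))
print('predicted class index:', pre_value)
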
Example 2
def test():
    # Instantiate a graph and use it as the default graph for the TensorFlow runtime
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, [
            MINI_BATCH, config.img_width, config.img_height,
            fer_forward.NUM_CHANNELS
        ])
        y_ = tf.placeholder(tf.float32, [None, fer_forward.OUTPUT_NODE])

        prob = tf.placeholder(tf.float32)
        bn_training = tf.placeholder(tf.bool)
        # y = fer_forward.forward(x, keep_prob=prob)
        y, dict_ret = fer_forward.forward(x,
                                          keep_prob=prob,
                                          bn_enable=True,
                                          bn_training=bn_training)

        ema = tf.train.ExponentialMovingAverage(backward.MOVING_AVERAGE_DECAY)
        ema_restore = ema.variables_to_restore()  # build the mapping that restores the EMA shadow values into the original variables

        loader = tf.train.Saver(ema_restore)

        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        # Fetch test data in batches
        img_batch, label_batch = fer_generator.get_tfrecord(
            MINI_BATCH, config.tfRecord_test)
        for i in range(3):
            # Create a session
            with tf.Session() as sess:
                # Find the model file name via the checkpoint file
                ckpt = tf.train.get_checkpoint_state(config.MODEL_SAVE_PATH)
                # If a saved model exists
                if ckpt and ckpt.model_checkpoint_path:

                    loader.restore(sess, ckpt.model_checkpoint_path)
                    global_step = ckpt.model_checkpoint_path.split(
                        '/')[-1].split('-')[-1]

                    # Create a thread coordinator
                    coord = tf.train.Coordinator()
                    threads = tf.train.start_queue_runners(sess=sess,
                                                           coord=coord)

                    iterations = int(TOTAL_TEST_NUM / MINI_BATCH)
                    total_accuracy_score = 0
                    for _ in range(iterations):
                        # This fetch must stay inside the loop.
                        xs, ys = sess.run([img_batch, label_batch])
                        # Reshape the test input xs
                        reshape_xs = np.reshape(
                            xs, (MINI_BATCH, config.img_width,
                                 config.img_height, fer_forward.NUM_CHANNELS))

                        accuracy_score = sess.run(accuracy,
                                                  feed_dict={
                                                      x: reshape_xs,
                                                      y_: ys,
                                                      prob: 1.0,
                                                      bn_training: False
                                                  })

                        print("%g" % (accuracy_score), end=', ')
                        total_accuracy_score += accuracy_score

                    # Print global_step and the average accuracy
                    print("After %s training step(s), test accuracy = %g" %
                          (global_step, total_accuracy_score / iterations))
                    # Stop all threads
                    coord.request_stop()
                    coord.join(threads)

                else:
                    print('No checkpoint file found')
                    return
            time.sleep(TEST_INTERVAL_SECS)
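
A self-contained sketch of what ExponentialMovingAverage.variables_to_restore() produces, since the test() functions in this listing rely on it: a dictionary that maps the shadow-variable names stored in the checkpoint to the live variables, so a Saver built from it loads the averaged weights rather than the raw ones (toy variable names, TF 1.x API).

import tensorflow as tf

w = tf.Variable(1.0, name='w')
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_op = ema.apply([w])  # creates the shadow variable 'w/ExponentialMovingAverage'
# e.g. {'w/ExponentialMovingAverage': <tf.Variable 'w:0' ...>}
print(ema.variables_to_restore())
loader = tf.train.Saver(ema.variables_to_restore())  # restores the averaged value into w
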
Example 3
def backward():
    # Controls the random flip; shapes were changed from [1] to [BATCH_SIZE], and tf.cond was replaced with tf.where.
    flip_control = tf.random_uniform(shape=[BATCH_SIZE])
    rotate_control = tf.random_uniform(shape=[BATCH_SIZE])
    # In my own tests 0.5 (radians) is still acceptable; some images are already tilted, so the rotation must not be too large.
    rotate_angle = tf.random_uniform(
        shape=[1], minval=-0.5, maxval=0.5, dtype=tf.float32)
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,
                                               global_step,
                                               train_num_examples / BATCH_SIZE,
                                               LEARNING_RATE_DECAY,
                                               staircase=True)

    # The input image size is configurable via the config file. LeNet takes 1 input channel (grayscale), so the images should be converted first; the originals may not be grayscale.
    x = tf.placeholder(tf.float32, [
        BATCH_SIZE, config.img_width, config.img_height,
        fer_forward.NUM_CHANNELS
    ],
                       name='x')
    y_ = tf.placeholder(tf.float32, [None, fer_forward.OUTPUT_NODE], name='y_')
    prob = tf.placeholder(tf.float32, name='keep_prob')
    bn_training = tf.placeholder(tf.bool, name='bn_training')
    # Whether to apply data augmentation; it must be off when measuring accuracy.
    # (This flag really belongs to the data-loading graph, not the network graph.)
    is_data_augment = tf.placeholder(tf.bool, name='data_augment')
    # Build the network and the loss function
    y, return_dict = fer_forward.forward(x,
                                         keep_prob=prob,
                                         bn_enable=True,
                                         bn_training=bn_training)
    cem = cross_entropy(y, y_)

    loss = cem + tf.add_n(tf.get_collection('regularization_losses'))
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(
        loss, global_step=global_step)
    # EMA is used: testing reads the EMA values, while the accuracy reported during training uses the raw weights
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())

    tensorboard_print(return_dict)

    print('ema:', ema)
    print('tf.trainable_variables():', tf.trainable_variables())
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    print('update_ops:', update_ops)
    print('tf.GraphKeys.UPDATE_OPS:', tf.GraphKeys.UPDATE_OPS)
    # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # BN needs these dependent update ops
    with tf.control_dependencies(update_ops):
        with tf.control_dependencies([train_step, ema_op]):
            train_op = tf.no_op(name='train')
    # Accuracy
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    summary_cem = tf.summary.scalar('cem', cem)
    summary_loss = tf.summary.scalar('loss', loss)
    summary_accuracy = tf.summary.scalar('accuracy', accuracy)

    saver = tf.train.Saver()
    # Fetch training data in batches
    img_batch, label_batch = fer_generator.get_tfrecord(
        BATCH_SIZE, config.tfRecord_train)
    reshaped_img_batch = tf.reshape(img_batch,
                                    shape=[
                                        BATCH_SIZE, config.img_width,
                                        config.img_height,
                                        fer_forward.NUM_CHANNELS
                                    ])
    print('reshaped_img_batch:', reshaped_img_batch)
    tf.summary.image('img_input', reshaped_img_batch)  # log the input images

    # Data augmentation should be optional
    # flipped_img_batch = tf.cond(is_data_augment,
    #                             lambda:tf.where(flip_control >= 0.5,
    #                                      lambda:tf.image.flip_left_right(reshaped_img_batch),
    #                                      lambda:reshaped_img_batch),
    #                             lambda:reshaped_img_batch)
    # #                             lambda: reshaped_img_batch)
    # rotated_img_batch = tf.cond(is_data_augment,
    #                             tf.where(rotate_control >= 0.5,
    #                                      tf.contrib.image.rotate(flipped_img_batch, rotate_angle[0],interpolation='BILINEAR'),
    #                                      flipped_img_batch),
    #                             flipped_img_batch
    #                             )
    # tf.where takes no lambdas; it is not given functions, just two tensors of the same shape
    flipped_img_batch = tf.where(flip_control >= 0.5,
                                 tf.image.flip_left_right(reshaped_img_batch),
                                 reshaped_img_batch)
    rotated_img_batch = tf.where(
        rotate_control >= 0.5,
        tf.contrib.image.rotate(flipped_img_batch,
                                rotate_angle[0],
                                interpolation='BILINEAR'),
        flipped_img_batch)  # interpolation: "NEAREST" or "BILINEAR"
    # tf.cond cannot take tensors directly; its branches must be callables
    final_img_batch = tf.cond(is_data_augment, lambda: rotated_img_batch,
                              lambda: reshaped_img_batch)

    # final_img_batch = rotated_img_batch

    test_img_batch, test_label_batch = fer_generator.get_tfrecord(
        MINI_BATCH, config.tfRecord_test)
    with tf.Session() as sess:
        log_dir = 'tensorboard_dir'
        test_log_dir = 'test_tensorboard_dir'
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        # Resume training from a checkpoint if one exists
        ckpt = tf.train.get_checkpoint_state(config.MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:  # only do this when training from scratch, never when resuming from a checkpoint!
            if os.path.exists(log_dir):  # delete old summaries so the curves do not overlap
                shutil.rmtree(log_dir)
            os.makedirs(log_dir)
            if os.path.exists(test_log_dir):  # delete old summaries so the curves do not overlap
                shutil.rmtree(test_log_dir)
            os.makedirs(test_log_dir)

        writer = tf.summary.FileWriter(logdir=log_dir, graph=sess.graph)
        test_writer = tf.summary.FileWriter(logdir=test_log_dir,
                                            graph=sess.graph)
        merged = tf.summary.merge_all()
        # Only the loss/accuracy summaries are needed for the test pass; reuse the summary ops' return values.
        test_merged = tf.summary.merge(
            inputs=[summary_cem, summary_loss, summary_accuracy])

        # Create a thread coordinator
        coord = tf.train.Coordinator()
        # Start the enqueue threads
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # img_batch, label_batch = fer_generator.get_tfrecord(BATCH_SIZE, config.tfRecord_train)  ## placing this call here was a hard-learned lesson
        for i in range(STEPS):
            xs, ys = sess.run([final_img_batch, label_batch],
                              feed_dict={is_data_augment: True})

            reshape_xs = np.reshape(
                xs, (BATCH_SIZE, config.img_width, config.img_height,
                     fer_forward.NUM_CHANNELS))
            # Training update; accuracy_value is not very meaningful here because dropout is active.
            _, loss_value, accuracy_value, step, lr = sess.run(
                [train_op, loss, accuracy, global_step, learning_rate],
                feed_dict={
                    x: reshape_xs,
                    y_: ys,
                    prob: 0.3,
                    bn_training: True,
                    is_data_augment: True
                })

            # writer.add_event()
            if (step + 1) % 30 == 0:
                # Training batch: to avoid disturbing the data the training pipeline fetches, reuse the
                # same batch here instead of getting a new one; only dropout is disabled.
                # Later this could be replaced with a validation set if needed.
                accuracy_score, train_summary = sess.run(
                    [accuracy, merged],
                    feed_dict={
                        x: reshape_xs,
                        y_: ys,
                        prob: 1.0,
                        bn_training: True,
                        is_data_augment: False
                    })  # cannot reuse the training run directly; must run separately because of dropout
                # accuracy should really use prob=1.0, but then the keep_prob summary would only ever record 1.0;
                # this is for debugging only, so it does not matter much.
                writer.add_summary(train_summary, step)

                print(
                    "%s : After %d training step(s),lr is %g loss,accuracy on training batch is %g , %g."
                    % (time.strftime('%Y-%m-%d %H:%M:%S'), step, lr,
                       loss_value, accuracy_score))

                # Test set. Test accuracy is computed per batch and would need to be aggregated; for now use a single batch.
                xs, ys = sess.run([test_img_batch, test_label_batch])
                # Reshape the test input xs
                reshape_xs = np.reshape(
                    xs, (MINI_BATCH, config.img_width, config.img_height,
                         fer_forward.NUM_CHANNELS))

                # The test accuracy here is not exact either, because the test set also went through the augmentation preprocessing
                accuracy_score, test_summary = sess.run(
                    [accuracy, test_merged],
                    feed_dict={
                        x: reshape_xs,
                        y_: ys,
                        prob: 1.0,
                        bn_training: False,
                        is_data_augment: False
                    })
                test_writer.add_summary(test_summary, step)
                print("After %s training step(s), test accuracy = %g" %
                      (step, accuracy_score))

                # Save the model
                saver.save(sess,
                           os.path.join(config.MODEL_SAVE_PATH,
                                        config.MODEL_NAME),
                           global_step=global_step)

        coord.request_stop()
        coord.join(threads)
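
A standalone sketch (toy shapes, not from the original repository) contrasting the two ops the augmentation code above relies on: tf.where selects element-wise between two tensors of the same shape, which is why it can flip a random subset of the batch, while tf.cond switches between two callables based on a scalar predicate.

import tensorflow as tf

images = tf.zeros([4, 8, 8, 1])              # a toy batch
flip_control = tf.random_uniform(shape=[4])  # one coin flip per example
# tf.where: plain tensors, no lambdas; examples where the condition holds come from the first tensor
per_example = tf.where(flip_control >= 0.5,
                       tf.image.flip_left_right(images),
                       images)
# tf.cond: a scalar predicate and two callables, applied to the whole batch
use_augment = tf.placeholder(tf.bool)
final_batch = tf.cond(use_augment,
                      lambda: per_example,
                      lambda: images)
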
Example 4
def test():
    # Instantiate a graph and use it as the default graph for the TensorFlow runtime
    with tf.Graph().as_default() as g:
        # Placeholder for the input x
        x = tf.placeholder(tf.float32, [
            TEST_NUM, config.img_width, config.img_height,
            fer_forward.NUM_CHANNELS
        ])
        # Placeholder for the labels y_
        y_ = tf.placeholder(tf.float32, [None, fer_forward.OUTPUT_NODE])
        # Build the forward-pass graph that produces the output y
        y = fer_forward.forward(x, False, None)
        # Define the exponential moving average
        ema = tf.train.ExponentialMovingAverage(
            fer_backward.MOVING_AVERAGE_DECAY)
        # Map the shadow (EMA) variables back onto the variables themselves
        ema_restore = ema.variables_to_restore()
        # Create a Saver object for the model
        saver = tf.train.Saver(ema_restore)

        # Check whether the predictions match the labels
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        # Define the accuracy
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Fetch test data in batches
        img_batch, label_batch = fer_generateds.get_tfrecord(
            TEST_NUM, config.tfRecord_test)
        for i in range(3):
            # Create a session
            with tf.Session() as sess:
                # Find the model file name via the checkpoint file
                ckpt = tf.train.get_checkpoint_state(config.MODEL_SAVE_PATH)
                # If a saved model exists
                if ckpt and ckpt.model_checkpoint_path:
                    # Restore the model
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # Recover global_step from the checkpoint file name
                    global_step = ckpt.model_checkpoint_path.split(
                        '/')[-1].split('-')[-1]

                    # Create a thread coordinator
                    coord = tf.train.Coordinator()
                    # Start the enqueue threads
                    threads = tf.train.start_queue_runners(sess=sess,
                                                           coord=coord)
                    xs, ys = sess.run([img_batch, label_batch])
                    # Reshape the test input xs
                    reshape_xs = np.reshape(
                        xs, (TEST_NUM, config.img_width, config.img_height,
                             fer_forward.NUM_CHANNELS))
                    # Compute the accuracy
                    accuracy_score = sess.run(accuracy,
                                              feed_dict={
                                                  x: reshape_xs,
                                                  y_: ys
                                              })
                    # Print global_step and the accuracy
                    print("After %s training step(s), test accuracy = %g" %
                          (global_step, accuracy_score))
                    # Stop all threads
                    coord.request_stop()
                    coord.join(threads)

                else:
                    print('No checkpoint file found')
                    return
            time.sleep(TEST_INTERVAL_SECS)
Example 5
def backward():
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,
                                               global_step,
                                               train_num_examples / BATCH_SIZE,
                                               LEARNING_RATE_DECAY,
                                               staircase=True)

    # The input image size is configurable via the config file. LeNet takes 1 input channel (grayscale), so the images should be converted first; the originals may not be grayscale.
    x = tf.placeholder(tf.float32, [
        BATCH_SIZE, config.img_width, config.img_height,
        fer_forward.NUM_CHANNELS
    ])
    y_ = tf.placeholder(tf.float32, [None, fer_forward.OUTPUT_NODE])
    # Build the network and the loss function
    y = fer_forward.forward(x, True)
    # Loss
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,
                                                        labels=tf.argmax(
                                                            y_, 1))
    cem = tf.reduce_mean(ce)
    loss = cem + tf.add_n(tf.get_collection('regularization_losses'))
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(
        loss, global_step=global_step)
    # EMA is used: testing reads the EMA values, while the accuracy reported during training uses the raw weights
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())
    # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # BN needs these dependent update ops
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name='train')
    # Accuracy
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    saver = tf.train.Saver()
    # Fetch training data in batches
    img_batch, label_batch = fer_generator.get_tfrecord(
        BATCH_SIZE, config.tfRecord_train)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()  # the input queue raised an error without this; is it really needed? unresolved
        # Resume training from a checkpoint if one exists
        ckpt = tf.train.get_checkpoint_state(config.MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        # Create a thread coordinator
        coord = tf.train.Coordinator()
        # Start the enqueue threads
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # img_batch, label_batch = fer_generator.get_tfrecord(BATCH_SIZE, config.tfRecord_train)  ## placing this call here was a hard-learned lesson
        for i in range(STEPS):
            xs, ys = sess.run([img_batch, label_batch])

            reshape_xs = np.reshape(
                xs, (BATCH_SIZE, config.img_width, config.img_height,
                     fer_forward.NUM_CHANNELS))
            # Training update
            _, loss_value, accuracy_value, step, lr = sess.run(
                [train_op, loss, accuracy, global_step, learning_rate],
                feed_dict={
                    x: reshape_xs,
                    y_: ys
                })
            if (i + 1) % 100 == 0:
                print(
                    "%s : After %d training step(s),lr is %g loss,accuracy on training batch is %g , %g."
                    % (time.strftime('%Y-%m-%d %H:%M:%S'), step, lr,
                       loss_value, accuracy_value))

                saver.save(sess,
                           os.path.join(config.MODEL_SAVE_PATH,
                                        config.MODEL_NAME),
                           global_step=global_step)

        coord.request_stop()
        coord.join(threads)
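
The commented-out UPDATE_OPS line above points at the batch-norm dependency that Example 3 wires in explicitly. Below is a standalone sketch of that pattern (assuming tf.layers.batch_normalization; the listing's fer_forward may implement BN differently): the moving mean/variance updates live in tf.GraphKeys.UPDATE_OPS and only run if the training op depends on them.

import tensorflow as tf

inputs = tf.placeholder(tf.float32, [None, 10])
training = tf.placeholder(tf.bool)
hidden = tf.layers.batch_normalization(inputs, training=training)
loss = tf.reduce_mean(tf.square(hidden))

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # BN moving mean/variance updates
with tf.control_dependencies(update_ops):
    train_step = tf.train.AdamOptimizer(1e-3).minimize(loss)
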
Example 6
def backward():
    # Placeholder for the input x
    x = tf.placeholder(tf.float32, [
        BATCH_SIZE, config.img_width, config.img_height,
        fer_forward.NUM_CHANNELS
    ])
    # Placeholder for the labels y_
    y_ = tf.placeholder(tf.float32, [None, fer_forward.OUTPUT_NODE])
    # Build the forward-pass graph that produces the output y
    y = fer_forward.forward(x, True, REGULARIZER)
    # Define global_step, initialized to 0 and not trainable
    global_step = tf.Variable(0, trainable=False)
    # Compute the sparse softmax cross-entropy
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,
                                                        labels=tf.argmax(
                                                            y_, 1))
    # Average the cross-entropy
    cem = tf.reduce_mean(ce)
    # The loss includes the regularization term
    loss = cem + tf.add_n(tf.get_collection('losses'))
    # Define an exponentially decaying learning rate
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,
                                               global_step,
                                               train_num_examples / BATCH_SIZE,
                                               LEARNING_RATE_DECAY,
                                               staircase=True)

    # Use the Adam optimizer
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(
        loss, global_step=global_step)
    # Define the exponential moving average
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    # Apply the moving average to all trainable variables
    ema_op = ema.apply(tf.trainable_variables())
    # Update the moving averages of all trainable parameters on every training step
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name='train')

    # Create a Saver object for the model
    saver = tf.train.Saver()
    # Check whether the predictions match the labels
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    # Define the accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # Fetch training data in batches
    img_batch, label_batch = fer_generateds.get_tfrecord(
        BATCH_SIZE, config.tfRecord_train)
    # Create a session
    with tf.Session() as sess:
        # Initialize variables
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        # Find the model file name via the checkpoint file
        ckpt = tf.train.get_checkpoint_state(config.MODEL_SAVE_PATH)
        # If a saved model exists
        if ckpt and ckpt.model_checkpoint_path:
            # Restore the model and resume training
            saver.restore(sess, ckpt.model_checkpoint_path)

        # Create a thread coordinator
        coord = tf.train.Coordinator()
        # Start the enqueue threads
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(STEPS):
            xs, ys = sess.run([img_batch, label_batch])
            # Reshape the input xs
            reshape_xs = np.reshape(
                xs, (BATCH_SIZE, config.img_width, config.img_height,
                     fer_forward.NUM_CHANNELS))
            # Training update: loss, accuracy, step
            _, loss_value, accuracy_value, step = sess.run(
                [train_op, loss, accuracy, global_step],
                feed_dict={
                    x: reshape_xs,
                    y_: ys
                })
            if (i + 1) % 200 == 0:
                # Print the training step, loss, and accuracy
                print(
                    "%s : After %d training step(s), loss,accuracy on training batch is %g , %g."
                    % (time.strftime('%Y-%m-%d %H:%M:%S'), step, loss_value,
                       accuracy_value))
                # Save the model
                saver.save(sess,
                           os.path.join(config.MODEL_SAVE_PATH,
                                        config.MODEL_NAME),
                           global_step=global_step)
        # Stop all threads
        coord.request_stop()
        coord.join(threads)
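
A standalone sketch of the train_op grouping used in these backward() functions (toy loss, not the project's network): running the no-op forces both the optimizer step and the EMA update to run in a single sess.run call, which is exactly what `with tf.control_dependencies([train_step, ema_op])` buys.

import tensorflow as tf

w = tf.Variable(0.0)
loss = tf.square(w - 3.0)
global_step = tf.Variable(0, trainable=False)
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss, global_step=global_step)

ema = tf.train.ExponentialMovingAverage(0.99, global_step)
ema_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([train_step, ema_op]):
    train_op = tf.no_op(name='train')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)                    # one call: gradient step + EMA update
    print(sess.run([w, ema.average(w)]))  # raw weight vs. its moving average
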
Example 7
def test():

    # Instantiate a graph and use it as the default graph for the TensorFlow runtime
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, [
            1, config.img_width, config.img_height, fer_forward.NUM_CHANNELS
        ])
        tf.summary.image('img_input', x)  # log the input image
        y_ = tf.placeholder(tf.float32, [None, fer_forward.OUTPUT_NODE])

        prob = tf.placeholder(tf.float32)
        bn_training = tf.placeholder(tf.bool)
        # y = fer_forward.forward(x, keep_prob=prob)
        y, return_dict = fer_forward.forward(x,
                                             keep_prob=prob,
                                             bn_enable=True,
                                             bn_training=bn_training)

        ema = tf.train.ExponentialMovingAverage(backward.MOVING_AVERAGE_DECAY)
        ema_restore = ema.variables_to_restore()  # build the mapping that restores the EMA shadow values into the original variables


        ## tf.summary.image cannot log all channels at once (only 1, 3 or 4 channels are supported),
        ## so loop over the channels and log each one separately.
        for i in range(return_dict['relu1_output'].shape[-1]):
            tf.summary.image('relu1/relu1_channel_' + str(i),
                             tf.expand_dims(input=return_dict['relu1_output'][:, :, :, i], axis=-1))  # log the activations

        for i in range(return_dict['relu2_output'].shape[-1]):
            tf.summary.image('relu2/relu2_channel_' + str(i),
                             tf.expand_dims(input=return_dict['relu2_output'][:, :, :, i], axis=-1))
        for i in range(return_dict['relu3_output'].shape[-1]):
            tf.summary.image('relu3/relu3_channel_' + str(i),
                             tf.expand_dims(input=return_dict['relu3_output'][:, :, :, i], axis=-1))
        for i in range(return_dict['pool1_output'].shape[-1]):
            tf.summary.image('pool1/pool1_output' + str(i),
                             tf.expand_dims(input=return_dict['pool1_output'][:, :, :, i], axis=-1))
        for i in range(return_dict['pool2_output'].shape[-1]):
            tf.summary.image('pool2/pool2_output' + str(i),
                             tf.expand_dims(input=return_dict['pool2_output'][:, :, :, i], axis=-1))
        for i in range(return_dict['pool3_output'].shape[-1]):
            tf.summary.image('pool3/pool3_output' + str(i),
                             tf.expand_dims(input=return_dict['pool3_output'][:, :, :, i], axis=-1))


        print('y:', y)
        tf.summary.histogram('y/pic', y[0])

        merged = tf.summary.merge_all()


        loader = tf.train.Saver(ema_restore)

        prediction = tf.argmax(y, 1)


        # Run predictions on real images, including all of the preprocessing.
        samples_column = 4
        samples_row = 3
        samples_size = samples_column * samples_row
        # ./picture_to_test/pic1


        with tf.Session() as sess:
            writer = tf.summary.FileWriter('./inference_tensorboard_dir', graph=sess.graph)
            ckpt = tf.train.get_checkpoint_state(config.MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                loader.restore(sess, ckpt.model_checkpoint_path)
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                for i in range(1, samples_size + 1):  # indices start at 1
                    img = Image.open('./picture_to_test/pic' + str(i) + '.jpg')
                    # print(img, img.mode, img.size, img.format)
                    # Center crop
                    edge_size = min(img.size[0], img.size[1])
                    if img.size[0] > edge_size:  # if the width is larger, crop along x
                        start_x = (img.size[0] - edge_size) / 2
                        end_x = start_x + edge_size
                        start_y = 0
                        end_y = edge_size
                    else:  # otherwise the height is larger, crop along y
                        start_y = (img.size[1] - edge_size) / 2
                        end_y = start_y + edge_size
                        start_x = 0
                        end_x = edge_size

                    bounds = (start_x, start_y, end_x, end_y)
                    cropped_img = img.crop(bounds)
                    resized_img = cropped_img.resize((48, 48))
                    # Convert to grayscale, then to an ndarray
                    converted_img = resized_img.convert("L")
                    ndarry_img = np.array(converted_img)
                    # Normalize the pixel values
                    ndarry_img = ndarry_img.astype(dtype=np.float32)
                    ndarry_img = ndarry_img * 1. / 255
                    ndarry_img = np.expand_dims(ndarry_img, axis=-1)
                    ndarry_img = np.expand_dims(ndarry_img, axis=0)
                    predict_idx, y_out, summary = sess.run(
                        [prediction, y, merged],
                        feed_dict={x: ndarry_img, prob: 1.0, bn_training: False})
                    predict_str = prediction_dict[predict_idx[0]]
                    print('pic%d:' % i, end=' ')
                    # print(type(y_out),y_out.shape)
                    for j in range(len(y_out[0])):
                        print(prediction_dict[j], ':%.2f' % y_out[0][j], end='\t')
                    print()
                    plt.subplot(samples_row, samples_column, i)
                    plt.title('prediction:' + predict_str)
                    plt.imshow(img)
                    # print('prediction is ',predict_idx)
                    # if i == 1:  # convenient for printing; note indices start at 1
                    writer.add_summary(summary, i)
                plt.show()

            else:
                print('No checkpoint file found')
                return
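
A hypothetical refactor of the preprocessing loop above into a single helper (a sketch, not part of the original listing; the (48, 48) default mirrors the hard-coded resize in this example): center-crop to a square, resize, convert to grayscale, scale to [0, 1], and add the batch and channel axes the placeholder expects.

import numpy as np
from PIL import Image

def preprocess_for_inference(path, size=(48, 48)):
    img = Image.open(path)
    edge = min(img.size)                                  # side of the centred square
    left = (img.size[0] - edge) // 2
    top = (img.size[1] - edge) // 2
    img = img.crop((left, top, left + edge, top + edge))  # centre crop on both axes
    img = img.resize(size).convert("L")                   # resize, then grayscale
    arr = np.asarray(img, dtype=np.float32) / 255.0       # scale to [0, 1]
    return arr[np.newaxis, :, :, np.newaxis]              # shape [1, H, W, 1]
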