def train(xs, ys):
    # Define the input and output placeholders.
    x = tf.placeholder(tf.float32,
                       [BATCH_SIZE,  # first dimension: number of examples in a batch
                        traffic_inference.IMAGE_SIZE,  # second and third dimensions: image size
                        traffic_inference.IMAGE_SIZE,
                        traffic_inference.NUM_CHANNELS],  # fourth dimension: image depth (channels)
                       name='x-input')
    y_ = tf.placeholder(tf.float32, [None, traffic_inference.OUTPUT_NODE], name='y-input')
    regularizer = tf.contrib.layers.l2_regularizer(REGULARZATION_RATE)
    # Reuse the forward-propagation graph defined in traffic_inference.py.
    y = traffic_inference.inference(x, True, regularizer)  # the prediction (logits) computed by the shared inference function
    global_step = tf.Variable(0, trainable=False)
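    # global_step is incremented automatically by minimize() below; marking it
    # non-trainable keeps it out of tf.trainable_variables(), so neither the
    # gradient updates nor the moving average touch it.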

    # Define the loss function, moving averages, learning rate, and training step.
    variable_average = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECARY, global_step)
    variable_average_op = variable_average.apply(
        tf.trainable_variables()
    )
    # sparse_softmax_cross_entropy_with_logits fuses the softmax and the cross-entropy into one op.
    # y holds the logits for every class; y_ is the one-hot label, so tf.argmax(y_, 1)
    # recovers the index of the correct class.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
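    # The 'losses' collection is assumed to be populated inside
    # traffic_inference.inference() with the L2 regularization term of each
    # weight variable; tf.add_n sums those terms into a single scalar.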
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        400,  # decay the learning rate every 400 steps
        LEARNING_RATE_DECAY,
        staircase=True
    )
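    # With staircase=True the decay happens in discrete steps:
    # learning_rate = LEARNING_RATE_BASE * LEARNING_RATE_DECAY ** (global_step // 400).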
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)  # back-propagate by minimizing the total loss
    with tf.control_dependencies([train_step, variable_average_op]):
        train_op = tf.no_op(name='train')
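    # The control dependency groups the two ops: one sess.run(train_op) performs
    # both the gradient step and the moving-average update. tf.no_op itself does
    # no computation.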

    # Initialize the TensorFlow persistence class.
    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        num_examples = len(xs)
        # No validation during training; testing and validation live in a separate program.
        for i in range(TRAINING_STEPS):
            start = (i * BATCH_SIZE) % num_examples
            end = min(start + BATCH_SIZE, num_examples)
            xs_batch = xs[start:end]
            ys_batch = ys[start:end]
            # The reshape assumes each slice holds exactly BATCH_SIZE examples,
            # i.e. that num_examples is a multiple of BATCH_SIZE; a short final
            # batch would make this reshape (and the placeholder shape) fail.
            reshaped_xs = np.reshape(xs_batch, (BATCH_SIZE,
                                                traffic_inference.IMAGE_SIZE,
                                                traffic_inference.IMAGE_SIZE,
                                                traffic_inference.NUM_CHANNELS))
            # ys is assumed to be a tensor, so .eval() materializes the label
            # batch inside the default session before feeding it.
            _, loss_value, step = sess.run([train_op, loss, global_step],
                                           feed_dict={x: reshaped_xs, y_: ys_batch.eval()})
            # Report and save the model every 100 steps.
            if i % 100 == 0:
                # Print the loss on the current training batch as a rough
                # progress indicator; accuracy on the validation data is
                # computed by a separate program.
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                # Save the current model. Passing global_step appends the number
                # of training steps to each checkpoint filename.
                saver.save(
                    sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step
                )
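

# A minimal usage sketch (an assumption, not part of the original program);
# `load_training_data` is a hypothetical helper that returns the training
# images and their one-hot label tensors:
#
#   def main(argv=None):
#       xs, ys = load_training_data()
#       train(xs, ys)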
def evaluate(images, labels):
    num_examples = len(images)
    # Define the input and output format.
    x = tf.placeholder(
        tf.float32,
        [BATCH_SIZE,  # first dimension: number of examples in a batch
         traffic_inference.IMAGE_SIZE,  # second and third dimensions: image size
         traffic_inference.IMAGE_SIZE,
         traffic_inference.NUM_CHANNELS],  # fourth dimension: image depth (channels)
        name='x-input')
    y_ = tf.placeholder(tf.float32, [None, traffic_inference.OUTPUT_NODE], name='y-input')

    # Compute the forward pass directly with the shared inference function. Testing
    # does not need the regularization loss, so the regularizer argument is None.
    y = traffic_inference.inference(x, False, None)

    # Use the forward-pass result to compute accuracy. To classify unseen examples,
    # tf.argmax(y, 1) alone yields the predicted class.
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  # cast booleans to floats before averaging

    # Restore the model through variable renaming, so the forward pass does not have
    # to call the moving-average functions itself. This lets us fully reuse the
    # forward propagation defined in traffic_inference.py.
    variable_averages = tf.train.ExponentialMovingAverage(
        traffic_train.MOVING_AVERAGE_DECARY)
    variables_to_restore = variable_averages.variables_to_restore()  # maps shadow (moving-average) variables back to the variables themselves on load
    saver = tf.train.Saver(variables_to_restore)

    # Compute the accuracy over the data set. (A variant could re-run this every
    # EVAL_INTERVAL_SECS seconds to monitor accuracy during training.)

    with tf.Session() as sess:
        # tf.train.get_checkpoint_state finds the newest model file in the
        # directory via its checkpoint file.
        ckpt = tf.train.get_checkpoint_state(traffic_train.MODEL_SAVE_PATH)
        total_accuracy = 0.0
        if ckpt and ckpt.model_checkpoint_path:
            # Restore the model.
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Recover the number of training steps from the checkpoint filename
            # (assumed to end in "-<step>").
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        else:
            print('No checkpoint file found')
            return
        for i in range(201):
            start = (i * BATCH_SIZE) % num_examples
            end = min(start + BATCH_SIZE, num_examples)
            xs_batch = images[start:end]
            ys_batch = labels[start:end]
            # As in train(), the reshape assumes each slice holds exactly BATCH_SIZE examples.
            reshaped_xs = np.reshape(xs_batch, (BATCH_SIZE,
                                                traffic_inference.IMAGE_SIZE,
                                                traffic_inference.IMAGE_SIZE,
                                                traffic_inference.NUM_CHANNELS))
            accuracy_score = sess.run(accuracy, feed_dict={x: reshaped_xs, y_: ys_batch.eval()})
            total_accuracy += accuracy_score
        # Report the accuracy averaged over the 201 evaluated batches.
        print("After %s training step(s), average accuracy = %g" % (global_step, total_accuracy / 201))
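
# A minimal usage sketch (an assumption, not part of the original program);
# `load_test_data` is a hypothetical helper returning the held-out images and
# their one-hot label tensors:
#
#   test_images, test_labels = load_test_data()
#   evaluate(test_images, test_labels)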
def test_sample(image):
    # Read the raw image file and decode it as a 3-channel PNG.
    image_data = tf.gfile.FastGFile(image, 'rb').read()
    decode_image = tf.image.decode_png(image_data, channels=3)

    # Convert pixel values to float32 in [0, 1].
    decode_image = tf.image.convert_image_dtype(decode_image, tf.float32)

    # Reshape into a single-example batch. The 1x28x28x3 shape is hardcoded here;
    # it is assumed to match traffic_inference.IMAGE_SIZE and NUM_CHANNELS.
    image = tf.reshape(decode_image, (1, 28, 28, 3))

    test_logit = traffic_inference.inference(image,
                                             train=False,
                                             regularizer=None)
    probabilities = tf.nn.softmax(test_logit)  # turn the logits into a probability distribution
    correct_prediction = tf.argmax(test_logit, 1)  # index of the most likely class
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # tf.train.get_checkpoint_state finds the newest model file in the
        # directory via its checkpoint file.
        ckpt = tf.train.get_checkpoint_state(traffic_train.MODEL_SAVE_PATH)

        if ckpt and ckpt.model_checkpoint_path:
            # Restore the model.
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored successfully: " + ckpt.model_checkpoint_path)

            # Recover the number of training steps from the checkpoint filename.
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                '-')[-1]

            probs, label = sess.run(
                [probabilities, correct_prediction])
            label = int(label[0])  # correct_prediction has shape (1,); unwrap the scalar
            probability = probs[0][label]

            print(
                "After %s training step(s), validation label = %d, has %g probability"
                % (global_step, label, probability))

        else:
            print('No checkpoint file found')
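

# A minimal usage sketch (an assumption, not part of the original program);
# 'sample.png' is a hypothetical path to a single traffic-sign image:
#
#   test_sample('sample.png')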