#coding:utf-8 from tutorials.image.cifar10 import cifar10_input from tutorials.image.cifar10 import cifar10 import tensorflow as tf import numpy as np import time max_steps = 3000 batch_size = 128 data_dir = '/tmp/cifar10_data/cifar-10-batches-bin' cifar10.maybe_download_and_extract() def interface(): mages_train, labels_train = cifar10_input.distorted_inputs( data_dir=data_dir, batch_size=batch_size) images_test, labels_test = cifar10_input.inputs(eval_data=True, data_dir=data_dir, batch_size=batch_size) image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3]) label_holder = tf.placeholder(tf.int32, [batch_size]) def variable_with_weight_loss(shape, stddev, w1): var = tf.Variable(tf.truncated_normal(shape, stddev)) if w1 is not None: #weight_loss = tf.matmul(tf.nn.l2_loss(var), w1, name='weight_loss') weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss') tf.add_to_collection('losses', weight_loss)
def main(argv=None): # pylint: disable=unused-argument cifar10.maybe_download_and_extract() if tf.gfile.Exists(FLAGS.eval_dir): tf.gfile.DeleteRecursively(FLAGS.eval_dir) tf.gfile.MakeDirs(FLAGS.eval_dir) evaluate()
# 定义初始化weights的函数,和之前一样依然使用tf.truncated_normal截断的正太分布来初始化权值 var = tf.Variable(tf.truncated_normal(shape, stddev=stddev)) if wl is not None: # 给weight加一个L2的loss,相当于做了一个L2的正则化处理 # 在机器学习中,不管是分类还是回归任务,都可能因为特征过多而导致过拟合,一般可以通过减少特征或者惩罚不重要特征的权重来缓解这个问题 # 但是通常我们并不知道该惩罚哪些特征的权重,而正则化就是帮助我们惩罚特征权重的,即特征的权重也会成为模型的损失函数的一部分 # 我们使用w1来控制L2 loss的大小 weight_loss = tf.multiply(tf.nn.l2_loss(var), wl, name='weight_loss') # 我们使用tf.add_to_collection把weight loss统一存到一个collection,这个collection名为"losses",它会在后面计算神经网络 # 总体loss时被用上 tf.add_to_collection("losses", weight_loss) return var # 下载cifar10类下载数据集,并解压,展开到其默认位置 cifar10.maybe_download_and_extract() # 使用cifar10_input类中的distorted_inputs函数产生训练需要使用的数据,包括特征及其对应的label,这里是封装好的tensor, # 每次执行都会生成一个batch_size的数量的样本。需要注意的是这里对数据进行了Data Augmentation数据增强 # 具体实现细节查看函数,其中数据增强操作包括随机水平翻转tf.image.random_flip_left_right() # 随机剪切一块24*24大小的图片tf.random_crop,随机设置亮度和对比度,tf.image.random_brightness、tf.image.random_contrast # 以及对数据进行标准化,白化 tf.image.per_image_standardization() 减去均值、除以方差,保证数据零均值,方差为1 images_train, labels_train = cifar10_input.distorted_inputs( data_dir=data_dir, batch_size=batch_size ) # 生成测试数据,不过这里不需要进行太多处理,不需要对图片进行翻转或修改亮度、对比度,不过需要裁剪图片正中间的24*24大小的区块, # 并进行数据标准化操作 images_test, labels_test = cifar10_input.inputs(eval_data=True, data_dir=data_dir, batch_size=batch_size) # 因为batch_size在之后定义网络结构时被用到了,所以数据尺寸中的第一个值即样本条数需要被预先设定,而不能像以前那样设置为None # 而数据尺寸中的图片尺寸为24*24即是剪裁后的大小,颜色通道数则设为3