Example #1
def main(argv=None):  # pylint: disable=unused-argument
    args = parse_args(argv)
    cifar10.maybe_download_and_extract(args.data_dir)
    if tf.gfile.Exists(args.train_dir):
        tf.gfile.DeleteRecursively(args.train_dir)
    tf.gfile.MakeDirs(args.train_dir)
    train(args)
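
Example #1 calls a parse_args helper that is not included in the snippet. A minimal sketch of what such a helper could look like, assuming argparse and the --data_dir / --train_dir options implied by the calls above (the defaults are illustrative, not from the original source):

import argparse

def parse_args(argv=None):
    # Hypothetical helper; the real parse_args is not shown in the snippet.
    parser = argparse.ArgumentParser(description='CIFAR-10 training')
    parser.add_argument('--data_dir', default='/tmp/cifar10_data',
                        help='where to download and extract the CIFAR-10 binaries')
    parser.add_argument('--train_dir', default='/tmp/cifar10_train',
                        help='where to write checkpoints and event files')
    return parser.parse_args(argv)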
Example #2
def main(argv=None):  # pylint: disable=unused-argument
    cifar10.maybe_download_and_extract()
    if gfile.Exists(FLAGS.train_dir):
        gfile.DeleteRecursively(FLAGS.train_dir)
    gfile.MakeDirs(FLAGS.train_dir)
    train()
Example #3
def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
  if tf.gfile.Exists(FLAGS.eval_dir):
    tf.gfile.DeleteRecursively(FLAGS.eval_dir)
  tf.gfile.MakeDirs(FLAGS.eval_dir)
  evaluate()
Example #4
# This snippet assumes TensorFlow 1.x (`tf`), its `flags` module, and the
# cifar10 / cifar10_input helpers are imported earlier in the original source.
FLAGS = flags.FLAGS


# define variable_with_weight_loss
# Unlike the plain weight variables defined earlier, this one carries an
# attached loss: the weight penalty keeps individual coefficients from
# growing too large and causing overfitting
def variable_with_weight_loss(shape, stddev, w1):
    var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
    if w1 is not None:
        weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')
        tf.add_to_collection('losses', weight_loss)
    return var
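

# The L2 penalties collected under 'losses' above only take effect once they
# are summed with the data loss. A minimal sketch of that step, following the
# common TF 1.x pattern (this `loss` helper and its `logits`/`labels`
# arguments are illustrative, not taken from the original source):
def loss(logits, labels):
    labels = tf.cast(labels, tf.int64)
    # Data term: average cross-entropy over the batch
    cross_entropy = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels, name='cross_entropy_per_example'))
    tf.add_to_collection('losses', cross_entropy)
    # Total loss = cross-entropy + all collected weight_loss terms
    return tf.add_n(tf.get_collection('losses'), name='total_loss')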


# Download the dataset - call the cifar10 helper to download and extract it
cifar10.maybe_download_and_extract()
# Note the path (Windows-style, relative to the working directory)
cifar_dir = '.\\cifar10_data\\cifar-10-batches-bin'

# Preprocess the data with data augmentation
# Generate training data; it passes through cifar10_input's distortion transforms
images_train, labels_train = cifar10_input.distorted_inputs(
    data_dir=cifar_dir, batch_size=FLAGS.batch_size)
# Test data (eval_data=True selects the evaluation set)
images_test, labels_test = cifar10_input.inputs(eval_data=True,
                                                data_dir=cifar_dir,
                                                batch_size=FLAGS.batch_size)

# Create the input placeholders
x_input = tf.placeholder(tf.float32, [FLAGS.batch_size, 24, 24, 3])
y_input = tf.placeholder(tf.int32, [FLAGS.batch_size])
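

# The distorted_inputs/inputs pipelines above are queue-based (TF 1.x), so a
# session has to start the queue runners before batches can be pulled and fed
# into the placeholders. A minimal sketch of that feed loop; `train_op`,
# `loss_op`, and `num_steps` are hypothetical arguments standing in for
# whatever graph the rest of the original program builds:
def run_training(train_op, loss_op, num_steps):
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # distorted_inputs/inputs read via input queues, so start the runners
        tf.train.start_queue_runners(sess=sess)
        for step in range(num_steps):
            # Pull one augmented batch from the pipeline...
            image_batch, label_batch = sess.run([images_train, labels_train])
            # ...and feed it into the placeholders defined above
            _, loss_value = sess.run([train_op, loss_op],
                                     feed_dict={x_input: image_batch,
                                                y_input: label_batch})
            if step % 100 == 0:
                print('step %d, loss %.3f' % (step, loss_value))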
Example #5
import tensorflow as tf  # TF 1.x API (tf.placeholder, tf.truncated_normal, ...)

from cifar import cifar10
from cifar import cifar10_input
max_steps = 3000
batch_size = 128
data_dir = '/tmp/cifar10_data/cifar-10-batches-bin'


def variable_with_weight_loss(shape, stddev, w1):  # w1 controls the size of the L2 loss
    var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))  # truncated-normal initialization
    if w1 is not None:
        weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')
        tf.add_to_collection('losses', weight_loss)
    return var


cifar10.maybe_download_and_extract()  # download the dataset

images_train, labels_train = cifar10_input.distorted_inputs(  # produce augmented training data
    data_dir=data_dir, batch_size=batch_size)

images_test, labels_test = cifar10_input.inputs(
    eval_data=True,  # generate test data
    data_dir=data_dir,
    batch_size=batch_size)
images_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
label_holder = tf.placeholder(tf.int32, [batch_size])

#### Convolutional layer 1 #### 5x5 kernels, 3 input channels, 64 filters
weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2, w1=0.0)
kernel1 = tf.nn.conv2d(images_holder, weight1, [1, 1, 1, 1], padding='SAME')
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))