Exemplo n.º 1
0
def distorted_inputs():
    """Build the augmented CIFAR-10 training batch.

    Thin wrapper around ``cifar10_input.distorted_inputs``. Its configuration
    comes from the module-level names ``data_dir``, ``batch_size`` and
    ``use_fp16``, which are not defined inside this function.

    Returns:
        The ``(images, labels)`` pair produced by
        ``cifar10_input.distorted_inputs``. When ``use_fp16`` is truthy, both
        tensors are cast to ``tf.float16`` before being returned.
    """
    # Imported locally so this function does not depend on file-level imports.
    import os

    # os.path.join instead of string concatenation: avoids hard-coded
    # separators and doubled slashes when data_dir already ends with '/'.
    batches_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
    images, labels = cifar10_input.distorted_inputs(data_dir=batches_dir,
                                                    batch_size=batch_size)
    if use_fp16:
        images = tf.cast(images, tf.float16)
        # NOTE(review): labels are presumably integer class ids; they are cast
        # to float16 here as well. Any loss that needs integer labels must
        # cast them back. Confirm this is what callers expect.
        labels = tf.cast(labels, tf.float16)
    return images, labels
Exemplo n.º 2
0
Arquivo: slvgg.py Projeto: JayYip/GLN
def distorted_inputs(batch_size,
                     data_dir='../../cifardataset/cifar-10-batches-bin',
                     image_size=64, num_classes=10):
    """Return an augmented CIFAR-10 training batch, resized and one-hot encoded.

    Args:
        batch_size: Number of examples per batch, forwarded to
            ``cifar10_input.distorted_inputs``.
        data_dir: Directory forwarded unchanged to
            ``cifar10_input.distorted_inputs``.
        image_size: Side length in pixels that images are resized to.
            The default of 64 matches the previously hard-coded value.
        num_classes: Depth of the one-hot label encoding. The default of 10
            matches the previously hard-coded value.

    Returns:
        ``(images, labels)``, where:
        - ``images`` are cast to ``tf.float32`` and resized to
          ``image_size x image_size``;
        - ``labels`` are cast to ``tf.int32`` and one-hot encoded with depth
          ``num_classes`` and dtype ``tf.int32``.
    """
    images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
                                                    batch_size=batch_size)
    # resize_images takes the target size as an int32 [height, width] pair.
    images = tf.image.resize_images(
        tf.cast(images, tf.float32),
        tf.convert_to_tensor([image_size, image_size], dtype=tf.int32))
    labels = tf.one_hot(tf.cast(labels, tf.int32), depth=num_classes,
                        dtype=tf.int32)
    return (images, labels)
Exemplo n.º 3
0
def distorted_inputs():
    """Build the distorted (augmented) CIFAR-10 training input via cifar10_input.

    Passes ``FLAGS.data_dir/cifar-10-batches-bin`` and ``FLAGS.batch_size``
    through to ``cifar10_input.distorted_inputs``.

    Returns:
      images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
      labels: Labels. 1D tensor of [batch_size] size.

    Raises:
      ValueError: If FLAGS.data_dir is empty or unset.
    """
    root_dir = FLAGS.data_dir
    if not root_dir:
        raise ValueError('Please supply a data_dir')
    batches_dir = os.path.join(root_dir, 'cifar-10-batches-bin')
    per_batch = FLAGS.batch_size
    return cifar10_input.distorted_inputs(data_dir=batches_dir,
                                          batch_size=per_batch)
Exemplo n.º 4
0
def distorted_inputs():
  """Return the augmented CIFAR-10 training input from cifar10_input.

  The 'cifar-10-batches-bin' subdirectory of FLAGS.data_dir and
  FLAGS.batch_size are forwarded to cifar10_input.distorted_inputs.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.

  Raises:
    ValueError: If FLAGS.data_dir is empty or unset.
  """
  if FLAGS.data_dir:
    return cifar10_input.distorted_inputs(
        data_dir=os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin'),
        batch_size=FLAGS.batch_size)
  raise ValueError('Please supply a data_dir')
Exemplo n.º 5
0
def distorted_inputs():
  """Return the augmented CIFAR-10 training input from cifar10_input.

  Forwards the 'cifar-10-batches-bin' subdirectory of FLAGS.data_dir and
  FLAGS.batch_size to cifar10_input.distorted_inputs and returns its result
  unchanged.

  Raises:
    ValueError: If FLAGS.data_dir is empty or unset.
  """
  base_dir = FLAGS.data_dir
  if not base_dir:
    raise ValueError('Please supply a data_dir')
  bin_dir = os.path.join(base_dir, 'cifar-10-batches-bin')
  size = FLAGS.batch_size
  return cifar10_input.distorted_inputs(data_dir=bin_dir, batch_size=size)
Exemplo n.º 6
0
    # Create a variable of the given shape, initialized from a truncated
    # normal distribution with standard deviation `stddev`.
    var = tf.Variable(tf.truncated_normal(shape, stddev=stddev), name=name)
    # Optional weight decay: when a coefficient `wl` is given, add
    # wl * l2_loss(var) to the 'losses' collection. Presumably the training
    # code sums that collection into the total loss; confirm with the caller.
    if wl is not None:
        weight_loss = tf.multiply(tf.nn.l2_loss(var), wl, name='weight_loss')
        tf.add_to_collection('losses', weight_loss)
    return var


# The weight-initialization helper above attaches an L2 loss term to the
# weights it creates (L2 regularization).
# L1 regularization: yields sparse features; most useless feature weights are
#   driven to 0.
# L2 regularization: keeps feature weights from growing too large, so the
#   weights stay more evenly spread.
# (Occam's razor: prefer the simpler model.)

# Download the CIFAR-10 archive and extract it.
# NOTE(review): confirm the extraction directory matches `data_dir` used below.
cifar10.maybe_download_and_extract()
# Input pipelines, grouped under the 'Get' name scope.
with tf.name_scope('Get'):
    # Training batch from distorted_inputs. Per the author's note below
    # (written in Chinese), this produces the training features and labels
    # and applies data augmentation: random horizontal flips, random crops,
    # random brightness/contrast, and standardization.
    images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=batch_size)
    '''
    distorted_inputs
    产生训练需要使用的数据(特征,label)
    进行数据增强操作(图片随机水平旋转、随机剪切、随机设置亮度和对比度、数据标准化)
    '''

    # Evaluation batch from inputs(eval_data=True). Per the author's note, this
    # applies only a 24x24 crop plus standardization, with no random augmentation.
    images_test, labels_test = cifar10_input.inputs(eval_data=True, data_dir=data_dir, batch_size=batch_size)
# Placeholders for one batch of features and labels. The batch dimension is
# fixed to batch_size because the network definition needs it up front.
# Images are 24x24x3, the post-crop size according to the author's notes.
with tf.name_scope('Inputs'):
    image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3], name='image_input')
    label_holder = tf.placeholder(tf.int32, [batch_size], name='label_input')
'''
创建输入数据的placeholder(特征+label)
(batch_size后面定义网格结构时用到->样本条数需预先设定)
剪裁后图片大小24X24
Exemplo n.º 7
0
 def get_cifar10_batch_aug(self):
   """Fetch one augmented CIFAR-10 training batch.

   Forwards ``self.data_dir`` and ``self.batch_size`` positionally to
   ``cifar10_input.distorted_inputs`` and returns the result as an
   ``(images, labels)`` tuple.

   NOTE(review): ``tensorflow.models.image.cifar10`` looks like a legacy
   TensorFlow package path; confirm it exists in the TensorFlow version in use.
   """
   from tensorflow.models.image.cifar10 import cifar10_input

   batch = cifar10_input.distorted_inputs(self.data_dir, self.batch_size)
   images, labels = batch
   return images, labels