예제 #1
0
def read_single_example(example_string):
  """Deserializes one serialized `Example` record.

  Args:
    example_string: a serialized Example protocol buffer (string).

  Returns:
    The dict produced by `tf.parse_single_example`, holding a scalar string
    tensor under 'image' and a scalar int64 tensor under 'label'.
  """
  feature_spec = {
      'image': tf.FixedLenFeature([], dtype=tf.string),
      'label': tf.FixedLenFeature([], tf.int64),
  }
  parsed = tf.parse_single_example(example_string, features=feature_spec)
  return parsed
예제 #2
0
def process_example(example_string, image_size, data_augmentation=None):
    """Processes a single example string.

    Extracts and processes the image, and ignores the label. We assume that the
    image has three channels.

    Args:
      example_string: str, an Example protocol buffer.
      image_size: int, desired image size. The extracted image will be resized
        to `[image_size, image_size]`.
      data_augmentation: A DataAugmentation object with parameters for
        perturbing the images, or None to skip augmentation entirely.

    Returns:
      image: float32 tensor, the image resized to `image_size x image_size`
        and rescaled to [-1, 1]. Note that Gaussian data augmentation may
        cause values to go beyond this range.
    """
    image_string = tf.parse_single_example(
        example_string,
        features={
            'image': tf.FixedLenFeature([], dtype=tf.string),
            # Parsed and then discarded. It has no default_value, so it is
            # presumably required to be present in every record.
            'label': tf.FixedLenFeature([], tf.int64),
        })['image']
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    image_resized = tf.image.resize_images(
        image_decoded, [image_size, image_size],
        method=tf.image.ResizeMethod.BILINEAR,
        align_corners=True)
    # Cast explicitly so the rescale below always runs in float32, regardless
    # of which dtype the chosen resize method returns. (No-op if the tensor
    # is already float32.)
    image_resized = tf.cast(image_resized, tf.float32)
    image = 2 * (image_resized / 255.0 - 0.5)  # Rescale to [-1, 1].

    if data_augmentation is not None:
        if data_augmentation.enable_gaussian_noise:
            image = image + tf.random_normal(
                tf.shape(image)) * data_augmentation.gaussian_noise_std

        if data_augmentation.enable_jitter:
            # Reflect-pad by `jitter_amount` on each spatial side, then take a
            # random crop back to the target size: a random translation.
            j = data_augmentation.jitter_amount
            paddings = tf.constant([[j, j], [j, j], [0, 0]])
            image = tf.pad(image, paddings, 'REFLECT')
            image = tf.image.random_crop(image, [image_size, image_size, 3])

    return image
예제 #3
0
    def __call__(self, example_string):
        """Decodes one serialized Example into a normalized square image.

        Only the 'image' feature is returned; the 'label' feature is parsed
        and discarded. The decoded image is forced to three channels.

        Args:
          example_string: str, an Example protocol buffer.

        Returns:
          A float32 image tensor of size `self.image_size x self.image_size`
          with values rescaled to [-1, 1]. Gaussian noise augmentation, when
          enabled, can push values outside that range.
        """
        side = self.image_size
        augment = self.data_augmentation

        parsed = tf.parse_single_example(
            example_string,
            features={
                'image': tf.FixedLenFeature([], dtype=tf.string),
                'label': tf.FixedLenFeature([], tf.int64),
            })
        pixels = tf.image.decode_image(parsed['image'], channels=3)
        # Pin the static shape to rank 3 so the resize below can be built.
        # NOTE(review): decode_image may yield a 4-D tensor for animated
        # GIFs, which this shape would reject. Confirm inputs are never GIFs.
        pixels.set_shape([None, None, 3])
        pixels = tf.image.resize_images(
            pixels, [side, side],
            method=tf.image.ResizeMethod.BILINEAR,
            align_corners=True)
        pixels = tf.cast(pixels, tf.float32)
        image = 2 * (pixels / 255.0 - 0.5)  # Rescale to [-1, 1].

        if augment is None:
            return image

        if augment.enable_gaussian_noise:
            noise = tf.random_normal(tf.shape(image))
            image = image + noise * augment.gaussian_noise_std

        if augment.enable_jitter:
            # Reflect-pad each spatial side, then randomly crop back to size.
            j = augment.jitter_amount
            paddings = tf.constant([[j, j], [j, j], [0, 0]])
            image = tf.pad(image, paddings, 'REFLECT')
            image = tf.image.random_crop(image, [side, side, 3])

        return image
예제 #4
0
    def __call__(self, example_string):
        """Extracts the serialized embedding tensor from one Example.

        The 'image/class/label' feature is parsed alongside the embedding but
        is not returned.

        Args:
          example_string: str, an Example protocol buffer.

        Returns:
          feat: float32 tensor obtained by `tf.io.parse_tensor` on the bytes
            stored under 'image/embedding'.
        """
        spec = {
            'image/embedding': tf.FixedLenFeature([], dtype=tf.string),
            'image/class/label': tf.FixedLenFeature([], tf.int64),
        }
        parsed = tf.parse_single_example(example_string, features=spec)
        return tf.io.parse_tensor(parsed['image/embedding'], tf.float32)