def read_single_example(example_string):
  """Parses one serialized Example record.

  Args:
    example_string: str, an Example protocol buffer.

  Returns:
    A dict of parsed features: 'image' maps to a scalar string tensor and
    'label' maps to a scalar int64 tensor.
  """
  feature_spec = {
      'image': tf.FixedLenFeature([], dtype=tf.string),
      'label': tf.FixedLenFeature([], tf.int64),
  }
  return tf.parse_single_example(example_string, features=feature_spec)
def process_example(example_string, image_size, data_augmentation=None):
  """Processes a single example string.

  Extracts and processes the image, and ignores the label. We assume that the
  image has three channels.

  Args:
    example_string: str, an Example protocol buffer.
    image_size: int, desired image size. The extracted image will be resized
      to `[image_size, image_size]`.
    data_augmentation: A DataAugmentation object with parameters for
      perturbing the images, or None to skip augmentation. Its
      `enable_gaussian_noise` / `gaussian_noise_std` and `enable_jitter` /
      `jitter_amount` attributes are read.

  Returns:
    image: float32 tensor of shape `[image_size, image_size, 3]`, resized and
      rescaled to [-1, 1]. Note that Gaussian data augmentation may cause
      values to go beyond this range.
  """
  # Reuse the shared parser so the Example feature spec lives in one place;
  # the label is parsed but discarded.
  image_string = read_single_example(example_string)['image']
  image_decoded = tf.image.decode_jpeg(image_string, channels=3)
  image_resized = tf.image.resize_images(
      image_decoded, [image_size, image_size],
      method=tf.image.ResizeMethod.BILINEAR,
      align_corners=True)
  # Per the resize_images docs, the result can keep the input dtype (uint8)
  # in some cases; cast explicitly so the rescaling below is always
  # floating-point. A no-op when the result is already float32.
  image_resized = tf.cast(image_resized, tf.float32)
  image = 2 * (image_resized / 255.0 - 0.5)  # Rescale to [-1, 1].

  if data_augmentation is not None:
    if data_augmentation.enable_gaussian_noise:
      image = image + tf.random_normal(
          tf.shape(image)) * data_augmentation.gaussian_noise_std
    if data_augmentation.enable_jitter:
      j = data_augmentation.jitter_amount
      # Reflect-pad by j pixels on each spatial side, then randomly crop back
      # to image_size: a random translation of up to j pixels.
      paddings = tf.constant([[j, j], [j, j], [0, 0]])
      image = tf.pad(image, paddings, 'REFLECT')
      image = tf.image.random_crop(image, [image_size, image_size, 3])
  return image
def __call__(self, example_string):
  """Turns one serialized Example into a preprocessed image tensor.

  Only the image feature is used; the label is parsed and dropped. The image
  is decoded with exactly three channels.

  Args:
    example_string: str, an Example protocol buffer.

  Returns:
    A float32 tensor of shape `[self.image_size, self.image_size, 3]`, scaled
    to [-1, 1]. Gaussian-noise augmentation can push values outside that
    range.
  """
  side = self.image_size
  augment = self.data_augmentation

  parsed = tf.parse_single_example(
      example_string,
      features={
          'image': tf.FixedLenFeature([], dtype=tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
      })
  pixels = tf.image.decode_image(parsed['image'], channels=3)
  # Pin a static rank-3 shape (height, width, 3) ahead of the resize.
  pixels.set_shape([None, None, 3])
  pixels = tf.image.resize_images(
      pixels, [side, side],
      method=tf.image.ResizeMethod.BILINEAR,
      align_corners=True)
  pixels = tf.cast(pixels, tf.float32)
  # Map [0, 255] onto [-1, 1].
  out = 2 * (pixels / 255.0 - 0.5)

  if augment is None:
    return out

  if augment.enable_gaussian_noise:
    noise = tf.random_normal(tf.shape(out)) * augment.gaussian_noise_std
    out = out + noise

  if augment.enable_jitter:
    pad = augment.jitter_amount
    # Reflect-pad by `pad` pixels per spatial side, then randomly crop back to
    # the original size: a random translation of up to `pad` pixels.
    out = tf.pad(out, tf.constant([[pad, pad], [pad, pad], [0, 0]]), 'REFLECT')
    out = tf.image.random_crop(out, [side, side, 3])

  return out
def __call__(self, example_string):
  """Decodes the feature tensor stored in one serialized Example.

  The label feature is parsed but not returned.

  Args:
    example_string: str, an Example protocol buffer.

  Returns:
    feat: float32 tensor deserialized from the 'image/embedding' feature.
  """
  spec = {
      'image/embedding': tf.FixedLenFeature([], dtype=tf.string),
      'image/class/label': tf.FixedLenFeature([], tf.int64),
  }
  serialized = tf.parse_single_example(example_string, features=spec)
  # The embedding is stored as a serialized tensor string; decode it as float32.
  return tf.io.parse_tensor(serialized['image/embedding'], tf.float32)