def wrapper_map_probe_v2(self, tfrecord):
  """Parse and preprocess one serialized probe-data (v2) example.

  Intended for use as a `tf.data.Dataset.map` function.

  Args:
    tfrecord: A serialized example holding a scalar string feature
      'image/encoded' and a scalar int64 feature 'image/label'.

  Returns:
    A tuple `(image, label)`: `image` is the output of
    `imagenet_preprocess_image` called with `is_training=True` and
    `image_size=self.image_size`; `label` is an int64 tensor.
  """
  feature_spec = {
      'image/encoded': tf.FixedLenFeature([], tf.string),
      'image/label': tf.FixedLenFeature([], tf.int64)
  }
  parsed = tf.parse_single_example(tfrecord, feature_spec)

  encoded_image = parsed['image/encoded']
  # The label is narrowed to int32 and then widened back to int64; this
  # round trip is kept exactly as in the original pipeline.
  label = tf.cast(parsed['image/label'], dtype=tf.int32)
  label = tf.cast(label, tf.int64)

  image = imagenet_preprocess_image(
      encoded_image, is_training=True, image_size=self.image_size)
  return image, label
def _func(data):
  """Preprocess one example dict into an (image, label, id, label) tuple.

  `train` and `self` are free variables taken from the enclosing scope,
  which is not shown here.

  Args:
    data: A mapping with the keys 'image', 'label' and 'id'.

  Returns:
    A 4-tuple `(image, label, data['id'], label)`, where `image` is the
    output of `imagenet_preprocess_image`.
  """
  raw_image = data['image']
  label = data['label']
  # The preprocessing helper is fed encoded bytes, so the image is JPEG
  # encoded first.
  jpeg_bytes = tf.image.encode_jpeg(raw_image, name='encode_jpeg')
  processed = imagenet_preprocess_image(
      jpeg_bytes, is_training=train, image_size=self.image_size)
  # Last field is useless
  return processed, label, data['id'], label
def _func(data):
  """Preprocess one example dict into one or two views of the image.

  `train` and `self` are free variables taken from the enclosing scope,
  which is not shown here.

  Args:
    data: A mapping with the keys 'image' and 'label'.

  Returns:
    A tuple `(images, label)`. When `train` is falsy, `images` is the
    single preprocessed image. When `train` is truthy, `images` holds two
    views joined along a new leading axis: the plain preprocessed image
    first, then a second one produced with `autoaugment_name='v0'` and
    `use_cutout=True`.
  """
  raw_image = data['image']
  label = data['label']
  # The preprocessing helper is fed encoded bytes, so the image is JPEG
  # encoded first.
  jpeg_bytes = tf.image.encode_jpeg(raw_image)

  plain_view = imagenet_preprocess_image(
      jpeg_bytes, is_training=train, image_size=self.image_size)
  if not train:
    return plain_view, label

  augmented_view = imagenet_preprocess_image(
      jpeg_bytes,
      is_training=train,
      image_size=self.image_size,
      autoaugment_name='v0',
      use_cutout=True)
  # Give each view a leading axis of size 1, then join them on that axis.
  views = [tf.expand_dims(plain_view, 0), tf.expand_dims(augmented_view, 0)]
  return tf.concat(views, axis=0), label