示例#1
0
    def wrapper_map_probe_v2(self, tfrecord):
        """Parses one serialized probe-v2 example into an (image, label) pair.

        Meant to be passed to tf.data.Dataset.map.

        Args:
          tfrecord: a single serialized example (scalar string) carrying the
            'image/encoded' (string) and 'image/label' (int64) features.

        Returns:
          A tuple (image, label). image is the output of
          imagenet_preprocess_image called with is_training=True and
          image_size=self.image_size; label is the int64 label tensor.
        """
        feature_spec = {
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'image/label': tf.FixedLenFeature([], tf.int64),
        }
        parsed = tf.parse_single_example(tfrecord, feature_spec)

        image_bytes = parsed['image/encoded']
        # The feature spec already declares the label as int64, which is the
        # dtype returned to the caller, so no intermediate int32 cast is needed.
        label = parsed['image/label']

        image = imagenet_preprocess_image(image_bytes,
                                          is_training=True,
                                          image_size=self.image_size)

        return image, label
示例#2
0
 def _func(data):
     """Maps a feature dict to (image, label, id, label).

     The 'image' entry is re-encoded as JPEG and passed through
     imagenet_preprocess_image; `train` and `self` are taken from the
     enclosing scope.
     """
     raw_image = data['image']
     target = data['label']
     jpeg_bytes = tf.image.encode_jpeg(raw_image, name='encode_jpeg')
     processed = imagenet_preprocess_image(
         jpeg_bytes,
         is_training=train,
         image_size=self.image_size)
     # The fourth element just repeats the label; it carries no extra
     # information and only pads the tuple to four fields.
     return processed, target, data['id'], target
示例#3
0
 def _func(data):
     """Maps a feature dict to (images, label).

     When `train` (from the enclosing scope) is truthy, two preprocessed
     views of the same image are produced -- one with default arguments and
     one with autoaugment_name='v0' and use_cutout=True -- and joined along a
     new leading axis. Otherwise only the single default view is returned.
     """
     source, target = data['image'], data['label']
     jpeg_bytes = tf.image.encode_jpeg(source)
     base_view = imagenet_preprocess_image(
         jpeg_bytes,
         is_training=train,
         image_size=self.image_size)

     # Evaluation path: a single view, no extra leading axis.
     if not train:
         return base_view, target

     augmented_view = imagenet_preprocess_image(
         jpeg_bytes,
         is_training=train,
         image_size=self.image_size,
         autoaugment_name='v0',
         use_cutout=True)
     # Give each view a leading axis of size 1, then join them on that axis.
     base_batch = tf.expand_dims(base_view, 0)
     augmented_batch = tf.expand_dims(augmented_view, 0)
     paired = tf.concat([base_batch, augmented_batch], axis=0)
     return paired, target