예제 #1
0
    def parser(self, serialized_example):
        """Parses a single tf.Example into image and label tensors.

        Args:
          serialized_example: A scalar string tensor holding a serialized
            `tf.Example`. When `self._test_small_sample` is set, this value is
            not parsed and is passed straight through as both image outputs.

        Returns:
          A tuple `(image_raw, image_preprocessing, label)`:
            image_raw: the encoded image bytes reshaped to a scalar string
              tensor (or `serialized_example` itself in small-sample mode).
            image_preprocessing: the image decoded as float32 with three
              channels, then passed through
              `preprocess_image(..., image_size=224, is_training=False)`.
            label: a scalar int32 label taken from 'image/class/label',
              shifted down by one when `self._dataset_name == 'imagenet'`.
              Always 0 in small-sample mode.
        """
        if self._test_small_sample:
            # Test mode: skip parsing entirely and feed the input through.
            image_raw = serialized_example
            image_preprocessing = serialized_example
            label = tf.constant(0, tf.int32)
            return (image_raw, image_preprocessing, label)

        features = tf.parse_single_example(
            serialized_example,
            features={
                'image/encoded':
                tf.FixedLenFeature([], tf.string, default_value=''),
                'image/class/label': tf.FixedLenFeature([], tf.int64),
            })
        image_raw = tf.reshape(features['image/encoded'], shape=[])

        # Request exactly three channels so grayscale / RGBA inputs are
        # converted to RGB instead of reaching preprocessing with 1 or 4
        # channels.
        image_preprocessing = tf.image.decode_image(image_raw,
                                                    channels=3,
                                                    dtype=tf.float32)
        image_preprocessing = preprocess_image(image=image_preprocessing,
                                               image_size=224,
                                               is_training=False)

        label = tf.cast(tf.reshape(features['image/class/label'], shape=[]),
                        dtype=tf.int32)
        if self._dataset_name == 'imagenet':
            # Subtract one so that labels are in [0, 1000).
            label = label - 1

        return (image_raw, image_preprocessing, label)
예제 #2
0
  def parser(self, serialized_example):
    """Turns one serialized tf.Example into an `(image, label)` pair.

    Args:
      serialized_example: The serialized tf.Example record. In small-sample
        test mode (`self.test_small_sample`) it is returned unchanged as the
        image, paired with the constant label 0.

    Returns:
      A tuple `(image, label)`. `label` is the example's 'label' feature cast
      to a scalar int32. `image` is the 'raw_image' feature decoded with three
      channels and converted to float32. For 'modified_image' and
      'random_baseline' it is first passed through test-time preprocessing,
      then transformed using the saliency map read from
      `self.saliency_method`. For those two transformations and 'raw_image', a
      final `preprocess_image` call runs with `is_training` set when
      `self.mode == 'train'`. Any other transformation value skips that final
      step.
    """
    if self.test_small_sample:
      return serialized_example, tf.constant(0, tf.int32)

    def _int64_or_missing():
      # Scalar int64 feature that falls back to -1 when absent.
      return tf.FixedLenFeature([], dtype=tf.int64, default_value=-1)

    parsed = tf.parse_single_example(
        serialized_example,
        features={
            'raw_image': tf.FixedLenFeature((), tf.string, default_value=''),
            'height': _int64_or_missing(),
            'width': _int64_or_missing(),
            self.saliency_method: tf.VarLenFeature(tf.float32),
            'label': _int64_or_missing(),
            'prediction_class': _int64_or_missing(),
        })

    # Decode with three channels, then rescale to float32.
    decoded = tf.image.decode_image(parsed['raw_image'], 3)
    img = tf.image.convert_image_dtype(decoded, dtype=tf.float32)

    # The saliency map arrives as a flat float list; restore it to IMAGE_DIMS.
    heatmap = tf.reshape(parsed[self.saliency_method].values, IMAGE_DIMS)

    if self.transformation in ('modified_image', 'random_baseline'):
      # Test-time preprocessing is applied to the raw image first; the
      # ranking-based modification then operates on that result.
      eval_view = preprocess_image(
          img, image_size=IMAGE_DIMS[0], is_training=False)
      if self.transformation == 'modified_image':
        tf.logging.info('Computing feature importance estimate now...')
        img = compute_feature_ranking(
            input_image=eval_view,
            saliency_map=heatmap,
            threshold=self.threshold,
            global_mean=self.global_mean,
            rescale_heatmap=True,
            keep_information=self.keep_information,
            use_squared_value=self.use_squared_value)
      elif self.transformation == 'random_baseline':
        tf.logging.info('generating a random baseline')
        img = random_ranking(
            input_image=eval_view,
            global_mean=self.global_mean,
            threshold=self.threshold,
            keep_information=self.keep_information)

    training = (self.mode == 'train')

    if self.transformation in ('random_baseline', 'modified_image',
                               'raw_image'):
      tf.logging.info('starting pre-processing for training/eval')
      img = preprocess_image(
          img, image_size=IMAGE_DIMS[0], is_training=training)

    label = tf.cast(tf.reshape(parsed['label'], shape=[]), dtype=tf.int32)
    return img, label