def _preprocess_fn(self, features, labels, mode):
        """Resize images, convert them from uint8 -> float32 and apply mixup.

        Args:
          features: Feature structure (mapping with attribute access). When it
            contains an 'image' entry, that image is preprocessed; the
            pre-preprocessing image is kept under `original_image`.
          labels: Label structure (mapping), or a falsy value to skip mixup.
          mode: Mode passed to the distortion pipeline and the output spec
            lookup; mixup only runs when it equals TRAIN.

        Returns:
          The (features, labels) pair, updated in place and returned.
        """
        if 'image' in features:
            ndim = len(features.image.shape)
            # Rank > 4 is treated as a sequence of images (presumably an extra
            # time axis in front of H, W, C -- TODO confirm against the spec).
            is_sequence = ndim > 4
            input_size = self._src_img_res
            target_size = self._crop_size
            # Keep the image as it was before distortion/resizing.
            features.original_image = features.image
            features.image = distortion.preprocess_image(
                features.image, mode, is_sequence, input_size, target_size)

            features.image = tf.image.convert_image_dtype(
                features.image, tf.float32)
            out_feature_spec = self.get_out_feature_specification(mode)
            if out_feature_spec.image.shape != features.image.shape:
                # Resize to the two dims preceding the last one in the output
                # spec (presumably height and width).
                features.image = meta_tfdata.multi_batch_apply(
                    tf.image.resize_images, 2, features.image,
                    out_feature_spec.image.shape.as_list()[-3:-1])

        if self._mixup_alpha > 0. and labels and mode == TRAIN:
            # One mixing weight shared by every feature and label tensor.
            lmbda = tfp.distributions.Beta(self._mixup_alpha,
                                           self._mixup_alpha).sample()

            def _mix(tensors):
                """Blend each float entry with its copy reversed along axis 0."""
                # Iterate over a snapshot so rebinding values never touches the
                # view being iterated.
                for key, x in list(tensors.items()):
                    if x.dtype in FLOAT_DTYPES:
                        # The Beta sample is float32; cast it so float entries
                        # of other precisions don't fail with a dtype mismatch.
                        weight = tf.cast(lmbda, x.dtype)
                        tensors[key] = weight * x + (1 - weight) * tf.reverse(
                            x, axis=[0])

            _mix(features)
            _mix(labels)
        return features, labels
    def _preprocess_fn(self, features, labels, mode):
        """Resize images and convert them from uint8 -> float32.

        Args:
          features: Feature structure with an `image` entry (attribute access).
          labels: Label structure; returned unchanged.
          mode: Mode passed to the distortion pipeline and the output spec
            lookup.

        Returns:
          The (features, labels) pair with `features.image` preprocessed,
          converted to float32 and, if its shape differs from the output spec,
          resized.
        """
        ndim = len(features.image.shape)
        # Rank > 4 is treated as a sequence of images (presumably an extra time
        # axis in front of H, W, C -- TODO confirm against the spec).
        is_sequence = ndim > 4
        features.image = distortion.preprocess_image(features.image, mode,
                                                     is_sequence,
                                                     self._src_img_res,
                                                     self._crop_size)

        # Convert unconditionally: the conversion used to happen only inside the
        # resize branch, so an image already matching the spec shape could skip
        # it. For input that is already float32 this is a no-op.
        features.image = tf.image.convert_image_dtype(features.image,
                                                      tf.float32)

        out_feature_spec = self.get_out_feature_specification(mode)
        if out_feature_spec.image.shape != features.image.shape:
            # Resize to the two dims preceding the last one in the output spec
            # (presumably height and width).
            features.image = meta_tfdata.multi_batch_apply(
                tf.image.resize_images, 2, features.image,
                out_feature_spec.image.shape.as_list()[-3:-1])
        return features, labels