def _preprocess_fn(self, features, labels, mode):
  """Resize images and convert them from uint8 -> float32.

  Image path:
    1. When ``features`` has an ``image`` entry, the raw image is kept as
       ``features.original_image``.
    2. The image is passed through ``distortion.preprocess_image`` and
       converted to float32.
    3. If its shape differs from the output feature specification, it is
       resized to the spec's spatial size.

  Mixup: in ``TRAIN`` mode with ``self._mixup_alpha > 0`` and truthy
  ``labels``, every floating-point feature and label is mixed with its
  batch-reversed counterpart. A single weight is drawn from
  Beta(alpha, alpha) for the whole call.

  Args:
    features: Feature container supporting attribute access, ``in``,
      ``items()`` and item assignment (presumably a TensorSpecStruct-like
      object -- confirm against callers).
    labels: Label container of the same kind, or a falsy value (e.g. None)
      when labels are unavailable; mixup is skipped in that case.
    mode: Mode key; mixup only runs when it equals ``TRAIN``.

  Returns:
    The ``(features, labels)`` pair. Both containers are modified in place
    and also returned.
  """

  def _mix(tensors, weight):
    """Replace each floating-point entry x with weight*x + (1-weight)*rev(x)."""
    # Snapshot the entries so we never write into the container while
    # iterating its live view.
    for key, x in list(tensors.items()):
      if x.dtype in FLOAT_DTYPES:
        # The Beta sample is float32; cast it so non-float32 floating
        # tensors can be mixed without a dtype mismatch error.
        w = tf.cast(weight, x.dtype)
        # Reversing along axis 0 pairs each example with another one from
        # the same batch (assumes axis 0 is the batch axis -- confirm).
        tensors[key] = w * x + (1 - w) * tf.reverse(x, axis=[0])

  if 'image' in features:
    ndim = len(features.image.shape)
    # Rank above 4 is treated as a sequence of images (presumably with an
    # extra time/sequence axis -- confirm against the feature spec).
    is_sequence = (ndim > 4)
    input_size = self._src_img_res
    target_size = self._crop_size
    # Keep the unprocessed image alongside the processed one.
    features.original_image = features.image
    features.image = distortion.preprocess_image(
        features.image, mode, is_sequence, input_size, target_size)
    # convert_image_dtype rescales integer images into [0, 1] and is a
    # no-op for tensors that are already float32.
    features.image = tf.image.convert_image_dtype(
        features.image, tf.float32)
    out_feature_spec = self.get_out_feature_specification(mode)
    if out_feature_spec.image.shape != features.image.shape:
      # Resize to the spec's spatial size (the two dims before the last,
      # presumably height/width of a channels-last layout), applying the
      # op over the two leading batch dimensions.
      features.image = meta_tfdata.multi_batch_apply(
          tf.image.resize_images, 2, features.image,
          out_feature_spec.image.shape.as_list()[-3:-1])

  if self._mixup_alpha > 0. and labels and mode == TRAIN:
    lmbda = tfp.distributions.Beta(
        self._mixup_alpha, self._mixup_alpha).sample()
    _mix(features, lmbda)
    _mix(labels, lmbda)
  return features, labels
def _preprocess_fn(self, features, labels, mode):
  """Resize images and convert them from uint8 -> float32.

  Steps:
    1. ``features.image`` is passed through ``distortion.preprocess_image``.
    2. The result is converted to float32.
    3. If its shape differs from the output feature specification, it is
       resized to the spec's spatial size.

  Args:
    features: Feature container with an ``image`` attribute (presumably a
      TensorSpecStruct-like object -- confirm against callers).
    labels: Labels; passed through unchanged.
    mode: Mode key, forwarded to ``preprocess_image`` and
      ``get_out_feature_specification``.

  Returns:
    The ``(features, labels)`` pair; ``features`` is modified in place.
  """
  ndim = len(features.image.shape)
  # Rank above 4 is treated as a sequence of images (presumably with an
  # extra time/sequence axis -- confirm against the feature spec).
  is_sequence = (ndim > 4)
  input_size = self._src_img_res
  target_size = self._crop_size
  features.image = distortion.preprocess_image(
      features.image, mode, is_sequence, input_size, target_size)
  # Convert before the shape check so the output is float32 whether or not
  # a resize is needed. convert_image_dtype rescales integer images into
  # [0, 1] and is a no-op for tensors that are already float32.
  features.image = tf.image.convert_image_dtype(features.image, tf.float32)
  out_feature_spec = self.get_out_feature_specification(mode)
  if out_feature_spec.image.shape != features.image.shape:
    # Resize to the spec's spatial size (the two dims before the last,
    # presumably height/width of a channels-last layout), applying the op
    # over the two leading batch dimensions.
    features.image = meta_tfdata.multi_batch_apply(
        tf.image.resize_images, 2, features.image,
        out_feature_spec.image.shape.as_list()[-3:-1])
  return features, labels