def test_random_crop_resize(self):
  """Checks random_crop_resize output shapes for several crop configurations."""
  # Each case: (target height, target width, aspect-ratio range, area range).
  cases = [
      (256, 256, (0.5, 2), (0.3, 1)),
      (224, 224, (0.5, 2), (0.3, 1)),
      (256, 256, (0.8, 1.2), (0.3, 1)),
      (256, 256, (0.5, 2), (0.1, 1)),
  ]
  # Run every crop first, then check shapes, mirroring the original ordering.
  outputs = [
      preprocess_ops_3d.random_crop_resize(
          self._frames, height, width, 6, 3, aspect_range, area_range)
      for height, width, aspect_range, area_range in cases
  ]
  for (height, width, _, _), output in zip(cases, outputs):
    self.assertAllEqual(output.shape, (6, height, width, 3))
def _process_image(image: tf.Tensor,
                   is_training: bool = True,
                   is_ssl: bool = False,
                   num_frames: int = 32,
                   stride: int = 1,
                   num_test_clips: int = 1,
                   min_resize: int = 224,
                   crop_size: int = 200,
                   zero_centering_image: bool = False,
                   seed: Optional[int] = None) -> tf.Tensor:
  """Turns a tensor of serialized frames into processed float frames.

  Stages: temporal sampling -> JPEG decoding -> spatial processing (random
  resized crop + flip in training, resize + three-crop in evaluation) ->
  normalization -> (self-supervised training only) color jitter and blur.

  Args:
    image: Input Tensor of shape [timesteps] and type tf.string holding the
      serialized frames.
    is_training: If True, clips are sampled at random and a random resized
      crop followed by a random left-right flip is applied.
    is_ssl: Only used when `is_training` is True. Selects the self-supervised
      pre-training path: the sampled frames are split into two halves along
      axis 0, each half is augmented separately, and the halves are
      concatenated back along axis 0.
    num_frames: Number of frames per subclip.
    stride: Temporal stride used when sampling frames.
    num_test_clips: Number of evaluation clips. When greater than 1, that many
      linearly spaced clips are sampled from each video; when 1, the middle
      clip is sampled. Clips are aggregated in the batch dimension. Ignored
      (with a warning) when `is_training` is True.
    min_resize: In evaluation, frames are resized so that min(height, width)
      is `min_resize`, and the crops taken afterwards are also
      `min_resize` x `min_resize`.
    crop_size: Height and width of the random resized crop used in training
      (and of the blur applied on the SSL path).
    zero_centering_image: If True, frames are normalized to values in
      [-1, 1]. If False, values in [0, 1].
    seed: Seed forwarded to the random left-right flip.

  Returns:
    Processed frames. Training crops are `crop_size` x `crop_size`;
    evaluation crops are `min_resize` x `min_resize`, with three crops
    requested per clip. NOTE(review): the leading dimension depends on
    `sample_ssl_sequence` / `crop_image`; presumably 2 * num_frames for SSL
    training and num_frames * num_test_clips * 3 for evaluation. Confirm.
  """
  ssl_training = is_training and is_ssl

  if is_training and num_test_clips != 1:
    logging.warning(
        '`num_test_clips` %d is ignored since `is_training` is `True`.',
        num_test_clips)

  # Temporal sampling.
  if ssl_training:
    # Two clips drawn from a linearly decreasing distribution.
    image = video_ssl_preprocess_ops.sample_ssl_sequence(
        image, num_frames, True, stride)
  elif is_training:
    # One randomly placed clip.
    image = preprocess_ops_3d.sample_sequence(image, num_frames, True, stride)
  elif num_test_clips > 1:
    # Linearly spaced clips across the video.
    image = preprocess_ops_3d.sample_linspace_sequence(
        image, num_test_clips, num_frames, stride)
  else:
    # The middle clip.
    image = preprocess_ops_3d.sample_sequence(image, num_frames, False, stride)

  # JPEG strings -> tf.uint8 frames with 3 channels.
  frames = preprocess_ops_3d.decode_jpeg(image, 3)

  if not is_training:
    # Resize (only if necessary, to save compute), then three-crop.
    frames = preprocess_ops_3d.resize_smallest(frames, min_resize)
    frames = preprocess_ops_3d.crop_image(
        frames, min_resize, min_resize, False, 3)
    return preprocess_ops_3d.normalize_image(frames, zero_centering_image)

  def _crop_and_flip(clip):
    """Applies a random resized crop, then a random left-right flip."""
    clip = preprocess_ops_3d.random_crop_resize(
        clip, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
    return preprocess_ops_3d.random_flip_left_right(clip, seed)

  if not is_ssl:
    frames = _crop_and_flip(frames)
    return preprocess_ops_3d.normalize_image(frames, zero_centering_image)

  # Self-supervised path. Every stage handles the first clip before the
  # second, so random ops are created in the same order as a clip-by-clip
  # implementation would create them.
  clips = tf.split(frames, num_or_size_splits=2, axis=0)
  clips = [_crop_and_flip(clip) for clip in clips]
  clips = [
      preprocess_ops_3d.normalize_image(clip, zero_centering_image)
      for clip in clips
  ]
  # Temporally consistent color jittering.
  clips = [
      video_ssl_preprocess_ops.random_color_jitter_3d(clip) for clip in clips
  ]
  # Temporally consistent gaussian blurring.
  clips = [
      video_ssl_preprocess_ops.random_blur_3d(
          clip, num_frames, crop_size, crop_size)
      for clip in clips
  ]
  return tf.concat(clips, axis=0)
def process_image(image: tf.Tensor,
                  is_training: bool = True,
                  num_frames: int = 32,
                  stride: int = 1,
                  random_stride_range: int = 0,
                  num_test_clips: int = 1,
                  min_resize: int = 256,
                  crop_size: int = 224,
                  num_crops: int = 1,
                  zero_centering_image: bool = False,
                  min_aspect_ratio: float = 0.5,
                  max_aspect_ratio: float = 2,
                  min_area_ratio: float = 0.49,
                  max_area_ratio: float = 1.0,
                  seed: Optional[int] = None) -> tf.Tensor:
  """Processes a serialized image tensor.

  Args:
    image: Input Tensor of shape [timesteps] and type tf.string of serialized
      frames.
    is_training: Whether or not in training mode. If True, random sample, crop
      and left right flip is used.
    num_frames: Number of frames per subclip.
    stride: Temporal stride to sample frames.
    random_stride_range: An int indicating the min and max bounds to uniformly
      sample different strides from the video. E.g., a value of 1 with
      stride=2 will uniformly sample a stride in {1, 2, 3} for each video in a
      batch. The lower bound is clamped to 1. Only used during training, for
      the purposes of frame-rate augmentation. Defaults to 0, which disables
      random sampling.
    num_test_clips: Number of test clips (1 by default). If more than 1, this
      will sample multiple linearly spaced clips within each video at test
      time. If 1, then a single clip in the middle of the video is sampled.
      The clips are aggregated in the batch dimension. Ignored (with a
      warning) during training.
    min_resize: Frames are resized so that min(height, width) is min_resize.
    crop_size: Final size of the frame after cropping the resized frames. Both
      height and width are the same.
    num_crops: Number of crops to perform on the resized frames (evaluation
      only).
    zero_centering_image: If True, frames are normalized to values in [-1, 1].
      If False, values in [0, 1].
    min_aspect_ratio: The minimum aspect range for cropping.
    max_aspect_ratio: The maximum aspect range for cropping.
    min_area_ratio: The minimum area range for cropping.
    max_area_ratio: The maximum area range for cropping.
    seed: A deterministic seed forwarded to the random clip sampler and the
      random left-right flip.

  Returns:
    Processed frames. Tensor of shape
    [num_frames * num_test_clips, crop_size, crop_size, 3].
    NOTE(review): with num_crops > 1 in evaluation, the leading dimension is
    presumably also multiplied by num_crops; confirm against `crop_image`.

  Raises:
    ValueError: If `random_stride_range` is negative.
  """
  # Validate parameters.
  if is_training and num_test_clips != 1:
    logging.warning(
        '`num_test_clips` %d is ignored since `is_training` is `True`.',
        num_test_clips)
  if random_stride_range < 0:
    raise ValueError('Random stride range should be >= 0, got {}'.format(
        random_stride_range))

  # Temporal sampler.
  if is_training:
    if random_stride_range > 0:
      # Uniformly sample a frame-rate from the closed interval
      # [max(stride - random_stride_range, 1), stride + random_stride_range].
      # `tf.random.uniform` excludes `maxval` for integer dtypes, hence the
      # `+ 1`; without it the top stride could never be drawn (and e.g.
      # stride=1, random_stride_range=1 would always yield 1).
      stride = tf.random.uniform(
          [],
          tf.maximum(stride - random_stride_range, 1),
          stride + random_stride_range + 1,
          dtype=tf.int32)
    # Sample random clip.
    image = preprocess_ops_3d.sample_sequence(image, num_frames, True, stride,
                                              seed)
  elif num_test_clips > 1:
    # Sample linspace clips.
    image = preprocess_ops_3d.sample_linspace_sequence(
        image, num_test_clips, num_frames, stride)
  else:
    # Sample middle clip.
    image = preprocess_ops_3d.sample_sequence(image, num_frames, False, stride)

  # Decode JPEG string to tf.uint8.
  image = preprocess_ops_3d.decode_jpeg(image, 3)

  if is_training:
    # Standard image data augmentation: random resized crop and random flip.
    image = preprocess_ops_3d.random_crop_resize(
        image, crop_size, crop_size, num_frames, 3,
        (min_aspect_ratio, max_aspect_ratio),
        (min_area_ratio, max_area_ratio))
    image = preprocess_ops_3d.random_flip_left_right(image, seed)
  else:
    # Resize images (resize happens only if necessary to save compute).
    image = preprocess_ops_3d.resize_smallest(image, min_resize)
    # Crop of the frames.
    image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False,
                                         num_crops)

  # Cast the frames in float32, normalizing according to zero_centering_image.
  return preprocess_ops_3d.normalize_image(image, zero_centering_image)