Example #1
    def test_sample_linspace_sequence(self):
        sequence = tf.range(100)
        sampled_seq_1 = preprocess_ops_3d.sample_linspace_sequence(
            sequence, 10, 10, 1)
        sampled_seq_2 = preprocess_ops_3d.sample_linspace_sequence(
            sequence, 7, 10, 1)
        sampled_seq_3 = preprocess_ops_3d.sample_linspace_sequence(
            sequence, 7, 5, 2)
        sampled_seq_4 = preprocess_ops_3d.sample_linspace_sequence(
            sequence, 101, 1, 1)

        self.assertAllEqual(sampled_seq_1, range(100))
        # [0, 1, 2, 3, 4, ..., 8, 9, 15, 16, ..., 97, 98, 99]
        self.assertAllEqual(
            sampled_seq_2,
            [15 * i + j for i, j in itertools.product(range(7), range(10))])
        # [0, 2, 4, 6, 8, 15, 17, 19, ..., 96, 98]
        self.assertAllEqual(
            sampled_seq_3,
            [15 * i + 2 * j for i, j in itertools.product(range(7), range(5))])
        self.assertAllEqual(sampled_seq_4, [0] + list(range(100)))
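The expected values above pin down where the sampler places each window. Below is a minimal standalone sketch of that start-index math, assuming (as the assertions imply) that the window starts are spread linearly over [0, seq_len - num_frames * stride] and truncated to integers; linspace_start_indices is a hypothetical helper, not part of preprocess_ops_3d:

import numpy as np

def linspace_start_indices(seq_len, num_windows, num_frames, stride):
    # Spread num_windows window starts evenly over the usable offsets and
    # truncate, mirroring a tf.linspace followed by tf.cast(..., tf.int32).
    max_offset = max(0, seq_len - num_frames * stride)
    return np.linspace(0, max_offset, num_windows).astype(np.int32)

print(linspace_start_indices(100, 7, 10, 1))  # [ 0 15 30 45 60 75 90]
print(linspace_start_indices(100, 7, 5, 2))   # [ 0 15 30 45 60 75 90]
# 101 one-frame windows over 100 frames repeat the first offset, which is
# why sampled_seq_4 equals [0] + list(range(100)).
print(linspace_start_indices(100, 101, 1, 1)[:4])  # [0 0 1 2]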
Example #2
def _process_image(image: tf.Tensor,
                   is_training: bool = True,
                   is_ssl: bool = False,
                   num_frames: int = 32,
                   stride: int = 1,
                   num_test_clips: int = 1,
                   min_resize: int = 256,
                   crop_size: int = 224,
                   num_crops: int = 1,
                   zero_centering_image: bool = False,
                   seed: Optional[int] = None) -> tf.Tensor:
    """Processes a serialized image tensor.

    Args:
      image: Input Tensor of shape [timesteps] and type tf.string of serialized
        frames.
      is_training: Whether or not in training mode. If True, random sampling,
        cropping and left-right flipping are applied.
      is_ssl: Whether or not in self-supervised pre-training mode.
      num_frames: Number of frames per subclip.
      stride: Temporal stride to sample frames.
      num_test_clips: Number of test clips (1 by default). If more than 1, this
        will sample multiple linearly spaced clips within each video at test
        time. If 1, then a single clip in the middle of the video is sampled.
        The clips are aggregated in the batch dimension.
      min_resize: Frames are resized so that min(height, width) is min_resize.
      crop_size: Final size of the frame after cropping the resized frames. Both
        height and width are the same.
      num_crops: Number of crops to perform on the resized frames.
      zero_centering_image: If True, frames are normalized to values in [-1, 1].
        If False, values in [0, 1].
      seed: A deterministic seed to use when sampling.

    Returns:
      Processed frames. Tensor of shape
        [num_frames * num_test_clips, crop_size, crop_size, 3]. When both
        is_training and is_ssl are True, the two augmented clips are stacked
        instead, giving [2 * num_frames, crop_size, crop_size, 3].
    """
    # Validate parameters.
    if is_training and num_test_clips != 1:
        logging.warning(
            '`num_test_clips` %d is ignored since `is_training` is `True`.',
            num_test_clips)

    # Temporal sampler.
    if is_training:
        # Sampler for training.
        if is_ssl:
            # Sample two clips from a linearly decreasing distribution.
            image = video_ssl_preprocess_ops.sample_ssl_sequence(
                image, num_frames, True, stride)
        else:
            # Sample random clip.
            image = preprocess_ops_3d.sample_sequence(image, num_frames, True,
                                                      stride)
    else:
        # Sampler for evaluation.
        if num_test_clips > 1:
            # Sample linspace clips.
            image = preprocess_ops_3d.sample_linspace_sequence(
                image, num_test_clips, num_frames, stride)
        else:
            # Sample middle clip.
            image = preprocess_ops_3d.sample_sequence(image, num_frames, False,
                                                      stride)

    # Decode JPEG string to tf.uint8.
    image = preprocess_ops_3d.decode_jpeg(image, 3)

    if is_training:
        # Standard image data augmentation: random resized crop and random flip.
        if is_ssl:
            image_1, image_2 = tf.split(image, num_or_size_splits=2, axis=0)
            image_1 = preprocess_ops_3d.random_crop_resize(
                image_1, crop_size, crop_size, num_frames, 3, (0.5, 2),
                (0.3, 1))
            image_1 = preprocess_ops_3d.random_flip_left_right(image_1, seed)
            image_2 = preprocess_ops_3d.random_crop_resize(
                image_2, crop_size, crop_size, num_frames, 3, (0.5, 2),
                (0.3, 1))
            image_2 = preprocess_ops_3d.random_flip_left_right(image_2, seed)
        else:
            image = preprocess_ops_3d.random_crop_resize(
                image, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
            image = preprocess_ops_3d.random_flip_left_right(image, seed)
    else:
        # Resize images (resize happens only if necessary to save compute).
        image = preprocess_ops_3d.resize_smallest(image, min_resize)
        # Three-crop of the frames.
        image = preprocess_ops_3d.crop_image(image, crop_size, crop_size,
                                             False, num_crops)

    # Cast the frames to float32, normalizing according to zero_centering_image.
    if is_training and is_ssl:
        image_1 = preprocess_ops_3d.normalize_image(image_1,
                                                    zero_centering_image)
        image_2 = preprocess_ops_3d.normalize_image(image_2,
                                                    zero_centering_image)
    else:
        image = preprocess_ops_3d.normalize_image(image, zero_centering_image)

    # Self-supervised pre-training augmentations.
    if is_training and is_ssl:
        # Temporally consistent color jittering.
        image_1 = video_ssl_preprocess_ops.random_color_jitter_3d(image_1)
        image_2 = video_ssl_preprocess_ops.random_color_jitter_3d(image_2)
        # Temporally consistent Gaussian blurring.
        image_1 = video_ssl_preprocess_ops.random_blur(image_1, crop_size,
                                                       crop_size, 1.0)
        image_2 = video_ssl_preprocess_ops.random_blur(image_2, crop_size,
                                                       crop_size, 0.1)
        image_2 = video_ssl_preprocess_ops.random_solarization(image_2)
        image = tf.concat([image_1, image_2], axis=0)
        image = tf.clip_by_value(image, 0., 1.)

    return image
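A hedged usage sketch for the SSL branch. The input below is a hypothetical stand-in (raw_frames is not defined anywhere in this module): random frames JPEG-encoded into the [timesteps] tf.string tensor the docstring requires.

import tensorflow as tf

# Hypothetical input: JPEG-encode 64 random frames into a [timesteps]
# tf.string tensor.
frames_uint8 = tf.cast(
    tf.random.uniform([64, 160, 200, 3], maxval=256, dtype=tf.int32), tf.uint8)
raw_frames = tf.map_fn(
    tf.io.encode_jpeg, frames_uint8, fn_output_signature=tf.string)

clips = _process_image(
    raw_frames,
    is_training=True,
    is_ssl=True,
    num_frames=16,
    stride=2,
    crop_size=224)
# The two augmented clips are concatenated along axis 0, so `clips` comes out
# as [2 * num_frames, crop_size, crop_size, 3].
print(clips.shape)  # (32, 224, 224, 3)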
Example #3
def process_image(image: tf.Tensor,
                  is_training: bool = True,
                  num_frames: int = 32,
                  stride: int = 1,
                  random_stride_range: int = 0,
                  num_test_clips: int = 1,
                  min_resize: int = 256,
                  crop_size: int = 224,
                  num_crops: int = 1,
                  zero_centering_image: bool = False,
                  min_aspect_ratio: float = 0.5,
                  max_aspect_ratio: float = 2,
                  min_area_ratio: float = 0.49,
                  max_area_ratio: float = 1.0,
                  augmenter: Optional[augment.ImageAugment] = None,
                  seed: Optional[int] = None) -> tf.Tensor:
    """Processes a serialized image tensor.

    Args:
      image: Input Tensor of shape [timesteps] and type tf.string of serialized
        frames.
      is_training: Whether or not in training mode. If True, random sampling,
        cropping and left-right flipping are applied.
      num_frames: Number of frames per subclip.
      stride: Temporal stride to sample frames.
      random_stride_range: An int indicating the min and max bounds to uniformly
        sample different strides from the video. E.g., a value of 1 with
        stride=2 will uniformly sample a stride in {1, 2, 3} for each video in
        a batch. Only used during training for the purposes of frame-rate
        augmentation. Defaults to 0, which disables random sampling.
      num_test_clips: Number of test clips (1 by default). If more than 1, this
        will sample multiple linearly spaced clips within each video at test
        time. If 1, then a single clip in the middle of the video is sampled.
        The clips are aggregated in the batch dimension.
      min_resize: Frames are resized so that min(height, width) is min_resize.
      crop_size: Final size of the frame after cropping the resized frames. Both
        height and width are the same.
      num_crops: Number of crops to perform on the resized frames.
      zero_centering_image: If True, frames are normalized to values in [-1, 1].
        If False, values in [0, 1].
      min_aspect_ratio: The minimum aspect range for cropping.
      max_aspect_ratio: The maximum aspect range for cropping.
      min_area_ratio: The minimum area range for cropping.
      max_area_ratio: The maximum area range for cropping.
      augmenter: Image augmenter to distort each image.
      seed: A deterministic seed to use when sampling.

    Returns:
      Processed frames. Tensor of shape
        [num_frames * num_test_clips, crop_size, crop_size, 3].
    """
    # Validate parameters.
    if is_training and num_test_clips != 1:
        logging.warning(
            '`num_test_clips` %d is ignored since `is_training` is `True`.',
            num_test_clips)

    if random_stride_range < 0:
        raise ValueError('Random stride range should be >= 0, got {}'.format(
            random_stride_range))

    # Temporal sampler.
    if is_training:
        if random_stride_range > 0:
            # Uniformly sample a different frame rate (stride) per video.
            stride = tf.random.uniform([],
                                       tf.maximum(stride - random_stride_range,
                                                  1),
                                       stride + random_stride_range,
                                       dtype=tf.int32)

        # Sample random clip.
        image = preprocess_ops_3d.sample_sequence(image, num_frames, True,
                                                  stride, seed)
    elif num_test_clips > 1:
        # Sample linspace clips.
        image = preprocess_ops_3d.sample_linspace_sequence(
            image, num_test_clips, num_frames, stride)
    else:
        # Sample middle clip.
        image = preprocess_ops_3d.sample_sequence(image, num_frames, False,
                                                  stride)

    # Decode JPEG string to tf.uint8.
    if image.dtype == tf.string:
        image = preprocess_ops_3d.decode_jpeg(image, 3)

    if is_training:
        # Standard image data augmentation: random resized crop and random flip.
        image = preprocess_ops_3d.random_crop_resize(
            image, crop_size, crop_size, num_frames, 3,
            (min_aspect_ratio, max_aspect_ratio),
            (min_area_ratio, max_area_ratio))
        image = preprocess_ops_3d.random_flip_left_right(image, seed)

        if augmenter is not None:
            image = augmenter.distort(image)
    else:
        # Resize images (resize happens only if necessary to save compute).
        image = preprocess_ops_3d.resize_smallest(image, min_resize)
        # Crop the frames.
        image = preprocess_ops_3d.crop_image(image, crop_size, crop_size,
                                             False, num_crops)

    # Cast the frames to float32, normalizing according to zero_centering_image.
    return preprocess_ops_3d.normalize_image(image, zero_centering_image)
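A matching evaluation-time sketch for process_image, reusing the hypothetical raw_frames tensor from the previous example to show how num_test_clips multiplies the leading dimension:

eval_clips = process_image(
    raw_frames,  # hypothetical [timesteps] tf.string tensor of JPEG frames
    is_training=False,
    num_frames=32,
    num_test_clips=4,
    min_resize=256,
    crop_size=224)
# Four linearly spaced clips aggregated along the leading axis, per the
# docstring: [num_frames * num_test_clips, crop_size, crop_size, 3].
print(eval_clips.shape)  # (128, 224, 224, 3)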