Exemplo n.º 1
0
    def test_random_erase_replaces_some_pixels(self):
        """Random erasing at probability 1 must leave a non-zero pixel behind.

        An all-zero float32 image of shape (224, 224, 3) is passed through
        `augment.RandomErasing(probability=1., max_count=10)`. The result must
        keep the input shape and must have a maximum value different from 0.
        """
        expected_shape = (224, 224, 3)
        eraser = augment.RandomErasing(probability=1., max_count=10)

        erased = eraser.distort(tf.zeros(expected_shape, dtype=tf.float32))

        # The augmentation must not change the image dimensions.
        self.assertEqual(expected_shape, erased.shape)
        # Starting from all zeros, a non-zero max means some pixels changed.
        self.assertNotEqual(0, tf.reduce_max(erased))
Exemplo n.º 2
0
    def __init__(self,
                 output_size: List[int],
                 num_classes: int,
                 image_field_key: str = DEFAULT_IMAGE_FIELD_KEY,
                 label_field_key: str = DEFAULT_LABEL_FIELD_KEY,
                 decode_jpeg_only: bool = True,
                 aug_rand_hflip: bool = True,
                 aug_crop: Optional[bool] = True,
                 aug_type: Optional[common.Augmentation] = None,
                 color_jitter: float = 0.,
                 random_erasing: Optional[common.RandomErasing] = None,
                 is_multilabel: bool = False,
                 dtype: str = 'float32'):
        """Initializes parameters for parsing annotations in the dataset.

        Args:
          output_size: `Tensor` or `list` for [height, width] of output image.
            The output_size should be divisible by the largest feature stride
            2^max_level.
          num_classes: `int`, number of classes.
          image_field_key: `str`, the key name of the encoded image in
            tf.Example.
          label_field_key: `str`, the key name of the label in tf.Example.
          decode_jpeg_only: `bool`, if True, only JPEG format is decoded, this
            is faster than decoding other types. Default is True.
          aug_rand_hflip: `bool`, if True, augment training with random
            horizontal flip.
          aug_crop: `bool`, if True, perform random cropping during training
            and center crop during validation.
          aug_type: An optional Augmentation config whose `type` field selects
            AutoAugment ('autoaug') or RandAugment ('randaug').
          color_jitter: Magnitude of color jitter. If > 0, the value is used to
            generate random scale factor for brightness, contrast and
            saturation. See `preprocess_ops.color_jitter` for more details.
          random_erasing: if not None, augment input image by random erasing.
            See `augment.RandomErasing` for more details.
          is_multilabel: A `bool`, whether or not each example has multiple
            labels.
          dtype: `str`, the dtype the output image is cast to. It can be
            'float32', 'float16', or 'bfloat16'.

        Raises:
          ValueError: if `dtype` is not one of the supported names, or if
            `aug_type` is set and `aug_type.type` is neither 'autoaug' nor
            'randaug'.
        """
        self._output_size = output_size
        self._num_classes = num_classes
        self._image_field_key = image_field_key
        self._label_field_key = label_field_key
        self._decode_jpeg_only = decode_jpeg_only
        self._aug_rand_hflip = aug_rand_hflip
        self._aug_crop = aug_crop
        self._color_jitter = color_jitter
        self._is_multilabel = is_multilabel

        # Resolve the output dtype name. Equality comparison (rather than a
        # dict lookup) is kept on purpose so that any value comparing equal to
        # one of these names is still accepted.
        if dtype == 'float32':
            self._dtype = tf.float32
        elif dtype == 'float16':
            self._dtype = tf.float16
        elif dtype == 'bfloat16':
            self._dtype = tf.bfloat16
        else:
            raise ValueError('dtype {!r} is not supported!'.format(dtype))

        # Optional policy-based augmentation: `aug_type.type` picks the
        # implementation, and the matching sub-config supplies its arguments.
        if aug_type:
            if aug_type.type == 'autoaug':
                self._augmenter = augment.AutoAugment(
                    augmentation_name=aug_type.autoaug.augmentation_name,
                    cutout_const=aug_type.autoaug.cutout_const,
                    translate_const=aug_type.autoaug.translate_const)
            elif aug_type.type == 'randaug':
                self._augmenter = augment.RandAugment(
                    num_layers=aug_type.randaug.num_layers,
                    magnitude=aug_type.randaug.magnitude,
                    cutout_const=aug_type.randaug.cutout_const,
                    translate_const=aug_type.randaug.translate_const,
                    prob_to_apply=aug_type.randaug.prob_to_apply,
                    exclude_ops=aug_type.randaug.exclude_ops)
            else:
                raise ValueError(
                    'Augmentation policy {} not supported.'.format(
                        aug_type.type))
        else:
            self._augmenter = None

        # Optional random erasing, built from the fields of the config object.
        if random_erasing:
            self._random_erasing = augment.RandomErasing(
                probability=random_erasing.probability,
                min_area=random_erasing.min_area,
                max_area=random_erasing.max_area,
                min_aspect=random_erasing.min_aspect,
                max_aspect=random_erasing.max_aspect,
                min_count=random_erasing.min_count,
                max_count=random_erasing.max_count,
                trials=random_erasing.trials)
        else:
            self._random_erasing = None