def _randaugment_inner_for_loop(_, in_args):
    """Run a single RandAugment layer on the carried loop state.

    The signature matches what `jax.lax.fori_loop` expects from a loop body:
    an iteration index followed by the carried value.

    Args:
        _: loop iteration index (unused).
        in_args: tuple of (image, geometric_transforms, random_key,
            available_ops, op_probs, magnitude, cutout_const,
            translate_const, join_transforms, default_replace_value).

    Returns:
        A tuple with the same layout as `in_args`, holding the updated image,
        the updated accumulated transform and a fresh random key; all other
        entries are passed through untouched.
    """
    (image, geometric_transforms, random_key, available_ops, op_probs,
     magnitude, cutout_const, translate_const, join_transforms,
     default_replace_value) = in_args

    # One sub-key per random draw below; index 0 is carried to the next layer.
    subkeys = random.split(random_key, num=8)
    random_key = subkeys[0]
    op_key = subkeys[1]
    fill_key = subkeys[2]
    magnitude_key = subkeys[3]
    cutout_key = subkeys[4]
    shift_first_key = subkeys[5]
    shift_second_key = subkeys[6]
    negate_key = subkeys[7]

    num_channels = image.shape[-1]

    # Which operation this layer applies, drawn according to `op_probs`.
    op_to_select = random.choice(op_key, available_ops, p=op_probs)

    # Per-channel fill value: the caller's constant when it is positive,
    # otherwise random integers drawn from [-1, 256).
    use_fixed_fill = default_replace_value > 0
    fixed_fill = jnp.ones([num_channels]) * default_replace_value
    random_fill = random.randint(fill_key, [num_channels],
                                 minval=-1, maxval=256)
    mask_value = jnp.where(use_fixed_fill, fixed_fill, random_fill)

    # Strength of this layer, drawn uniformly between 0 and `magnitude`.
    random_magnitude = random.uniform(magnitude_key, [],
                                      minval=0., maxval=magnitude)

    cutout_mask = color_transforms.get_random_cutout_mask(
        cutout_key, image.shape, cutout_const)

    # Two independent uniform fractions of `translate_const`.
    first_shift = random.uniform(
        shift_first_key, [], minval=0.0, maxval=1.0) * translate_const
    second_shift = random.uniform(
        shift_second_key, [], minval=0.0, maxval=1.0) * translate_const
    translate_vals = (first_shift, second_shift)

    # Fair coin flip, passed on to `level_to_arg`.
    negate = random.randint(negate_key, [], minval=0, maxval=2).astype('bool')

    args = level_to_arg(cutout_mask, translate_vals, negate,
                        random_magnitude, mask_value)

    if DEBUG:
        print(op_to_select, args[op_to_select])

    image, geometric_transform = _apply_ops(image, args, op_to_select)

    def _keep_pending(operand):
        # Leave the image alone; the transform stays queued for later.
        return operand[0], operand[1]

    def _apply_now(operand):
        # Resample right away and hand back an identity so that nothing is
        # left to accumulate.
        resampled = transforms.apply_transform(operand[0],
                                               operand[1],
                                               mask_value=mask_value)
        return resampled, jnp.identity(4)

    # NOTE(review): the second term is true only when *every* entry of the
    # matrix differs from the identity. Confirm whether `jnp.equal` (i.e.
    # "nothing to apply") was intended here.
    differs_everywhere = jnp.all(
        jnp.not_equal(geometric_transform, jnp.identity(4)))
    defer = jnp.logical_or(join_transforms, differs_everywhere)
    image, geometric_transform = jax.lax.cond(
        defer,
        _keep_pending,
        _apply_now,
        (image, geometric_transform)
    )

    geometric_transforms = jnp.matmul(geometric_transforms,
                                      geometric_transform)

    return (image, geometric_transforms, random_key, available_ops, op_probs,
            magnitude, cutout_const, translate_const, join_transforms,
            default_replace_value)
def distort_image_with_randaugment(image,
                                   num_layers,
                                   magnitude,
                                   random_key,
                                   cutout_const=40,
                                   translate_const=50.0,
                                   default_replace_value=None,
                                   available_ops=DEFAULT_OPS,
                                   op_probs=DEFAULT_PROBS,
                                   join_transforms=False):
    """Applies the RandAugment policy to `image`.

    RandAugment is from the paper https://arxiv.org/abs/1909.13719.

    Args:
        image: `Tensor` of shape [height, width, 3] representing an image.
            The size of the last axis determines how many fill values are
            drawn.
        num_layers: Integer, the number of augmentation transformations to
            apply sequentially to an image. Represented as (N) in the paper.
            Usually best values will be in the range [1, 3].
        magnitude: Integer, shared magnitude across all augmentation
            operations. Represented as (M) in the paper. Usually best values
            are in the range [5, 30].
        random_key: JAX PRNG key driving every random choice.
        cutout_const: max cutout size int.
        translate_const: maximum translation amount int.
        default_replace_value: replacement value for pixels outside of the
            image. `None` or a non-positive value means a random fill value
            is drawn instead.
        available_ops: available operations.
        op_probs: probabilities of operations.
        join_transforms: if true, the per-layer geometric transforms are
            accumulated into one matrix and applied once after the last
            layer; otherwise each layer applies its transform immediately.

    Returns:
        The augmented version of `image`.
    """
    # The loop body evaluates `default_replace_value > 0`, which raises a
    # TypeError for the default of None. Hand it a non-positive sentinel so
    # that it takes its "draw a random fill value" path instead.
    loop_replace_value = (-1 if default_replace_value is None
                          else default_replace_value)

    geometric_transforms = jnp.identity(4)

    for_i_args = (image, geometric_transforms, random_key, available_ops,
                  op_probs, magnitude, cutout_const, translate_const,
                  join_transforms, loop_replace_value)

    if DEBUG:  # un-jitted
        for i in range(num_layers):
            for_i_args = _randaugment_inner_for_loop(i, for_i_args)
    else:  # jitted
        for_i_args = jax.lax.fori_loop(0, num_layers,
                                       _randaugment_inner_for_loop,
                                       for_i_args)

    # Continue with the key carried out of the loop: the key passed in by the
    # caller has already been split by the first layer and must not be used
    # for another draw.
    image, geometric_transforms, random_key = for_i_args[:3]

    if join_transforms:
        replace_value = default_replace_value or random.randint(
            random_key, [image.shape[-1]], minval=0, maxval=256)
        image = transforms.apply_transform(image,
                                           geometric_transforms,
                                           mask_value=replace_value)

    return image
def test_scale():
    """Scale a single bright pixel by a factor of 3 without bilinear sampling.

    The input is one 255-valued pixel with a one-pixel zero border; the
    expected result is an all-255 array shaped like `rgba_img`.
    """
    factor = 3
    border = ((1, 1), (1, 1), (0, 0))
    pixel = jnp.ones((1, 1, 4), dtype='uint8')

    inputs = jnp.pad(pixel * 255, border, constant_values=0)
    targets = jnp.ones_like(rgba_img) * 255

    source = jnp.pad(pixel, border, constant_values=0) * 255
    upscale = transforms.scale_3d(scale_x=factor, scale_y=factor)
    outputs = transforms.apply_transform(source, upscale, bilinear=False)

    compare(inputs, targets, outputs)
def test_rotate90():
    """`rotate90(n=2)` applied to `rgba_img` should match `jnp.rot90` with k=2."""
    expected = jnp.rot90(rgba_img, k=2)
    half_turn = transforms.rotate90(n=2)
    actual = transforms.apply_transform(rgba_img, half_turn)
    compare(rgba_img, expected, actual)
def test_vertical_flip():
    """A vertical-only flip should reverse `rgba_img` along its first axis."""
    expected = rgba_img[::-1]
    flip_rows = transforms.flip(horizontal=False, vertical=True)
    actual = transforms.apply_transform(rgba_img, flip_rows)
    compare(rgba_img, expected, actual)