Beispiel #1
0
def _randaugment_inner_for_loop(_, in_args):
    """Run one RandAugment layer: pick a random op and apply it to the image.

    The signature ``(index, carry) -> carry`` is the loop-body shape used by
    ``jax.lax.fori_loop``; the returned tuple has the same layout as
    ``in_args`` so it can be fed straight into the next iteration.

    Args:
        _: loop iteration index (unused).
        in_args: loop carry, a 10-tuple of ``(image, geometric_transforms,
            random_key, available_ops, op_probs, magnitude, cutout_const,
            translate_const, join_transforms, default_replace_value)``.
            ``default_replace_value`` may be ``None``; ``None`` or a
            non-positive value selects a randomly drawn fill value.

    Returns:
        The updated loop carry: the (possibly transformed) image, the
        accumulated geometric transform, a fresh random key for the next
        iteration, and the remaining entries passed through unchanged.
    """
    (image, geometric_transforms, random_key, available_ops, op_probs,
     magnitude, cutout_const, translate_const, join_transforms,
     default_replace_value) = in_args

    # One key per random draw below; index 0 is carried to the next iteration
    # so no key is ever consumed twice.
    random_keys = random.split(random_key, num=8)
    random_key = random_keys[0]  # keep for next iteration
    op_to_select = random.choice(random_keys[1], available_ops, p=op_probs)

    # ``None > 0`` raises a TypeError, and ``None`` is the default of the
    # public entry point. Map it to 0 so the comparison is valid and the
    # random fill branch of ``jnp.where`` is taken.
    fixed_replace_value = (0 if default_replace_value is None
                           else default_replace_value)
    num_channels = image.shape[-1]
    # Per-channel fill value: the fixed value when it is positive, otherwise
    # random integers drawn from [-1, 256).
    # NOTE(review): how ``apply_transform`` interprets a fill of -1 is not
    # visible here - confirm that ``minval=-1`` is intentional.
    mask_value = jnp.where(fixed_replace_value > 0,
                           jnp.ones([num_channels]) * fixed_replace_value,
                           random.randint(random_keys[2],
                                          [num_channels],
                                          minval=-1, maxval=256))

    # The magnitude of this layer is drawn uniformly from [0, magnitude).
    random_magnitude = random.uniform(random_keys[3], [], minval=0.,
                                      maxval=magnitude)
    cutout_mask = color_transforms.get_random_cutout_mask(
        random_keys[4],
        image.shape,
        cutout_const)

    # Two independent translation amounts, each uniform in
    # [0, translate_const).
    translate_vals = (random.uniform(random_keys[5], [], minval=0.0,
                                     maxval=1.0) * translate_const,
                      random.uniform(random_keys[6], [], minval=0.0,
                                     maxval=1.0) * translate_const)
    # Random boolean flag (0 or 1 with equal probability).
    negate = random.randint(random_keys[7], [], minval=0,
                            maxval=2).astype('bool')

    args = level_to_arg(cutout_mask, translate_vals, negate,
                        random_magnitude, mask_value)

    if DEBUG:
        print(op_to_select, args[op_to_select])

    image, geometric_transform = _apply_ops(image, args, op_to_select)

    # Predicate true: leave the image alone and keep the op's transform so it
    # is only accumulated below. Predicate false: apply the transform to the
    # image right away and reset it to the identity.
    # NOTE(review): the second clause is true only when *every* entry of the
    # transform differs from the identity matrix; it looks like it may have
    # been meant to detect an identity transform
    # (``jnp.all(jnp.equal(...))``) - confirm before changing.
    image, geometric_transform = jax.lax.cond(
        jnp.logical_or(join_transforms, jnp.all(
            jnp.not_equal(geometric_transform, jnp.identity(4)))),
        lambda op: (op[0], op[1]),
        lambda op: (transforms.apply_transform(op[0],
                                               op[1],
                                               mask_value=mask_value),
                    jnp.identity(4)),
        (image, geometric_transform)
    )

    geometric_transforms = jnp.matmul(geometric_transforms, geometric_transform)
    # ``default_replace_value`` is returned exactly as received so the carry
    # structure stays identical between iterations.
    return (image, geometric_transforms, random_key, available_ops, op_probs,
            magnitude, cutout_const, translate_const, join_transforms,
            default_replace_value)
Beispiel #2
0
def distort_image_with_randaugment(image,
                                   num_layers,
                                   magnitude,
                                   random_key,
                                   cutout_const=40,
                                   translate_const=50.0,
                                   default_replace_value=None,
                                   available_ops=DEFAULT_OPS,
                                   op_probs=DEFAULT_PROBS,
                                   join_transforms=False):
    """Applies the RandAugment policy to `image`.

    RandAugment is from the paper https://arxiv.org/abs/1909.13719.

    Args:
        image: image array whose last axis is the channel axis, e.g. of shape
          [height, width, 3].
        num_layers: Integer, the number of augmentation transformations to apply
          sequentially to an image. Represented as (N) in the paper. Usually best
          values will be in the range [1, 3].
        magnitude: Integer, shared magnitude across all augmentation operations.
          Represented as (M) in the paper. Usually best values are in the range
          [5, 30]. Each layer draws its own magnitude uniformly from
          [0, magnitude).
        random_key: JAX PRNG key that drives every random choice.
        cutout_const: max cutout size int.
        translate_const: maximum translation amount.
        default_replace_value: replacement value for pixels outside of the
          image. `None` (the default) or 0 selects a randomly drawn value.
        available_ops: operations to sample from.
        op_probs: sampling probability of each entry of `available_ops`.
        join_transforms: if True, the geometric transforms of all layers are
          multiplied together and applied to the image once after the loop;
          if False, each layer applies its transform to the image itself.

    Returns:
        The augmented version of `image`.
    """

    geometric_transforms = jnp.identity(4)

    # The loop body evaluates `default_replace_value > 0`, which raises a
    # TypeError for None. Hand it 0 instead; the body treats any non-positive
    # value as "draw a random fill value".
    loop_replace_value = (0 if default_replace_value is None
                          else default_replace_value)

    for_i_args = (image, geometric_transforms, random_key, available_ops, op_probs,
                  magnitude, cutout_const, translate_const, join_transforms, loop_replace_value)

    if DEBUG:  # un-jitted
        for i in range(num_layers):
            for_i_args = _randaugment_inner_for_loop(i, for_i_args)
    else:  # jitted
        for_i_args = jax.lax.fori_loop(0, num_layers, _randaugment_inner_for_loop, for_i_args)

    # Besides the image and the accumulated transform, take the key carried
    # out of the loop: `random_key` itself has already been split inside the
    # loop body and must not be consumed a second time.
    image, geometric_transforms, unused_key = for_i_args[0], for_i_args[1], for_i_args[2]

    if join_transforms:
        # Same truthiness rule as before: None or 0 means "random fill".
        if default_replace_value:
            replace_value = default_replace_value
        else:
            replace_value = random.randint(unused_key,
                                           [image.shape[-1]],
                                           minval=0,
                                           maxval=256)
        image = transforms.apply_transform(image, geometric_transforms, mask_value=replace_value)

    return image
Beispiel #3
0
def test_scale():
    """Scaling by 3 with ``bilinear=False`` is expected to fill the image with 255."""
    factor = 3
    # One 4-channel pixel of value 255 surrounded by a one-pixel border of
    # zeros, i.e. a 3x3x4 uint8 image.
    inputs = jnp.pad(jnp.ones((1, 1, 4), dtype='uint8') * 255,
                     ((1, 1), (1, 1), (0, 0)),
                     constant_values=0)
    # Expected result: every pixel is 255. The shape is taken from `inputs`
    # (the image that is actually transformed) rather than from an unrelated
    # module-level fixture.
    targets = jnp.ones_like(inputs) * 255
    # Transform `inputs` itself instead of rebuilding an equal array inline,
    # so the array reported by `compare` is the one that was transformed.
    outputs = transforms.apply_transform(inputs,
                                         transforms.scale_3d(scale_x=factor,
                                                             scale_y=factor),
                                         bilinear=False)
    compare(inputs, targets, outputs)
Beispiel #4
0
def test_rotate90():
    """``transforms.rotate90(n=2)`` must reproduce ``jnp.rot90(rgba_img, k=2)``."""
    expected = jnp.rot90(rgba_img, k=2)
    rotation = transforms.rotate90(n=2)
    actual = transforms.apply_transform(rgba_img, rotation)
    compare(rgba_img, expected, actual)
Beispiel #5
0
def test_vertical_flip():
    """A vertical-only flip must equal the image reversed along its first axis."""
    expected = rgba_img[::-1]
    vertical_only = transforms.flip(horizontal=False, vertical=True)
    actual = transforms.apply_transform(rgba_img, vertical_only)
    compare(rgba_img, expected, actual)