Example #1
0
 def body(args, _):
     """Advance the bitwise binary search for the top-k threshold by one bit.

     Tentatively sets bit ``bit_index`` of the current integer bit pattern and
     keeps it when the candidate (reinterpreted as float32) satisfies the
     search predicate: at least ``k`` larger scores, with the predicate
     inverted through XOR when ``kth_negative`` is set.

     NOTE(review): ``scores``, ``k``, ``kth_negative``, ``larger_count`` and
     ``bitcast`` come from an enclosing scope that is not visible in this
     fragment; ``larger_count`` presumably counts scores above the candidate.

     Args:
       args: ``(bit_index, value)`` scan carry.
       _: unused per-step scan input.

     Returns:
       ``((bit_index - 1, updated_value), None)``.
     """
     bit_index, value = args
     candidate = jnp.bitwise_or(value, jnp.left_shift(1, bit_index))
     count_above = larger_count(scores, bitcast(candidate, jnp.float32))
     accept = jnp.logical_xor(count_above >= k, kth_negative)
     updated_value = jnp.where(accept, candidate, value)
     return (bit_index - 1, updated_value), None
Example #2
0
File: util.py Project: byzhang/d3p
        def iter_func(position):
            """Run ``num_iters`` Feistel-style mixing rounds over ``position``.

            Each round splits the value into a high part (``>> bits_lower``)
            and a low part (``& mask_lower``). It XORs the low part with a hash
            of the high part plus the seed offset and round number, then
            rebuilds the value with the old high part in the low bits and the
            masked result shifted up by ``bits_upper``.

            NOTE(review): ``num_iters``, ``bits_lower``, ``bits_upper``,
            ``mask_lower``, ``seed_offst`` and ``hash_func_in`` come from an
            enclosing scope that is not visible in this fragment.
            """
            for round_no in range(num_iters):
                round_key = jnp.uint32(round_no)
                hi = jnp.right_shift(position, bits_lower)
                lo = jnp.bitwise_and(position, mask_lower)
                scrambled = jnp.bitwise_xor(
                    lo, hash_func_in(hi + seed_offst + round_key))
                position = hi + jnp.left_shift(
                    jnp.bitwise_and(scrambled, mask_lower), bits_upper)
            return position
Example #3
0
def _topk_mask(scores, k):
    """Efficient implementation of topk_mask for TPUs.

    Instead of sorting, this binary-searches the IEEE-754 bit pattern of a
    float32 threshold, one bit per step (31 steps), so that exactly the
    elements ranked in the top ``k`` are selected. The search runs
    independently for each index of the leading dimension. Elements tied
    with the k-th value are all kept.

    Args:
      scores: sequence of arrays whose leading dimensions broadcast together
        (presumably a shared batch dimension). Every non-leading dimension of
        every array is counted together when ranking. Values are compared
        against float32 thresholds.
      k: number of top elements to keep per leading index. It must be a
        scalar or an array broadcastable to the per-row counts.

    Returns:
      A list with one array per entry of ``scores``. Each array has 1.0 where
      the element is in the top ``k`` and 0.0 elsewhere.
    """
    def bitcast(data, newtype):
        return jax.lax.bitcast_convert_type(data, newtype)

    def per_row(x, ndim):
        """Reshape a per-row value so it broadcasts against rank-``ndim`` data."""
        return jnp.reshape(x, [-1] + [1] * (ndim - 1))

    def larger_count(data, limit):
        """Number of elements larger than limit along all non-leading dims."""
        ret = []
        for d in data:
            ret.append(
                jnp.sum(
                    (d > per_row(limit, len(d.shape))).astype(jnp.int32),
                    axis=list(range(1, len(d.shape)))))
        return sum(ret)

    def body(args, _):
        """Body for the while loop executing the binary search."""
        bit_index, value = args
        new_value = jnp.bitwise_or(value, jnp.left_shift(1, bit_index))
        larger = larger_count(scores, bitcast(new_value, jnp.float32))
        # For positive floats a larger bit pattern is a larger value, so fewer
        # elements exceed it. For negative floats the order is reversed, hence
        # the XOR that flips the acceptance test.
        next_value = jnp.where(jnp.logical_xor(larger >= k, kth_negative),
                               new_value, value)
        return (bit_index - 1, next_value), None

    # The k-th largest value is negative iff fewer than k elements exceed 0.
    kth_negative = (larger_count(scores, jnp.array(0.0)) < k)
    # Start from +0.0 or -0.0 (only the sign bit set). The explicit int32
    # keeps the int -> float32 bitcast valid when jax_enable_x64 is on.
    limit_sign = jnp.where(kth_negative, 1, 0).astype(jnp.int32)
    next_value = jnp.left_shift(limit_sign, 31)
    bit_index = jnp.array(30, dtype=jnp.int32)
    (_, limit), _ = lax.scan(body, (bit_index, next_value), None, length=31)
    threshold = bitcast(limit, jnp.float32)

    ret = []
    for score in scores:
        ndim = len(score.shape)
        row_threshold = per_row(threshold, ndim)
        # When the k-th value is negative, the search converges to that value
        # itself, so the comparison must be inclusive. When it is positive,
        # the search converges to the float just below it, so the comparison
        # must be strict; otherwise a score exactly one ulp below would be kept.
        keep = jnp.where(per_row(kth_negative, ndim),
                         score >= row_threshold,
                         score > row_threshold)
        ret.append(
            jnp.where(keep, jnp.ones(score.shape), jnp.zeros(score.shape)))

    return ret
Example #4
0
def posterize(image, bits):
    """
    Equivalent of PIL Posterize.

    Keeps the top ``bits`` bits of each color channel by shifting right and
    then left by ``8 - bits``. If the last axis has 4 channels, the final
    channel is treated as alpha and passed through unchanged. Channel slicing
    uses the last axis, so inputs with any number of leading dimensions
    (e.g. HWC or NHWC) are supported.

    Args:
        image: image tensor with channels on the last axis. Its values are
            presumably uint8 in [0, 255]; they are cast to uint8 on return.
        bits: bits to keep per channel. Accepts a Python int or an integer
            array/scalar.

    Returns:
        Augmented image as uint8, with the same shape as ``image``.
    """
    has_alpha = image.shape[-1] == 4
    alpha = None

    if has_alpha:
        # Use `...` so the channel axis is always the last one, even for
        # batched input; `image[:, :, :3]` would slice axis 2 instead.
        image, alpha = image[..., :3], image[..., -1:]

    # jnp.asarray lets plain Python ints through (they have no .astype).
    shift = 8 - jnp.asarray(bits).astype('int32')
    degenerate = jnp.left_shift(jnp.right_shift(image, shift), shift)

    if has_alpha:
        return jnp.concatenate([degenerate, alpha], axis=-1).astype('uint8')
    return degenerate.astype('uint8')
Example #5
0
File: util.py Project: byzhang/d3p
    def permute32(vals):
        """Map ``vals`` through a seeded pseudo-random permutation of [0, capacity).

        A Feistel-style network scrambles integers below ``2**bits``, where
        ``bits`` is the number of binary digits of ``capacity``. Results that
        land at or above ``capacity`` are sent through the network again
        until they fall below it, so inputs in [0, capacity) map back into
        that range.

        NOTE(review): ``capacity`` and ``seed`` come from the enclosing scope,
        which is not visible here. ``capacity`` is passed to ``bin`` and so is
        presumably a Python int. The ``while_loop`` predicate must be a scalar
        bool, so ``vals`` is presumably a single uint32 value (e.g. mapped with
        ``vmap``). Confirm both against the caller.
        """
        def avalanche(h):
            # uint32 shift/xor/multiply finalizer; the multiplies wrap mod 2**32.
            for shift_by, factor in ((16, 0x85ebca6b), (13, 0xc2b2ae35)):
                h = jnp.bitwise_xor(h, jnp.right_shift(h, jnp.uint32(shift_by)))
                h *= jnp.uint32(factor)
            return jnp.bitwise_xor(h, jnp.right_shift(h, jnp.uint32(16)))

        rounds = np.uint32(8)

        total_bits = jnp.uint32(len(bin(capacity)) - 2)
        low_bits = jnp.right_shift(total_bits, 1)
        high_bits = total_bits - low_bits
        low_mask = jnp.left_shift(jnp.uint32(1), low_bits) - jnp.uint32(1)

        seed_mix = avalanche(seed)

        def scramble(pos):
            for r in range(rounds):
                hi = jnp.right_shift(pos, low_bits)
                lo = jnp.bitwise_and(pos, low_mask)
                mixed = jnp.bitwise_xor(
                    lo, avalanche(hi + seed_mix + jnp.uint32(r)))
                pos = hi + jnp.left_shift(
                    jnp.bitwise_and(mixed, low_mask), high_bits)
            return pos

        # Cycle walking: keep re-scrambling until the result is in range.
        return jax.lax.while_loop(lambda pos: pos >= capacity, scramble,
                                  scramble(vals))
Example #6
0
def left_shift(x1, x2):
  """Element-wise ``jnp.left_shift`` over raw values or ``JaxArray`` wrappers.

  Any ``JaxArray`` operand is replaced by its ``.value`` before the call.
  The result is always wrapped in a new ``JaxArray``.
  """
  def _unwrap(operand):
    return operand.value if isinstance(operand, JaxArray) else operand

  return JaxArray(jnp.left_shift(_unwrap(x1), _unwrap(x2)))