def body(args, _):
    """Run one step of the bitwise binary search over a float32 threshold.

    ``args`` is the carry ``(bit_index, value)``. ``value`` holds the
    32-bit integer bit pattern of the threshold found so far; it is later
    bitcast to float32. ``bit_index`` is the next bit to try.

    The step tentatively sets that bit and counts how many scores exceed
    the resulting float. It keeps the bit when
    ``(count >= k) XOR kth_negative`` holds.

    NOTE(review): ``scores``, ``k``, ``kth_negative``, ``larger_count`` and
    ``bitcast`` are free names. They resolve from an enclosing scope that
    is not part of this snippet.

    Returns:
      ``((bit_index - 1, next_value), None)``. This is presumably the
      ``(carry, per-step output)`` pair consumed by ``lax.scan``.
    """
    bit_index, value = args
    candidate = jnp.bitwise_or(value, jnp.left_shift(1, bit_index))
    num_larger = larger_count(scores, bitcast(candidate, jnp.float32))
    # XOR flips the acceptance test when the threshold's sign bit is set.
    keep_bit = jnp.logical_xor(num_larger >= k, kth_negative)
    return (bit_index - 1, jnp.where(keep_bit, candidate, value)), None
def iter_func(position):
    """Apply ``num_iters`` keyed mixing rounds to ``position``.

    Each round works as follows:
      * Split ``position`` into a high part (``position >> bits_lower``)
        and a low part (``position & mask_lower``).
      * XOR the low part with a hash of the high part, the seed offset and
        the round number.
      * Recombine so that the high part lands in the low bits and the
        mixed low part is shifted up by ``bits_upper``.

    NOTE(review): ``num_iters``, ``bits_lower``, ``bits_upper``,
    ``mask_lower``, ``seed_offst`` and ``hash_func_in`` are free names.
    They come from an enclosing scope that is not visible in this snippet.
    """
    for round_idx in range(num_iters):
        round_key = jnp.uint32(round_idx)
        high = jnp.right_shift(position, bits_lower)
        low = jnp.bitwise_and(position, mask_lower)
        scrambled = jnp.bitwise_xor(low, hash_func_in(high + seed_offst + round_key))
        # The old high half goes to the low bits; the scrambled low half goes on top.
        position = high + jnp.left_shift(jnp.bitwise_and(scrambled, mask_lower), bits_upper)
    return position
def _topk_mask(scores, k): """Efficient implementation of topk_mask for TPUs.""" def bitcast(data, newtype): return jax.lax.bitcast_convert_type(data, newtype) def larger_count(data, limit): """Number of elements larger than limit along the most minor dimension.""" ret = [] for d in data: ret.append( jnp.sum( (d > jnp.reshape(limit, [-1] + [1] * (len(d.shape) - 1))).astype(jnp.int32), axis=list(range(1, len(d.shape))))) return sum(ret) def body(args, _): """Body for the while loop executing the binary search.""" bit_index, value = args new_value = jnp.bitwise_or(value, jnp.left_shift(1, bit_index)) larger = larger_count(scores, bitcast(new_value, jnp.float32)) next_value = jnp.where(jnp.logical_xor(larger >= k, kth_negative), new_value, value) return (bit_index - 1, next_value), None kth_negative = (larger_count(scores, jnp.array(0.0)) < k) limit_sign = jnp.where(kth_negative, jnp.broadcast_to(1, kth_negative.shape), jnp.broadcast_to(0, kth_negative.shape)) next_value = jnp.left_shift(limit_sign, 31) bit_index = jnp.array(30) (_, limit), _ = lax.scan(body, (bit_index, next_value), None, length=31) ret = [] for score in scores: # Filter scores that are smaller than the threshold. ret.append( jnp.where( score >= jnp.reshape(bitcast(limit, jnp.float32), [-1] + [1] * (len(score.shape) - 1)), jnp.ones(score.shape), jnp.zeros(score.shape))) return ret
def posterize(image, bits):
    """Equivalent of PIL Posterize: keep only the top ``bits`` bits of each channel.

    Args:
      image: integer image tensor with channels on the last axis. The shift
        amount assumes 8-bit channel values. When the last axis has size 4,
        the final channel is treated as alpha and passed through untouched.
        Leading axes (for example a batch axis) are allowed.
      bits: number of most-significant bits to keep, as a Python int or an
        integer array. Presumably in [0, 8]; this is not validated.

    Returns:
      Augmented image, cast to uint8.
    """
    # asarray lets callers pass a plain int as well as an array.
    shift = 8 - jnp.asarray(bits).astype('int32')
    has_alpha = image.shape[-1] == 4
    if has_alpha:
        # Slice the last axis with ``...`` so inputs of any rank (e.g. NHWC) split correctly.
        color, alpha = image[..., :3], image[..., -1:]
    else:
        color = image
    # Clearing the low ``shift`` bits: shift right, then back left.
    degenerate = jnp.left_shift(jnp.right_shift(color, shift), shift)
    if has_alpha:
        degenerate = jnp.concatenate([degenerate, alpha], axis=-1)
    return degenerate.astype('uint8')
def permute32(vals):
    """Pseudo-randomly permute values in ``[0, capacity)`` with a keyed Feistel-style network.

    NOTE(review): ``capacity`` and ``seed`` are not parameters. They are
    read from an enclosing scope that is not visible here. ``capacity`` is
    passed to ``bin`` (so it must be a Python int); ``seed`` is presumably
    a uint32 scalar. TODO confirm against the caller.

    ``bits`` is the bit length of ``capacity``. Each round maps
    ``[0, 2**bits)`` onto itself invertibly. Outputs that land at or above
    ``capacity`` are pushed through the network again until they fall back
    in range (cycle walking).

    Args:
      vals: value(s) presumed to lie in ``[0, capacity)``, presumably
        uint32. Scalars and arrays are both accepted; every element is
        cycle-walked independently.

    Returns:
      The permuted value(s), with the same shape as ``vals``.
    """

    def hash_func_in(x):
        # 32-bit avalanche mix (same shifts/constants as MurmurHash3's fmix32).
        x = jnp.bitwise_xor(x, jnp.right_shift(x, jnp.uint32(16)))
        x *= jnp.uint32(0x85ebca6b)
        x = jnp.bitwise_xor(x, jnp.right_shift(x, jnp.uint32(13)))
        x *= jnp.uint32(0xc2b2ae35)
        x = jnp.bitwise_xor(x, jnp.right_shift(x, jnp.uint32(16)))
        return x

    num_iters = 8
    bits = jnp.uint32(len(bin(capacity)) - 2)
    bits_lower = jnp.right_shift(bits, 1)
    bits_upper = bits - bits_lower
    mask_lower = jnp.left_shift(jnp.uint32(1), bits_lower) - jnp.uint32(1)
    seed_offst = hash_func_in(seed)

    def iter_func(position):
        # Run all mixing rounds; each one swaps the halves and XORs in a keyed hash.
        for j in range(num_iters):
            round_key = jnp.uint32(j)
            upper = jnp.right_shift(position, bits_lower)
            lower = jnp.bitwise_and(position, mask_lower)
            mixer = hash_func_in(upper + seed_offst + round_key)
            tmp = jnp.bitwise_xor(lower, mixer)
            position = upper + jnp.left_shift(jnp.bitwise_and(tmp, mask_lower), bits_upper)
        return position

    def out_of_range(position):
        return position >= capacity

    position = iter_func(vals)
    # lax.while_loop needs a scalar predicate, so loop while ANY element is
    # out of range and only re-encrypt those elements. This gives the same
    # result per element as the original scalar loop, and also works on arrays.
    position = jax.lax.while_loop(
        lambda p: jnp.any(out_of_range(p)),
        lambda p: jnp.where(out_of_range(p), iter_func(p), p),
        position,
    )
    return position
def left_shift(x1, x2):
    """Element-wise ``x1 << x2``; the result is wrapped in ``JaxArray``.

    Either operand may be a ``JaxArray`` (a project wrapper type not
    defined in this snippet). In that case its underlying array is read
    from ``.value`` before shifting.
    """
    lhs = x1.value if isinstance(x1, JaxArray) else x1
    rhs = x2.value if isinstance(x2, JaxArray) else x2
    return JaxArray(jnp.left_shift(lhs, rhs))