def hash_func_in(x):
    """Scramble 32-bit unsigned integers with the MurmurHash3 ``fmix32`` finalizer.

    Two xor-shift / multiply rounds (shifts 16 and 13, multipliers
    ``0x85ebca6b`` and ``0xc2b2ae35``) are followed by a final xor-shift by 16.
    All operations are element-wise, so ``x`` may be a scalar or an array.
    Multiplication wraps modulo 2**32 when ``x`` is uint32.

    Args:
        x: uint32 scalar or array (a plain Python int is promoted to uint32).

    Returns:
        The mixed value(s), same shape as ``x``.
    """
    for shift, multiplier in ((16, 0x85ebca6b), (13, 0xc2b2ae35)):
        x = jnp.bitwise_xor(x, jnp.right_shift(x, jnp.uint32(shift)))
        x = x * jnp.uint32(multiplier)
    return jnp.bitwise_xor(x, jnp.right_shift(x, jnp.uint32(16)))
def iter_func(position):
    """Apply ``num_iters`` keyed mixing rounds to ``position``.

    Each round does the following:

    * splits ``position`` into an upper part (``position >> bits_lower``) and a
      lower part (``position & mask_lower``);
    * XORs the lower part with ``hash_func_in(upper + seed_offst + round)``;
    * reassembles the value with the old upper part in the low bits and the
      scrambled lower part shifted left by ``bits_upper``.

    ``num_iters``, ``bits_lower``, ``bits_upper``, ``mask_lower``,
    ``seed_offst`` and ``hash_func_in`` are free names resolved from the
    enclosing scope (not visible here).

    NOTE(review): presumably ``mask_lower == (1 << bits_lower) - 1`` and
    ``bits_lower + bits_upper`` is the working bit width. Under that assumption
    each round is invertible; confirm at the definition site.
    """
    for round_idx in range(num_iters):
        round_key = jnp.uint32(round_idx)
        hi = jnp.right_shift(position, bits_lower)
        lo = jnp.bitwise_and(position, mask_lower)
        scrambled = jnp.bitwise_xor(lo, hash_func_in(hi + seed_offst + round_key))
        relocated = jnp.left_shift(jnp.bitwise_and(scrambled, mask_lower), bits_upper)
        position = hi + relocated
    return position
def permute32(vals):
    """Pseudo-randomly permute values in ``[0, capacity)`` using a seeded bijection.

    The mapping is built from ``num_iters`` rounds of an unbalanced
    Feistel-like network over ``bits`` bits, where ``bits`` is the bit length
    of ``capacity``. The round function is the MurmurHash3 ``fmix32``
    finalizer, keyed by ``hash_func_in(seed)`` and the round number.

    Each round is invertible on ``[0, 2**bits)``, but that range may exceed
    ``capacity``. Results landing at or above ``capacity`` are therefore fed
    back through the network ("cycle walking") until they fall in range.

    ``capacity`` and ``seed`` are free names resolved from the enclosing scope
    (not visible here). ``capacity`` must be a concrete Python int because its
    bit length is computed with ``bin``. NOTE(review): ``seed`` is presumably a
    uint32 scalar; confirm against the caller.

    Args:
        vals: uint32 scalar or array. Each element is presumably in
            ``[0, capacity)``; TODO confirm. Arrays of any shape are supported,
            and each element is permuted independently.

    Returns:
        An array of the same shape as ``vals`` with every element mapped
        through the permutation.
    """
    def hash_func_in(x):
        # MurmurHash3 fmix32 finalizer, used as the keyed round function.
        x = jnp.bitwise_xor(x, jnp.right_shift(x, jnp.uint32(16)))
        x *= jnp.uint32(0x85ebca6b)
        x = jnp.bitwise_xor(x, jnp.right_shift(x, jnp.uint32(13)))
        x *= jnp.uint32(0xc2b2ae35)
        x = jnp.bitwise_xor(x, jnp.right_shift(x, jnp.uint32(16)))
        return x

    num_iters = np.uint32(8)
    # Bit length of `capacity` (`len(bin(n)) - 2 == n.bit_length()` for n > 0).
    # NOTE(review): for an exact power of two this is one bit more than
    # strictly needed. That only costs extra cycle-walking steps, and changing
    # it would change the permutation, so it is kept as is.
    bits = jnp.uint32(len(bin(capacity)) - 2)
    bits_lower = jnp.right_shift(bits, 1)
    bits_upper = bits - bits_lower
    mask_lower = (jnp.left_shift(jnp.uint32(1), bits_lower)) - jnp.uint32(1)
    seed_offst = hash_func_in(seed)

    def iter_func(position):
        # One pass of `num_iters` rounds; invertible on [0, 2**bits).
        for j in range(num_iters):
            j = jnp.uint32(j)
            upper = jnp.right_shift(position, bits_lower)
            lower = jnp.bitwise_and(position, mask_lower)
            mixer = hash_func_in(upper + seed_offst + j)
            tmp = jnp.bitwise_xor(lower, mixer)
            # Swap halves: the old upper bits go to the low end and the
            # scrambled lower bits go to the top of the `bits`-wide word.
            position = upper + (jnp.left_shift(
                jnp.bitwise_and(tmp, mask_lower), bits_upper))
        return position

    def out_of_range(position):
        return position >= capacity

    position = iter_func(vals)
    # Cycle walking. `lax.while_loop` needs a scalar boolean predicate, so
    # reduce with `any` and re-permute only the elements that are still out of
    # range. For a scalar `vals` this is exactly
    # `while position >= capacity: position = iter_func(position)`.
    position = jax.lax.while_loop(
        lambda p: jnp.any(out_of_range(p)),
        lambda p: jnp.where(out_of_range(p), iter_func(p), p),
        position)
    return position
def _fold_in_str(rng: PRNGKey, data: str) -> PRNGKey:
    """Derive a new PRNG key by folding a string into ``rng``.

    The string is UTF-8 encoded and hashed with SHA-1. The first four digest
    bytes, read as a big-endian unsigned 32-bit integer, are then passed to
    ``random.fold_in``. Unlike splitting a key, the result depends only on
    ``rng`` and ``data``. Keys for different strings can therefore be derived
    independently of one another (e.g. in parallel).

    Only 32 bits of the digest are used, so two distinct strings can, with very
    small probability, yield the same key.

    Args:
        rng: the PRNG key to fold the string into.
        data: the string to fold in.

    Returns:
        The newly derived PRNG key.
    """
    digest = hashlib.sha1(data.encode('utf-8')).digest()
    folded_value = int.from_bytes(digest[:4], byteorder='big')
    return random.fold_in(rng, jnp.uint32(folded_value))
def threefry_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
    """Fold the integer ``data`` into ``key`` via ``_threefry_fold_in``.

    ``data`` is first converted with ``jnp.uint32``. ``_threefry_fold_in`` is
    defined elsewhere in this module (not visible here).
    """
    data_u32 = jnp.uint32(data)
    return _threefry_fold_in(key, data_u32)
def Bernoulli(p, prec):
    """Build a ``NonUniform`` codec for a binary symbol with probability ``p`` of 1.

    ``p`` is quantized to an integer count out of ``2**prec`` (rounded by
    ``_nearest_uint``). It is then clipped to ``[1, 2**prec - 1]`` so that
    neither outcome gets zero mass.

    Symbol 1 is assigned the interval ``[2**prec - p, 2**prec)``. Symbol 0 is
    assigned ``[0, 2**prec - p)``.

    NOTE(review): the pair returned by the encoder statistics function is
    presumably ``(start, frequency)`` as expected by ``NonUniform``; confirm
    against its definition.

    Args:
        p: probability of the symbol being 1 (scalar or array).
        prec: number of bits of precision for the quantized probability.

    Returns:
        ``NonUniform(enc_statfun, dec_statfun, prec)``.
    """
    total = 1 << prec
    p = jnp.clip(_nearest_uint(p * total), 1, total - 1)
    onemp = total - p

    def enc_statfun(x):
        # Symbol 1 -> (onemp, p); symbol 0 -> (0, onemp).
        start = jnp.where(x, onemp, 0)
        freq = jnp.where(x, p, onemp)
        return start, freq

    def dec_statfun(cf):
        # Values at or above `onemp` fall in symbol 1's interval.
        return jnp.uint32(cf >= onemp)

    return NonUniform(enc_statfun, dec_statfun, prec)
def _nearest_uint(arr): return jnp.uint32(jnp.ceil(arr - 0.5))
def get_index(x):
    """Map a point to per-axis cell indices on a uniform grid.

    The cell width per axis is ``grid.size / grid.shape``. The index is
    ``floor((x - grid.lower) / width)``. Any axis whose index falls outside
    ``[0, grid.shape)`` gets the sentinel ``grid.shape`` instead. This now
    includes NaN indices, which arise from NaN or infinite coordinates.

    ``grid`` is a free name resolved from the enclosing scope. NOTE(review):
    ``grid.lower``, ``grid.size`` and ``grid.shape`` are presumably per-axis
    arrays; confirm.

    Args:
        x: the point; it is flattened to one value per axis.

    Returns:
        A tuple of uint32 indices in reversed axis order (``np.flip``).
    """
    h = grid.size / grid.shape
    idx = (x.flatten() - grid.lower) // h
    # A negated in-range test (rather than `idx < 0 | idx > shape`) also sends
    # NaN indices to the sentinel. Otherwise np.uint32 would cast them to an
    # arbitrary value and silently put the point in some real cell.
    in_range = (idx >= 0) & (idx < grid.shape)
    idx = np.where(in_range, idx, grid.shape)
    return (*np.flip(np.uint32(idx)), )
def get_index(x):
    """Map a point to per-axis cell indices on a cosine-spaced grid.

    Each coordinate is rescaled to ``u = 1 - 2 * (x - grid.lower) / grid.size``,
    which maps ``[lower, lower + size]`` onto ``[1, -1]``. The index is
    ``floor(grid.shape * arccos(u) / pi)``, so cells are uniform in angle.

    Coordinates outside the grid make ``arccos`` return NaN; those axes get the
    sentinel ``grid.shape``. NOTE(review): a coordinate exactly at
    ``lower + size`` may also yield ``grid.shape`` (``arccos(-1) == pi``);
    confirm this is intended.

    ``grid`` is a free name resolved from the enclosing scope.

    Args:
        x: the point; it is flattened to one value per axis.

    Returns:
        A tuple of uint32 indices in reversed axis order (``np.flip``).
    """
    u = 2 * (grid.lower - x.flatten()) / grid.size + 1
    angle = np.arccos(u)
    cell = (grid.shape * angle) // np.pi
    cell = np.nan_to_num(cell, nan=grid.shape)
    return tuple(np.flip(np.uint32(cell)))
def get_index(x):
    """Map a point to per-axis cell indices on a periodic uniform grid.

    The cell width per axis is ``grid.size / grid.shape``. The raw index
    ``floor((x - grid.lower) / width)`` is wrapped modulo ``grid.shape``, so
    points outside the grid re-enter from the opposite side.

    ``grid`` is a free name resolved from the enclosing scope.

    Args:
        x: the point; it is flattened to one value per axis.

    Returns:
        A tuple of uint32 indices in reversed axis order (``np.flip``).
    """
    spacing = grid.size / grid.shape
    cell = (x.flatten() - grid.lower) // spacing
    wrapped = cell % grid.shape
    return tuple(np.flip(np.uint32(wrapped)))