Example #1
0
def _rbg_random_bits(key: jnp.ndarray, bit_width: int,
                     shape: Sequence[int]) -> jnp.ndarray:
    if not key.shape == (4, ) and key.dtype == jnp.dtype('uint32'):
        raise TypeError("_rbg_random_bits got invalid prng key.")
    if bit_width not in (8, 16, 32, 64):
        raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
    _, bits = lax.rng_bit_generator(key, shape, dtype=UINT_DTYPES[bit_width])
    return bits
Example #2
0
def _unsafe_rbg_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
    """Fold the integer ``data`` into ``key`` by XOR-ing in a derived mask.

    The mask is the last of ten rows that ``lax.rng_bit_generator`` produces
    from ``_rbg_seed(data)``. The mask depends only on ``data``, not on
    ``key``.

    Args:
      key: the key to perturb (presumably a ``(4,)`` ``uint32`` RBG key —
        confirm against callers).
      data: integer to fold in; passed to ``_rbg_seed``, defined elsewhere.

    Returns:
      ``key`` XOR-ed element-wise with the final generated row.
    """
    seed = _rbg_seed(data)
    _, stream = lax.rng_bit_generator(seed, (10, 4), dtype='uint32')
    # Only the final of the ten generated rows is used as the mask.
    mask = stream[-1]
    return key ^ mask
Example #3
0
def _unsafe_rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
    # treat 10 iterations of random bits as a 'hash function'
    _, keys = lax.rng_bit_generator(key, (10 * num, 4), dtype='uint32')
    return keys[::10]