def hardware_bernoulli(rng_key, p=np.float32(0.5), shape=None):
  """Faster RNG."""
  y = 1.0
  x = 0.0
  if FLAGS.use_bfloat16_activation:
    y = jnp.bfloat16(y)
    x = jnp.bfloat16(0.0)
    p = jnp.bfloat16(p)
  # tie_in creates a data dependence on rng_key so the hardware RNG op is
  # staged into the traced computation instead of being constant-folded.
  y = lax.tie_in(rng_key, y)
  m = lax.rng_uniform(x, y, shape)
  if FLAGS.use_bfloat16_activation:
    assert m.dtype == jnp.bfloat16
  return m < p
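
# A minimal usage sketch (the dropout wrapper below is an assumption for
# illustration, not from the original source): draw a keep-mask with the
# hardware RNG. Requires a JAX version that still provides lax.tie_in, and
# the absl FLAGS object used above.
import numpy as np
import jax.numpy as jnp

def dropout(rng_key, x, rate=0.1):
  keep = hardware_bernoulli(rng_key, p=np.float32(1.0 - rate), shape=x.shape)
  return jnp.where(keep, x / (1.0 - rate), jnp.zeros_like(x))
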
def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
  # Implements:
  #  if(rng(0, 1) < 2)
  #    return eval_jaxpr(*args)
  #  else:
  #    return 0
  avals_out = tuple(ov.aval for ov in jaxpr.outvars)

  def remat_comp(*args):
    return tuple(core.eval_jaxpr(jaxpr, (), *args))

  def dummy_comp(*args):
    return tuple(_map(_dummy_remat_result, avals_out))

  # A uniform draw from [0, 1) is always < 2, so the predicate is always
  # true at run time but opaque to the compiler.
  cond_pred = (lax.rng_uniform(np.float32(0), np.float32(1), shape=()) <
               np.float32(2))
  return cond(cond_pred, remat_comp, dummy_comp, *args)
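
# A self-contained sketch of the same opaque-predicate trick (the guarded()
# helper and its names are assumptions for illustration): wrap a computation
# in lax.cond with a predicate that is always true at run time but that the
# compiler cannot prove constant, so the branch is preserved in the lowered
# program.
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

def guarded(f, *args):
  # A uniform draw from [0, 1) is always < 2.
  pred = (lax.rng_uniform(np.float32(0), np.float32(1), shape=()) <
          np.float32(2))
  out_shapes = jax.eval_shape(f, *args)  # output avals for the dummy branch
  dummy = lambda operands: jax.tree_util.tree_map(
      lambda s: jnp.zeros(s.shape, s.dtype), out_shapes)
  return lax.cond(pred, lambda operands: f(*operands), dummy, args)
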
def hardware_bernoulli(rng_key, p=np.float32(0.5), shape=None):
  return lax.rng_uniform(lax.tie_in(rng_key, 0.0), 1.0, shape) < p
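
# One way such a drop-in is commonly wired up (an assumption here, not taken
# from the original source): shadow jax.random.bernoulli process-wide so
# existing dropout code picks up the hardware RNG without changes.
import jax
jax.random.bernoulli = hardware_bernoulli
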
def cond(carry):
  counter, _, _ = carry
  # rng_uniform over [1, 2) with int32 bounds always yields 1, so the loop
  # bound is fixed at run time but unknown to the compiler.
  return counter < lax.rng_uniform(np.int32(1), np.int32(2), shape=())
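
# A minimal sketch (the carry layout and body below are assumptions for
# illustration) of using this predicate as the condition of lax.while_loop:
import numpy as np
from jax import lax

def body(carry):
  counter, x, y = carry
  return counter + 1, x + y, y

result = lax.while_loop(
    cond, body, (np.int32(0), np.float32(0.0), np.float32(1.0)))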