예제 #1
0
def hardware_bernoulli(rng_key, p=np.float32(0.5), shape=None):
  """Faster RNG: Bernoulli samples built on `lax.rng_uniform`.

  Args:
    rng_key: Value the upper bound of the draw is tied to through
      `lax.tie_in`, when that function is available. It is not otherwise
      consumed by this function.
    p: Threshold compared against the uniform draw; cast to bfloat16 when
      `FLAGS.use_bfloat16_activation` is set.
    shape: Shape of the uniform draw. `None` falls back to the shape of `p`
      (a scalar draw for the default `p`).

  Returns:
    The boolean result of `uniform_draw < p`.
  """
  if shape is None:
    # `lax.rng_uniform` expects a concrete shape, so the `None` default is
    # resolved here instead of being forwarded as-is.
    shape = np.shape(p)

  low = 0.0
  high = 1.0
  if FLAGS.use_bfloat16_activation:
    # Keep bounds and threshold in the same dtype as the activations.
    low = jnp.bfloat16(low)
    high = jnp.bfloat16(high)
    p = jnp.bfloat16(p)

  # `lax.tie_in` is absent from newer JAX releases; only call it where it
  # still exists so the function works on both old and new versions.
  tie_in = getattr(lax, "tie_in", None)
  if tie_in is not None:
    high = tie_in(rng_key, high)

  m = lax.rng_uniform(low, high, shape)
  if FLAGS.use_bfloat16_activation:
    # Sanity check: the draw must not have been silently promoted.
    assert m.dtype == jnp.bfloat16
  return m < p
예제 #2
0
def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
    """Evaluate `jaxpr` on `args` behind a `cond` with a drawn predicate.

    Equivalent pseudocode:

        if rng(0, 1) < 2:
            return eval_jaxpr(*args)
        else:
            return <one dummy result per output aval>

    NOTE(review): the threshold 2 lies above the sampled range, so the
    dummy branch is presumably never expected to run - confirm.
    """
    out_avals = tuple(var.aval for var in jaxpr.outvars)

    def evaluate_branch(*operands):
        # Real branch: run the jaxpr on the operands with no constants.
        return tuple(core.eval_jaxpr(jaxpr, (), *operands))

    def placeholder_branch(*operands):
        # Fallback branch: ignores the operands and produces one dummy
        # value per output aval.
        return tuple(_map(_dummy_remat_result, out_avals))

    lower = np.float32(0)
    upper = np.float32(1)
    threshold = np.float32(2)
    sample = lax.rng_uniform(lower, upper, shape=())
    return cond(sample < threshold, evaluate_branch, placeholder_branch, *args)
예제 #3
0
def hardware_bernoulli(rng_key, p=np.float32(0.5), shape=None):
    """Draw Bernoulli samples by thresholding `lax.rng_uniform`.

    Args:
        rng_key: Value the lower bound of the draw is tied to through
            `lax.tie_in`, when that function is available. It is not
            otherwise consumed by this function.
        p: Threshold compared against the uniform draw from [0, 1).
        shape: Shape of the uniform draw. `None` falls back to the shape of
            `p` (a scalar draw for the default `p`).

    Returns:
        The boolean result of `uniform_draw < p`.
    """
    if shape is None:
        # `lax.rng_uniform` expects a concrete shape, so the `None` default
        # is resolved here instead of being forwarded as-is.
        shape = np.shape(p)

    low = 0.0
    # `lax.tie_in` is absent from newer JAX releases; only call it where it
    # still exists so the function works on both old and new versions.
    tie_in = getattr(lax, "tie_in", None)
    if tie_in is not None:
        low = tie_in(rng_key, low)

    return lax.rng_uniform(low, 1.0, shape) < p
예제 #4
0
 def cond(carry):
     """Return whether the carry's counter is below a fresh int32 draw.

     `carry` must unpack into exactly three items; only the first one (the
     counter) is used. The bound is sampled with
     `lax.rng_uniform(int32(1), int32(2))` on every call.
     """
     counter, _, _ = carry
     lower = np.int32(1)
     upper = np.int32(2)
     bound = lax.rng_uniform(lower, upper, shape=())
     return counter < bound