Example #1
def tri(n, m, k=0):
    # Tie in the key to avoid the mask becoming a constant. This way XLA can
    # construct the mask during computation and fuse it with the attention ops.
    # (key is captured from the enclosing attention function's scope.)
    x = lax.tie_in(key, jnp.arange(n, dtype=jnp.int32))
    y = lax.tie_in(key, jnp.arange(m, dtype=jnp.int32))
    mask = lax.ge(
        lax.broadcast_in_dim(x, shape=(n, m), broadcast_dimensions=(0,)) + k,
        lax.broadcast(y, [n]))
    return mask
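
For context, a minimal usage sketch (my own illustration, not from the snippet's source): tri is meant to be nested inside the attention computation so that the closed-over key is the attention key array, and the resulting boolean mask gates the logits. Note that lax.tie_in exists only in pre-omnistaging JAX releases, so treat this as illustrative rather than current API.

import jax.numpy as jnp
from jax import lax

def causal_attention_logits(query, key):
    """Hypothetical caller; query has shape (n, d) and key has shape (m, d)."""

    def tri(n, m, k=0):
        # Same helper as above, nested so the closed-over key is in scope.
        x = lax.tie_in(key, jnp.arange(n, dtype=jnp.int32))
        y = lax.tie_in(key, jnp.arange(m, dtype=jnp.int32))
        return lax.ge(
            lax.broadcast_in_dim(x, shape=(n, m), broadcast_dimensions=(0,)) + k,
            lax.broadcast(y, [n]))

    logits = jnp.einsum('nd,md->nm', query, key)
    mask = tri(query.shape[0], key.shape[0])  # True where position i may attend to j
    return jnp.where(mask, logits, -1e9)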
Example #2
def f(x):
    token = lax.create_token(x)
    y, token = lax.infeed(token,
                          shape=jax.ShapedArray((3, 4), jnp.float32))
    token = lax.outfeed(token, y + np.float32(1))
    return x - 1 if config.omnistaging_enabled else lax.tie_in(token, x - 1)
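
This snippet only runs meaningfully together with host code that services the device's infeed and outfeed queues. Below is a very rough sketch of that host side, modeled on how JAX's infeed tests of that era drive it; the transfer_to_infeed / transfer_from_outfeed device methods are backend- and version-specific, so treat the exact calls as assumptions.

import numpy as np
import jax
from jax import jit

device = jax.local_devices()[0]

# Queue the (3, 4) float32 array that the computation will consume via lax.infeed.
y = np.ones((3, 4), np.float32)
device.transfer_to_infeed((y,))

# Run the jitted computation; it pops y from the infeed queue and pushes
# y + 1 onto the outfeed queue.
result = jit(f)(np.float32(3.))

# Reading the outfeed back is also backend-specific, e.g. something like
# device.transfer_from_outfeed(...) with an xla_client.Shape for a (3, 4) array.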
Example #3
def _call(self, x, training=True, rng=None):
  info = self.info
  if training:
    if rng is None:
      raise ValueError('rng is required when training is True')
    # Using tie_in to avoid materializing constants
    keep = lax.tie_in(x, random.bernoulli(rng, info.rate, x.shape))
    return np.where(keep, x / info.rate, 0)
  else:
    return x
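
Read in isolation: info.rate here acts as the keep probability (elements survive with probability rate and are rescaled by 1/rate, i.e. inverted dropout). A self-contained sketch of the same pattern, assuming an older JAX where lax.tie_in still exists:

import jax.numpy as jnp
from jax import lax, random

def dropout(x, keep_rate, rng):
    # Tie the Bernoulli mask to x so it is not folded into a constant,
    # mirroring the layer above.
    keep = lax.tie_in(x, random.bernoulli(rng, keep_rate, x.shape))
    # Inverted dropout: rescale survivors so the expected activation is unchanged.
    return jnp.where(keep, x / keep_rate, 0.)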
Example #4
def template_build(cls, init_key, *args, name=None, **kwargs):
  """Instantiates layer object from RNG and layer specifications."""
  if init_key is None:
    raise ValueError('Cannot initialize template with `None` PRNGKey.')
  layer_params = cls.initialize(init_key, *args, **kwargs)
  if init_key is not None:
    new_params = tree_util.tree_map(lambda x: lax.tie_in(init_key, x),
                                    (layer_params.params, layer_params.state))
    layer_params = LayerParams(params=new_params[0], state=new_params[1],
                               info=layer_params.info)
  return cls.new(layer_params, name=name)
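
The usual reason for the tree_map over lax.tie_in above: when initialization is traced (for example under jit), parameters and state that do not actually depend on the key, such as zero-filled buffers, would otherwise be staged as compile-time constants; tying every leaf to init_key keeps them data-dependent. A minimal standalone sketch with hypothetical names (lax.tie_in exists only in pre-omnistaging JAX):

import jax
import jax.numpy as jnp
from jax import lax, random

def init_dense_params(init_key):
    params = {'w': random.normal(init_key, (512, 512)),
              'b': jnp.zeros((512,))}  # does not depend on the key
    # Tie every leaf to the PRNG key so nothing is baked in as a constant.
    return jax.tree_util.tree_map(lambda p: lax.tie_in(init_key, p), params)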
Example #5
def hardware_bernoulli(rng_key, p=np.float32(0.5), shape=None):
  """Faster RNG."""
  y = 1.0
  x = 0.0
  if FLAGS.use_bfloat16_activation:
    y = jnp.bfloat16(y)
    x = jnp.bfloat16(0.0)
    p = jnp.bfloat16(p)
  # Tie the bound to the RNG key so the rng_uniform op below is not
  # constant-folded or de-duplicated across calls with different keys.
  y = lax.tie_in(rng_key, y)
  m = lax.rng_uniform(x, y, shape)
  if FLAGS.use_bfloat16_activation:
    assert m.dtype == jnp.bfloat16
  return m < p
Example #6
def step(key, state, init_key=None):
    transition_key, accept_key = random.split(key)
    next_state = st.init(inner_step)(init_key, transition_key,
                                     state)(transition_key, state)
    # TODO(sharadmv): add log probabilities to the state to avoid recalculation.
    state_log_prob = unnormalized_log_prob(state)
    next_state_log_prob = unnormalized_log_prob(next_state)
    log_unclipped_accept_prob = next_state_log_prob - state_log_prob
    accept_prob = harvest.sow(np.clip(np.exp(log_unclipped_accept_prob), 0., 1.),
                              tag=MCMC_METRICS,
                              name='accept_prob')
    u = lax.tie_in(accept_prob, random.uniform(accept_key))
    accept = np.log(u) < log_unclipped_accept_prob
    return tree_util.tree_multimap(lambda n, s: np.where(accept, n, s),
                                   next_state, state)
Example #7
def hardware_bernoulli(rng_key, p=np.float32(0.5), shape=None):
    return lax.rng_uniform(lax.tie_in(rng_key, 0.0), 1.0, shape) < p
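
A quick usage sketch (hypothetical caller, not from the source): lax.rng_uniform wraps XLA's stateful hardware RNG rather than the reproducible jax.random machinery, and the lower bound is tied to rng_key presumably so the op is not constant-folded or de-duplicated across calls. lax.tie_in exists only in older JAX releases.

import jax.numpy as jnp

def fast_dropout(x, rng_key, keep_prob=0.9):
    # Draw the keep mask with the hardware RNG; True with probability keep_prob.
    keep = hardware_bernoulli(rng_key, p=jnp.float32(keep_prob), shape=x.shape)
    return jnp.where(keep, x / keep_prob, 0.)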
Example #8
def fake_quant(self, x, *, quantized_type, fake_dependency=None):
    x_dtype = x.dtype
    quantized_x = self.to_quantized(x, dtype=quantized_type)
    if fake_dependency is not None:
        quantized_x = lax.tie_in(fake_dependency, quantized_x)
    return self.from_quantized(quantized_x, dtype=x_dtype)
Example #9
def cond(idx_carry):
    i, c = idx_carry
    return i < jnp.sum(lax.tie_in(i, cond_const))  # Capture cond_const
Example #10
def f(n):
    token = lax.create_token(n)
    token = lax.fori_loop(0, n, doubler, token)
    return n if config.omnistaging_enabled else lax.tie_in(token, n)
Example #11
def f(n):
    token = lax.create_token(n)
    token = lax.fori_loop(0, n, doubler, token)
    return lax.tie_in(token, n)
Example #12
def f(x):
    token = lax.create_token(x)
    y, token = lax.infeed(token,
                          shape=jax.ShapedArray((3, 4), np.float32))
    token = lax.outfeed(token, y + onp.float32(1))
    return lax.tie_in(token, x - 1)
Example #13
def f(x, init_key=None):
    y = module.variable(np.zeros(x.shape), name='y', key=init_key)
    next_y = module.assign(y + 1., name='y')
    return lax.tie_in(next_y, x) + y