def tri(n, m, k=0):
  # Tie in the key to avoid the mask becoming a constant.
  # This way XLA can construct the mask during computation and fuse it
  # with the attention ops.
  x = lax.tie_in(key, jnp.arange(n, dtype=jnp.int32))
  y = lax.tie_in(key, jnp.arange(m, dtype=jnp.int32))
  mask = lax.ge(
      (lax.broadcast_in_dim(x, shape=(n, m), broadcast_dimensions=(0,))) + k,
      lax.broadcast(y, [n]))
  return mask
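# Hypothetical usage sketch (not from the original source): `tri` above closes over
# `key`, the attention-key activations of its enclosing function, so it is meant to
# be defined inside an attention routine roughly like the one below. `query`, `key`,
# and the -1e9 fill value are assumptions.
def causal_attention_weights(query, key):
  n = query.shape[-2]
  dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2))
  mask = tri(n, n)                     # True where position j <= i (past and self)
  dots = jnp.where(mask, dots, -1e9)   # block attention to future positions
  return jax.nn.softmax(dots, axis=-1)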
def f(x):
  token = lax.create_token(x)
  y, token = lax.infeed(token, shape=jax.ShapedArray((3, 4), jnp.float32))
  token = lax.outfeed(token, y + np.float32(1))
  return x - 1 if config.omnistaging_enabled else lax.tie_in(token, x - 1)
def _call(self, x, training=True, rng=None):
  info = self.info
  if training:
    if rng is None:
      raise ValueError('rng is required when training is True')
    # Using tie_in to avoid materializing constants
    keep = lax.tie_in(x, random.bernoulli(rng, info.rate, x.shape))
    return np.where(keep, x / info.rate, 0)
  else:
    return x
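# Hypothetical usage sketch (not from the original source): calling the dropout
# layer above in training and evaluation mode. `DropoutLayer`, its constructor, and
# `info.rate` being the keep probability are assumptions about the surrounding class.
layer = DropoutLayer(rate=0.9)                      # keep each unit with probability 0.9
rng = random.PRNGKey(0)
x = jnp.ones((4, 8))
y_train = layer._call(x, training=True, rng=rng)    # stochastic, scaled by 1/rate (inverted dropout)
y_eval = layer._call(x, training=False)             # identity at eval time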
def template_build(cls, init_key, *args, name=None, **kwargs):
  """Instantiates layer object from RNG and layer specifications."""
  if init_key is None:
    raise ValueError('Cannot initialize template with `None` PRNGKey.')
  layer_params = cls.initialize(init_key, *args, **kwargs)
  if init_key is not None:
    new_params = tree_util.tree_map(lambda x: lax.tie_in(init_key, x),
                                    (layer_params.params, layer_params.state))
    layer_params = LayerParams(params=new_params[0],
                               state=new_params[1],
                               info=layer_params.info)
  return cls.new(layer_params, name=name)
def hardware_bernoulli(rng_key, p=np.float32(0.5), shape=None):
  """Faster RNG."""
  y = 1.0
  x = 0.0
  if FLAGS.use_bfloat16_activation:
    y = jnp.bfloat16(y)
    x = jnp.bfloat16(0.0)
    p = jnp.bfloat16(p)
  y = lax.tie_in(rng_key, y)
  m = lax.rng_uniform(x, y, shape)
  if FLAGS.use_bfloat16_activation:
    assert m.dtype == jnp.bfloat16
  return m < p
def step(key, state, init_key=None):
  transition_key, accept_key = random.split(key)
  next_state = st.init(inner_step)(init_key, transition_key, state)(transition_key, state)
  # TODO(sharadmv): add log probabilities to the state to avoid recalculation.
  state_log_prob = unnormalized_log_prob(state)
  next_state_log_prob = unnormalized_log_prob(next_state)
  log_unclipped_accept_prob = next_state_log_prob - state_log_prob
  accept_prob = harvest.sow(np.clip(np.exp(log_unclipped_accept_prob), 0., 1.),
                            tag=MCMC_METRICS, name='accept_prob')
  u = lax.tie_in(accept_prob, random.uniform(accept_key))
  accept = np.log(u) < log_unclipped_accept_prob
  return tree_util.tree_multimap(lambda n, s: np.where(accept, n, s),
                                 next_state, state)
def hardware_bernoulli(rng_key, p=np.float32(0.5), shape=None):
  return lax.rng_uniform(lax.tie_in(rng_key, 0.0), 1.0, shape) < p
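# Hypothetical usage sketch (not from the original source): drawing a dropout-style
# keep mask with the hardware RNG. Unlike random.bernoulli, lax.rng_uniform uses a
# stateful, non-reproducible generator; tying 0.0 to rng_key only keeps the op from
# being constant-folded. The key and shape below are assumptions.
key = random.PRNGKey(0)
keep_mask = hardware_bernoulli(key, p=np.float32(0.9), shape=(128, 1024))  # True ~90% of the time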
def fake_quant(self, x, *, quantized_type, fake_dependency=None):
  x_dtype = x.dtype
  quantized_x = self.to_quantized(x, dtype=quantized_type)
  if fake_dependency is not None:
    quantized_x = lax.tie_in(fake_dependency, quantized_x)
  return self.from_quantized(quantized_x, dtype=x_dtype)
def cond(idx_carry):
  i, c = idx_carry
  return i < jnp.sum(lax.tie_in(i, cond_const))  # Capture cond_const
def f(n):
  token = lax.create_token(n)
  token = lax.fori_loop(0, n, doubler, token)
  return n if config.omnistaging_enabled else lax.tie_in(token, n)
def f(n):
  token = lax.create_token(n)
  token = lax.fori_loop(0, n, doubler, token)
  return lax.tie_in(token, n)
def f(x):
  token = lax.create_token(x)
  y, token = lax.infeed(token, shape=jax.ShapedArray((3, 4), np.float32))
  token = lax.outfeed(token, y + onp.float32(1))
  return lax.tie_in(token, x - 1)
def f(x, init_key=None):
  y = module.variable(np.zeros(x.shape), name='y', key=init_key)
  next_y = module.assign(y + 1., name='y')
  return lax.tie_in(next_y, x) + y
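# Minimal sketch of what lax.tie_in itself does, assuming an older JAX release where
# it is still a real primitive (with omnistaging it became a no-op and was later
# removed). tie_in(x, y) returns y with a fake data dependence on x, so constants
# like the iota below are staged into the jitted computation instead of being
# materialized at trace time. The function and names here are illustrative only.
def first_three(x):
  idx = lax.tie_in(x, jnp.arange(x.shape[0]))  # constructed on device, fused with the rest
  return jnp.where(idx < 3, x, 0.)

print(jax.make_jaxpr(first_three)(jnp.ones(8)))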