def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
    """Run dot-product attention with a lower-triangular boolean mask.

    Args:
        inputs: Sequence that unpacks into ``(q, k, v)``.
        weights: Unused; discarded immediately.
        state: Layer state. Returned untouched in 'train'/'eval' mode. In
            'predict' mode it is replaced by the result of
            ``_fast_inference_update_state`` and supplies ``k``, ``v`` and
            the mask.
        rng: Forwarded to ``DotProductAttention``.
        **kwargs: Accepted but not used.

    Returns:
        A ``(result, state)`` pair, where ``result`` is whatever
        ``DotProductAttention`` returns.
    """
    del weights
    q, k, v = inputs

    if self._mode not in ('train', 'eval'):
        # The only other supported mode is incremental prediction, where
        # keys, values and the mask all come from the updated state.
        assert self._mode == 'predict'
        state = _fast_inference_update_state(inputs, state)
        k, v, mask, _ = state
    else:
        side = q.shape[-2]
        # Not every backend defines np.tril, yet onp.tril is inefficient
        # because it creates a large global constant.
        # TODO(kitaev): try to find an alternative that works across all
        # backends.
        array_lib = np if math.backend_name() == 'jax' else onp
        all_true = array_lib.ones((1, side, side), dtype=onp.bool_)
        mask = array_lib.tril(all_true, k=0)

    res = DotProductAttention(
        q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
    return res, state
def forward_with_state(self, inputs, weights, state, rng):
    """Run dot-product attention with a lower-triangular boolean mask.

    Args:
        inputs: Sequence that unpacks into ``(q, k, v)``.
        weights: Unused; discarded immediately.
        state: Layer state. Returned untouched unless the mode is
            'predict', in which case it is replaced by the result of
            ``_fast_inference_update_state`` and supplies ``k``, ``v`` and
            the mask.
        rng: Forwarded to ``DotProductAttention``.

    Returns:
        A ``(result, state)`` pair, where ``result`` is whatever
        ``DotProductAttention`` returns.
    """
    del weights
    q, k, v = inputs

    if self._mode != 'predict':
        side = q.shape[-2]
        # Not every backend defines jnp.tril, yet np.tril is inefficient
        # because it creates a large global constant.
        # TODO(kitaev): try to find an alternative that works across all
        # backends.
        array_lib = jnp if math.backend_name() == 'jax' else np
        all_true = array_lib.ones((1, side, side), dtype=np.bool_)
        mask = array_lib.tril(all_true, k=0)
    else:
        # Incremental prediction: keys, values and the mask all come from
        # the updated state rather than from the inputs.
        state = _fast_inference_update_state(inputs, state)
        k, v, mask, _ = state

    res = DotProductAttention(
        q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
    return res, state