def forward_unbatched(self, x, mask=None, *, weights, state, update_state):
  del update_state

  if self.share_qk:
    w_q, w_v, w_o = weights
  else:
    w_q, w_k, w_v, w_o = weights

  q = np.matmul(x, w_q)
  k = None
  if not self.share_qk:
    k = np.matmul(x, w_k)
  v = np.matmul(x, w_v)

  mask_fn = functools.partial(
      mask_self_attention,
      causal=self.causal, exclude_self=self.share_qk, masked=self.masked)
  q_info = kv_info = jax.lax.tie_in(x, np.arange(q.shape[-2]))

  assert (mask is not None) == self.masked
  if self.masked:
    # mask is a boolean array (True means "is valid token")
    ones_like_mask = jax.lax.tie_in(x, np.ones_like(mask, dtype=np.int32))
    kv_info = kv_info * np.where(mask, ones_like_mask, -ones_like_mask)

  o, _ = attend(
      q, k, v,
      q_chunk_len=self.chunk_len,
      kv_chunk_len=self.chunk_len,
      n_chunks_before=self.n_chunks_before,
      n_chunks_after=self.n_chunks_after,
      mask_fn=mask_fn, q_info=q_info, kv_info=kv_info,
      dropout=self.attention_dropout,
      rng=None,  # TODO(kitaev): support RNG
  )

  out = np.matmul(o, w_o)
  return out, state
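# A minimal usage sketch (not part of the original code) of the kv_info
# encoding above: when `masked` is set, positions where `mask` is False have
# their index negated so the attention masking function can exclude them.
# Plain NumPy (imported as `onp`) stands in for the JAX-backed `np` used above.
import numpy as onp

mask = onp.array([True, True, False, True])          # True means "is valid token"
kv_info = onp.arange(mask.shape[-1])                 # [0, 1, 2, 3]
ones_like_mask = onp.ones_like(mask, dtype=onp.int32)
kv_info = kv_info * onp.where(mask, ones_like_mask, -ones_like_mask)
# kv_info is now [0, 1, -2, 3]; the padding position carries a negated index.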
def init(self, params):
  return np.ones_like(params)
def WeightMask(target, mask_id=0, **unused_kwargs):
  if mask_id is None:
    return np.ones_like(target)
  return 1.0 - np.equal(target, mask_id).astype(np.float32)
def _ElementMask(target, id_to_mask=0, **unused_kwargs):
  """Returns a mask with zeros marking elements to exclude from calculations."""
  if id_to_mask is None:
    return np.ones_like(target)
  return 1.0 - np.equal(target, id_to_mask).astype(np.float32)
def WeightMask(target, mask_id=0, **kw):
  del kw
  if mask_id is None:
    return np.ones_like(target)
  return 1.0 - np.equal(target, mask_id).astype(np.float32)
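# A minimal usage sketch (not part of the original code) of the WeightMask /
# _ElementMask functions above: tokens equal to the mask id (here the assumed
# padding id 0) get weight 0.0, all other tokens get weight 1.0. Plain NumPy
# (imported as `onp`) stands in for the backend `np` used above.
import numpy as onp

target = onp.array([[5, 2, 0, 0],
                    [7, 0, 0, 0]])
weights = 1.0 - onp.equal(target, 0).astype(onp.float32)
# weights == [[1., 1., 0., 0.],
#             [1., 0., 0., 0.]]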
def init(self, weights):
  return np.ones_like(weights)