def AtariConvInit(kernel_shape, rng, dtype=jnp.float32):
  """The standard init for Conv layers and Atari."""
  filter_height, filter_width, fan_in, _ = kernel_shape
  std = 1 / jnp.sqrt(fan_in * filter_height * filter_width)
  return random.uniform(rng, kernel_shape, dtype, minval=-std, maxval=std)
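# Usage sketch (illustrative, not part of the original source): initialize an
# 8x8 Atari conv kernel with 4 input channels and 32 filters. Assumes `jnp`
# is jax.numpy and `random` is jax.random, consistent with AtariConvInit
# above; the helper name is hypothetical.
def _example_atari_conv_init():
  from jax import random as jax_random
  key = jax_random.PRNGKey(0)
  # Kernel shape follows (filter_height, filter_width, fan_in, fan_out).
  kernel = AtariConvInit((8, 8, 4, 32), key)
  assert kernel.shape == (8, 8, 4, 32)  # entries drawn uniformly from [-std, std]
  return kernel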
def _z_score(self, x, mean, variance):
  """Standardizes x (z-score) using the given mean and variance."""
  mu = mean.astype(x.dtype)
  sigma = np.sqrt(variance + self._epsilon).astype(x.dtype)
  return (x - mu) / sigma
def KaimingUniformInitializer(out_dim=-1, in_dim=-2, param=0.):
  """Returns an initializer for random uniform Kaiming-scaled coefficients."""
  return ScaledInitializer(
      out_dim, in_dim, 2.0 / jnp.sqrt(1 + param**2), 'fan_in', 'uniform')
def LayerNorm(x, weights, epsilon=1e-6, **unused_kwargs):  # pylint: disable=invalid-name
  (scale, bias) = weights
  mean = np.mean(x, axis=-1, keepdims=True)
  variance = np.mean((x - mean)**2, axis=-1, keepdims=True)
  norm_inputs = (x - mean) / np.sqrt(variance + epsilon)
  return norm_inputs * scale + bias
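# Usage sketch (illustrative, not part of the original source): LayerNorm
# expects weights = (scale, bias), each broadcastable against the last axis
# of x. The helper name below is hypothetical.
def _example_layer_norm():
  import jax.numpy as jnp
  x = jnp.ones((2, 5, 16))   # (batch, length, features)
  scale = jnp.ones((16,))    # learned per-feature scale
  bias = jnp.zeros((16,))    # learned per-feature bias
  y = LayerNorm(x, (scale, bias))
  return y                   # same shape as x, normalized over the last axis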
def l2_norm(tree):
  """Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
  leaves = _tree_flatten(tree)
  return np.sqrt(sum(np.vdot(x, x) for x in leaves))
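# Usage sketch (illustrative, not part of the original source): l2_norm treats
# an arbitrary pytree of arrays (e.g. a dict of parameters) as one flat
# vector. This assumes `_tree_flatten` behaves like
# jax.tree_util.tree_leaves; the helper name is hypothetical.
def _example_l2_norm():
  import jax.numpy as jnp
  params = {'w': jnp.array([3.0, 0.0]), 'b': jnp.array([4.0])}
  return l2_norm(params)  # sqrt(3**2 + 0**2 + 4**2) = 5.0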
def Gelu(x, **unused_kwargs):
  """Gaussian error linear unit, exact erf-based formulation."""
  return x * 0.5 * (1.0 + math.erf(x / np.sqrt(2.0)))
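# Note (illustrative, not part of the original source): assuming `math.erf`
# above is an array-capable backend erf (as in trax's math module) rather
# than the Python stdlib, this matches the exact GELU that
# jax.nn.gelu(x, approximate=False) computes; the default approximate=True
# uses the tanh approximation instead. The helper name is hypothetical.
def _example_gelu_check():
  import jax
  import jax.numpy as jnp
  x = jnp.linspace(-3.0, 3.0, 7)
  return jnp.allclose(Gelu(x), jax.nn.gelu(x, approximate=False), atol=1e-6)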
def attend(
    q, k=None, v=None,
    q_chunk_len=None, kv_chunk_len=None,
    n_chunks_before=0, n_chunks_after=0,
    mask_fn=None, q_info=None, kv_info=None,
    dropout=0.0, rng=None,
    ):
  """Dot-product attention, with optional chunking and/or masking.

  Args:
    q: Query vectors, shape [q_len, d_qk]
    k: Key vectors, shape [kv_len, d_qk]; or None
    v: Value vectors, shape [kv_len, d_v]
    q_chunk_len: Set to non-zero to enable chunking for query vectors
    kv_chunk_len: Set to non-zero to enable chunking for key/value vectors
    n_chunks_before: Number of adjacent previous chunks to attend to
    n_chunks_after: Number of adjacent subsequent chunks to attend to
    mask_fn: TODO(kitaev) doc
    q_info: Query-associated metadata for masking
    kv_info: Key-associated metadata for masking
    dropout: Dropout rate
    rng: RNG for dropout

  Returns:
    A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and
    dots_logsumexp has shape [q_len]. The logsumexp of the attention
    probabilities is useful for combining multiple rounds of attention (as in
    LSH attention).
  """
  assert v is not None
  share_qk = (k is None)

  if q_info is None:
    q_info = np.arange(q.shape[-2])

  if kv_info is None and not share_qk:
    kv_info = np.arange(v.shape[-2])

  # Split q/k/v into chunks along the time axis, if desired.
  if q_chunk_len is not None:
    q = np.reshape(q, (-1, q_chunk_len, q.shape[-1]))
    q_info = np.reshape(q_info, (-1, q_chunk_len))

  if share_qk:
    assert kv_chunk_len is None or kv_chunk_len == q_chunk_len
    k = q
    kv_chunk_len = q_chunk_len
    kv_info = q_info
  elif kv_chunk_len is not None:
    k = np.reshape(k, (-1, kv_chunk_len, k.shape[-1]))
    kv_info = np.reshape(kv_info, (-1, kv_chunk_len))

  if kv_chunk_len is not None:
    v = np.reshape(v, (-1, kv_chunk_len, v.shape[-1]))

  if share_qk:
    k = length_normalized(k)
  k = k / np.sqrt(k.shape[-1])

  # Optionally include adjacent chunks.
  if q_chunk_len is not None or kv_chunk_len is not None:
    assert q_chunk_len is not None and kv_chunk_len is not None
  else:
    assert n_chunks_before == 0 and n_chunks_after == 0

  k = look_adjacent(k, n_chunks_before, n_chunks_after)
  v = look_adjacent(v, n_chunks_before, n_chunks_after)
  kv_info = look_adjacent(kv_info, n_chunks_before, n_chunks_after)

  # Dot-product attention.
  dots = np.matmul(q, np.swapaxes(k, -1, -2))

  # Masking
  if mask_fn is not None:
    dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])

  # Softmax.
  dots_logsumexp = logsumexp(dots, axis=-1, keepdims=True)
  dots = np.exp(dots - dots_logsumexp)

  if dropout > 0.0:
    assert rng is not None
    # Dropout is broadcast across the bin dimension
    dropout_shape = (dots.shape[-2], dots.shape[-1])
    # TODO(kitaev): verify that tie-in is safe to remove (in light of jax fix)
    keep_prob = jax.lax.tie_in(dots, 1.0 - dropout)
    keep = jax.random.bernoulli(rng, keep_prob, dropout_shape)
    multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
    dots = dots * multiplier

  # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn.
  out = np.matmul(dots, v)
  out = np.reshape(out, (-1, out.shape[-1]))
  dots_logsumexp = np.reshape(dots_logsumexp, (-1,))
  return out, dots_logsumexp
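# Usage sketch (illustrative, not part of the original source): chunked
# self-attention with shared query/key vectors (pass k=None to take the
# share_qk path), where each chunk also attends to one preceding chunk. No
# masking or dropout is applied; shapes follow the docstring above and the
# helper name is hypothetical.
def _example_chunked_self_attend():
  import jax.numpy as jnp
  seq_len, d_qk, d_v = 16, 8, 8
  q = jnp.ones((seq_len, d_qk))
  v = jnp.ones((seq_len, d_v))
  out, dots_logsumexp = attend(
      q, k=None, v=v,
      q_chunk_len=4, kv_chunk_len=4,
      n_chunks_before=1, n_chunks_after=0,
      dropout=0.0, rng=None)
  return out, dots_logsumexp  # shapes: (16, 8) and (16,)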
def length_normalized(x, epsilon=1e-6):
  variance = np.mean(x**2, axis=-1, keepdims=True)
  norm_inputs = x / np.sqrt(variance + epsilon)
  return norm_inputs