Example #1
def AtariConvInit(kernel_shape, rng, dtype=jnp.float32):
    """The standard init for Conv laters and Atari."""
    filter_height, filter_width, fan_in, _ = kernel_shape
    std = 1 / jnp.sqrt(fan_in * filter_height * filter_width)
    return random.uniform(rng, kernel_shape, dtype, minval=-std, maxval=std)
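A minimal usage sketch (not from the original source), assuming the standard JAX imports; the kernel shape below is a hypothetical 8x8 conv with 4 input channels and 32 filters:

import jax.numpy as jnp
from jax import random

rng = random.PRNGKey(0)
# Hypothetical Atari-style first conv layer: 8x8 filters, 4 input frames, 32 output channels.
kernel = AtariConvInit((8, 8, 4, 32), rng)
print(kernel.shape, kernel.dtype)  # (8, 8, 4, 32) float32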
Example #2
def _z_score(self, x, mean, variance):
    mu = mean.astype(x.dtype)
    sigma = np.sqrt(variance + self._epsilon).astype(x.dtype)
    return (x - mu) / sigma
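_z_score is a method that reads self._epsilon; a standalone sketch of the same computation with epsilon passed explicitly (names here are illustrative, not from the original class):

import numpy as np

def z_score(x, mean, variance, epsilon=1e-5):
    # Standardize x with the given (running) mean and variance.
    mu = mean.astype(x.dtype)
    sigma = np.sqrt(variance + epsilon).astype(x.dtype)
    return (x - mu) / sigma

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
print(z_score(x, x.mean(), x.var()))  # approximately [-1.22, 0., 1.22]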
Example #3
def KaimingUniformInitializer(out_dim=-1, in_dim=-2, param=0.):
    """Returns an initializer for random uniform Kaiming-scaled coefficients."""
    return ScaledInitializer(out_dim, in_dim, 2.0 / jnp.sqrt(1 + param**2),
                             'fan_in', 'uniform')
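ScaledInitializer itself is not shown above. As a rough, self-contained sketch of the 'fan_in' / 'uniform' case it is configured with here, assuming the scale argument acts as a variance numerator (an assumption about code not shown):

import jax.numpy as jnp
from jax import random

def scaled_uniform_init(shape, rng, scale, in_dim=-2, dtype=jnp.float32):
    # fan_in mode: target variance = scale / fan_in.
    fan_in = shape[in_dim]
    limit = jnp.sqrt(3.0 * scale / fan_in)  # uniform on [-limit, limit] has that variance
    return random.uniform(rng, shape, dtype, minval=-limit, maxval=limit)

param = 0.0
w = scaled_uniform_init((128, 256), random.PRNGKey(0), scale=2.0 / jnp.sqrt(1 + param**2))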
Example #4
def LayerNorm(x, weights, epsilon=1e-6, **unused_kwargs):  # pylint: disable=invalid-name
    (scale, bias) = weights
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.mean((x - mean)**2, axis=-1, keepdims=True)
    norm_inputs = (x - mean) / np.sqrt(variance + epsilon)
    return norm_inputs * scale + bias
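A quick numerical check of the snippet above; np is assumed to behave like plain NumPy here (in trax it is the backend numpy):

import numpy as np

x = np.arange(6, dtype=np.float32).reshape(2, 3)
scale = np.ones(3, dtype=np.float32)
bias = np.zeros(3, dtype=np.float32)
y = LayerNorm(x, (scale, bias))
print(y.mean(axis=-1), y.var(axis=-1))  # each row has ~0 mean and ~1 variance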
Example #5
def l2_norm(tree):
  """Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
  leaves = _tree_flatten(tree)
  return np.sqrt(sum(np.vdot(x, x) for x in leaves))
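The snippet relies on a module-local _tree_flatten helper that is not shown. A minimal equivalent using the public jax.tree_util.tree_leaves:

import numpy as np
from jax.tree_util import tree_leaves

def l2_norm(tree):
  """Compute the l2 norm of a pytree of arrays."""
  leaves = tree_leaves(tree)
  return np.sqrt(sum(np.vdot(x, x) for x in leaves))

params = {'w': np.ones((2, 2)), 'b': np.zeros(2)}
print(l2_norm(params))  # 2.0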
Example #6
def Gelu(x, **unused_kwargs):
    return x * 0.5 * (1.0 + math.erf(x / np.sqrt(2.0)))
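Here math is presumably a backend math module rather than the Python standard library (whose math.erf does not accept arrays). A standalone JAX equivalent for comparison:

import jax.numpy as jnp
from jax.scipy.special import erf

def gelu(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return x * 0.5 * (1.0 + erf(x / jnp.sqrt(2.0)))

print(gelu(jnp.array([-1.0, 0.0, 1.0])))  # roughly [-0.159, 0., 0.841]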
Example #7
def attend(
    q, k=None, v=None,
    q_chunk_len=None, kv_chunk_len=None,
    n_chunks_before=0, n_chunks_after=0,
    mask_fn=None, q_info=None, kv_info=None,
    dropout=0.0, rng=None,
    ):
  """Dot-product attention, with optional chunking and/or masking.

  Args:
    q: Query vectors, shape [q_len, d_qk]
    k: Key vectors, shape [kv_len, d_qk]; or None
    v: Value vectors, shape [kv_len, d_v]
    q_chunk_len: Set to non-zero to enable chunking for query vectors
    kv_chunk_len: Set to non-zero to enable chunking for key/value vectors
    n_chunks_before: Number of adjacent previous chunks to attend to
    n_chunks_after: Number of adjacent subsequent chunks to attend to
    mask_fn: Optional function used to mask the attention logits; called as
      mask_fn(dots, q_info, kv_info) with broadcastable position info, or None
    q_info: Query-associated metadata for masking
    kv_info: Key-associated metadata for masking
    dropout: Dropout rate
    rng: RNG for dropout

  Returns:
    A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and
    dots_logsumexp has shape [q_len]. The logsumexp of the attention
    probabilities is useful for combining multiple rounds of attention (as in
    LSH attention).
  """
  assert v is not None
  share_qk = (k is None)

  if q_info is None:
    q_info = np.arange(q.shape[-2])

  if kv_info is None and not share_qk:
    kv_info = np.arange(v.shape[-2])

  # Split q/k/v into chunks along the time axis, if desired.
  if q_chunk_len is not None:
    q = np.reshape(q, (-1, q_chunk_len, q.shape[-1]))
    q_info = np.reshape(q_info, (-1, q_chunk_len))

  if share_qk:
    assert kv_chunk_len is None or kv_chunk_len == q_chunk_len
    k = q
    kv_chunk_len = q_chunk_len
    kv_info = q_info
  elif kv_chunk_len is not None:
    k = np.reshape(k, (-1, kv_chunk_len, k.shape[-1]))
    kv_info = np.reshape(kv_info, (-1, kv_chunk_len))

  if kv_chunk_len is not None:
    v = np.reshape(v, (-1, kv_chunk_len, v.shape[-1]))

  if share_qk:
    k = length_normalized(k)
  k = k / np.sqrt(k.shape[-1])

  # Optionally include adjacent chunks.
  if q_chunk_len is not None or kv_chunk_len is not None:
    assert q_chunk_len is not None and kv_chunk_len is not None
  else:
    assert n_chunks_before == 0 and n_chunks_after == 0

  k = look_adjacent(k, n_chunks_before, n_chunks_after)
  v = look_adjacent(v, n_chunks_before, n_chunks_after)
  kv_info = look_adjacent(kv_info, n_chunks_before, n_chunks_after)

  # Dot-product attention.
  dots = np.matmul(q, np.swapaxes(k, -1, -2))

  # Masking
  if mask_fn is not None:
    dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])

  # Softmax.
  dots_logsumexp = logsumexp(dots, axis=-1, keepdims=True)
  dots = np.exp(dots - dots_logsumexp)

  if dropout > 0.0:
    assert rng is not None
    # Dropout is broadcast across the bin dimension
    dropout_shape = (dots.shape[-2], dots.shape[-1])
    # TODO(kitaev): verify that tie-in is safe to remove (in light of jax fix)
    keep_prob = jax.lax.tie_in(dots, 1.0 - dropout)
    keep = jax.random.bernoulli(rng, keep_prob, dropout_shape)
    multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
    dots = dots * multiplier

  # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn.
  out = np.matmul(dots, v)
  out = np.reshape(out, (-1, out.shape[-1]))
  dots_logsumexp = np.reshape(dots_logsumexp, (-1,))
  return out, dots_logsumexp
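The Returns section notes that dots_logsumexp lets multiple rounds of attention be combined, as in LSH attention. A self-contained sketch of that combination, assuming each round i produced (out_i, lse_i) shaped [q_len, d_v] and [q_len] as documented above:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

def combine_attention_rounds(outs, lses):
  # outs: list of [q_len, d_v] per-round outputs; lses: list of [q_len] normalizers.
  outs = jnp.stack(outs)           # [n_rounds, q_len, d_v]
  lses = jnp.stack(lses)           # [n_rounds, q_len]
  # Each round normalized its softmax only over its own keys; reweighting each
  # round by exp(lse_i - total) combines the rounds as if they were one softmax
  # (duplicate keys across rounds need extra handling in full LSH attention).
  total = logsumexp(lses, axis=0)  # [q_len]
  weights = jnp.exp(lses - total)  # [n_rounds, q_len]
  return jnp.sum(outs * weights[..., None], axis=0)  # [q_len, d_v]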
Example #8
def length_normalized(x, epsilon=1e-6):
  variance = np.mean(x**2, axis=-1, keepdims=True)
  norm_inputs = x / np.sqrt(variance + epsilon)
  return norm_inputs
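This is the normalization applied to the shared query/key vectors in attend above, before the 1/sqrt(d) scaling. A quick check that each output vector has roughly unit RMS, assuming plain NumPy:

import numpy as np

x = 3.0 * np.random.randn(4, 64).astype(np.float32)
y = length_normalized(x)
print(np.sqrt(np.mean(y**2, axis=-1)))  # ~1.0 for every row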