Example #1
  def _update_diagonal(self, g, w, m, v1, v2, opt_params):
    learning_rate = opt_params['learning_rate']
    beta2 = opt_params['second_moment_averaging']
    weight_decay = opt_params['weight_decay']

    # When beta2 == 1 the accumulator is a running sum of g**2 (Adagrad-style);
    # otherwise it is an exponential moving average with rate (1 - beta2).
    is_beta2_1 = (beta2 == 1).astype(g.dtype)
    one_minus_beta2_except1 = is_beta2_1 + (1.0 - beta2) * (1.0 - is_beta2_1)
    v1[0] = beta2 * v1[0] + one_minus_beta2_except1 * g * g

    preconditioner = jnp.where(v1[0] > 0, 1.0 / (jnp.sqrt(v1[0]) + 1e-16),
                               jnp.zeros_like(v1[0]))

    pg = preconditioner * g
    if self._graft:
      v2[0] += g * g
      preconditioner_graft = jnp.where(
          v2[0] > 0, 1.0 / (jnp.sqrt(v2[0]) + 1e-16), jnp.zeros_like(v2[0]))
      pg_graft = preconditioner_graft * g
      pg_norm = jnp.linalg.norm(pg)
      pg_graft_norm = jnp.linalg.norm(pg_graft)
      pg = pg * (pg_graft_norm/(pg_norm + 1e-16))

    pg = pg + w * weight_decay

    if self._has_momentum:
      m, update = self._momentum_update(pg, m, opt_params['momentum'])
    else:
      update = pg

    w = w - (update * learning_rate).astype(w.dtype)
    return w, (m, v1, v2)
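
The core of the update above can be exercised without the optimizer class. Below is a minimal standalone sketch of the diagonal second-moment preconditioning (beta2 < 1 assumed so the moving-average branch is active; all values are chosen only for illustration):

import jax.numpy as jnp

g = jnp.array([0.5, -0.3])               # gradient
v = jnp.zeros_like(g)                    # diagonal second-moment accumulator
beta2 = 0.999
v = beta2 * v + (1.0 - beta2) * g * g    # exponential moving average of g**2
preconditioner = jnp.where(v > 0, 1.0 / (jnp.sqrt(v) + 1e-16), jnp.zeros_like(v))
pg = preconditioner * g                  # preconditioned gradient, fed to the momentum/SGD step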
Example #2
 def learning_rate(step):
     """Step to learning rate function."""
     ret = 1.0
     for name in factors:
         if name == 'constant':
             ret *= constant
         elif name == 'linear_warmup':
             ret *= jnp.minimum(1.0, step / warmup_steps)
         elif name == 'rsqrt_decay':
             ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
         elif name == 'rsqrt_normalized_decay':
             ret *= jnp.sqrt(warmup_steps)
             ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
         elif name == 'decay_every':
             ret *= (decay_factor**(step // steps_per_decay))
         elif name == 'cosine_decay':
             progress = jnp.maximum(0.0, (step - warmup_steps) /
                                    float(steps_per_cycle))
             ret *= (0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
         else:
             raise ValueError('Unknown factor %s.' % name)
     # TODO(henrykm): return float(jnp.max(minimum, ret)) would be
     # better but causes TypeError: 'numpy.float64' object cannot
     # be interpreted as an integer
     if ret <= minimum:
         return minimum
     return ret
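
For a concrete feel of how such factor schedules behave, here is a self-contained sketch of the common 'constant * linear_warmup * rsqrt_decay' combination (the factory name and hyperparameter values below are only illustrative, not the library's API):

import jax.numpy as jnp

def make_schedule(constant=0.1, warmup_steps=1000):
    def learning_rate(step):
        warmup = jnp.minimum(1.0, step / warmup_steps)           # 'linear_warmup'
        decay = 1.0 / jnp.sqrt(jnp.maximum(step, warmup_steps))  # 'rsqrt_decay'
        return float(constant * warmup * decay)                  # 'constant'
    return learning_rate

lr = make_schedule()
# lr(1) is tiny, lr(1000) peaks at ~3.2e-3, lr(10000) has decayed to ~1e-3.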
Example #3
    def _update_sketched(self, g, w, m, v1, v2, opt_params):
        """Update for higher-rank parameters."""
        learning_rate = opt_params['learning_rate']
        momentum = opt_params['momentum']
        beta2 = opt_params['second_moment_averaging']
        weight_decay = opt_params['weight_decay']

        shape = w.shape
        rank = len(shape)
        reshaped_accumulators = [
            jnp.reshape(v1[i], self._expanded_shape(shape, i))
            for i in range(rank)
        ]
        acc = self._minimum(reshaped_accumulators)

        # Same trick as in the diagonal update: beta2 == 1 gives a running sum,
        # beta2 < 1 gives an exponential moving average of g**2.
        is_beta2_1 = (beta2 == 1).astype(g.dtype)
        one_minus_beta2_except1 = is_beta2_1 + (1.0 - beta2) * (1.0 - is_beta2_1)
        acc = beta2 * acc + one_minus_beta2_except1 * g * g

        preconditioner = jnp.where(acc > 0.0, 1.0 / (jnp.sqrt(acc) + 1e-16),
                                   jnp.zeros_like(acc))
        pg = g * preconditioner
        if self._graft:
            v2_acc = self._minimum([
                jnp.reshape(v2[i], self._expanded_shape(shape, i))
                for i in range(rank)
            ])
            v2_acc = v2_acc + g * g
            preconditioner_graft = jnp.where(v2_acc > 0.0,
                                             1.0 / (jnp.sqrt(v2_acc) + 1e-16),
                                             jnp.zeros_like(v2_acc))
            pg_graft = preconditioner_graft * g
            pg_norm = jnp.linalg.norm(pg)
            pg_graft_norm = jnp.linalg.norm(pg_graft)
            pg = pg * (pg_graft_norm / (pg_norm + 1e-16))

        pg = pg + w * weight_decay

        if self._has_momentum:
            m, update = self._momentum_update(pg, m, momentum)
        else:
            update = pg

        w = w - (learning_rate * update).astype(w.dtype)
        # Project the dense accumulator back onto the per-dimension (rank-1)
        # sketches by taking the max over all other axes.
        for i in range(len(v1)):
            axes = list(range(int(i))) + list(range(int(i) + 1, rank))
            dim_accumulator = jnp.amax(acc, axis=axes)
            v1[i] = dim_accumulator

        if self._graft:
            for i in range(len(v2)):
                axes = list(range(int(i))) + list(range(int(i) + 1, rank))
                dim_accumulator = jnp.amax(v2_acc, axis=axes)
                v2[i] = dim_accumulator
        return w, (m, v1, v2)
Example #4
def Gelu():
  r"""Returns a layer that computes the Gaussian Error Linear Unit function.

  .. math::
      f(x) = \frac{x}{2} \cdot (1 + \hbox{erf}(\frac{x}{\sqrt{2}}))
  """
  return Fn('Gelu', lambda x: x * 0.5 * (1.0 + fastmath.erf(x / jnp.sqrt(2.0))))
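
A quick usage sketch, assuming Trax is installed and the layer is exposed as trax.layers.Gelu (the example values are only illustrative):

import jax.numpy as jnp
import trax.layers as tl

gelu = tl.Gelu()                       # weightless layer, can be applied directly
y = gelu(jnp.array([-1.0, 0.0, 1.0]))  # approximately [-0.159, 0.0, 0.841]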
Example #5
def DotProductAttention(query, key, value, mask):
    """Dot product self-attention.
    Args:
        query (jax.interpreters.xla.DeviceArray): array of query representations with shape (L_q by d)
        key (jax.interpreters.xla.DeviceArray): array of key representations with shape (L_k by d)
        value (jax.interpreters.xla.DeviceArray): array of value representations with shape (L_k by d) where L_v = L_k
        mask (jax.interpreters.xla.DeviceArray): attention-mask, gates attention with shape (L_q by L_k)

    Returns:
        jax.interpreters.xla.DeviceArray: Self-attention array for q, k, v arrays, with shape (L_q by d)
    """

    assert query.shape[-1] == key.shape[-1] == value.shape[
        -1], "Embedding dimensions of q, k, v aren't all the same"

    depth = query.shape[-1]

    # Scaled query-key dot product, divided by sqrt(depth)
    dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)

    if mask is not None:  # Mask out padded / disallowed positions with a large negative score
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Softmax formula implementation
    logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)

    dots = jnp.exp(dots - logsumexp)
    attention = jnp.matmul(dots, value)

    return attention
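
A small, self-contained invocation of the function above with a causal mask (shapes and values chosen purely for illustration; trax must be importable because the body calls trax.fastmath.logsumexp):

import jax.numpy as jnp
import trax  # needed by DotProductAttention above

q = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])   # (L_q=3, d=2)
k = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])   # (L_k=3, d=2)
v = jnp.arange(6.0).reshape(3, 2)                     # (L_k=3, d=2)
mask = jnp.tril(jnp.ones((3, 3), dtype=bool))         # causal: query i sees keys j <= i
out = DotProductAttention(q, k, v, mask)              # shape (3, 2)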
Example #6
 def forward(self, x):
     scale, bias = self.weights
     mean = jnp.mean(x, axis=-1, keepdims=True)
     sub = x - mean
     variance = jnp.mean(sub * sub, axis=-1, keepdims=True)
     norm_inputs = sub / jnp.sqrt(variance + self._epsilon)
     return norm_inputs * scale + bias
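
The same normalization written out on a concrete array, without the layer class (scale = 1, bias = 0, and the epsilon value are assumed only for illustration):

import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0]])
scale, bias, epsilon = jnp.ones(3), jnp.zeros(3), 1e-6
mean = jnp.mean(x, axis=-1, keepdims=True)
sub = x - mean
variance = jnp.mean(sub * sub, axis=-1, keepdims=True)
norm_inputs = sub / jnp.sqrt(variance + epsilon)
out = norm_inputs * scale + bias    # approximately [[-1.2247, 0.0, 1.2247]]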
Example #7
 def log_prob(self, inputs, point):
     point = point.reshape(inputs.shape[:-1] + (-1, ))
     return (
         # L2 term.
         -jnp.sum((point - inputs)**2, axis=-1) / (2 * self._std**2) -
         # Normalizing constant.
         ((jnp.log(self._std) + jnp.log(jnp.sqrt(2 * jnp.pi))) *
          np.prod(self._shape)))
Example #8
 def update(self, step, grads, weights, avg_sq_grad, opt_params):
     del step
     lr = opt_params['learning_rate']
     gamma = opt_params['gamma']
     eps = opt_params['eps']
     avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1. - gamma)
     weights = weights - (lr * grads /
                          (jnp.sqrt(avg_sq_grad) + eps)).astype(
                              weights.dtype)
     return weights, avg_sq_grad
Example #9
def _per_head_attention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new per-head activations via scaled dot-product attention.

  This function is the core of the attention mechanism. Given per-head
  ``queries`` (Q), ``keys`` (K), ``values`` (V), and ``mask``, it:

    - computes the scaled dot product of each Q-K pair;
    - applies ``mask`` to screen out positions that come from padding tokens
      (indicated by 0 value);
    - [in ``'train'`` mode] applies dropout to Q-K dot products;
    - computes Q-K attention strengths using a per-query softmax of the Q-K dot
      products; and
    - for each query position, combines V vectors according to the Q-K
      attention strengths.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention strengths.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probabilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored. Applies only in ``'train'``
        mode.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Tuple of (activations, attn_strengths), where activations are new per-head
    activation vectors and attn_strengths is a matrix of per-head attention
    strengths.
  """
  if dropout >= 1.0:
    raise ValueError(f'Dropout rate ({dropout}) must be lower than 1.')

  d_feature = queries.shape[-1]

  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    dots = jnp.where(mask,
                     dots,
                     jnp.full_like(dots, -1e9))
  attn_strengths = (
      jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)))
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, attn_strengths.shape)
    attn_strengths = jnp.where(keep,
                               attn_strengths / (1.0 - dropout),
                               jnp.zeros_like(attn_strengths))
  activations = jnp.matmul(attn_strengths, values).astype(jnp.float32)
  attn_strengths = attn_strengths.astype(jnp.float32)
  return activations, attn_strengths
Example #10
 def learning_rate(step):
     """Step to learning rate function."""
     ret = 1.0
     for name in factors:
         if name == 'constant':
             ret *= constant
         elif name == 'linear_warmup':
             ret *= jnp.minimum(1.0, step / warmup_steps)
         elif name == 'rsqrt_decay':
             ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
         elif name == 'rsqrt_normalized_decay':
             ret *= jnp.sqrt(warmup_steps)
             ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
         elif name == 'decay_every':
             ret *= (decay_factor**(step // steps_per_decay))
         elif name == 'cosine_decay':
             progress = jnp.maximum(0.0, (step - warmup_steps) /
                                    float(steps_per_cycle))
             ret *= (0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
         else:
             raise ValueError('Unknown factor %s.' % name)
     return float(ret)
Example #11
 def update(self, step, grads, weights, slots, opt_params):
   m, v = slots
   learning_rate = opt_params['learning_rate']
   weight_decay_rate = opt_params['weight_decay_rate']
   b1 = opt_params['b1']
   b2 = opt_params['b2']
   eps = opt_params['eps']
   m = (1 - b1) * grads + b1 * m  # First  moment estimate.
   v = (1 - b2) * (grads ** 2) + b2 * v  # Second moment estimate.
   mhat = m / (1 - b1 ** (step + 1))  # Bias correction.
   vhat = v / (1 - b2 ** (step + 1))
   new_weights = ((1 - weight_decay_rate) * weights - (
       learning_rate * mhat / (jnp.sqrt(vhat) + eps))).astype(weights.dtype)
   return new_weights, (m, v)
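
A single Adam step written out with plain jnp arrays, mirroring the formulas above (hyperparameters and inputs here are illustrative, not the optimizer's defaults):

import jax.numpy as jnp

step, b1, b2, eps, lr, wd = 0, 0.9, 0.999, 1e-5, 1e-3, 0.0
grads = jnp.array([0.1, -0.2])
weights, m, v = jnp.zeros(2), jnp.zeros(2), jnp.zeros(2)

m = (1 - b1) * grads + b1 * m         # first moment estimate
v = (1 - b2) * grads**2 + b2 * v      # second moment estimate
mhat = m / (1 - b1**(step + 1))       # bias correction
vhat = v / (1 - b2**(step + 1))
weights = (1 - wd) * weights - lr * mhat / (jnp.sqrt(vhat) + eps)
# With zero initial moments, the first step moves each weight by roughly -lr * sign(grad).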
Example #12
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head ``queries`` and
      ``keys``,
    - applies ``mask`` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head ``values`` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probabilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.
  """
    d_feature = queries.shape[-1]
    dots = jnp.matmul(queries, jnp.swapaxes(keys, -1,
                                            -2)) / jnp.sqrt(d_feature)
    if mask is not None:
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax.
    dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
    if dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    out = jnp.matmul(dots, values)
    out = out.astype(jnp.float32)
    dots = dots.astype(jnp.float32)
    return out, dots
Example #13
  def _calc_attn_scores(q, k):
    ac = jnp.einsum('bnid,bnjd->bnij', q + context_bias, k)
    bd = jnp.einsum('bnid,jnd->bnij', q + location_bias, pos_emb)

    if mode != 'predict':
      bd = _fast_matrix_shift(bd)

    dots = (ac + bd) / jnp.sqrt(d_feature)
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Softmax.
    dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
    if dropout >= 1.0:
      raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
      keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
      dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))

    return dots
Example #14
def DotProductAttention(query, key, value, mask):
    assert query.shape[-1] == key.shape[-1] == value.shape[-1]

    depth = query.shape[-1]

    dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)  # Scaled query-key dot product

    # Apply mask
    if mask is not None:
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Softmax via the logsumexp trick for numerical stability
    logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)

    dots = jnp.exp(dots - logsumexp)

    attention = jnp.matmul(dots, value)

    return attention
Example #15
def DotProductAttention(query, key, value, mask):
    """Dot product self-attention.
    Args:
        query (jax.interpreters.xla.DeviceArray): array of query representations with shape (L_q by d)
        key (jax.interpreters.xla.DeviceArray): array of key representations with shape (L_k by d)
        value (jax.interpreters.xla.DeviceArray): array of value representations with shape (L_k by d) where L_v = L_k
        mask (jax.interpreters.xla.DeviceArray): attention-mask, gates attention with shape (L_q by L_k)

    Returns:
        jax.interpreters.xla.DeviceArray: Self-attention array for q, k, v arrays, with shape (L_q by d)
    """

    assert query.shape[-1] == key.shape[-1] == value.shape[
        -1], "Embedding dimensions of q, k, v aren't all the same"

    # The Q·K dot product is scaled down by the square root of the depth d
    depth = query.shape[-1]

    # Calculate the scaled query-key dot product
    dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)

    # Apply the mask
    if mask is not None:
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Softmax formula implementation
    # Use trax.fastmath.logsumexp of dots to avoid numerical overflow when exponentiating large scores
    logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)

    # Note: softmax = e^(dots - logsumexp(dots)) = e^dots / sum(e^dots)
    dots = jnp.exp(dots - logsumexp)

    # Multiply the attention weights by the values to get the self-attention output
    attention = jnp.matmul(dots, value)

    return attention