Esempio n. 1
0
  def _update_diagonal(self, g, w, m, v1, v2, opt_params):
    learning_rate = opt_params['learning_rate']
    beta2 = opt_params['second_moment_averaging']
    weight_decay = opt_params['weight_decay']

    is_beta2_1 = (beta2 == 1).astype(g.dtype)
    one_minus_beta2_except1 = is_beta2_1  + (1.0 - beta2) * (1.0 - is_beta2_1)
    v1[0] = beta2 * v1[0] + one_minus_beta2_except1 * g * g

    preconditioner = jnp.where(v1[0] > 0, 1.0 / (jnp.sqrt(v1[0]) + 1e-16),
                               jnp.zeros_like(v1[0]))

    pg = preconditioner * g
    if self._graft:
      v2[0] += g * g
      preconditioner_graft = jnp.where(
          v2[0] > 0, 1.0 / (jnp.sqrt(v2[0]) + 1e-16), jnp.zeros_like(v2[0]))
      pg_graft = preconditioner_graft * g
      pg_norm = jnp.linalg.norm(pg)
      pg_graft_norm = jnp.linalg.norm(pg_graft)
      pg = pg * (pg_graft_norm/(pg_norm + 1e-16))

    pg = pg + w * weight_decay

    if self._has_momentum:
      m, update = self._momentum_update(pg, m, opt_params['momentum'])
    else:
      update = pg

    w = w - (update * learning_rate).astype(w.dtype)
    return w, (m, v1, v2)
Esempio n. 2
0
    def _update_sketched(self, g, w, m, v1, v2, opt_params):
        """Runs one update step for higher-rank parameters.

        The full-shape accumulator is never stored: it is rebuilt from one
        accumulator per dimension of `w`, updated, and then folded back into
        per-dimension maxima.

        Args:
          g: Gradient with respect to `w`.
          w: Current weights.
          m: Momentum slot; only used when `self._has_momentum` is true.
          v1: List of per-dimension second-moment accumulators; its entries
            are overwritten.
          v2: List of per-dimension grafting accumulators; its entries are
            overwritten only when `self._graft` is true.
          opt_params: Mapping providing 'learning_rate', 'momentum',
            'second_moment_averaging' and 'weight_decay'.

        Returns:
          A pair `(new_w, (m, v1, v2))`.
        """
        step_size = opt_params['learning_rate']
        momentum = opt_params['momentum']
        decay = opt_params['second_moment_averaging']
        l2_coeff = opt_params['weight_decay']

        dims = w.shape
        rank = len(dims)

        def combine(per_dim):
            # Reshape each per-dimension accumulator with the class helper
            # and reduce the resulting list with `self._minimum`.
            return self._minimum([
                jnp.reshape(per_dim[d], self._expanded_shape(dims, d))
                for d in range(rank)
            ])

        def inverse_root(acc):
            # 1 / sqrt(acc) where acc is positive, exactly zero elsewhere.
            return jnp.where(acc > 0.0, 1.0 / (jnp.sqrt(acc) + 1e-16),
                             jnp.zeros_like(acc))

        def write_back(slots, full):
            # Slot d keeps the maximum of `full` over every axis except d.
            for d in range(len(slots)):
                other_axes = [*range(d), *range(d + 1, rank)]
                slots[d] = jnp.amax(full, axis=other_axes)

        # With decay == 1 the squared gradient is added with weight 1;
        # otherwise it is weighted by (1 - decay).
        decay_is_one = (decay == 1).astype(g.dtype)
        new_term_scale = decay_is_one + (1.0 - decay) * (1.0 - decay_is_one)
        acc = combine(v1)
        acc = decay * acc + new_term_scale * g * g

        scaled_grad = g * inverse_root(acc)
        if self._graft:
            # Rescale the step so its norm matches the one obtained from
            # the un-decayed accumulator rebuilt from v2.
            graft_acc = combine(v2)
            graft_acc = graft_acc + g * g
            graft_grad = inverse_root(graft_acc) * g
            own_norm = jnp.linalg.norm(scaled_grad)
            graft_norm = jnp.linalg.norm(graft_grad)
            scaled_grad = scaled_grad * (graft_norm / (own_norm + 1e-16))

        scaled_grad = scaled_grad + w * l2_coeff

        if self._has_momentum:
            m, step = self._momentum_update(scaled_grad, m, momentum)
        else:
            step = scaled_grad

        w = w - (step_size * step).astype(w.dtype)
        write_back(v1, acc)
        if self._graft:
            write_back(v2, graft_acc)
        return w, (m, v1, v2)
Esempio n. 3
0
 def init(self, weights):
     """Creates the zero-initialised optimizer slots for `weights`.

     When `self._factored` is set and `weights` has at least two axes, two
     float32 slots are created: one with the last axis dropped and one with
     the second-to-last axis dropped. Otherwise a single slot shaped like
     `weights` is created. If `self._do_momentum` is set, one more slot
     shaped like `weights` is appended.

     Args:
       weights: Array the slots are created for.

     Returns:
       List of slot arrays.
     """
     shape = weights.shape
     if self._factored and len(shape) >= 2:
         slots = [
             jnp.zeros(shape[:-1], dtype=jnp.float32),
             jnp.zeros(shape[:-2] + shape[-1:], dtype=jnp.float32),
         ]
     else:
         slots = [jnp.zeros_like(weights)]
     if self._do_momentum:
         slots.append(jnp.zeros_like(weights))
     return slots
Esempio n. 4
0
 def f(preds, values, returns, actions, mask):
   """Returns the AWR policy loss plus the scaled L2 value loss.

   Advantages are `returns - values` with the gradient through `values`
   stopped and the last axis squeezed away; the value loss is the mean
   squared difference multiplied by `self._value_loss_coeff`.
   """
   advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
   logps = self._policy_dist.log_prob(preds, actions)
   awr_layer = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)
   policy_loss = awr_layer((logps, advantages, jnp.zeros_like(logps), mask))
   squared_error = (returns - values)**2
   value_loss = jnp.mean(squared_error) * self._value_loss_coeff
   return policy_loss + value_loss
Esempio n. 5
0
 def forward(self, x):
   """Adds multi-base positional embeddings to `x`.

   Each position is decomposed into `self._n_digits` digits for every base in
   `self._bases`; each digit is multiplied by its entry of `self.weights`
   and the per-digit results are concatenated on the last axis. The
   per-base embeddings are summed and added to `x`.

   In 'train' mode every batch row gets a random position offset (kept at 0
   for one in `self._start_from_zero_one_in` rows on average) and the
   embeddings of a row are zeroed for one in `self._base_dropout_one_in`
   rows on average.

   Args:
     x: Array whose axis 0 is read as the batch size and axis 1 as the
       length.

   Returns:
     `x` plus the summed positional embeddings.
   """
   rng = self.rng
   batch_size, length = x.shape[0], x.shape[1]
   max_pos = min(self._bases)**self._n_digits
   start_flag_rng, start_rng, dropout_rng = fastmath.random.split(rng, 3)
   assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length, max_pos)

   training = self._mode == 'train'
   positions = jnp.arange(0, length)[None, :]
   row_keep = None
   if training:
     # Flag is 0 when the draw is 0 (start from position 0, exactly as in
     # eval) and 1 otherwise.
     use_offset = jnp.minimum(1, fastmath.random.randint(
         start_flag_rng, (batch_size,), 0, self._start_from_zero_one_in))
     offsets = fastmath.random.randint(
         start_rng, (batch_size,), 0, max_pos-length)
     offsets = offsets * use_offset
     positions = positions + offsets[:, None]
     # The same key is used for every base, so one draw serves all of them.
     row_keep = jnp.minimum(1, fastmath.random.randint(
         dropout_rng, (batch_size,), 0, self._base_dropout_one_in))
     row_keep = row_keep.astype(jnp.float32)

   per_base = []
   for bn, base in enumerate(self._bases):
     digit_embeddings = []
     remaining = positions
     for i in range(self._n_digits):
       digit = jnp.mod(remaining, base)
       remaining = remaining // base
       scale = self.weights[bn][i]
       digit_embeddings.append(digit.astype(jnp.float32)[:, :, None] * scale)
     base_embedding = jnp.concatenate(digit_embeddings, axis=-1)
     if training:
       base_embedding = base_embedding * row_keep[:, None, None]
     per_base.append(base_embedding)

   total = sum(per_base) + jnp.zeros_like(x)
   return x + total
Esempio n. 6
0
def PPOObjective(dist_inputs, values, returns, dones, rewards, actions,
                 old_log_probs, log_prob_fun, epsilon, normalize_advantages):
    """Computes the element-wise PPO objective.

    The objective is the element-wise minimum of the unclipped and the
    clipped surrogate objectives. Where `dones` is set, the return is
    replaced by the reward and the value by zero before advantages are
    computed.

    Args:
      dist_inputs: Inputs of the action distribution, forwarded to
        `ProbsRatio`.
      values: Value predictions with a trailing size-1 axis (axis 2).
      returns: Returns with a trailing size-1 axis (axis 2).
      dones: Episode-termination flags with a trailing size-1 axis (axis 2).
      rewards: Rewards with a trailing size-1 axis (axis 2).
      actions: Actions taken, forwarded to `ProbsRatio`.
      old_log_probs: Log-probabilities of `actions` under the old policy;
        must match the squeezed shape of `returns`.
      log_prob_fun: Log-probability function, forwarded to `ProbsRatio`.
      epsilon: Clipping parameter, forwarded to `ClippedObjective`.
      normalize_advantages: If true, advantages are shifted to zero mean and
        divided by their standard deviation (plus 1e-8).

    Returns:
      Array of the same shape as the advantages.

    Raises:
      AssertionError: If any of the intermediate shapes disagree.
    """
    # Example shapes:
    # dist_inputs of the shape float32[128,1,18]
    # values of the shape float32[128,1,1]
    # returns of the shape float32[128,1,1]
    # dones of the shape float32[128,1,1]
    # rewards of the shape int32[128,1,1]
    # actions of the shape int32[128,1]
    # and old_log_probs of the shape float32[128,1]
    returns = returns.squeeze(axis=2)
    values = values.squeeze(axis=2)
    dones = dones.squeeze(axis=2)
    rewards = rewards.squeeze(axis=2)
    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
    assert dones.shape == values.shape, (
        f'dones.shape was {dones.shape} and values.shape was {values.shape}')
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} and values.shape was {values.shape}'
    )
    assert returns.shape == old_log_probs.shape, (
        f'returns.shape was {returns.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')

    probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)
    assert probs_ratio.shape == old_log_probs.shape, (
        f'probs_ratio.shape was {probs_ratio.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')

    # jaxified versions of
    # returns[dones] = rewards[dones]
    # values[dones] = 0
    returns = jnp.where(dones, rewards, returns)
    values = jnp.where(dones, jnp.zeros_like(values), values)
    advantages = returns - values
    if normalize_advantages:
        advantages = advantages - jnp.mean(advantages)
        advantages /= jnp.std(advantages) + 1e-8
    assert old_log_probs.shape == advantages.shape, (
        f'old_log_probs.shape was {old_log_probs.shape} and advantages.shape was '
        f'{advantages.shape}')

    unclipped_objective = UnclippedObjective(probs_ratio, advantages)
    # The message names the two shapes that are actually compared.
    assert unclipped_objective.shape == advantages.shape, (
        f'unclipped_objective.shape was {unclipped_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon)
    assert clipped_objective.shape == advantages.shape, (
        f'clipped_objective.shape was {clipped_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    assert ppo_objective.shape == advantages.shape, (
        f'ppo_objective.shape was {ppo_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    return ppo_objective
Esempio n. 7
0
 def init(self, w):
   """Creates the zero-initialised optimizer slots for `w`.

   Args:
     w: Array the slots are created for.

   Returns:
     Triple `(momentum, v1s, v2s)`. `momentum` is an array shaped like `w`
     when `self._has_momentum` is set, else an empty list. `v1s` holds one
     1-D zero vector per axis of `w` (length equal to that axis' size).
     `v2s` holds the same kind of vectors when `self._graft` is set, else it
     is an empty list.
   """

   def zeros_per_axis():
     return [jnp.zeros(size, dtype=w.dtype) for size in w.shape]

   momentum = jnp.zeros_like(w) if self._has_momentum else []
   v1s = zeros_per_axis()
   v2s = zeros_per_axis() if self._graft else []
   return (momentum, v1s, v2s)
Esempio n. 8
0
 def _UpdateRow(x):
   # row_e - (L1, H), row_d - (L2, H), row_mask_e - (L1,)
   row_e, row_d, row_mask_e = x
   # final_row - (L1+L2, H)
   final_row = jnp.concatenate([row_e, jnp.zeros_like(row_d)], axis=0)
   # Find the last real token/vector of the encoder.
   e_idx = jnp.sum(row_mask_e, dtype=jnp.int32)
   # Starting after that index, update with the decoder row.
   return jax.lax.dynamic_update_slice(final_row, row_d, (e_idx, 0))
Esempio n. 9
0
def DotProductAttention(queries, keys, values, pos_emb, context_bias,
                        location_bias, mask, separate_cls, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

    This function is the core of the attention mechanism. It:
      - computes per-head attention weights from per-head `queries` and
        `keys`, plus a positional term built from `pos_emb`,
      - applies `mask` to screen out positions that come from padding tokens,
      - optionally applies dropout to attention weights, and
      - uses attention weights to combine per-head `values` vectors.

    Args:
      queries: Per-head activations representing attention queries.
      keys: Per-head activations representing attention keys.
      values: Per-head activations to be combined by computed attention
        weights.
      pos_emb: Per-head activations representing positional embeddings.
      context_bias: Global context bias from Transformer XL's attention.
      location_bias: Global location bias from Transformer XL's attention.
      mask: Mask that distinguishes positions with real content vs. padding;
        `None` disables masking.
      separate_cls: True/False if we separate_cls in calculations.
      dropout: Probabilistic rate for dropout applied to attention strengths
        (based on query-key pairs) before applying them to values; `None` or
        0 disables dropout.
      mode: One of `'train'`, `'eval'`, or `'predict'`.
      rng: Single-use random number generator (JAX PRNG key).

    Returns:
      Tuple `(out, dots)`, both cast to float32: the per-head activations
      resulting from the attention-weighted sum of per-head values, and the
      attention weights themselves.

    Raises:
      ValueError: If `dropout` is 1.0 or larger.
    """
    # Validate up front, before any computation. `dropout` may be None, and
    # an ordering comparison with None would raise TypeError.
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')

    d_feature = queries.shape[-1]
    keys_len, queries_len = keys.shape[-2], queries.shape[-2]
    funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len)

    ac = jnp.einsum('bnid,bnjd->bnij', queries + context_bias, keys)
    bd = jnp.einsum('bnid,jnd->bnij', queries + location_bias, pos_emb)

    if mode != 'predict':
        bd = _fast_matrix_shift(bd, funnel_factor, is_upsampling)

    if separate_cls:
        # Masking out location part of attention for cls token
        bd = bd.at[:, :, :, 0].set(0)
        bd = bd.at[:, :, 0, :].set(0)

    dots = (ac + bd) / jnp.sqrt(d_feature)
    if mask is not None:
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax.
    dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        # Rescale kept weights so their expected value is unchanged.
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    out = jnp.matmul(dots, values)
    out = out.astype(jnp.float32)
    dots = dots.astype(jnp.float32)
    return out, dots
Esempio n. 10
0
 def _update_diagonal(self, grads, weights, m, v, opt_params):
   learning_rate = opt_params['learning_rate']
   momentum = opt_params['momentum']
   v[0] += grads * grads
   preconditioner = jnp.where(v[0] > 0, 1.0 / jnp.sqrt(v[0]),
                              jnp.zeros_like(v[0]))
   preconditioned_grads = preconditioner * grads
   m = (1 - momentum) * preconditioned_grads + momentum * m
   weights = weights - (learning_rate * m).astype(weights.dtype)
   return weights, (m, v)
Esempio n. 11
0
 def _UpdateRow(x):
     """Writes the decoder row right after the real encoder tokens.

     `x` is a triple `(row_e, row_d, row_mask_e)` with row_e - (L1, H),
     row_d - (L2, H) and row_mask_e - (L1,). The result has shape (L1+L2, H).
     """
     enc_row, dec_row, enc_mask = x
     # Encoder row followed by zero padding for the decoder part.
     padded_row = jnp.concatenate([enc_row, jnp.zeros_like(dec_row)], axis=0)
     # The number of real encoder tokens is the row at which writing starts.
     n_real = jnp.sum(enc_mask, dtype=jnp.int32)
     # Same dtype for both start indices: avoids an int32/int64 mismatch.
     col_start = jnp.array(0, dtype=n_real.dtype)
     start = (n_real, col_start)
     return fastmath.dynamic_update_slice(padded_row, dec_row, start)
Esempio n. 12
0
def Relu():
  r"""Returns a layer that computes the Rectified Linear Unit (ReLU) function.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          0 & \text{if}\ x \leq 0, \\
          x & \text{otherwise}.
      \end{array} \right.
  """

  def relu(x):
    # Non-positive entries become zero; all others pass through unchanged.
    return jnp.where(x <= 0, jnp.zeros_like(x), x)

  return Fn('Relu', relu)
Esempio n. 13
0
def _per_head_attention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new per-head activations via scaled dot-product attention.

  This function is the core of the attention mechanism. Given per-head
  ``queries`` (Q), ``keys`` (K), ``values`` (V), and ``mask``, it:

    - computes the scaled dot product of each Q-K pair;
    - applies ``mask`` to screen out positions that come from padding tokens
      (indicated by 0 value);
    - [in ``'train'`` mode] applies dropout to Q-K dot products;
    - computes Q-K attention strengths using a per-query softmax of the Q-K dot
      products; and
    - for each query position, combines V vectors according to the Q-K
      attention strengths.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention strengths.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probababilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored. Applies only in ``'train'``
        mode.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Tuple of (activations, attn_strengths), where activations are new per-head
    activation vectors and attn_strengths is a matrix of per-head attention
    strengths.
  """
  if dropout >= 1.0:
    raise ValueError(f'Dropout rate ({dropout}) must be lower than 1.')

  d_feature = queries.shape[-1]

  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    dots = jnp.where(mask,
                     dots,
                     jnp.full_like(dots, -1e9))
  attn_strengths = (
      jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)))
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, attn_strengths.shape)
    attn_strengths = jnp.where(keep,
                               attn_strengths / (1.0 - dropout),
                               jnp.zeros_like(attn_strengths))
  activations = jnp.matmul(attn_strengths, values).astype(jnp.float32)
  attn_strengths = attn_strengths.astype(jnp.float32)
  return activations, attn_strengths
Esempio n. 14
0
def A2CObjective(dist_inputs, values, returns, dones, rewards, actions, mask,
                 log_prob_fun, normalize_advantages):
    """Computes the Advantage Actor Critic (A2C) loss.

    The loss is the negated, mask-weighted sum of
    `new_log_probs * advantages`, divided by the sum of `mask`. Where `dones`
    is set, the return is replaced by the reward and the value by zero
    before advantages are computed.

    Example shapes: dist_inputs float32[128,1,18]; values and returns
    float32[128,1,1]; dones int32[128,1,1]; actions int32[128,1]; mask
    float32[128,1].
    """
    # Drop the trailing axis so every factor of
    # (returns - values) * new_log_probs * mask has the same shape.
    values, returns, dones, rewards = (
        t.squeeze(axis=2) for t in (values, returns, dones, rewards))

    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
    assert dones.shape == values.shape, (
        f'dones.shape was {dones.shape} and values.shape was {values.shape}')
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} and values.shape was {values.shape}'
    )
    assert values.shape == mask.shape, (
        f'values.shape was {values.shape} and mask.shape was {mask.shape}')
    assert returns.shape[0] == dist_inputs.shape[0], (
        f'returns.shape[0] was {returns.shape[0]} and dist_inputs.shape[0] was '
        f'{dist_inputs.shape[0]}')

    new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
    assert new_log_probs.shape == mask.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} and mask.shape was '
        f'{mask.shape}')

    # Functional equivalents of
    #   returns[dones] = rewards[dones]
    #   values[dones] = 0
    final_returns = jnp.where(dones, rewards, returns)
    final_values = jnp.where(dones, jnp.zeros_like(values), values)
    advantages = final_returns - final_values
    if normalize_advantages:
        centered = advantages - jnp.mean(advantages)
        advantages = centered / (jnp.std(centered) + 1e-8)
    assert new_log_probs.shape == advantages.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} and advantages.shape was '
        f'{advantages.shape}')

    # The squeezes and assertions above rule out accidental broadcasting such
    # as [128,1] * [128,1,1] * [128] in the product below.
    weighted = new_log_probs * advantages * mask
    return -jnp.sum(weighted) / jnp.sum(mask)
Esempio n. 15
0
 def batches_stream(self):
   """Use the RLTask self._task to create inputs to the value model.

   Yields:
     7-tuples `(observations, returns, dones, rewards, actions, zeros, mask)`
     where `returns`, `dones` and `rewards` carry an extra trailing axis and
     `zeros` is an all-zero array shaped like `mask`.
   """

   def with_depth(array):
     # Extra depth dimension, so the target shape is consistent with the
     # network output shape.
     return array[:, :, None]

   trajectories = self._task.trajectory_batch_stream(
       self._batch_size, max_slice_length=self._max_slice_length, epochs=[-1])
   for trajectory in trajectories:
     observations = trajectory.observations  # Inputs to the value model.
     returns = with_depth(trajectory.returns)
     dones = with_depth(trajectory.dones)
     rewards = with_depth(trajectory.rewards)
     actions = trajectory.actions
     zeros = jnp.zeros_like(trajectory.mask)
     yield (observations, returns, dones, rewards, actions, zeros,
            trajectory.mask)
Esempio n. 16
0
def ParametricRelu(a=1.):
  r"""Returns a layer that computes a ReLU function with the given slope.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          0  & \text{if}\ x \leq 0, \\
          ax & \text{otherwise}.
      \end{array} \right.

  The layer evaluates :math:`\max(ax, 0)` element-wise, which coincides with
  the formula above for non-negative ``a``.

  Args:
    a: Slope of line for positive inputs.
  """

  def scaled_relu(x):
    return jnp.maximum(a * x, jnp.zeros_like(x))

  return Fn('ParametricRelu', scaled_relu)
Esempio n. 17
0
  def favor_denominator_bwd(qkp, r_ct):
    """Backward rule producing cotangents for `qs` and `ks`.

    Args:
      qkp: Saved residuals `(precision, qs, ks, p)`.
      r_ct: Cotangent of the forward result; scanned together with `qs` and
        `ks`.

    Returns:
      `(None, None, qs_ct, ks_ct)`. NOTE(review): the two leading `None`
      entries presumably line up with forward arguments that receive no
      gradient - confirm against the forward definition.
    """
    precision, qs, ks, p = qkp

    def body(carry, qkx):
      # carry: (p, p_ct) - the running value `p` and its cotangent so far.
      p, p_ct = carry
      q, k, x_ct = qkx
      # Cotangent for q uses `p` as it is before this step's k is removed.
      q_ct = jnp.einsum('...,...m->...m', x_ct, p, precision=precision)
      # Accumulate this step's contribution into the cotangent of p ...
      p_ct += jnp.einsum('...,...m->...m', x_ct, q, precision=precision)
      # ... which, after the accumulation, is also the cotangent for k.
      k_ct = p_ct
      # Remove this step's k so `p` is ready for the next (earlier) element.
      p -= k
      return (p, p_ct), (q_ct, k_ct)

    # Scan in reverse, starting from the saved `p` and a zero cotangent.
    _, (qs_ct, ks_ct) = fastmath.scan(
        body, (p, jnp.zeros_like(p)), (qs, ks, r_ct), reverse=True)
    return (None, None, qs_ct, ks_ct)
Esempio n. 18
0
  def favor_numerator_bwd(pqkv, w_ct):
    """Backward rule producing cotangents for `qs`, `ks` and `vs`.

    Args:
      pqkv: Saved residuals `(precision, p, qs, ks, vs)`.
      w_ct: Cotangent of the forward result; scanned together with `qs`,
        `ks` and `vs`.

    Returns:
      `(None, None, qs_ct, ks_ct, vs_ct)`. NOTE(review): the two leading
      `None` entries presumably line up with forward arguments that receive
      no gradient - confirm against the forward definition.
    """
    precision, p, qs, ks, vs = pqkv

    def body(carry, qkv_xct):
      # carry: (p, p_ct) - the running value `p` and its cotangent so far.
      p, p_ct = carry
      q, k, v, x_ct = qkv_xct
      # Cotangent for q uses `p` as it is before this step's term is removed.
      q_ct = jnp.einsum('...d,...md->...m', x_ct, p, precision=precision)
      # Accumulate the outer product of x_ct and q into the cotangent of p.
      p_ct += jnp.einsum('...d,...m->...md', x_ct, q, precision=precision)
      # k and v cotangents are contractions of the updated p_ct with v and k.
      k_ct = jnp.einsum('...md,...d->...m', p_ct, v, precision=precision)
      v_ct = jnp.einsum('...md,...m->...d', p_ct, k, precision=precision)
      # Remove this step's outer product k (x) v from `p` for the next
      # (earlier) element.
      p -= jnp.einsum('...m,...d->...md', k, v, precision=precision)
      return (p, p_ct), (q_ct, k_ct, v_ct)

    # Scan in reverse, starting from the saved `p` and a zero cotangent.
    _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(
        body, (p, jnp.zeros_like(p)), (qs, ks, vs, w_ct), reverse=True)
    return (None, None, qs_ct, ks_ct, vs_ct)
Esempio n. 19
0
    def favor_denominator_bwd(init_prefix_sum_value, precision, qkp, r_ct):
        """Backward rule producing cotangents for `qs` and `ks`.

        Args:
          init_prefix_sum_value: Unused; deleted immediately.
          precision: Forwarded to every `einsum` call.
          qkp: Saved residuals `(qs, ks, p)`.
          r_ct: Cotangent of the forward result; scanned together with `qs`
            and `ks`.

        Returns:
          The pair `(qs_ct, ks_ct)`.
        """
        del init_prefix_sum_value

        def body(carry, qkx):
            # carry: (p, p_ct) - the running value `p` and its cotangent.
            p, p_ct = carry
            q, k, x_ct = qkx
            # Cotangent for q uses `p` before this step's k is removed.
            q_ct = np.einsum('...,...m->...m', x_ct, p, precision=precision)
            # Accumulate this step's contribution into the cotangent of p,
            # which afterwards is also the cotangent for k.
            p_ct += np.einsum('...,...m->...m', x_ct, q, precision=precision)
            k_ct = p_ct
            # Remove this step's k so `p` is ready for the next (earlier)
            # element.
            p -= k
            return (p, p_ct), (q_ct, k_ct)

        qs, ks, p = qkp
        # Scan in reverse, starting from the saved `p` and a zero cotangent.
        _, (qs_ct, ks_ct) = fastmath.scan(body, (p, np.zeros_like(p)),
                                          (qs, ks, r_ct),
                                          reverse=True)
        return (qs_ct, ks_ct)
Esempio n. 20
0
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `queries` and `keys`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding;
        `None` disables masking.
    dropout: Probabilistic rate for dropout applied to attention strengths
        (based on query-key pairs) before applying them to values; `None` or
        0 disables dropout.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Tuple `(out, dots)`, both cast to float32: the per-head activations
    resulting from the masked per-head attention-weighted sum of per-head
    values, and the attention weights themselves.

  Raises:
    ValueError: If `dropout` is 1.0 or larger.
  """
  # Validate up front, before any computation. `dropout` may be None, and an
  # ordering comparison with None would raise TypeError.
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')

  d_feature = queries.shape[-1]
  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data dependency
    # on the input. Broadcasted copies of these use a lot of memory, so they
    # should be computed at runtime (rather than being global constants).
    if fastmath.is_backend(fastmath.Backend.JAX):
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Softmax.
  dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    # Rescale kept weights so their expected value is unchanged.
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  out = jnp.matmul(dots, values)
  out = out.astype(jnp.float32)
  dots = dots.astype(jnp.float32)
  return out, dots
Esempio n. 21
0
    def favor_numerator_bwd(init_prefix_sum_value, precision, pqkv, w_ct):
        """Backward rule producing cotangents for `qs`, `ks` and `vs`.

        Args:
          init_prefix_sum_value: Unused; deleted immediately.
          precision: Forwarded to every `einsum` call.
          pqkv: Saved residuals `(p, qs, ks, vs)`.
          w_ct: Cotangent of the forward result; scanned together with `qs`,
            `ks` and `vs`.

        Returns:
          The triple `(qs_ct, ks_ct, vs_ct)`.
        """
        del init_prefix_sum_value

        def body(carry, qkv_xct):
            # carry: (p, p_ct) - the running value `p` and its cotangent.
            p, p_ct = carry
            q, k, v, x_ct = qkv_xct
            # Cotangent for q uses `p` before this step's term is removed.
            q_ct = np.einsum('...d,...md->...m', x_ct, p, precision=precision)
            # Accumulate the outer product of x_ct and q into p's cotangent.
            p_ct += np.einsum('...d,...m->...md', x_ct, q, precision=precision)
            # k and v cotangents contract the updated p_ct with v and k.
            k_ct = np.einsum('...md,...d->...m', p_ct, v, precision=precision)
            v_ct = np.einsum('...md,...m->...d', p_ct, k, precision=precision)
            # Remove this step's outer product k (x) v from `p` for the next
            # (earlier) element.
            p -= np.einsum('...m,...d->...md', k, v, precision=precision)
            return (p, p_ct), (q_ct, k_ct, v_ct)

        p, qs, ks, vs = pqkv
        # Scan in reverse, starting from the saved `p` and a zero cotangent.
        _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(body, (p, np.zeros_like(p)),
                                                 (qs, ks, vs, w_ct),
                                                 reverse=True)
        return qs_ct, ks_ct, vs_ct
Esempio n. 22
0
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

    This function is the core of the attention mechanism. It:
      - computes per-head attention weights from per-head ``queries`` and
        ``keys``,
      - applies ``mask`` to screen out positions that come from padding
        tokens,
      - optionally applies dropout to attention weights, and
      - uses attention weights to combine per-head ``values`` vectors.

    Args:
      queries: Per-head activations representing attention queries.
      keys: Per-head activations representing attention keys.
      values: Per-head activations to be combined by computed attention
        weights.
      mask: Mask that distinguishes positions with real content vs. padding;
        ``None`` disables masking.
      dropout: Probabilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored. ``None`` or 0 disables
        it; it is applied only in ``'train'`` mode.
      mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
      rng: Single-use random number generator (JAX PRNG key).

    Returns:
      Tuple ``(out, dots)``, both cast to float32: the per-head activations
      resulting from the masked per-head attention-weighted sum of per-head
      values, and the attention weights themselves.

    Raises:
      ValueError: If ``dropout`` is 1.0 or larger.
    """
    # Validate up front, before any computation. `dropout` may be None, and
    # an ordering comparison with None would raise TypeError.
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')

    d_feature = queries.shape[-1]
    dots = jnp.matmul(queries, jnp.swapaxes(keys, -1,
                                            -2)) / jnp.sqrt(d_feature)
    if mask is not None:
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax.
    dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        # Rescale kept weights so their expected value is unchanged.
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    out = jnp.matmul(dots, values)
    out = out.astype(jnp.float32)
    dots = dots.astype(jnp.float32)
    return out, dots
Esempio n. 23
0
  def _calc_attn_scores(q, k):
    """Returns softmax-normalised attention weights for `q` and `k`.

    Besides its two arguments this closure reads `context_bias`,
    `location_bias`, `pos_emb`, `mode`, `d_feature`, `mask`, `dropout` and
    `rng` from the enclosing scope.

    Raises:
      ValueError: If `dropout` is 1.0 or larger.
    """
    # Validate up front, before any computation. `dropout` may be None, and an
    # ordering comparison with None would raise TypeError.
    if dropout is not None and dropout >= 1.0:
      raise ValueError('Dropout rates must be lower than 1.')

    ac = jnp.einsum('bnid,bnjd->bnij', q + context_bias, k)
    bd = jnp.einsum('bnid,jnd->bnij', q + location_bias, pos_emb)

    if mode != 'predict':
      bd = _fast_matrix_shift(bd)

    dots = (ac + bd) / jnp.sqrt(d_feature)
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Softmax.
    dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout > 0.0 and mode == 'train':
      keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
      # Rescale kept weights so their expected value is unchanged.
      dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))

    return dots
Esempio n. 24
0
 def _update_sketched(self, grads, weights, m, v, opt_params):
   """Update for higher-rank parameters.

   The full-shape accumulator is rebuilt from one accumulator per dimension
   of `weights`, incremented by the squared gradients, and then folded back
   into per-dimension maxima.

   Args:
     grads: Gradient with respect to `weights`.
     weights: Current weights.
     m: Momentum slot.
     v: List of per-dimension accumulators; its entries are overwritten.
     opt_params: Mapping providing 'learning_rate' and 'momentum'.

   Returns:
     A pair `(new_weights, (m, v))`.
   """
   step_size = opt_params['learning_rate']
   mu = opt_params['momentum']
   dims = weights.shape
   rank = len(dims)

   accumulated = self._minimum([
       jnp.reshape(v[d], self._expanded_shape(dims, d)) for d in range(rank)
   ])
   accumulated += grads * grads
   # 1 / sqrt(accumulated) where positive, exactly zero everywhere else.
   inverse_root = jnp.where(accumulated > 0.0, 1.0 / jnp.sqrt(accumulated),
                            jnp.zeros_like(accumulated))
   scaled_grads = grads * inverse_root
   m = (1.0 - mu) * scaled_grads + mu * m
   new_weights = weights - (step_size * m).astype(weights.dtype)

   # Slot d keeps the maximum of the accumulator over every axis except d.
   for d in range(len(v)):
     other_axes = [*range(d), *range(d + 1, rank)]
     v[d] = jnp.amax(accumulated, axis=other_axes)
   return new_weights, (m, v)
Esempio n. 25
0
def threefry_2x32_prange(key, lo: int = 0, hi: int = 2):
    """Derives a batch of keys from `key`, one per counter value in [lo, hi).

    Each counter is a pair of uint32 words: the low word runs over
    `arange(lo, hi)` and the high word is always zero (little-endian counter
    mode). The counters are passed to `threefry_2x32_prf` together with `key`.

    Args:
      key: uint32[2] the key to split
      lo: first counter value to use
      hi: counter value to stop before

    Returns:
      keys: uint32[hi - lo, 2] the split keys

    Raises:
      ValueError: If `key` is not a uint32 array of shape (2,).
      NotImplementedError: If `hi` does not fit in 32 bits.
    """
    well_formed = key.shape == (2, ) and key.dtype == jnp.uint32
    if not well_formed:
        raise ValueError('key must be uint32[2]')
    if hi >= 2**32:
        # You shouldn't really be using more than half the key size anyways.
        raise NotImplementedError('only 32-bit sizes are supported')

    # Build the 64-bit counter as (low word, high word) pairs.
    low_word = jnp.arange(lo, hi, dtype=jnp.uint32)
    high_word = jnp.zeros_like(low_word)
    counter = jnp.stack([low_word, high_word], axis=-1)
    return threefry_2x32_prf(key, counter)
Esempio n. 26
0
    def forward(self, x):
        """Runs the expert-gated feed-forward computation on `x`.

        The input is flattened to two dimensions and a softmax over experts
        is computed from `x @ m1`. One expert per row is then chosen by
        adding Gumbel noise, scaled by `self._temperature`, to the
        log-probabilities and taking the argmax. The chosen expert gates a
        block of `self._n_elements_in_block` hidden units in a two-layer ReLU
        network (`w1`, then `w2` plus bias `b2`).

        Args:
          x: Tensor of same shape and dtype as the input signature used to
            initialize this layer.

        Returns:
          Tensor reshaped to the shape of the input.
        """
        m1, w1, w2, b2 = self.weights
        input_shape = x.shape
        # Easier to operate on flattened x.
        x = jnp.reshape(x, [-1, input_shape[-1]])

        # Open question from the original authors: check if we need bias
        # and/or put relu after the m1 dot?
        logits = jnp.dot(x, m1)
        # Softmax, computed in log space.
        log_probs = logits - fastmath.logsumexp(
            logits, axis=-1, keepdims=True)
        probs = jnp.exp(log_probs)

        # Gumbel-softmax with straight-through discretization.
        # TODO(lukaszkaiser, chowdhery): Extract this block and share
        noise_rng, select_rng = fastmath.random.split(self.rng, 2)
        uniform = fastmath.random.uniform(
            noise_rng, probs.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
        gumbel = -jnp.log(-jnp.log(uniform))
        selected_experts = jnp.argmax(
            log_probs + gumbel * self._temperature, axis=-1)

        gate = tl.one_hot(selected_experts, self._num_experts)
        if self._mode == 'train':
            # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
            gate = fastmath.stop_gradient(gate)
            # Straight-through: the forward value stays one-hot while the
            # gradient flows through the soft probabilities.
            gate += probs - fastmath.stop_gradient(probs)
            # A single draw from [-1, 1) per call decides whether the
            # quantized gate (draw > 0) or the soft probabilities are used;
            # the paper above motivates this as improving training stability.
            # Open question from the original authors: is 50% the best
            # choice? Other %? Mixed in-batch?
            coin = fastmath.random.uniform(
                select_rng, (), jnp.float32, -1.0, 1.0)
            gate = jnp.where(coin > 0.0, gate, probs)
        gate = jnp.reshape(gate, [-1, self._num_experts, 1])
        n_rows = gate.shape[0]

        block = self._n_elements_in_block
        if self._mode == 'predict' and n_rows == 1:
            # This implementation mimicks inference for batch_size 1: only
            # the selected expert's slice of each weight matrix is used.
            offset = selected_experts[0] * block
            # w1 is [d_model, d_ff], the slice is [d_model, block].
            w1_slice = fastmath.dynamic_slice(
                w1, [0, offset], [w1.shape[0], block])
            hidden = jnp.dot(x, w1_slice)
            activated = jnp.where(hidden <= 0, jnp.zeros_like(hidden), hidden)
            # w2 is [d_ff, d_model], the slice is [block, d_model].
            w2_slice = fastmath.dynamic_slice(
                w2, [offset, 0], [block, w2.shape[-1]])
            w2_slice = jnp.reshape(w2_slice, [block, -1])
            res = jnp.dot(activated, w2_slice) + b2
        else:
            # Repeat each expert's gate value across its block of hidden
            # units, then flatten to one gate per hidden unit.
            unit_gate = jnp.broadcast_to(
                gate, (gate.shape[0], gate.shape[1], block))
            unit_gate = jnp.reshape(unit_gate, (-1, self._d_ff))
            hidden = jnp.dot(x, w1) * unit_gate  # [joint_batch, d_ff]
            activated = jnp.where(hidden <= 0, jnp.zeros_like(hidden), hidden)
            res = jnp.dot(activated, w2) + b2

        return jnp.reshape(res, input_shape)  # un-flatten if needed
Esempio n. 27
0
 def init(self, weights):
   """Creates zero-initialized slots for `weights`.

   Args:
     weights: Array whose shape and dtype determine the slots.

   Returns:
     A pair: zeros shaped like `weights`, and a list holding one 1-D zero
     vector per axis of `weights`, each with that axis' length.
   """
   zeros_like_weights = jnp.zeros_like(weights)
   per_axis = []
   for axis_len in weights.shape:
     per_axis.append(jnp.zeros(axis_len, dtype=weights.dtype))
   return (zeros_like_weights, per_axis)
Esempio n. 28
0
 def relu(x):
     """Returns `x` with every entry that is `<= 0` replaced by zero."""
     zeros = jnp.zeros_like(x)
     non_positive = x <= 0
     return jnp.where(non_positive, zeros, x)
Esempio n. 29
0
    def forward(self, x):
        """Runs the block-sparse feed-forward computation on `x`.

        The input is flattened to two dimensions. A low-rank projection
        (`m1`, then `m2`, plus bias `mb`) produces logits that are reshaped
        to `[-1, self._d1, self._d2]` and normalized with a softmax over the
        last axis. One index per group is then chosen by adding Gumbel noise,
        scaled by `self._temperature`, and taking the argmax. The choice
        controls which hidden units of the two-layer ReLU network (`w1`,
        then `w2` plus bias `b2`) contribute to the output.

        Args:
          x: Tensor of same shape and dtype as the input signature used to
              initialize this layer.

        Returns:
          Tensor reshaped to the shape of the input.
        """
        m1, m2, mb, w1, w2, b2 = self.weights
        mode = self._mode
        if mode != 'predict':
            # Outside of predict mode the weights are used as dense 2-D
            # matrices.
            w1 = jnp.reshape(w1.T, (-1, self._d_ff))
            w2 = jnp.reshape(w2, (self._d_ff, -1))
        input_shape = x.shape
        # Easier to operate on flattened x.
        x = jnp.reshape(x, [-1, input_shape[-1]])

        # Open question from the original authors: should we add bias and/or
        # put relu after the low-rank m1 dot?
        logits = jnp.dot(jnp.dot(x, m1), m2) + mb
        logits = jnp.reshape(logits, [-1, self._d1, self._d2])
        # Softmax, computed in log space.
        log_probs = logits - fastmath.logsumexp(
            logits, axis=-1, keepdims=True)
        probs = jnp.exp(log_probs)

        # Gumbel-softmax with straight-through discretization.
        noise_rng, select_rng = fastmath.random.split(self.rng, 2)
        uniform = fastmath.random.uniform(
            noise_rng, probs.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
        gumbel = -jnp.log(-jnp.log(uniform))
        chosen = jnp.argmax(log_probs + gumbel * self._temperature, axis=-1)

        if mode == 'predict':
            # This implementation mimicks inference. It's not efficient for
            # large size of joint_batch, but at inference that will be 1 most
            # of the time.
            # `chosen` is [joint_batch, self._d1].
            # NOTE(review): the indexing below implies w1 and w2 are laid out
            # as [self._d1, self._d2, d_model] here -- confirm against the
            # weight initialization.
            # Advanced indexing: the first index ranges over self._d1 once per
            # batch row, the second index is the chosen entry in each group.
            n_rows = chosen.shape[0]
            group_idx = jnp.array([jnp.arange(self._d1)] * n_rows)
            # Flatten indices and select from w1.
            group_idx = jnp.reshape(group_idx, [-1])
            unit_idx = jnp.reshape(chosen, [-1])
            # Per-row weights, with a batch dimension after the reshape.
            w1_rows = w1[group_idx, unit_idx, :]
            w1_rows = jnp.reshape(w1_rows, [n_rows, self._d1, -1])
            hidden = jnp.einsum('ai,aji->aj', x, w1_rows)
            activated = jnp.where(hidden <= 0, jnp.zeros_like(hidden), hidden)
            w2_rows = w2[group_idx, unit_idx, :]
            w2_rows = jnp.reshape(w2_rows, [n_rows, self._d1, -1])
            res = jnp.einsum('ai,aij->aj', activated, w2_rows) + b2
        else:
            gate = tl.one_hot(chosen, self._n_elements_in_block)
            if mode == 'train':
                # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
                gate = fastmath.stop_gradient(gate)
                # Straight-through: the forward value stays one-hot while the
                # gradient flows through the soft probabilities.
                gate += probs - fastmath.stop_gradient(probs)
                # A single draw from [0, 1) per call: the quantized gate is
                # used when the draw is below self._quant_prob, otherwise the
                # soft probabilities are used (see paper above).
                coin = fastmath.random.uniform(
                    select_rng, (), jnp.float32, 0.0, 1.0)
                gate = jnp.where(coin < self._quant_prob, gate, probs)
            gate = jnp.reshape(gate, [-1, self._d_ff])
            # Run the full matmul and gate the hidden units afterwards.
            hidden = jnp.dot(x, w1) * gate  # [joint_batch, d_ff]
            activated = jnp.where(hidden <= 0, jnp.zeros_like(hidden), hidden)
            res = jnp.dot(activated, w2) + b2

        return jnp.reshape(res, input_shape)  # un-flatten if needed
Esempio n. 30
0
 def backward(self, inputs, output, grad, weights, state, new_state,
              rング):
     """Returns zeros shaped like `grad`, paired with an empty tuple.

     Only `grad` is read; every other argument is ignored.
     """
     zeros = jnp.zeros_like(grad)
     empty = ()
     return (zeros, empty)