Exemple #1
0
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

    Computes ``softmax(query @ swapaxes(key, -1, -2) / sqrt(depth)) @ value``
    with ``depth = query.shape[-1]``. Masked-out positions receive a logit of
    -1e9 before the softmax, and in 'train' mode dropout is applied to the
    attention weights.

    Args:
      query: array of representations.
      key: array of representations.
      value: array of representations.
      mask: attention-mask, gates attention; wherever it is falsy the
          corresponding logit is replaced by -1e9. ``None`` disables masking.
      dropout: float: dropout rate, or ``None`` / 0 to disable dropout.
      mode: 'eval' or 'train': dropout is only applied when mode == 'train'.
      rng: JAX PRNGKey: subkey for disposable use; only consumed by dropout.

    Returns:
      Self attention for q, k, v arrays.

    Raises:
      ValueError: if ``dropout`` is not ``None`` and is >= 1.0.
    """
    # Validate up front. `None` means "no dropout" (see the train-mode check
    # below), so it must not reach the `>=` comparison, which raises TypeError
    # for None in Python 3.
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax over the last (key) axis, computed in log space for stability.
    dots = np.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        # Rescale the surviving weights by 1 / (1 - dropout).
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
Exemple #2
0
def PPOObjective(dist_inputs, values, returns, dones, rewards, actions,
                 old_log_probs, log_prob_fun, epsilon, normalize_advantages):
    """PPO Objective.

    Computes the per-step PPO objective: the elementwise minimum of
    ``UnclippedObjective(probs_ratio, advantages)`` and
    ``ClippedObjective(probs_ratio, advantages, epsilon)``, where
    ``probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs,
    log_prob_fun)`` and ``advantages = returns - values``. For steps flagged by
    ``dones`` the return is replaced by the reward and the value by 0.

    Args:
      dist_inputs: inputs to the policy distribution, e.g. float32[128,1,18].
      values: value estimates, e.g. float32[128,1,1] (axis 2 must be size 1).
      returns: returns, e.g. float32[128,1,1] (axis 2 must be size 1).
      dones: terminal-step flags, e.g. float32[128,1,1]; nonzero marks done.
      rewards: rewards, e.g. int32[128,1,1] (axis 2 must be size 1).
      actions: actions taken, e.g. int32[128,1].
      old_log_probs: presumably log-probabilities of ``actions`` under the
          previous policy, e.g. float32[128,1].
      log_prob_fun: passed through to ``ProbsRatio``.
      epsilon: clipping parameter passed to ``ClippedObjective``.
      normalize_advantages: if true, advantages are centered and divided by
          their standard deviation plus 1e-8.

    Returns:
      The elementwise PPO objective, with the same shape as ``old_log_probs``.
    """
    # dist_inputs of the shape float32[128,1,18]
    # values of the shape float32[128,1,1]
    # returns of the shape float32[128,1,1]
    # dones of the shape float32[128,1,1]
    # rewards of the shape int32[128,1,1]
    # actions of the shape int32[128,1]
    # and old_log_probs of the shape float32[128,1]
    # Drop the trailing singleton axis so all per-step tensors share a shape.
    returns = returns.squeeze(axis=2)
    values = values.squeeze(axis=2)
    dones = dones.squeeze(axis=2)
    rewards = rewards.squeeze(axis=2)
    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
    assert dones.shape == values.shape, (
        f'dones.shape was {dones.shape} and values.shape was {values.shape}')
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} and values.shape was {values.shape}'
    )
    assert returns.shape == old_log_probs.shape, (
        f'returns.shape was {returns.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')

    probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)
    assert probs_ratio.shape == old_log_probs.shape, (
        f'probs_ratio.shape was {probs_ratio.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')

    # jaxified versions of
    # returns[dones] = rewards[dones]
    # values[dones] = 0
    returns = jnp.where(dones, rewards, returns)
    values = jnp.where(dones, jnp.zeros_like(values), values)
    advantages = returns - values
    if normalize_advantages:
        advantages = advantages - jnp.mean(advantages)
        advantages /= jnp.std(advantages) + 1e-8
    assert old_log_probs.shape == advantages.shape, (
        f'old_log_probs.shape was {old_log_probs.shape} and advantages.shape was '
        f'{advantages.shape}')

    unclipped_objective = UnclippedObjective(probs_ratio, advantages)
    assert unclipped_objective.shape == advantages.shape, (
        f'unclipped_objective.shape was {unclipped_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon)
    assert clipped_objective.shape == advantages.shape, (
        f'clipped_objective.shape was {clipped_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    assert ppo_objective.shape == advantages.shape, (
        f'ppo_objective.shape was {ppo_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    return ppo_objective
Exemple #3
0
def A2CObjective(dist_inputs, values, returns, dones, rewards, actions, mask,
                 log_prob_fun, normalize_advantages):
    """Definition of the Advantage Actor Critic (A2C) loss.

    Args:
      dist_inputs: policy distribution inputs, e.g. float32[128,1,18].
      values: value estimates, e.g. float32[128,1,1] (axis 2 must be size 1).
      returns: returns, e.g. float32[128,1,1] (axis 2 must be size 1).
      dones: terminal-step flags, e.g. int32[128,1,1]; nonzero marks done.
      rewards: rewards; axis 2 must be size 1, like ``dones``.
      actions: actions taken, e.g. int32[128,1].
      mask: per-step weights, e.g. float32[128,1]. The loss is divided by
          ``sum(mask)``, so the mask must not sum to zero.
      log_prob_fun: passed through to ``NewLogProbs``.
      normalize_advantages: if true, advantages are centered and divided by
          their standard deviation plus 1e-8.

    Returns:
      Scalar loss ``-sum(new_log_probs * advantages * mask) / sum(mask)``.
    """
    # Remove the trailing singleton axis so every per-step tensor has the same
    # shape as `mask`; this prevents accidental broadcasting such as
    # [128,1] * [128,1,1] * [128] in the objective below.
    values, returns, dones, rewards = (
        t.squeeze(axis=2) for t in (values, returns, dones, rewards))

    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
    assert dones.shape == values.shape, (
        f'dones.shape was {dones.shape} and values.shape was {values.shape}')
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} and values.shape was {values.shape}'
    )
    assert values.shape == mask.shape, (
        f'values.shape was {values.shape} and mask.shape was {mask.shape}')
    assert returns.shape[0] == dist_inputs.shape[0], (
        f'returns.shape[0] was {returns.shape[0]} and dist_inputs.shape[0] was '
        f'{dist_inputs.shape[0]}')

    log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
    assert log_probs.shape == mask.shape, (
        f'new_log_probs.shape was {log_probs.shape} and mask.shape was '
        f'{mask.shape}')

    # Terminal steps use the immediate reward as target and 0 as baseline
    # (jnp form of `returns[dones] = rewards[dones]; values[dones] = 0`).
    targets = jnp.where(dones, rewards, returns)
    baselines = jnp.where(dones, jnp.zeros_like(values), values)
    advantages = targets - baselines
    if normalize_advantages:
        centered = advantages - jnp.mean(advantages)
        advantages = centered / (jnp.std(centered) + 1e-8)
    assert log_probs.shape == advantages.shape, (
        f'new_log_probs.shape was {log_probs.shape} and advantages.shape was '
        f'{advantages.shape}')

    weighted = log_probs * advantages * mask
    return -jnp.sum(weighted) / jnp.sum(mask)
Exemple #4
0
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

    This function is the core of the attention mechanism. It:
      - computes per-head attention weights from per-head `(queries, keys)`,
      - applies `mask` to screen out positions that come from padding tokens,
      - optionally applies dropout to attention weights, and
      - uses attention weights to combine per-head `values` vectors.

    Args:
      queries: Per-head activations representing attention queries.
      keys: Per-head activations representing attention keys.
      values: Per-head activations to be combined by computed attention weights.
      mask: Mask that distinguishes positions with real content vs. padding;
          logits at positions where it is falsy are replaced by -1e9. May be
          None, in which case no masking is applied.
      dropout: Probabilistic rate for dropout applied to attention activations
          (based on query-key pairs) before dotting them with values. None or
          0 disables dropout.
      mode: Either 'train' or 'eval'. Dropout applies only in 'train' mode.
      rng: Single-use random number generator (JAX PRNG key).

    Returns:
      Per-head activations resulting from masked per-head attention-weighted
      sum of per-head values.

    Raises:
      ValueError: If `dropout` is not None and is >= 1.0.
    """
    # Validate up front. `None` is an accepted "no dropout" value (see the
    # train-mode check below) and must not reach `>=`, which raises TypeError
    # for None in Python 3.
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    d_feature = queries.shape[-1]
    dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax over the last (key) axis, computed in log space for stability.
    dots = jnp.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        # Rescale the surviving weights by 1 / (1 - dropout).
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    out = jnp.matmul(dots, values)
    return out
Exemple #5
0
def Relu():
  r"""Builds a layer applying the Rectified Linear Unit (ReLU) elementwise.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          0 & \text{if}\ x \leq 0, \\
          x & \text{otherwise}.
      \end{array} \right.

  Non-positive entries become 0; every other entry is passed through as is.
  """
  def _relu(x):
    # Testing `x <= 0` (not `x > 0`) means a NaN input is passed through.
    return np.where(x <= 0, np.zeros_like(x), x)

  return Fn('Relu', _relu)
Exemple #6
0
 def _update_diagonal(self, grads, weights, m, v, opt_params):
   """Preconditioned momentum update for a parameter with one accumulator.

   Adds `grads * grads` into `v[0]`, scales the gradient by `1 / sqrt(v[0])`
   wherever `v[0] > 0` (by 0 elsewhere), mixes the result into the momentum
   buffer `m`, and steps the weights by `learning_rate * m` (cast to the
   weights' dtype).

   Returns:
     `(new_weights, (new_m, v))`, where `v` is the same container passed in.
   """
   lr = opt_params['learning_rate']
   beta = opt_params['momentum']
   v[0] += grads * grads
   accumulator = v[0]
   inv_root = np.where(accumulator > 0, 1.0 / np.sqrt(accumulator),
                       np.zeros_like(accumulator))
   scaled_grads = inv_root * grads
   new_m = (1 - beta) * scaled_grads + beta * m
   step = (lr * new_m).astype(weights.dtype)
   return weights - step, (new_m, v)
Exemple #7
0
 def forward(self, x, weights):
     """Applies dropout to `x` in 'train' mode; returns `x` otherwise.

     The rate is `self._initial_rate`, unless `self.state` is a dict holding
     an entry under `self._name`, which then overrides it. Kept elements
     (sampled via `math.random.bernoulli` with keep probability `1 - rate`)
     are divided by `1 - rate`; the rest become zero. `weights` is unused.
     """
     if self._mode == 'train':
         layer_state, key = self.state, self.rng
         rate = self._initial_rate
         if isinstance(layer_state, dict) and self._name in layer_state:
             rate = layer_state[self._name]
         keep_prob = 1.0 - rate
         survivors = math.random.bernoulli(key, keep_prob, x.shape)
         return jnp.where(survivors, x / keep_prob, jnp.zeros_like(x))
     return x
Exemple #8
0
def LeakyRelu(a=0.01):
  r"""Builds a ReLU-like layer whose negative side is a line of slope `a`.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          ax & \text{if}\ x \leq 0, \\
          x  & \text{otherwise}.
      \end{array} \right.

  At x = 0 both branches give 0, so the `x >= 0` test used below agrees with
  the formula.

  Args:
    a: Slope of line for negative inputs.
  """
  def _leaky_relu(x):
    return np.where(x >= 0, x, a * x)

  return Fn('LeakyRelu', _leaky_relu)
Exemple #9
0
 def forward_with_state(self, x, weights, state, rng):
   """Applies dropout to `x` in 'train' mode; passes `x` through otherwise.

   The rate is `self._initial_rate`, unless `state` is a dict with an entry
   under `self._name`, which overrides it. `state` is always returned
   unchanged alongside the output.

   Raises:
     ValueError: in 'train' mode when `rng` is None.
   """
   if self._mode == 'train':
     rate = self._initial_rate
     if isinstance(state, dict) and self._name in state:
       rate = state[self._name]
     if rng is None:
       raise ValueError(
           'Dropout layer requires apply_fn to be called with a rng keyword '
           'argument. That is, instead of `Dropout(weights, inputs)`, call '
           'it like `Dropout(weights, inputs, rng=key)`.')
     keep_prob = 1.0 - rate
     survivors = math.random.bernoulli(rng, keep_prob, x.shape)
     return jnp.where(survivors, x / keep_prob, jnp.zeros_like(x)), state
   return x, state
Exemple #10
0
def Selu(alpha=1.6732632423543772848170429916717,
         lmbda=1.0507009873554804934193349852946):
  r"""Builds a scaled `Elu`-style layer (SELU).

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          \lambda \cdot \alpha \cdot (e^x - 1) & \text{if}\ x \leq 0, \\
          \lambda \cdot x                      & \text{otherwise}.
      \end{array} \right.

  Args:
    alpha: Coefficient multiplying the exponential, for negative inputs.
    lmbda: Coefficient scaling the whole function.
  """
  def _selu(x):
    # `expm1` stays accurate for x near 0, unlike `exp(x) - 1`.
    exponential_part = alpha * np.expm1(x)
    return lmbda * np.where(x > 0, x, exponential_part)

  return Fn('Selu', _selu)
Exemple #11
0
def Elu(a=1.):
  r"""Builds a ReLU-like layer with an exponential negative side (ELU).

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          a \cdot (e^x - 1) & \text{if}\ x \leq 0, \\
          x                 & \text{otherwise}.
      \end{array} \right.

  (As :math:`x\rightarrow - \infty`, :math:`f(x)\rightarrow -a`.)

  Args:
    a: Coefficient multiplying the exponential, for negative inputs.
  """
  def _elu(x):
    exponential_part = a * np.expm1(x)
    return np.where(x > 0, x, exponential_part)

  return Fn('Elu', _elu)
Exemple #12
0
  def forward(self, x):
    """Runs dropout on `x` during training and is the identity otherwise.

    Args:
      x: Tensor of activations.

    Returns:
      Tensor of the same shape as `x`. Outside 'train' mode this is `x`
      itself. In 'train' mode, entries kept by `math.random.bernoulli` (keep
      probability `1 - rate`) are divided by `1 - rate` and the rest are 0.
      The rate is `self._initial_rate`, unless `self.state` is a dict with an
      entry under `self._name`.
    """
    if self._mode == 'train':
      layer_state, key = self.state, self.rng
      rate = self._initial_rate
      if isinstance(layer_state, dict) and self._name in layer_state:
        rate = layer_state[self._name]
      keep_prob = 1.0 - rate
      survivors = math.random.bernoulli(key, keep_prob, x.shape)
      return jnp.where(survivors, x / keep_prob, jnp.zeros_like(x))
    return x
Exemple #13
0
 def tree_update(self, step, grad_tree, weight_tree, slots, opt_params):
     """Assembles node-local weight and slot updates for the full layer tree.

     When `self._clip_grad_norm` is set, the global L2 norm over all flattened
     gradients is computed. If that norm is not below the limit, each
     gradient is scaled by `limit / norm`. Each (grad, weight, slot) triple is
     then updated with `self._update_and_check`. The new slots are stored on
     `self.slots` and returned together with the re-nested weights.
     """
     grads = _tree_flatten(grad_tree)
     limit = self._clip_grad_norm
     if limit is not None:
         global_norm = np.sqrt(sum(np.vdot(g, g) for g in grads))
         within_limit = global_norm < limit
         grads = [np.where(within_limit, g, g * (limit / global_norm))
                  for g in grads]
     flat_weights = _tree_flatten(weight_tree)
     results = [self._update_and_check(step, g, w, s, opt_params)
                for g, w, s in zip(grads, flat_weights, slots)]
     updated_weights, self.slots = zip(*results)
     nested_weights, _ = _tree_unflatten(updated_weights, weight_tree)
     return nested_weights, self.slots
Exemple #14
0
    def forward_unbatched(self, x, mask=None, *, weights, state, update_state):
        """Single-example self-attention forward pass.

        Projects `x` to queries, keys (only when `self.share_qk` is false) and
        values. It then runs `attend` with chunk sizes and neighbour counts
        taken from the layer, and projects the attended output with `w_o`.

        Args:
          x: input activations; `q.shape[-2]` is used as the number of
              positions.
          mask: boolean validity mask (True means "is valid token"); must be
              provided exactly when `self.masked` is true.
          weights: `(w_q, w_v, w_o)` when `self.share_qk`, otherwise
              `(w_q, w_k, w_v, w_o)`.
          state: returned unchanged.
          update_state: ignored.

        Returns:
          `(out, state)`.
        """
        del update_state
        if self.share_qk:
            w_q, w_v, w_o = weights
        else:
            w_q, w_k, w_v, w_o = weights

        q = np.matmul(x, w_q)
        # With shared query/key projections, `attend` receives k=None.
        k = None if self.share_qk else np.matmul(x, w_k)
        v = np.matmul(x, w_v)

        mask_fn = functools.partial(
            mask_self_attention,
            causal=self.causal,
            exclude_self=self.share_qk,
            masked=self.masked)
        # Position indices, tied to `x` so they carry a data dependency on it.
        positions = jax.lax.tie_in(x, np.arange(q.shape[-2]))
        q_info = kv_info = positions

        assert (mask is not None) == self.masked
        if self.masked:
            # mask is a boolean array (True means "is valid token"); the key
            # index of every invalid token is negated in kv_info.
            # NOTE(review): index 0 cannot be flagged this way (-0 == 0);
            # presumably padding never sits at position 0 — confirm.
            ones = jax.lax.tie_in(x, np.ones_like(mask, dtype=np.int32))
            kv_info = kv_info * np.where(mask, ones, -ones)

        o, _ = attend(
            q,
            k,
            v,
            q_chunk_len=self.chunk_len,
            kv_chunk_len=self.chunk_len,
            n_chunks_before=self.n_chunks_before,
            n_chunks_after=self.n_chunks_after,
            mask_fn=mask_fn,
            q_info=q_info,
            kv_info=kv_info,
            dropout=self.attention_dropout,
            rng=None,  # TODO(kitaev): support RNG
        )
        return np.matmul(o, w_o), state
Exemple #15
0
 def tree_update(self, step, grad_tree, weight_tree, slots, opt_params):
   """Assembles node-local weight and slot updates for the full layer tree.

   Returns `(new_weights, slots, metrics)`. `metrics` holds, under the keys
   'gradients_l2' and 'weights_l2', the `self._l2_norm` of the flattened
   gradients (measured before any clipping) and of the current weights. The
   new slots are also stored on `self.slots`.
   """
   grads_flat = _tree_flatten(grad_tree)
   grads_norm = self._l2_norm(grads_flat)
   max_norm = self._clip_grad_norm
   if max_norm is not None:
     within_limit = grads_norm < max_norm
     grads_flat = [np.where(within_limit, g, g * (max_norm / grads_norm))
                   for g in grads_flat]
   weights_flat = _tree_flatten(weight_tree)
   weights_norm = self._l2_norm(weights_flat)
   results = [self._update_and_check(step, g, w, s, opt_params)
              for g, w, s in zip(grads_flat, weights_flat, slots)]
   updated_weights, self.slots = zip(*results)
   nested_weights, _ = _tree_unflatten(updated_weights, weight_tree)
   stats = {'gradients_l2': grads_norm, 'weights_l2': weights_norm}
   return nested_weights, self.slots, stats
Exemple #16
0
 def _update_sketched(self, grads, weights, m, v, opt_params):
   """Update for higher-rank parameters.

   Keeps one accumulator per weight axis in `v`. Each `v[i]` is reshaped with
   `self._expanded_shape(shape, i)` (presumably `shape` with every axis
   except i set to 1 — confirm), and the reshaped accumulators are combined
   with `self._minimum`. The result plus `grads * grads` preconditions the
   gradient (`1 / sqrt(acc)` where acc > 0, else 0), which is blended into
   the momentum `m` and applied to the weights. Afterwards, `v[i]` is
   replaced by the max of the accumulator over every axis except i.

   Returns:
     `(new_weights, (new_m, v))`; `v` is the same container, updated in place.
   """
   learning_rate = opt_params['learning_rate']
   momentum = opt_params['momentum']
   shape = weights.shape
   rank = len(shape)
   reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                            for i in range(rank)]
   # Out-of-place add: `np.reshape` may return a view of `v[i]`, and if
   # `_minimum` hands one of its inputs back, an in-place `+=` would
   # silently mutate that accumulator.
   current_accumulator = (self._minimum(reshaped_accumulators)
                          + grads * grads)
   accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                   1.0 / np.sqrt(current_accumulator),
                                   np.zeros_like(current_accumulator))
   preconditioned_gradient = grads * accumulator_inv_sqrt
   m = (1.0 - momentum) * preconditioned_gradient + momentum * m
   weights = weights - (learning_rate * m).astype(weights.dtype)
   for i in range(len(v)):
     # NumPy reductions take `axis` as an int or a tuple, not a list.
     axes = tuple(range(i)) + tuple(range(i + 1, rank))
     v[i] = np.amax(current_accumulator, axis=axes)
   return weights, (m, v)
Exemple #17
0
    def tree_update(self, step, grad_tree, weight_tree, slots, opt_params):
        """Assembles node-local weight and slot updates for the full layer tree.

        If `self._clip_grad_norm` is set and the L2 norm of all gradients (as
        computed by `self._l2_norm`) is not below it, every gradient is scaled
        by `self._clip_grad_norm / norm` before the per-node updates.

        Args:
          step: Current step number in the training process.
          grad_tree: Gradients for the entire model, in a tree that matches the
              model's layer structure.
          weight_tree: Current weights for the entire model, in a tree that
              matches the model's layer structure.
          slots: Optimizer slots.
          opt_params: Optimizer hyperparameters (e.g. learning rate, momentum).

        Returns:
          Tuple `(weights, slots, metrics)`, where `weights` are the
          optimizer-updated weights for the whole model (in a tree matching the
          model's layer structure), `slots` are the updated optimizer slot
          values (also stored on `self.slots`), and `metrics` is a dict with the
          L2 norm of the gradients, measured before clipping, under
          'gradients_l2' and the L2 norm of the pre-update weights under
          'weights_l2'.
        """
        grads_flat = math.tree_flatten(grad_tree)
        grads_norm = self._l2_norm(grads_flat)
        if self._clip_grad_norm is not None:
            max_norm = self._clip_grad_norm
            grads_flat = [  # pylint: disable=g-complex-comprehension
                np.where(grads_norm < max_norm, g, g * (max_norm / grads_norm))
                for g in grads_flat
            ]
        weights_flat = math.tree_flatten(weight_tree)
        weights_norm = self._l2_norm(weights_flat)
        updated_pairs = [
            self._update_and_check(step, grad, weight, slot, opt_params)
            for (grad, weight, slot) in zip(grads_flat, weights_flat, slots)
        ]
        new_weights_flat, self.slots = zip(*updated_pairs)
        new_weights, _ = math.tree_unflatten(new_weights_flat, weight_tree)
        metrics = {'gradients_l2': grads_norm, 'weights_l2': weights_norm}
        return new_weights, self.slots, metrics
Exemple #18
0
def clip_grads(grad_tree, max_norm):
    """Clips a pytree of gradient arrays to a global L2 norm of `max_norm`.

    The norm is computed once for the whole tree with `l2_norm`. When it is at
    least `max_norm`, every leaf is multiplied by `max_norm / norm`; otherwise
    leaf values are kept. Leaves are mapped with `layers.nested_map`.
    """
    total_norm = l2_norm(grad_tree)

    def _clip_leaf(g):
        return np.where(total_norm < max_norm, g, g * (max_norm / total_norm))

    return layers.nested_map(grad_tree, _clip_leaf)
Exemple #19
0
def Elu(a=1.):
    """Returns an ELU layer: `x` where x > 0, else `a * (exp(x) - 1)`."""
    def _elu(x):
        exponential_part = a * np.expm1(x)
        return np.where(x > 0, x, exponential_part)

    return Fn('Elu', _elu)
Exemple #20
0
def LeakyRelu(a=0.01):
    """Returns a leaky-ReLU layer: `x` where x >= 0, else `a * x`."""
    def _leaky_relu(x):
        return np.where(x >= 0, x, a * x)

    return Fn('LeakyRelu', _leaky_relu)
Exemple #21
0
def Elu(x, a=1., **unused_kwargs):
    """Exponential Linear Unit, applied elementwise.

    Returns `x` where `x > 0` and `a * (exp(x) - 1)` elsewhere; `expm1` keeps
    the negative branch accurate near zero. Extra keyword arguments are
    ignored.
    """
    del unused_kwargs
    exponential_part = a * np.expm1(x)
    return np.where(x > 0, x, exponential_part)
Exemple #22
0
def LeakyRelu(x, a=0.01, **unused_kwargs):
    """Leaky ReLU, applied elementwise: `x` where x >= 0, else `a * x`.

    Extra keyword arguments are ignored.
    """
    del unused_kwargs
    sloped = a * x
    return np.where(x >= 0, x, sloped)
Exemple #23
0
def LeakyRelu(x, a=0.01):
    """Leaky ReLU, applied elementwise: `x` where x >= 0, else `a * x`."""
    sloped = a * x
    return np.where(x >= 0, x, sloped)
Exemple #24
0
def Selu(x,
         alpha=1.6732632423543772848170429916717,
         lmbda=1.0507009873554804934193349852946):
    """Scaled Exponential Linear Unit, applied elementwise.

    Returns `lmbda * x` where `x > 0` and `lmbda * alpha * (exp(x) - 1)`
    elsewhere.
    """
    is_positive = x > 0
    exponential_part = alpha * np.expm1(x)
    return lmbda * np.where(is_positive, x, exponential_part)
Exemple #25
0
def Elu(x, a=1.):
    """Exponential Linear Unit: `x` where x > 0, else `a * (exp(x) - 1)`."""
    exponential_part = a * np.expm1(x)
    return np.where(x > 0, x, exponential_part)