Example #1
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if backend.get_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
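The np.where call above implements the attention mask: masked positions receive a large negative score before the softmax, so they get (numerically) zero weight. A minimal standalone sketch of the same pattern, written directly against jax and jax.numpy rather than trax's backend (names and shapes are illustrative):

import jax
import jax.numpy as jnp

def masked_attention_sketch(q, k, v, mask):
    """Masked dot-product attention, mirroring the np.where pattern above."""
    depth = q.shape[-1]
    dots = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / jnp.sqrt(depth)
    # Positions where the mask is False get a large negative score.
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    weights = jax.nn.softmax(dots, axis=-1)
    return jnp.matmul(weights, v)

rng = jax.random.PRNGKey(0)
q = k = v = jax.random.normal(rng, (1, 4, 8))        # (batch, length, depth)
causal_mask = jnp.tril(jnp.ones((1, 4, 4), dtype=bool))
out = masked_attention_sketch(q, k, v, causal_mask)  # shape (1, 4, 8)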
Example #2
def _update_diagonal(self, grads, params, m, v, opt_params):
    learning_rate = opt_params['learning_rate']
    momentum = opt_params['momentum']
    # Accumulate squared gradients in the diagonal accumulator.
    v[0] += grads * grads
    # Entries with an empty accumulator get a zero preconditioner instead of
    # a division by zero.
    preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]),
                              np.zeros_like(v[0]))
    preconditioned_grads = preconditioner * grads
    m = (1 - momentum) * preconditioned_grads + momentum * m
    params = params - (learning_rate * m).astype(params.dtype)
    return params, (m, v)
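Here np.where doubles as a division guard: coordinates whose accumulator is still zero get a zero preconditioner instead of an infinite one. A tiny illustration of that guard in plain jax.numpy (values are made up):

import jax.numpy as jnp

v0 = jnp.array([0.0, 4.0, 9.0])
# 1/sqrt(0) would be inf; np.where selects 0.0 for those entries instead.
preconditioner = jnp.where(v0 > 0, 1.0 / jnp.sqrt(v0), jnp.zeros_like(v0))
# -> [0.0, 0.5, 0.3333]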
Example #3
    def _forward_predict(self, inputs, state, rng):
        if not self._share_qk:
            state = _fast_inference_update_state(inputs, state)
            (q, _, _) = inputs
            (ks, vs, mask, index) = state
        else:
            mask_excluding_attention_in_place = state[2]
            (q, _, v) = inputs
            k = self.make_unit_length(q)
            state = _fast_inference_update_state((q, k, v), state)
            (ks, vs, mask, index) = state
            # Only the initial position in a sequence may attend to itself.
            mask = np.where(index > 1, mask_excluding_attention_in_place, mask)

        output = attention.DotProductAttention(q,
                                               ks,
                                               vs,
                                               mask,
                                               dropout=self.dropout,
                                               mode=self._mode,
                                               rng=rng)

        def roll_state(state):
            """Rolls the buffers backward to make space for new data."""
            (ks, vs, mask, index) = state

            # Move the second bin into the first one's place in both buffers.
            def roll_buffer(buf):
                return jax.ops.index_update(
                    buf,
                    jax.ops.index[:, :self.bin_length, :],
                    buf[:, self.bin_length:, :],
                )

            (ks, vs) = map(roll_buffer, (ks, vs))
            # Zero out the second bin in the mask.
            mask = jax.ops.index_update(
                mask, jax.ops.index[:, :, self.bin_length:], 0)
            # Update the index to match the rolled buffers.
            index -= self.bin_length
            return (ks, vs, mask, index)

        # Once we get to the end of the buffer, move the second bin back to make
        # space for new data: [ bin_i bin_{i+1} | ] -> [ bin_{i+1} | bin_{i+1} ],
        # where | is where index points at in the buffer.
        state = jax.lax.cond(
            pred=(index == 2 * self.bin_length),
            true_operand=state,
            true_fun=roll_state,
            false_operand=state,
            false_fun=(lambda x: x),
        )
        return (output, state)
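Note that jax.ops.index_update / jax.ops.index (and the keyword-argument form of jax.lax.cond used here) come from an older JAX API and have since been removed; on recent JAX versions the same functional buffer update is written with the .at indexed-update syntax. A sketch of the equivalent, assuming the same buffer layout as above:

# Equivalent of roll_buffer and the mask update on recent JAX versions.
def roll_buffer(buf, bin_length):
    return buf.at[:, :bin_length, :].set(buf[:, bin_length:, :])

# mask = mask.at[:, :, bin_length:].set(0)
# state = jax.lax.cond(index == 2 * bin_length, roll_state, lambda s: s, state)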
Example #4
def forward_with_state(self, x, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
    """Execute dropout."""
    del kwargs
    if self._mode != 'train':
      return x, state
    rate = self._initial_rate
    if isinstance(state, dict) and self._name in state:
      rate = state[self._name]
    if rng is None:
      msg = ('Dropout layer requires apply_fn to be called with a rng keyword '
             'argument. That is, instead of `Dropout(weights, inputs)`, call '
             'it like `Dropout(weights, inputs, rng=key)`.')
      raise ValueError(msg)
    keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape)
    return np.where(keep, x / (1.0 - rate), np.zeros_like(x)), state
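The last two lines are the standard inverted-dropout pattern: surviving activations are rescaled by 1/(1 - rate) so that their expected value matches the eval-mode forward pass. A self-contained sketch with plain jax (the function name is ours, shapes illustrative):

import jax
import jax.numpy as jnp

def inverted_dropout(x, rate, rng):
    # Keep each unit with probability (1 - rate) and rescale the survivors.
    keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), jnp.zeros_like(x))

x = jnp.ones((2, 3))
y = inverted_dropout(x, rate=0.5, rng=jax.random.PRNGKey(0))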
Example #5
def _update_sketched(self, grads, params, m, v, opt_params):
    """Update for higher-rank parameters."""
    learning_rate = opt_params['learning_rate']
    momentum = opt_params['momentum']
    shape = params.shape
    rank = len(shape)
    # Rebuild a full-size accumulator estimate from the per-axis sketches.
    reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                             for i in range(rank)]
    current_accumulator = self._minimum(reshaped_accumulators)
    current_accumulator += grads * grads
    # Entries with an empty accumulator get a zero preconditioner instead of
    # a division by zero.
    accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                    1.0 / np.sqrt(current_accumulator),
                                    np.zeros_like(current_accumulator))
    preconditioned_gradient = grads * accumulator_inv_sqrt
    m = (1.0 - momentum) * preconditioned_gradient + momentum * m
    params = params - (learning_rate * m).astype(params.dtype)
    # Write the per-axis sketches back as maxima over all other axes.
    for i in range(len(v)):
        axes = list(range(i)) + list(range(i + 1, rank))
        dim_accumulator = np.amax(current_accumulator, axis=axes)
        v[i] = dim_accumulator
    return params, (m, v)
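The accumulators here are "sketched": v holds one vector of statistics per axis rather than a full-size tensor, an estimate is rebuilt as the elementwise minimum of their broadcast copies, and the per-axis maxima are written back. An illustrative 2-D version of that round trip in plain jax.numpy, with the trax helpers _expanded_shape and _minimum replaced by explicit broadcasting (values are made up):

import jax.numpy as jnp

grads = jnp.array([[1.0, 2.0], [3.0, 4.0]])
v_rows = jnp.zeros(2)   # accumulator sketch for axis 0
v_cols = jnp.zeros(2)   # accumulator sketch for axis 1

# Reconstruct a full-size accumulator estimate and add the new squared grads.
current = jnp.minimum(v_rows[:, None], v_cols[None, :]) + grads * grads

# Write the sketches back as maxima over all other axes.
v_rows = jnp.amax(current, axis=1)   # -> [4., 16.]
v_cols = jnp.amax(current, axis=0)   # -> [9., 16.]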
Example #6
def drop_for_hash(self, x, rng):
    rate = self._drop_for_hash_rate
    # Apply inverted dropout only in train mode and when the rate is positive.
    if self._mode == 'train' and rate > 0.0:
        keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape)
        return np.where(keep, x / (1.0 - rate), np.zeros_like(x))
    return x
Example #7
def Selu(x,
         alpha=1.6732632423543772848170429916717,
         lmbda=1.0507009873554804934193349852946):
  return lmbda * np.where(x > 0, x, alpha * np.expm1(x))
Example #8
def Elu(x, a=1., **unused_kwargs):
  return np.where(x > 0, x, a * np.expm1(x))
Example #9
def LeakyRelu(x, a=0.01, **unused_kwargs):
  return np.where(x >= 0, x, a * x)
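The activations in Examples #7-#9 are one-line piecewise functions expressed with np.where; recent JAX versions ship equivalent built-ins (jax.nn.selu, jax.nn.elu, jax.nn.leaky_relu). A quick sanity check for the latter two, using plain jax.numpy:

import jax
import jax.numpy as jnp

x = jnp.array([-2.0, -0.5, 0.0, 1.5])
# Elu with a=1 and LeakyRelu with a=0.01, written with the same np.where pattern.
assert jnp.allclose(jnp.where(x > 0, x, jnp.expm1(x)), jax.nn.elu(x))
assert jnp.allclose(jnp.where(x >= 0, x, 0.01 * x), jax.nn.leaky_relu(x))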
Example #10
def clip_grads(grad_tree, max_norm):
    """Clip gradients stored as a pytree of arrays to maximum norm `max_norm`."""
    norm = l2_norm(grad_tree)
    normalize = lambda g: np.where(norm < max_norm, g, g * (max_norm / norm))
    return layers.nested_map(grad_tree, normalize)
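clip_grads rescales every leaf of the gradient pytree by the same factor whenever the global L2 norm exceeds max_norm; using np.where keeps the computation traceable under jit. A standalone sketch of the same idea using jax.tree_util in place of trax's l2_norm and layers.nested_map helpers (the function name is ours):

import jax
import jax.numpy as jnp

def clip_by_global_norm(grad_tree, max_norm):
    leaves = jax.tree_util.tree_leaves(grad_tree)
    norm = jnp.sqrt(sum(jnp.vdot(g, g) for g in leaves))
    # Leave gradients untouched below the threshold, otherwise rescale them all.
    scale = jnp.where(norm < max_norm, 1.0, max_norm / norm)
    return jax.tree_util.tree_map(lambda g: g * scale, grad_tree)

grads = {'w': jnp.full((2, 2), 3.0), 'b': jnp.ones(2)}
clipped = clip_by_global_norm(grads, max_norm=1.0)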