Example no. 1
 def init(self, params):
     shape = params.shape
     slots = []
     if self._factored and len(shape) >= 2:
         v_row = np.zeros(shape[:-1], dtype=np.float32)
         v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)
         slots.extend([v_row, v_col])
     else:
         v = np.zeros_like(params)
         slots.append(v)
     if self._do_momentum:
         m = np.zeros_like(params)
         slots.append(m)
     return slots
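For a concrete sense of the factored branch above, here is a small shape check in plain NumPy; the (768, 3072) matrix is only an illustrative size, not anything from the source:

import numpy as np

# Hypothetical weight matrix, chosen only to show the slot shapes.
params = np.zeros((768, 3072), dtype=np.float32)
shape = params.shape

v_row = np.zeros(shape[:-1], dtype=np.float32)               # shape (768,)
v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)  # shape (3072,)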
Example no. 2
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if backend.get_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
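As a cross-check of the math above, a minimal standalone sketch of the same scaled dot-product attention in jax.numpy, without masking or dropout (illustrative only, not the library's API):

import jax.numpy as jnp
from jax.nn import softmax

def dot_product_attention_sketch(query, key, value):
    # softmax(Q K^T / sqrt(depth)) V
    depth = query.shape[-1]
    dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)
    return jnp.matmul(softmax(dots, axis=-1), value)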
Example no. 3
 def _update_diagonal(self, grads, params, m, v, opt_params):
     learning_rate = opt_params['learning_rate']
     momentum = opt_params['momentum']
     v[0] += grads * grads
     preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]),
                               np.zeros_like(v[0]))
     preconditioned_grads = preconditioner * grads
     m = (1 - momentum) * preconditioned_grads + momentum * m
     params = params - (learning_rate * m).astype(params.dtype)
     return params, (m, v)
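Stripped of the class plumbing, the update above is an Adagrad-style diagonal preconditioner followed by a momentum step; a minimal NumPy sketch under that reading (the function name is hypothetical):

import numpy as np

def diagonal_update_sketch(grads, params, m, v, learning_rate, momentum):
    v = v + grads * grads                                  # accumulate squared gradients
    precond = np.where(v > 0, 1.0 / np.sqrt(v), np.zeros_like(v))
    m = (1 - momentum) * (precond * grads) + momentum * m  # momentum over preconditioned grads
    params = params - (learning_rate * m).astype(params.dtype)
    return params, (m, v)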
Example no. 4
 def forward_and_backward(self,
                          inputs,
                          ct,
                          state=base.EMPTY_STATE,
                          new_state=base.EMPTY_STATE,
                          rng=None,
                          **kwargs):
     del kwargs
     output, _, (qk_ct, v_ct) = self.batch_call_and_or_grad(
         inputs[0], inputs[2], ct=ct, new_state=new_state, rng=rng)
     return output, (qk_ct, np.zeros_like(inputs[1]), v_ct)
Example no. 5
 def forward_with_state(self, x, weights=base.EMPTY_WEIGHTS,
                        state=base.EMPTY_STATE, rng=None, **kwargs):
   """Execute dropout."""
   del kwargs
   if self._mode != 'train':
     return x, state
   rate = self._initial_rate
   if isinstance(state, dict) and self._name in state:
     rate = state[self._name]
   if rng is None:
     msg = ('Dropout layer requires apply_fn to be called with a rng keyword '
            'argument. That is, instead of `Dropout(weights, inputs)`, call '
            'it like `Dropout(weights, inputs, rng=key)`.')
     raise ValueError(msg)
   keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape)
   return np.where(keep, x / (1.0 - rate), np.zeros_like(x)), state
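For reference, the same inverted-dropout transformation as a self-contained JAX snippet (an illustrative sketch, not the layer itself):

import jax
import jax.numpy as jnp

def inverted_dropout_sketch(x, rate, rng):
    # Keep each element with probability (1 - rate) and rescale the survivors
    # so the expected value of the output matches the input.
    keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), jnp.zeros_like(x))

y = inverted_dropout_sketch(jnp.ones((2, 4)), rate=0.1, rng=jax.random.PRNGKey(0))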
Example no. 6
 def backward(self,
              inputs,
              output,
              ct,
              weights=base.EMPTY_WEIGHTS,
              state=base.EMPTY_STATE,
              new_state=base.EMPTY_STATE,
              rng=None,
              **kwargs):
     del output, weights, state
     _, _, (qk_ct, v_ct) = self.batch_call_and_or_grad(inputs[0],
                                                       inputs[2],
                                                       return_output=False,
                                                       ct=ct,
                                                       new_state=new_state,
                                                       rng=rng)
     inputs_ct = (qk_ct, np.zeros_like(inputs[1]), v_ct)
     return inputs_ct, ()
Example no. 7
 def _update_sketched(self, grads, params, m, v, opt_params):
   """Update for higher-rank parameters."""
   learning_rate = opt_params['learning_rate']
   momentum = opt_params['momentum']
   shape = params.shape
   rank = len(shape)
   reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                            for i in range(rank)]
   current_accumulator = self._minimum(reshaped_accumulators)
   current_accumulator += grads * grads
   accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                   1.0 / np.sqrt(current_accumulator),
                                   np.zeros_like(current_accumulator))
   preconditioned_gradient = grads * accumulator_inv_sqrt
   m = (1.0 - momentum) * preconditioned_gradient + momentum * m
   params = params - (learning_rate * m).astype(params.dtype)
   for i in range(len(v)):
     axes = list(range(int(i))) + list(range(int(i) + 1, rank))
     dim_accumulator = np.amax(current_accumulator, axis=axes)
     v[i] = dim_accumulator
   return params, (m, v)
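The sketched branch keeps one accumulator per tensor dimension rather than a full-size one. Assuming _expanded_shape(shape, i) keeps axis i and sets the other axes to 1, and _minimum takes the elementwise minimum of the broadcast accumulators, the 2-D case looks like this (illustrative sketch, names and sizes hypothetical):

import numpy as np

grads = np.random.randn(3, 4).astype(np.float32)
v = [np.zeros(3, dtype=np.float32), np.zeros(4, dtype=np.float32)]  # one accumulator per axis

# Broadcast minimum of the per-axis accumulators, then add the new squared gradient.
current = np.minimum(v[0].reshape(3, 1), v[1].reshape(1, 4)) + grads * grads

# Fold the full-size accumulator back into the per-axis statistics.
v[0] = np.amax(current, axis=1)  # max over every axis except 0
v[1] = np.amax(current, axis=0)  # max over every axis except 1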
Example no. 8
    def forward_and_backward(self,
                             inputs,
                             ct,
                             state=base.EMPTY_STATE,
                             new_state=base.EMPTY_STATE,
                             rng=None,
                             **kwargs):
        del state, new_state, kwargs
        query, key, value = inputs
        depth = np.shape(query)[-1]
        do_backprop = ct is not None

        # jax uses the term cotangent (ct) to refer to gradient signals, and
        # vector-Jacobian product (vjp) for back-propagation through a layer.

        def make_mask(N, M, k):  # pylint: disable=invalid-name
            """Constructs a slice of the causal attention mask.

      Args:
        N: number of query positions
        M: number of key positions
        k: position of the initial query element

      Returns:
        N x M mask, where 1.0 indicates that attention is not allowed.
      """
            x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
            y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
            mask = jax.lax.lt((jax.lax.broadcast_in_dim(
                x, shape=(N, M), broadcast_dimensions=(0, )) + k),
                              jax.lax.broadcast(y, [N]))
            mask = jax.lax.convert_element_type(mask, np.float32)
            return mask

        def make_self_mask(N, M, k):  # pylint: disable=invalid-name
            """Masks out elements attending to self.

      Args:
        N: number of query positions
        M: number of key positions
        k: position of the initial query element

      Returns:
        N x M mask, where 1.0 indicates that attention is not allowed.
      """
            x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
            y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
            mask = jax.lax.eq((jax.lax.broadcast_in_dim(
                x, shape=(N, M), broadcast_dimensions=(0, )) + k),
                              jax.lax.broadcast(y, [N]))
            mask = jax.lax.convert_element_type(mask, np.float32)
            return mask

        def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
            """Forward pass for a subset of the query vectors."""
            if self._share_qk:
                key = self.make_unit_length(key)

            dots = np.matmul(query_slice, np.swapaxes(key, -1,
                                                      -2)) / np.sqrt(depth)

            # Causal masking
            mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
            dots = dots - 1e9 * mask

            # Mask out attention to self except when no other targets are available.
            if self._share_qk:
                self_mask = make_self_mask(dots.shape[-2], dots.shape[-1],
                                           q_loop_idx)
                dots = dots - 1e5 * self_mask

            # Softmax.
            dots = np.exp(dots -
                          backend.logsumexp(dots, axis=-1, keepdims=True))

            if self.dropout is not None and self.dropout > 0.0:
                # Dropout is broadcast across the batch+head dimension
                dropout_shape = (1, dots.shape[-2], dots.shape[-1])
                slice_rng = jax.random.fold_in(rng, q_loop_idx)
                keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
                keep = backend.random.bernoulli(slice_rng, keep_prob,
                                                dropout_shape)
                multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
                    keep, keep_prob)
                dots = dots * multiplier

            if self._hard_k > 0:
                top_k = np.sort(dots)[..., -self._hard_k]  # The k-th largest weight.
                top_k = jax.lax.stop_gradient(top_k)
                dots -= top_k[..., np.newaxis]  # Weights below the top k drop to <= 0.
                dots = np.maximum(dots, 0)
                dots_sum = np.sum(dots, axis=-1, keepdims=True)
                dots /= dots_sum  # Re-normalize the surviving weights.

            out_slice = np.matmul(dots, value)
            return out_slice

        def forward_and_vjp_slice(query_slice, q_loop_idx, key, value,
                                  ct_slice):  # pylint: disable=invalid-name
            # Capture q_loop_idx to avoid computing gradients with respect to it.
            def forward_slice_with_q_loop_idx(query_slice, key, value):  # pylint: disable=invalid-name
                return forward_slice(query_slice, q_loop_idx, key, value)

            output_slice, vjpfun = jax.vjp(forward_slice_with_q_loop_idx,
                                           query_slice, key, value)
            return output_slice, vjpfun(ct_slice)

        q_loop_idx = np.zeros((), dtype=np.int32)
        q_loop_max = query.shape[-2]
        q_loop_stride = self._loop_stride
        if q_loop_max == 1:  # For abstract runs with unknown shapes.
            q_loop_stride = 1
        assert q_loop_max % q_loop_stride == 0, (
            'Stride must evenly divide the number of query elements.')

        out_accum = np.zeros_like(query)
        if do_backprop:
            query_ct_accum = np.zeros_like(query)
            key_ct_accum = np.zeros_like(key)
            value_ct_accum = np.zeros_like(value)
            init_vals = (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                         value_ct_accum)
        else:
            init_vals = (q_loop_idx, out_accum)

        def cond_fun(vals):  # pylint: disable=invalid-name
            q_loop_idx = vals[0]
            return jax.lax.lt(q_loop_idx, q_loop_max)

        def body_fun(vals):  # pylint: disable=invalid-name
            """Compute a slice of the attention mechanism."""
            if do_backprop:
                (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                 value_ct_accum) = vals
            else:
                q_loop_idx, out_accum = vals

            query_slice = jax.lax.dynamic_slice_in_dim(query,
                                                       q_loop_idx,
                                                       q_loop_stride,
                                                       axis=-2)

            if do_backprop:
                ct_slice = jax.lax.dynamic_slice_in_dim(ct,
                                                        q_loop_idx,
                                                        q_loop_stride,
                                                        axis=-2)
                out_slice, partial_ct = forward_and_vjp_slice(
                    query_slice, q_loop_idx, key, value, ct_slice)
                query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
                    query_ct_accum, partial_ct[0], q_loop_idx, axis=-2)
                key_ct_accum = key_ct_accum + partial_ct[1]
                value_ct_accum = value_ct_accum + partial_ct[2]
            else:
                out_slice = forward_slice(query_slice, q_loop_idx, key, value)

            out_accum = jax.lax.dynamic_update_slice_in_dim(out_accum,
                                                            out_slice,
                                                            q_loop_idx,
                                                            axis=-2)
            q_loop_idx = q_loop_idx + q_loop_stride

            if do_backprop:
                return (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                        value_ct_accum)
            else:
                return (q_loop_idx, out_accum)

        final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

        if not do_backprop:
            return final_vals[1], None
        else:
            return final_vals[1], final_vals[2:]
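The loop machinery above follows a standard JAX pattern: carry an index and accumulators through jax.lax.while_loop, slice blocks in with dynamic_slice_in_dim, and write results back with dynamic_update_slice_in_dim. A minimal standalone sketch of that pattern, just doubling row blocks (illustrative only):

import jax
import jax.numpy as jnp

def blockwise_double(x, stride):
    def cond(vals):
        i, _ = vals
        return jax.lax.lt(i, x.shape[0])

    def body(vals):
        i, acc = vals
        block = jax.lax.dynamic_slice_in_dim(x, i, stride, axis=0)
        acc = jax.lax.dynamic_update_slice_in_dim(acc, 2.0 * block, i, axis=0)
        return i + stride, acc

    init = (jnp.zeros((), dtype=jnp.int32), jnp.zeros_like(x))
    _, out = jax.lax.while_loop(cond, body, init)
    return out

out = blockwise_double(jnp.arange(12.0).reshape(6, 2), stride=2)  # equals 2 * x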
Example no. 9
 def drop_for_hash(self, x, rng):
     rate = self._drop_for_hash_rate
     if self._mode == 'train' and rate > 0.0:
         keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape)
         return np.where(keep, x / (1.0 - rate), np.zeros_like(x))
     return x
Example no. 10
    def batch_call_and_or_grad(self,
                               qk,
                               v,
                               ct=None,
                               return_output=True,
                               new_state=None,
                               return_state=False,
                               rng=None):
        assert return_output or ct is not None, 'No work to perform!'
        if new_state is not None and new_state is not base.EMPTY_STATE:
            buckets = new_state
        else:
            buckets = None

        # The approach here is to perform attention for one batch element and head
        # at a time. Note that there is absolutely no interaction across examples or
        # heads: this layer has no parameters, and hashing patterns are also
        # different across examples/heads. As a result, batching doesn't give any
        # performance gains except in the case of accelerator under-utilization. We
        # assume that hash-based attention will be applied primarily to long
        # sequences, where unbatched attention for a single head has sufficient
        # computation to fill up the accelerator.

        batch_loop_idx = np.zeros((), dtype=np.int32)
        batch_loop_max = qk.shape[0]

        init_vals = (batch_loop_idx, )
        if return_output:
            out_accum = np.zeros_like(qk)
            init_vals = init_vals + (out_accum, )
        if return_state:
            buckets_accum = np.zeros(
                [qk.shape[0], self.n_hashes * qk.shape[1]], dtype=np.int32)
            init_vals = init_vals + (buckets_accum, )
        if ct is not None:
            qk_ct_accum = np.zeros_like(qk)
            v_ct_accum = np.zeros_like(v)
            init_vals = init_vals + (qk_ct_accum, v_ct_accum)

        def cond_fun(vals):
            batch_loop_idx = vals[0]
            return jax.lax.lt(batch_loop_idx, batch_loop_max)

        def body_fun(vals):
            """Performs attention for a single batch element and head."""
            batch_loop_idx = vals[0]
            if self._prng is None:
                hash_slice_rng = jax.random.fold_in(rng, batch_loop_idx)
                hash_rng, slice_rng = backend.random.split(hash_slice_rng)
            else:
                # TODO(kitaev): Maybe use the same RNG across examples (but not heads)?
                hash_rng = jax.random.fold_in(self._prng, batch_loop_idx)
                slice_rng = jax.random.fold_in(rng, batch_loop_idx)
            qk_slice = jax.lax.dynamic_index_in_dim(qk,
                                                    batch_loop_idx,
                                                    axis=0,
                                                    keepdims=False)
            v_slice = jax.lax.dynamic_index_in_dim(v,
                                                   batch_loop_idx,
                                                   axis=0,
                                                   keepdims=False)

            if buckets is None:
                buckets_slice = self.hash_vectors(qk_slice, rng=hash_rng)
            else:
                buckets_slice = jax.lax.dynamic_index_in_dim(buckets,
                                                             batch_loop_idx,
                                                             axis=0,
                                                             keepdims=False)

            if ct is None:
                out_slice = self.single_call(qk_slice,
                                             v_slice,
                                             buckets_slice,
                                             rng=slice_rng)
            else:

                def _do_single_call(qk_slice, v_slice):
                    return self.single_call(qk_slice,
                                            v_slice,
                                            buckets_slice,
                                            rng=slice_rng)

                ct_slice = jax.lax.dynamic_index_in_dim(ct,
                                                        batch_loop_idx,
                                                        axis=0,
                                                        keepdims=False)
                out_slice, vjpfun = jax.vjp(_do_single_call, qk_slice, v_slice)
                qk_ct_slice, v_ct_slice = vjpfun(ct_slice)

            new_vals = (batch_loop_idx + 1, )
            if return_output:
                out_accum = vals[1]
                out_accum = jax.lax.dynamic_update_index_in_dim(out_accum,
                                                                out_slice,
                                                                batch_loop_idx,
                                                                axis=0)
                new_vals = new_vals + (out_accum, )
            if return_state:
                buckets_accum = vals[2]
                buckets_accum = jax.lax.dynamic_update_index_in_dim(
                    buckets_accum, buckets_slice, batch_loop_idx, axis=0)
                new_vals = new_vals + (buckets_accum, )
            if ct is not None:
                qk_ct_accum, v_ct_accum = vals[-2:]
                qk_ct_accum = jax.lax.dynamic_update_index_in_dim(
                    qk_ct_accum, qk_ct_slice, batch_loop_idx, axis=0)
                v_ct_accum = jax.lax.dynamic_update_index_in_dim(
                    v_ct_accum, v_ct_slice, batch_loop_idx, axis=0)
                new_vals = new_vals + (qk_ct_accum, v_ct_accum)

            return new_vals

        final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

        if return_output:
            out = final_vals[1]
        else:
            out = None

        if return_state:
            state = final_vals[2]
        else:
            state = None

        if ct is not None:
            input_ct = final_vals[-2:]
        else:
            input_ct = None

        return out, state, input_ct
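The backward path above hinges on jax.vjp, which returns both the forward output and a pullback function; a minimal sketch of that call pattern with a toy function (illustrative only):

import jax
import jax.numpy as jnp

def f(qk, v):
    return jnp.matmul(qk, v)

qk = jnp.ones((4, 3))
v = jnp.ones((3, 2))

out, vjpfun = jax.vjp(f, qk, v)   # forward pass plus a pullback closure
ct = jnp.ones_like(out)           # cotangent arriving from the layer above
qk_ct, v_ct = vjpfun(ct)          # gradients with respect to qk and v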
Example no. 11
def ParametricRelu(x, a=1., **unused_kwargs):
  return np.maximum(a * x, np.zeros_like(x))
Example no. 12
def Relu(x, **unused_kwargs):
  return np.maximum(x, np.zeros_like(x))
Example no. 13
 def init(self, params):
     vs = [np.zeros(sz, dtype=params.dtype) for sz in params.shape]
     return (np.zeros_like(params), vs)
Example no. 14
 def init(self, params):
     m = np.zeros_like(params)
     v = np.zeros_like(params)
     return m, v
Example no. 15
 def init(self, params):
     return np.zeros_like(params)