Example #1
 def init(self, params):
     shape = params.shape
     slots = []
     if self._factored and len(shape) >= 2:
         v_row = np.zeros(shape[:-1], dtype=np.float32)
         v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)
         slots.extend([v_row, v_col])
     else:
         v = np.zeros_like(params)
         slots.append(v)
     if self._do_momentum:
         m = np.zeros_like(params)
         slots.append(m)
     return slots
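For reference, here is a minimal standalone sketch (plain NumPy, hypothetical weight shape) of the slot shapes the factored branch above produces for a 2-D parameter, with the optional momentum slot included as in the snippet:

import numpy as np

params = np.zeros((1024, 512), dtype=np.float32)  # hypothetical weight matrix
shape = params.shape

# Factored second-moment accumulators, mirroring the branch above:
v_row = np.zeros(shape[:-1], dtype=np.float32)               # shape (1024,)
v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)  # shape (512,)
m = np.zeros_like(params)                                    # momentum slot, shape (1024, 512)

print(v_row.shape, v_col.shape, m.shape)  # (1024,) (512,) (1024, 512)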
Example #2
 def init(self, x):
   shape = x.shape
   state = []
   if self._factored and len(shape) >= 2:
     v_row = np.zeros(shape[:-1], dtype=np.float32)
     v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)
     state.extend([v_row, v_col])
   else:
     v = np.zeros_like(x)
     state.append(v)
   if self._beta1:
     m = np.zeros_like(x)
     state.append(m)
   return state
Example #3
 def backward(self, inputs, output, ct, params=(), state=(), rng=None,
              **kwargs):
   del output, params, state
   _, (qk_ct, v_ct) = self.batch_call_and_or_grad(
       inputs[0], inputs[2], return_output=False, ct=ct, rng=rng)
   inputs_ct = (qk_ct, np.zeros_like(inputs[1]), v_ct)
   return inputs_ct, ()
Example #4
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if backend.get_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
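A shape-only usage sketch follows; it assumes the snippet's module-level imports are in scope (trax-style `np`/`backend` plus `jax`), so it is meant to run in that context rather than standalone. The mask uses True where attention is allowed, matching the `np.where(mask, dots, ...)` gating above.

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
batch, heads, seqlen, d_head = 2, 4, 16, 8
q = k = v = jnp.ones((batch, heads, seqlen, d_head))
# Causal mask, True where a query may attend to a key; broadcasts over batch/heads.
mask = jnp.tril(jnp.ones((seqlen, seqlen), dtype=bool))[None, None, :, :]

out = DotProductAttention(q, k, v, mask, dropout=0.0, mode='eval', rng=rng)
# out.shape == (2, 4, 16, 8)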
Example #5
 def _update_diagonal(self, step, g, x, m, v):
     v[0] += g * g
     preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]),
                               np.zeros_like(v[0]))
     preconditioned_g = preconditioner * g
     m = (1 - self._momentum) * preconditioned_g + self._momentum * m
     x = x - self.step_size(step) * m
     return x, (m, v)
Example #6
 def _update_diagonal(self, grads, params, m, v, opt_params):
     (learning_rate, momentum) = opt_params
     v[0] += grads * grads
     preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]),
                               np.zeros_like(v[0]))
     preconditioned_grads = preconditioner * grads
     m = (1 - momentum) * preconditioned_grads + momentum * m
     params = params - (learning_rate * m).astype(params.dtype)
     return params, (m, v)
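The diagonal preconditioner above is essentially AdaGrad-style accumulation of squared gradients. A tiny plain-NumPy sketch with made-up values (NumPy emits a divide-by-zero warning for the 1/sqrt(0) entry, which the `where` then discards):

import numpy as np

grads = np.array([0.5, -2.0, 0.0])
v = [np.zeros(3)]
v[0] += grads * grads                      # [0.25, 4.0, 0.0]
preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]), np.zeros_like(v[0]))
# preconditioner == [2.0, 0.5, 0.0]; coordinates with no accumulated signal stay at zero.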
Example #7
def Dropout(x, params, rate=0.0, mode='train', rng=None, **kwargs):
  """Layer construction function for a dropout layer with given rate."""
  del params, kwargs
  if rng is None:
    msg = ('Dropout layer requires apply_fn to be called with a rng keyword '
           'argument. That is, instead of `Dropout(params, inputs)`, call '
           'it like `Dropout(params, inputs, rng=key)`.')
    raise ValueError(msg)
  if rate >= 1.0:
    raise ValueError('Dropout rate (%f) must be lower than 1.' % rate)
  if mode == 'train' and rate > 0.0:
    keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape)
    return np.where(keep, x / (1.0 - rate), np.zeros_like(x))
  else:
    return x
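Minimal usage sketch, assuming the snippet's `np` and `backend` come from trax (so `backend.random.bernoulli` wraps JAX's Bernoulli sampler) and are available in this scope:

import jax
import jax.numpy as jnp

x = jnp.ones((4, 8))
key = jax.random.PRNGKey(0)
y_train = Dropout(x, params=(), rate=0.5, mode='train', rng=key)  # ~half the entries zeroed, the rest scaled by 2
y_eval = Dropout(x, params=(), rate=0.5, mode='eval', rng=key)    # identity in eval mode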
Example #8
 def call(self, x, params, state, rng=None, **unused_kwargs):
   """Execute dropout."""
   del params
   rate = self._initial_rate
   if isinstance(state, dict) and self._name in state:
     rate = state[self._name]
   if rng is None:
     msg = ('Dropout layer requires apply_fn to be called with a rng keyword '
            'argument. That is, instead of `Dropout(params, inputs)`, call '
            'it like `Dropout(params, inputs, rng=key)`.')
     raise ValueError(msg)
   if self._mode != 'train':
     return x, state
   keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape)
   return np.where(keep, x / (1.0 - rate), np.zeros_like(x)), state
Example #9
def ShiftRight(x, **unused_kwargs):
    """Layer to shift the tensor to the right by padding on axis 1."""
    if not isinstance(x, (list, tuple)):  # non-chunked inputs
        pad_widths = [(0, 0), (1, 0)]
        padded = np.pad(x, pad_widths, mode='constant')
        return padded[:, :-1]
    # Handling chunked inputs. Recall that the list of chunks represents a big
    # sequence (the concatenation of the chunks). We want to shift that sequence,
    # so we put a 0 at the beginning of the first chunk, and the last element of
    # each chunk becomes the new first element of the next chunk, and so on.
    padded = []
    last_value = np.zeros_like(x[0][:, -1])
    for chunk in x:
        padded_chunk = np.concatenate([last_value[:, np.newaxis], chunk],
                                      axis=1)
        last_value = chunk[:, -1]
        padded.append(padded_chunk[:, :-1])
    return padded
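A worked illustration of the non-chunked case in plain jax.numpy (assuming the snippet's `np` is the trax/JAX numpy backend): pad one zero column on the left of axis 1, then drop the last column so the shape is preserved.

import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])
shifted = jnp.pad(x, [(0, 0), (1, 0)], mode='constant')[:, :-1]
# shifted == [[0, 1, 2],
#             [0, 4, 5]]

# For chunked inputs, e.g. two chunks [[1, 2]] and [[3, 4]], ShiftRight returns
# the chunks [[0, 1]] and [[2, 3]]: the zero enters the first chunk and each
# chunk's last column becomes the next chunk's first column.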
Example #10
 def _update_sketched(self, grads, params, m, v, opt_params):
   """Update for higher-rank parameters."""
   (learning_rate, momentum) = opt_params
   shape = params.shape
   rank = len(shape)
   reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                            for i in range(rank)]
   current_accumulator = self._minimum(reshaped_accumulators)
   current_accumulator += grads * grads
   accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                   1.0 / np.sqrt(current_accumulator),
                                   np.zeros_like(current_accumulator))
   preconditioned_gradient = grads * accumulator_inv_sqrt
   m = (1.0 - momentum) * preconditioned_gradient + momentum * m
   params = params - (learning_rate * m).astype(params.dtype)
   for i in range(len(v)):
     axes = list(range(int(i))) + list(range(int(i) + 1, rank))
     dim_accumulator = np.amax(current_accumulator, axis=axes)
     v[i] = dim_accumulator
   return params, (m, v)
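The helpers `_expanded_shape` and `_minimum` are not shown in this snippet; a reasonable reading, consistent with the per-dimension accumulators initialized elsewhere in these examples (one vector per parameter dimension, SM3-style), is sketched below for a rank-2 parameter. Treat the reshape/minimum details as assumptions.

import numpy as np

shape = (3, 4)
v = [np.zeros(d, dtype=np.float32) for d in shape]   # one accumulator per dimension
expanded = [v[0].reshape(3, 1), v[1].reshape(1, 4)]  # presumed _expanded_shape(shape, i)
current = np.minimum(*expanded)                      # presumed _minimum -> full (3, 4) estimate
grads = np.ones(shape, dtype=np.float32)
current += grads * grads
# After the update, each v[i] keeps the max of `current` over all other axes:
v = [current.max(axis=1), current.max(axis=0)]       # shapes (3,) and (4,)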
Example #11
 def _update_sketched(self, step, g, x, m, v):
     """Update for higher-rank parameters."""
     shape = x.shape
     rank = len(shape)
     reshaped_accumulators = [
         np.reshape(v[i], self._expanded_shape(shape, i))
         for i in range(rank)
     ]
     current_accumulator = self._minimum(reshaped_accumulators)
     current_accumulator += g * g
     accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                     1.0 / np.sqrt(current_accumulator),
                                     np.zeros_like(current_accumulator))
     preconditioned_gradient = g * accumulator_inv_sqrt
     m = (1.0 -
          self._momentum) * preconditioned_gradient + self._momentum * m
     x = x - self.step_size(step) * m
     for i in range(len(v)):
         axes = list(range(int(i))) + list(range(int(i) + 1, rank))
         dim_accumulator = np.amax(current_accumulator, axis=axes)
         v[i] = dim_accumulator
     return x, (m, v)
Example #12
    def call_and_grad(self, inputs, ct, rng=None, **kwargs):
        del kwargs
        # We use the same vector as both a query and a key. For now we haven't
        # adjusted any of the surrounding code, so we still get a separate "key"
        # input that we ignore.
        qk, ignored_k, v = inputs
        seqlen = qk.shape[-2]
        # qk/v are n_batch*n_heads, seqlen, d_head

        # bins are n_batch*n_heads, seqlen
        # They specify which hash bucket the query/key/value vectors fall in.
        bins = self.hash_vectors(qk, rng=rng)

        # joint_t is n_batch*n_heads, seqlen
        joint_t = jax.lax.tie_in(qk, np.arange(seqlen))
        joint_t = np.reshape(joint_t, (1, seqlen))
        joint_t = np.broadcast_to(joint_t, qk.shape[:-1])

        assert int((self.n_bins + 1) * seqlen) < 2**31, (
            'Potential 32-bit integer overflow; please double-check the code.')
        joint_bins_and_t = seqlen * bins + joint_t

        def chunk_scalars(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], self.n_bins, -1))

        def chunk_vectors(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

        def unchunk_vectors(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

        # Sort everything by bin number, with a secondary sort by time
        # (variables starting with "s" are sorted)
        _, sjoint_t = jax.lax.sort_key_val(joint_bins_and_t,
                                           joint_t,
                                           dimension=-1)

        sqk = np.take_along_axis(qk, sjoint_t[:, :, None], axis=-2)
        sv = np.take_along_axis(v, sjoint_t[:, :, None], axis=-2)

        if ct is not None:
            so_ct = np.take_along_axis(ct, sjoint_t[:, :, None], axis=-2)

        @jax.jit
        def binned_attn(sqk, sv):  # pylint: disable=invalid-name
            """Performs attention on sorted queries/keys/values."""
            # Split off a "bin" axis so that attention only occurs whithin chunks.
            bq_t = bkv_t = chunk_scalars(sjoint_t)
            bqk = chunk_vectors(sqk)
            bv = chunk_vectors(sv)

            # Hashing operates on unit-length vectors. Unnormalized query vectors are
            # fine because they effectively provide a learnable temperature for the
            # attention softmax, but normalizing keys is needed so that similarity for
            # the purposes of attention correctly corresponds to hash locality.
            bq = bqk
            bk = self.make_unit_length(bqk)

            # Allow each chunk to attend within itself, and also one chunk back. Chunk
            # boundaries might occur in the middle of a sequence of items from the
            # same bin, so this increases the chances of attending to relevant items.
            # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
            bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]],
                                      axis=1)
            bk = np.concatenate([bk, bk_extra], axis=2)
            bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]],
                                      axis=1)
            bv = np.concatenate([bv, bv_extra], axis=2)
            bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]],
                                         axis=1)
            bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

            # Dot-product attention.
            dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(
                bq.shape[-1])

            # Causal masking
            mask = jax.lax.convert_element_type(
                jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]),
                np.float32)
            dots = dots - 1e9 * mask

            # Mask out attention to self except when no other targets are available.
            self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
            self_mask = jax.lax.tie_in(dots, self_mask)
            dots = dots - 32 * self_mask

            # Softmax.
            dots = np.exp(dots -
                          backend.logsumexp(dots, axis=-1, keepdims=True))
            bo = np.matmul(dots, bv)

            so = unchunk_vectors(bo)
            return so

        @jax.jit
        def binned_attn_vjp(sqk, sv, so_ct):  # pylint: disable=invalid-name
            so, vjpfun = jax.vjp(binned_attn, sqk, sv)
            sqkv_ct = vjpfun(so_ct)
            return so, sqkv_ct

        if ct is None:
            so = binned_attn(sqk, sv)
            _, undo_sort = jax.lax.sort_key_val(sjoint_t,
                                                joint_t,
                                                dimension=-1)
            out = np.take_along_axis(so, undo_sort[:, :, None], axis=-2)
            return out, None
        else:
            # Jax can construct a backward pass automatically, but it's about 2x
            # slower than writing our own. The main reason is that the backward pass
            # of gather is in general a scatter operation, but we know we're dealing
            # with permutations so we use gather for the backward pass too.
            so, (sqk_ct, sv_ct) = binned_attn_vjp(sqk, sv, so_ct)

            _, undo_sort = jax.lax.sort_key_val(sjoint_t,
                                                joint_t,
                                                dimension=-1)
            out = np.take_along_axis(so, undo_sort[:, :, None], axis=-2)

            qk_ct = np.take_along_axis(sqk_ct, undo_sort[:, :, None], axis=-2)
            v_ct = np.take_along_axis(sv_ct, undo_sort[:, :, None], axis=-2)

            return out, (qk_ct, np.zeros_like(ignored_k), v_ct)
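The sort / undo-sort trick above (sort by bucket, attend, then invert the permutation with another sort_key_val) can be seen in isolation with plain JAX:

import jax
import jax.numpy as jnp

keys = jnp.array([[2, 0, 3, 1]])          # stand-in for joint_bins_and_t
t = jnp.arange(4).reshape(1, 4)           # stand-in for joint_t

_, s_t = jax.lax.sort_key_val(keys, t, dimension=-1)   # permutation into sorted order
_, undo = jax.lax.sort_key_val(s_t, t, dimension=-1)   # inverse permutation

x = jnp.array([[10., 11., 12., 13.]])
sx = jnp.take_along_axis(x, s_t, axis=-1)      # gather into sorted order
back = jnp.take_along_axis(sx, undo, axis=-1)  # gather back; equals x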
Example #13
    def call_and_grad(self, inputs, ct, rng=None, **kwargs):
        del kwargs
        query, key, value = inputs
        depth = np.shape(query)[-1]
        do_backprop = ct is not None

        # jax uses the term cotangent (ct) to refer to gradient signals, and
        # vector-Jacobian product (vjp) for back-propagation through a layer.

        def make_mask(N, M, k):  # pylint: disable=invalid-name
            """Constructs a slice of the causal attention mask.

      Args:
        N: number of query positions
        M: number of key positions
        k: position of the initial query element

      Returns:
        N x M mask, where 1.0 indicates that attention is not allowed.
      """
            x = np.arange(N, dtype=np.int32)
            y = np.arange(M, dtype=np.int32)
            mask = jax.lax.lt((jax.lax.broadcast_in_dim(
                x, shape=(N, M), broadcast_dimensions=(0, )) + k),
                              jax.lax.broadcast(y, [N]))
            mask = jax.lax.convert_element_type(mask, np.float32)
            return mask

        def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
            """Forward pass for a subset of the query vectors."""
            dots = np.matmul(query_slice, np.swapaxes(key, -1,
                                                      -2)) / np.sqrt(depth)

            # Causal masking
            mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
            dots = dots - 1e9 * mask

            # Softmax.
            dots = np.exp(dots -
                          backend.logsumexp(dots, axis=-1, keepdims=True))

            if self.dropout is not None and self.dropout > 0.0:
                # Dropout is broadcast across the batch+head dimension
                dropout_shape = (1, dots.shape[-2], dots.shape[-1])
                slice_rng = jax.random.fold_in(rng, q_loop_idx)
                keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
                keep = backend.random.bernoulli(slice_rng, keep_prob,
                                                dropout_shape)
                multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
                    keep, keep_prob)
                dots = dots * multiplier

            out_slice = np.matmul(dots, value)
            return out_slice

        def forward_and_vjp_slice(query_slice, q_loop_idx, key, value,
                                  ct_slice):  # pylint: disable=invalid-name
            # Capture q_loop_idx to avoid computing gradients with respect to it.
            def forward_slice_with_q_loop_idx(query_slice, key, value):  # pylint: disable=invalid-name
                return forward_slice(query_slice, q_loop_idx, key, value)

            output_slice, vjpfun = jax.vjp(forward_slice_with_q_loop_idx,
                                           query_slice, key, value)
            return output_slice, vjpfun(ct_slice)

        q_loop_idx = np.zeros((), dtype=np.int32)
        q_loop_max = query.shape[-2]
        q_loop_stride = self._loop_stride
        assert q_loop_max % q_loop_stride == 0, (
            'Stride must evenly divide the number of query elements.')

        out_accum = np.zeros_like(query)
        if do_backprop:
            query_ct_accum = np.zeros_like(query)
            key_ct_accum = np.zeros_like(key)
            value_ct_accum = np.zeros_like(value)
            init_vals = (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                         value_ct_accum)
        else:
            init_vals = (q_loop_idx, out_accum)

        def cond_fun(vals):  # pylint: disable=invalid-name
            q_loop_idx = vals[0]
            return jax.lax.lt(q_loop_idx, q_loop_max)

        def body_fun(vals):  # pylint: disable=invalid-name
            """Compute a slice of the attention mechanism."""
            if do_backprop:
                (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                 value_ct_accum) = vals
            else:
                q_loop_idx, out_accum = vals

            query_slice = jax.lax.dynamic_slice_in_dim(query,
                                                       q_loop_idx,
                                                       q_loop_stride,
                                                       axis=-2)

            if do_backprop:
                ct_slice = jax.lax.dynamic_slice_in_dim(ct,
                                                        q_loop_idx,
                                                        q_loop_stride,
                                                        axis=-2)
                out_slice, partial_ct = forward_and_vjp_slice(
                    query_slice, q_loop_idx, key, value, ct_slice)
                query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
                    query_ct_accum, partial_ct[0], q_loop_idx, axis=-2)
                key_ct_accum = key_ct_accum + partial_ct[1]
                value_ct_accum = value_ct_accum + partial_ct[2]
            else:
                out_slice = forward_slice(query_slice, q_loop_idx, key, value)

            out_accum = jax.lax.dynamic_update_slice_in_dim(out_accum,
                                                            out_slice,
                                                            q_loop_idx,
                                                            axis=-2)
            q_loop_idx = q_loop_idx + q_loop_stride

            if do_backprop:
                return (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                        value_ct_accum)
            else:
                return (q_loop_idx, out_accum)

        final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

        if not do_backprop:
            return final_vals[1], None
        else:
            return final_vals[1], final_vals[2:]
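The dynamic_slice / dynamic_update_slice pattern driving the loop above reduces to a few lines of plain JAX (using fori_loop here instead of the snippet's while_loop, purely for brevity):

import jax
import jax.numpy as jnp

x = jnp.arange(12.).reshape(1, 6, 2)   # (batch, positions, depth)
stride = 2
out = jnp.zeros_like(x)

def body(i, out):
  sl = jax.lax.dynamic_slice_in_dim(x, i * stride, stride, axis=1)
  return jax.lax.dynamic_update_slice_in_dim(out, sl * 10., i * stride, axis=1)

out = jax.lax.fori_loop(0, x.shape[1] // stride, body, out)
# out == x * 10, computed one stride of positions at a time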
Example #14
  def batch_call_and_or_grad(self, qk, v, ct=None, return_output=True,
                             rng=None):
    assert return_output or ct is not None, 'No work to perform!'
    # pylint: disable=protected-access
    stash_buckets = (return_output and ct is None
                     and base.Layer._STASH_IN is not None)
    if return_output and ct is not None and base.Layer._STASH_OUT is not None:
      buckets = base.Layer._STASH_OUT.pop(self)
    else:
      buckets = None
    # pylint: enable=protected-access

    # The approach here is to perform attention for one batch element and head
    # at a time. Note that there is absolutely no interaction across examples or
    # heads: this layer has no parameters, and hashing patterns are also
    # different across examples/heads. As a result, batching doesn't give any
    # performance gains except in the case of accelerator under-utilization. We
    # assume that hash-based attention will be applied primarily to long
    # sequences, where unbatched attention for a single head has sufficient
    # computation to fill up the accelerator.

    batch_loop_idx = np.zeros((), dtype=np.int32)
    batch_loop_max = qk.shape[0]

    init_vals = (batch_loop_idx,)
    if return_output:
      out_accum = np.zeros_like(qk)
      init_vals = init_vals + (out_accum,)
    if stash_buckets:
      buckets_accum = np.zeros(
          [qk.shape[0], self.n_hashes * qk.shape[1]], dtype=np.int32)
      init_vals = init_vals + (buckets_accum,)
    if ct is not None:
      qk_ct_accum = np.zeros_like(qk)
      v_ct_accum = np.zeros_like(v)
      init_vals = init_vals + (qk_ct_accum, v_ct_accum)

    def cond_fun(vals):
      batch_loop_idx = vals[0]
      return jax.lax.lt(batch_loop_idx, batch_loop_max)

    def body_fun(vals):
      """Performs attention for a single batch element and head."""
      batch_loop_idx = vals[0]
      if self._prng is None:
        hash_rng = jax.random.fold_in(rng, batch_loop_idx)
      else:
        # TODO(kitaev): Maybe use the same RNG across examples (but not heads)?
        hash_rng = jax.random.fold_in(self._prng, batch_loop_idx)
      qk_slice = jax.lax.dynamic_index_in_dim(
          qk, batch_loop_idx, axis=0, keepdims=False)
      v_slice = jax.lax.dynamic_index_in_dim(
          v, batch_loop_idx, axis=0, keepdims=False)

      if buckets is None:
        buckets_slice = self.hash_vectors(qk_slice, rng=hash_rng)
      else:
        buckets_slice = jax.lax.dynamic_index_in_dim(
            buckets, batch_loop_idx, axis=0, keepdims=False)

      if ct is None:
        out_slice = self.single_call(
            qk_slice, v_slice, buckets_slice, hash_rng=hash_rng)
      else:
        def _do_single_call(qk_slice, v_slice):
          return self.single_call(
              qk_slice, v_slice, buckets_slice, hash_rng=hash_rng)
        ct_slice = jax.lax.dynamic_index_in_dim(
            ct, batch_loop_idx, axis=0, keepdims=False)
        out_slice, vjpfun = jax.vjp(_do_single_call, qk_slice, v_slice)
        qk_ct_slice, v_ct_slice = vjpfun(ct_slice)

      new_vals = (batch_loop_idx + 1,)
      if return_output:
        out_accum = vals[1]
        out_accum = jax.lax.dynamic_update_index_in_dim(
            out_accum, out_slice, batch_loop_idx, axis=0)
        new_vals = new_vals + (out_accum,)
      if stash_buckets:
        buckets_accum = vals[2]
        buckets_accum = jax.lax.dynamic_update_index_in_dim(
            buckets_accum, buckets_slice, batch_loop_idx, axis=0)
        new_vals = new_vals + (buckets_accum,)
      if ct is not None:
        qk_ct_accum, v_ct_accum = vals[-2:]
        qk_ct_accum = jax.lax.dynamic_update_index_in_dim(
            qk_ct_accum, qk_ct_slice, batch_loop_idx, axis=0)
        v_ct_accum = jax.lax.dynamic_update_index_in_dim(
            v_ct_accum, v_ct_slice, batch_loop_idx, axis=0)
        new_vals = new_vals + (qk_ct_accum, v_ct_accum)

      return new_vals

    final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

    if return_output:
      out = final_vals[1]
    else:
      out = None

    if stash_buckets:
      base.Layer._STASH_IN[self] = final_vals[2]  # pylint: disable=protected-access

    if ct is not None:
      input_ct = final_vals[-2:]
    else:
      input_ct = None

    return out, input_ct
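The ct / vjp convention used throughout these snippets, in isolation: jax.vjp returns the forward output together with a function that maps an output cotangent back to input cotangents.

import jax
import jax.numpy as jnp

def f(a, b):
  return a * b

a = jnp.ones(3)
b = 2.0 * jnp.ones(3)
out, vjpfun = jax.vjp(f, a, b)
a_ct, b_ct = vjpfun(jnp.ones(3))   # for elementwise multiply: a_ct == b, b_ct == a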
Example #15
 def init(self, x):
     vs = [np.zeros(sz, dtype=x.dtype) for sz in x.shape]
     return (np.zeros_like(x), vs)
Example #16
 def init(self, x):
     m = np.zeros_like(x)
     v = np.zeros_like(x)
     return m, v
Example #17
def ParametricRelu(x, a=1., **unused_kwargs):
    return np.maximum(a * x, np.zeros_like(x))
Example #18
    def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
        # This is the core of the memory-efficient attention implementation, where
        # we use the jax.lax.while_loop primitive to compute attention for a small
        # set of query positions at a time. Note how in the backwards pass, we
        # compute both the forward direction (to recover the previous layer's
        # activations) and the backward direction simultaneously. This allows us to
        # only use a single loop, where the inner portion of the loop does a slice
        # of the forward+backward joint computation. Unfortunately we have had to
        # introduce a large number of wrapper classes (including
        # ReversibleAttentionHalfResidual and ApplyAttentionWrapper) for the sole
        # purpose of connecting this implementation of forward_and_vjp with the core
        # backprop implementation.

        query, key, value = inputs
        depth = np.shape(query)[-1]
        do_backprop = ct is not None

        def make_mask(N, M, k):
            x = np.arange(N, dtype=np.int32)
            y = np.arange(M, dtype=np.int32)
            mask = jax.lax.lt((jax.lax.broadcast_in_dim(
                x, shape=(N, M), broadcast_dimensions=(0, )) + k),
                              jax.lax.broadcast(y, [N]))
            mask = jax.lax.convert_element_type(mask, np.float32)
            return mask

        def forward_slice(query_slice, q_loop_idx, key, value):
            """Forward pass for a subset of the query vectors."""
            dots = np.matmul(query_slice, np.swapaxes(key, -1,
                                                      -2)) / np.sqrt(depth)

            # Causal masking
            mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
            dots = dots - 1e9 * mask

            # Softmax.
            dots = np.exp(dots - dots.max(axis=-1, keepdims=True))
            dots = dots / dots.sum(axis=-1, keepdims=True)
            out_slice = np.matmul(dots, value)
            return out_slice

        def forward_and_vjp_slice(query_slice, q_loop_idx, key, value,
                                  ct_slice):
            output_slice, vjpfun = jax.vjp(forward_slice, query_slice,
                                           q_loop_idx, key, value)
            return output_slice, vjpfun(ct_slice)

        q_loop_idx = np.zeros((), dtype=np.int32)
        q_loop_max = query.shape[2]
        q_loop_stride = self._loop_stride
        assert q_loop_max % q_loop_stride == 0, (
            'Stride must evenly divide the number of query elements.')

        out_accum = np.zeros_like(query)
        if do_backprop:
            query_ct_accum = np.zeros_like(query)
            key_ct_accum = np.zeros_like(key)
            value_ct_accum = np.zeros_like(value)
            init_vals = (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                         value_ct_accum)
        else:
            init_vals = (q_loop_idx, out_accum)

        def cond_fun(vals):
            q_loop_idx = vals[0]
            return jax.lax.lt(q_loop_idx, q_loop_max)

        def body_fun(vals):
            """Compute a slice of the attention mechanism."""
            if do_backprop:
                (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                 value_ct_accum) = vals
            else:
                q_loop_idx, out_accum = vals

            query_slice = jax.lax.dynamic_slice_in_dim(query,
                                                       q_loop_idx,
                                                       q_loop_stride,
                                                       axis=2)

            if do_backprop:
                ct_slice = jax.lax.dynamic_slice_in_dim(ct,
                                                        q_loop_idx,
                                                        q_loop_stride,
                                                        axis=2)
                out_slice, partial_ct = forward_and_vjp_slice(
                    query_slice, q_loop_idx, key, value, ct_slice)
                query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
                    query_ct_accum, partial_ct[0], q_loop_idx, axis=2)
                # ignore partial_ct[1], which is wrt the loop idx
                key_ct_accum = key_ct_accum + partial_ct[2]
                value_ct_accum = value_ct_accum + partial_ct[3]
            else:
                out_slice = forward_slice(query_slice, q_loop_idx, key, value)

            out_accum = jax.lax.dynamic_update_slice_in_dim(out_accum,
                                                            out_slice,
                                                            q_loop_idx,
                                                            axis=2)
            q_loop_idx = q_loop_idx + q_loop_stride

            if do_backprop:
                return (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                        value_ct_accum)
            else:
                return (q_loop_idx, out_accum)

        final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

        if not do_backprop:
            return final_vals[1], None
        else:
            return final_vals[1], final_vals[2:]
Example #19
 def forward_and_backward(self, inputs, ct, rng=None, **kwargs):
   del kwargs
   output, (qk_ct, v_ct) = self.batch_call_and_or_grad(
       inputs[0], inputs[2], ct=ct, rng=rng)
   return output, (qk_ct, np.zeros_like(inputs[1]), v_ct)
Example #20
 def init(self, params):
     m = np.zeros_like(params)
     v = np.zeros_like(params)
     return m, v
Example #21
 def init(self, params):
     return np.zeros_like(params)
Example #22
 def drop_for_hash(self, x, rng):
     rate = self._drop_for_hash_rate
     if self._mode == 'train' and rate > 0.0:
         keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape)
         return np.where(keep, x / (1.0 - rate), np.zeros_like(x))
     return x
Example #23
 def init(self, x):
     return np.zeros_like(x)
Example #24
 def init(self, params):
     vs = [np.zeros(sz, dtype=params.dtype) for sz in params.shape]
     return (np.zeros_like(params), vs)
Example #25
def Relu(x, **unused_kwargs):
    return np.maximum(x, np.zeros_like(x))