Example #1
 def vjpfun(o_logits_grads):
   so_grad = np.take(o_logits_grads[0], sticker, axis=0)
   # TODO(kitaev): this exists to match the forward pass, but I'm not sure
   # if it's actually required.
   buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
       o_logits_grads[1][0] > 0, buckets_and_t.dtype)
   _, slogits_grad = jax.lax.sort_key_val(
       buckets_and_t_, o_logits_grads[1], dimension=-1)
   return (so_grad, slogits_grad)
Example #2
 def unsort_for_output_impl(so, slogits):
   o = np.take(so, undo_sort, axis=0)
   # Sorting is considerably faster than gather, but first we need to get the
   # XLA compiler to abandon the idea of fusing this sort with the input sort
   # (which introduces a computation cycle and leads to a crash).
   # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
   sticker_ = sticker + jax.lax.convert_element_type(
       slogits[0] > 0, sticker.dtype)
   _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
   return o, logits
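
Examples #1 and #2 are the backward and forward halves of the same unsorting trick from Example #7 below: applying jax.lax.sort_key_val with a permutation as the keys moves the values back to their pre-sort positions, which is faster than a gather. Adding the scalar slogits[0] > 0 to every key cannot change their relative order; it exists only to create a data dependence that keeps XLA from fusing this sort with the input sort. A minimal, self-contained sketch of the inversion trick (perm and data are illustrative names, not from the original code):

 import jax
 import jax.numpy as jnp

 # perm[i] is the original position of the item that the sort placed at slot i.
 perm = jnp.array([2, 0, 3, 1])
 data = jnp.array([10., 20., 30., 40.])  # values living in sorted order
 # Sorting the permutation ascending carries each value back to its original
 # position, i.e. it applies the inverse permutation to `data`.
 _, restored = jax.lax.sort_key_val(perm, data, dimension=-1)
 print(restored)  # [20. 40. 10. 30.]
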
Example #3
 def call(self, x, params, **kwargs):
     del kwargs
     return np.take(params, x, axis=0)
Example #4
 def call(self, params, inputs, **kwargs):
     del kwargs
     return np.take(params, inputs, axis=0)
Example #5
 def forward(self, x, params=(), state=(), **kwargs):
     del kwargs
     return np.take(params, x, axis=0), state
Example #6
 def apply_fun(params, inputs, **kwargs):
     del kwargs
     dense_embedding = params
     return np.take(dense_embedding, inputs, axis=0)
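
Examples #3 through #6 appear to be successive revisions of the same embedding-layer forward pass: params holds a (vocab_size, d_feature) table, and np.take(..., axis=0) selects one row per token id. A minimal sketch with hypothetical sizes:

 import jax.numpy as jnp

 # Hypothetical sizes, chosen only for illustration.
 vocab_size, d_feature = 5, 3
 params = jnp.arange(vocab_size * d_feature, dtype=jnp.float32).reshape(
     vocab_size, d_feature)            # the embedding table
 inputs = jnp.array([4, 0, 2])         # a batch of token ids
 vectors = jnp.take(params, inputs, axis=0)
 print(vectors.shape)  # (3, 3): one d_feature-dim row per token id
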
Example #7
  def single_call(self, qk, v, buckets, hash_rng=None):
    # We use the same vector as both a query and a key.
    seqlen = qk.shape[-2]
    assert int(buckets.shape[0]) == self.n_hashes * seqlen

    ticker = jax.lax.tie_in(qk, np.arange(self.n_hashes * seqlen))
    buckets_and_t = seqlen * buckets + (ticker % seqlen)
    buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

    # Hash-based sort ("s" at the start of variable names means "sorted")
    sbuckets_and_t, sticker = jax.lax.sort_key_val(
        buckets_and_t, ticker, dimension=-1)
    _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
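    # Sorting the permutation `sticker` (with `ticker` as values) recovers the
    # identity, so `undo_sort` is the inverse permutation of `sticker`.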
    sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
    sticker = jax.lax.stop_gradient(sticker)
    undo_sort = jax.lax.stop_gradient(undo_sort)

    st = (sticker % seqlen)
    sqk = np.take(qk, st, axis=0)
    sv = np.take(v, st, axis=0)

    # Split off a "bin" axis so that attention only occurs within chunks.
    bq_t = bkv_t = np.reshape(st, (self.n_hashes * self.n_bins, -1))
    bqk = np.reshape(sqk, (self.n_hashes * self.n_bins, -1, sqk.shape[-1]))
    bv = np.reshape(sv, (self.n_hashes * self.n_bins, -1, sv.shape[-1]))
    bq_buckets = bkv_buckets = np.reshape(
        sbuckets_and_t // seqlen, (self.n_hashes * self.n_bins, -1))

    # Hashing operates on unit-length vectors. Unnormalized query vectors are
    # fine because they effectively provide a learnable temperature for the
    # attention softmax, but normalizing keys is needed so that similarity for
    # the purposes of attention correctly corresponds to hash locality.
    bq = bqk
    bk = self.make_unit_length(bqk)

    # Allow each chunk to attend within itself, and also one chunk back. Chunk
    # boundaries might occur in the middle of a sequence of items from the
    # same bucket, so this increases the chances of attending to relevant items.
    # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
    def look_one_back(x):
      if len(x.shape) == 2:
        x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
      else:
        x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
      return np.concatenate([x, x_extra], axis=1)

    bk = look_one_back(bk)
    bv = look_one_back(bv)
    bkv_t = look_one_back(bkv_t)
    bkv_buckets = look_one_back(bkv_buckets)
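    # Each bin's keys, values, positions, and bucket ids now also cover the
    # preceding bin, doubling the window each query can attend to.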

    # Dot-product attention.
    dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

    # Causal masking
    mask = jax.lax.convert_element_type(
        jax.lax.lt(bq_t[:, :, None], bkv_t[:, None, :]),
        np.float32)
    dots = dots - 1e9 * mask

    # Mask out attention to self except when no other targets are available.
    self_mask = jax.lax.convert_element_type(
        jax.lax.eq(bq_t[:, :, None], bkv_t[:, None, :]),
        np.float32)
    dots = dots - 1e5 * self_mask

    # Mask out attention to other hash buckets.
    if not self._attend_across_buckets:
      bucket_mask = jax.lax.convert_element_type(
          jax.lax.ne(bq_buckets[:, :, None], bkv_buckets[:, None, :]),
          np.float32)
      dots = dots - 1e7 * bucket_mask

    # Don't double-count query-key pairs across multiple rounds of hashing.
    # There are two possible strategies here. (1) The default is to count how
    # many times a query-key pair is repeated, and to lower its log-prob
    # correspondingly at each repetition. (2) When hard_k is set, the code
    # instead masks all but the first occurrence of each query-key pair.
    # TODO(kitaev): is one strategy faster or more numerically stable?
    if not self._allow_duplicate_attention:
      locs1 = undo_sort // bq_t.shape[-1]
      locs2 = (locs1 + 1) % (self.n_hashes * self.n_bins)
      if not self._attend_across_buckets:
        locs1 = buckets * (self.n_hashes * self.n_bins) + locs1
        locs2 = buckets * (self.n_hashes * self.n_bins) + locs2
      locs = np.moveaxis(np.concatenate([
          np.reshape(locs1, (self.n_hashes, seqlen)),
          np.reshape(locs2, (self.n_hashes, seqlen)),
      ], 0), 0, -1)  # produces shape (seqlen, 2 * self.n_hashes)
      slocs = np.take(locs, st, axis=0)
      b_locs = np.reshape(
          slocs, (self.n_hashes * self.n_bins, -1, 2 * self.n_hashes))
      # Queries always use the primary location (based on locs1).
      b_locs1 = b_locs[:, :, None, :self.n_hashes]
      if self._hard_k > 0:
        range_n_hashes = jax.lax.tie_in(b_locs, np.arange(self.n_hashes))
        nouse_locs = (range_n_hashes[:, None] > range_n_hashes[None, :])
        nouse_locs = 2 * nouse_locs - 1  # 1 = use, -1 = don't use
        nouse_locs = np.reshape(
            np.broadcast_to(nouse_locs[:, None, :],
                            (self.n_hashes, self.n_bins, self.n_hashes)),
            (self.n_hashes * self.n_bins, 1, 1, self.n_hashes))
        b_locs1 = b_locs1 * nouse_locs
      bq_locs = np.broadcast_to(
          b_locs1,
          b_locs.shape[:2] + (2, self.n_hashes))
      bq_locs = np.reshape(bq_locs, b_locs.shape)
      bkv_locs = look_one_back(b_locs)

      dup_counts = np.sum(
          jax.lax.convert_element_type(
              jax.lax.eq(bq_locs[:, :, None, :], bkv_locs[:, None, :, :]),
              np.float32),
          axis=-1)
      assert dup_counts.shape == dots.shape
      if self._hard_k > 0:
        dots = dots - 1e7 * jax.lax.stop_gradient(dup_counts)
      else:
        dots = dots - jax.lax.stop_gradient(np.log(dup_counts + 1e-9))

    # Each query only attends to the top k most relevant keys.
    if self._hard_k > 0:
      b_top_dots = np.sort(dots)[..., -self._hard_k:]  # Get the top k dots.
      b_top_dots = jax.lax.stop_gradient(b_top_dots)
      s_top_dots = np.reshape(b_top_dots, (-1, self._hard_k))
      top_dots = np.take(s_top_dots, undo_sort, axis=0)

      merged_top_dots = np.moveaxis(
          np.reshape(top_dots, (self.n_hashes, seqlen, self._hard_k)), 0, -1)
      merged_top_dots = np.reshape(merged_top_dots, (seqlen, -1))

      dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k]
      # It's possible to compute the partition function at this point, but right
      # now this codepath isn't set up for backprop, and there might also be
      # issues computing it this way if two dot-products are exactly equal.

      sdots_thresh = dots_thresh[st]
      bdots_thresh = np.reshape(sdots_thresh, (self.n_hashes * self.n_bins, -1))
      bdots_thresh = jax.lax.stop_gradient(bdots_thresh)

      top_k_mask = jax.lax.convert_element_type(
          dots < bdots_thresh[..., None], np.float32)
      dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)

    # Softmax.
    dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
    dots = np.exp(dots - dots_logsumexp)

    bo = np.matmul(dots, bv)
    so = np.reshape(bo, (-1, bo.shape[-1]))
    slogits = np.reshape(dots_logsumexp, (-1,))
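    # Each entry of `slogits` is a query's log partition function; when
    # n_hashes > 1 it is used below to weight the rounds against each other.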

    def unsort_for_output_impl(so, slogits):
      o = np.take(so, undo_sort, axis=0)
      # Sorting is considerably faster than gather, but first we need to get the
      # XLA compiler to abandon the idea of fusing this sort with the input sort
      # (which introduces a computation cycle and leads to a crash).
      # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
      sticker_ = sticker + jax.lax.convert_element_type(
          slogits[0] > 0, sticker.dtype)
      _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
      return o, logits

    def unsort_for_output_vjp(so, slogits):
      """Custom gradient for unsort_for_output."""
      so = jax.lax.stop_gradient(so)
      slogits = jax.lax.stop_gradient(slogits)
      o, logits = unsort_for_output_impl(so, slogits)
      def vjpfun(o_logits_grads):
        so_grad = np.take(o_logits_grads[0], sticker, axis=0)
        # TODO(kitaev): this exists to match the forward pass, but I'm not sure
        # if it's actually required.
        buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
            o_logits_grads[1][0] > 0, buckets_and_t.dtype)
        _, slogits_grad = jax.lax.sort_key_val(
            buckets_and_t_, o_logits_grads[1], dimension=-1)
        return (so_grad, slogits_grad)
      return (o, logits), vjpfun

    unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
    jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
    o, logits = unsort_for_output(so, slogits)  # call the wrapped version so the custom VJP applies

    if self.n_hashes == 1:
      out = o
    else:
      o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
      logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
      probs = np.exp(logits - backend.logsumexp(logits, axis=0, keepdims=True))
      out = np.sum(o * probs, axis=0)

    assert out.shape == v.shape
    return out
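
Two of the APIs in Example #7 are legacy JAX: jax.lax.tie_in later became a no-op, and jax.custom_transforms / jax.defvjp_all were removed in favor of jax.custom_vjp. A minimal sketch of the equivalent custom-gradient wiring in the newer API, using a hypothetical stand-in function f rather than the attention code:

 import jax
 import jax.numpy as jnp

 @jax.custom_vjp
 def f(x):
   return jnp.sin(x)

 def f_fwd(x):
   # Return the primal output plus residuals for the backward pass.
   return f(x), x

 def f_bwd(x, g):
   # VJP: multiply the incoming cotangent by the derivative.
   return (g * jnp.cos(x),)

 f.defvjp(f_fwd, f_bwd)
 print(jax.grad(f)(0.0))  # 1.0

Here f_fwd returns residuals alongside the primal output, and f_bwd consumes them together with the output cotangent, mirroring the (o, logits), vjpfun pair returned by unsort_for_output_vjp above.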