Пример #1
0
 def make_mask(N, M, k):
     """Build an N x M causal-attention mask slice.

     Query row i (shifted by the offset k) is blocked from key column j
     whenever i + k < j.

     Args:
       N: number of query positions.
       M: number of key positions.
       k: position of the initial query element.

     Returns:
       float32 array of shape (N, M); 1.0 marks a disallowed query/key pair.
     """
     query_pos = jax.lax.broadcast_in_dim(
         np.arange(N, dtype=np.int32), shape=(N, M), broadcast_dimensions=(0,))
     key_pos = jax.lax.broadcast(np.arange(M, dtype=np.int32), [N])
     blocked = jax.lax.lt(query_pos + k, key_pos)
     return jax.lax.convert_element_type(blocked, np.float32)
Пример #2
0
def one_hot(x, size, dtype=np.float32):  # pylint: disable=invalid-name
    """Make a n+1 dim one-hot array from n dim int-categorical array.

    Args:
      x: array of integer category ids (any rank); indexed with
        `x[..., np.newaxis]`, so it must support ellipsis indexing.
      size: number of categories; length of the new trailing axis.
      dtype: dtype of the returned array.

    Returns:
      Array of shape x.shape + (size,) that is 1 where the trailing index
      equals the corresponding element of x, and 0 elsewhere.
    """
    categories = np.arange(size)
    if backend.get_name() == 'jax':
        # Work around a jax broadcasting issue.
        categories = jax.lax.tie_in(x, categories)
    matches = x[..., np.newaxis] == categories
    return np.array(matches, dtype)
Пример #3
0
        def make_mask(N, M, k):  # pylint: disable=invalid-name
            """Constructs a slice of the causal attention mask.

            Args:
              N: number of query positions
              M: number of key positions
              k: position of the initial query element

            Returns:
              N x M float32 mask, where 1.0 indicates that attention is not
              allowed (query position i + k lies strictly before key j).
            """
            rows = jax.lax.broadcast_in_dim(np.arange(N, dtype=np.int32),
                                            shape=(N, M),
                                            broadcast_dimensions=(0,))
            cols = jax.lax.broadcast(np.arange(M, dtype=np.int32), [N])
            return jax.lax.convert_element_type(
                jax.lax.lt(rows + k, cols), np.float32)
Пример #4
0
    def bin_vectors_by_time(self, vecs):
        """Assign each time step of `vecs` to one of `self.n_bins` contiguous bins.

        Args:
          vecs: array whose second-to-last axis is the time axis; its length
            must be divisible by `self.n_bins`.

        Returns:
          int32 array of shape vecs.shape[:-1] giving the bin index of every
          time step; consecutive runs of seqlen // n_bins steps share a bin.
        """
        seqlen = vecs.shape[-2]
        assert seqlen % self.n_bins == 0
        per_bin = int(seqlen // self.n_bins)

        time_idx = np.arange(seqlen, dtype=np.int32)
        bin_ids = jax.lax.tie_in(vecs, time_idx // per_bin)
        return np.broadcast_to(bin_ids[None, :], vecs.shape[:-1])
Пример #5
0
    def make_self_mask(N, M, k):  # pylint: disable=invalid-name
      """Masks out elements attending to self.

      Args:
        N: number of query positions
        M: number of key positions
        k: position of the initial query element

      Returns:
        N x M float32 mask with 1.0 wherever query position i + k equals key
        position j, i.e. where attention is not allowed.
      """
      query_idx = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
      key_idx = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
      shifted_queries = jax.lax.broadcast_in_dim(
          query_idx, shape=(N, M), broadcast_dimensions=(0,)) + k
      tiled_keys = jax.lax.broadcast(key_idx, [N])
      is_self = jax.lax.eq(shifted_queries, tiled_keys)
      return jax.lax.convert_element_type(is_self, np.float32)
Пример #6
0
def label_smoothed_loss(logpred, target, size, padding_idx=0, smoothing=0.0):
    """Label-smoothed KL-divergence loss between predictions and targets.

    Builds a smoothed target distribution: `1 - smoothing` on the true class,
    `smoothing` spread uniformly over the remaining non-padding classes, and
    zero on the padding class and on rows whose target is padding. Then returns
    `kl_div(logpred, truedist, eps=1e-6)`.

    Args:
      logpred: (batch, size) predictions handed to `kl_div` (presumably
        log-probabilities, given the name — confirm against `kl_div`).
      target: (batch,) integer class ids.
      size: number of classes, including the padding class.
      padding_idx: class id used for padding.
      smoothing: total probability mass moved off the true class.

    Returns:
      Whatever `kl_div` returns for the smoothed target distribution.
    """
    assert logpred.shape[1] == size
    # Smoothing mass goes to every class except the true target and the
    # padding class, i.e. to size - 2 classes.
    n_other = size - 2
    if n_other > 0:
        confidence = 1.0 - smoothing
        zerosmoothed = smoothing / n_other
    else:
        # No class can receive smoothing mass (this used to divide by zero,
        # even with the default smoothing=0.0); keep all mass on the target.
        confidence = 1.0
        zerosmoothed = 0.0
    delta = confidence - zerosmoothed
    truedist = (np.full_like(logpred, zerosmoothed) +
                delta * slax.one_hot(target, size))
    # Zero the padding column so no probability goes to the pad class.
    truedist *= (1 - (np.arange(size) == padding_idx))
    # Zero entire rows whose target is padding so they contribute no loss.
    truedist *= (1 - (target == padding_idx))[:, np.newaxis]
    return kl_div(logpred, truedist, eps=1e-6)
Пример #7
0
  def hash_vectors(self, vecs, rng):
    """Hash each input vector into buckets using random rotations.

    Scheme from https://arxiv.org/pdf/1509.02897.pdf: project onto random
    directions, append the negated projections and pick the arg-max bucket.

    Args:
      vecs: input vectors; after `self.drop_for_hash` they are consumed by an
        einsum with subscript 'tf', i.e. treated as (seqlen, d).
      rng: PRNG key, split into one key for the rotations and one for
        `self.drop_for_hash`.

    Returns:
      1-D array of bucket ids of length n_hashes * seqlen, laid out with the
      hash round as the slower-varying axis.
    """
    assert self.n_buckets % 2 == 0
    # Rehashing samples a different random rotation for each round, which
    # lowers the probability of hash misses.
    n_rounds = self.n_hashes if self._rehash_each_round else 1
    rotations_shape = (vecs.shape[-1], n_rounds, self.n_buckets // 2)

    key = jax.lax.tie_in(vecs, rng)
    rotation_key, dropout_key = backend.random.split(key)
    rotations = jax.random.normal(
        rotation_key, rotations_shape).astype('float32')
    # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
    # hashing, so it's shared between them. Check if that's what we want.
    hashed_input = self.drop_for_hash(vecs, dropout_key)
    projected = np.einsum('tf,fhb->htb', hashed_input, rotations)
    projected = np.concatenate([projected, -projected], axis=-1)

    if self._rehash_each_round:
      # (n_hashes, seqlen): the arg-max bucket for every round.
      per_round = np.argmax(projected, axis=-1)
      # Shift round r's ids by r * n_buckets so different rounds never share
      # bucket numbers.
      round_ids = jax.lax.tie_in(per_round, np.arange(self.n_hashes))
      round_offsets = np.reshape(round_ids * self.n_buckets, (-1, 1))
      return np.reshape(per_round + round_offsets, (-1,))

    # Single rotation: map every item to its top self.n_hashes buckets.
    projected = np.squeeze(projected, 0)
    candidates = jax.lax.tie_in(vecs, np.arange(projected.shape[-1]))
    candidates = np.broadcast_to(np.reshape(candidates, (1, -1)),
                                 projected.shape)
    _, ranked = jax.lax.sort_key_val(projected, candidates, dimension=-1)
    # The sort is ascending by score, so the best buckets are at the end.
    top = ranked[:, -self.n_hashes:]
    return np.reshape(np.moveaxis(top, 0, -1), (-1,))
Пример #8
0
 def test_batch_norm(self):
     """One BatchNorm training step over all axes of a (2, 3, 4) input.

     Checks the initial state, the running statistics after a single batch,
     the step counter and the normalized output.
     """
     input_shape = (2, 3, 4)
     input_dtype = np.float32
     eps = 1e-5
     rng = backend.random.get_prng(0)
     n_elems = np.prod(input_shape)
     inp1 = np.reshape(np.arange(n_elems, dtype=input_dtype), input_shape)
     # inp1 is the deterministic sequence 0..23: its mean is 11.5 and its
     # population variance is (24**2 - 1) / 12 ~= 47.9167.
     m1 = 11.5
     v1 = 47.9167
     layer = normalization.BatchNorm(axis=(0, 1, 2))
     params, state = layer.initialize(input_shape, input_dtype, rng)
     onp.testing.assert_allclose(state[0], 0)
     onp.testing.assert_allclose(state[1], 1)
     self.assertEqual(state[2], 0)

     out, state = layer(inp1, params, state)
     # Expected running stats: 0.999 * initial value + 0.001 * batch stat.
     onp.testing.assert_allclose(state[0], 0.001 * m1)
     onp.testing.assert_allclose(state[1], 0.999 + 0.001 * v1, rtol=1e-6)
     self.assertEqual(state[2], 1)
     expected = (inp1 - m1) / np.sqrt(v1 + eps)
     onp.testing.assert_allclose(out, expected, rtol=1e-6)
Пример #9
0
def PreparePairedSequenceBatch(source, target_in, pad=0):
    """Build masks for this batch.

    Args:
      source: (batch, source_len) array of integer-coded symbols for inputs
      target_in: (batch, target_len) array of integer-coded symbols for targets
      pad: int: the padding symbol used to pad the above

    Returns:
      Prepared batch of tuple of arrays: source, input-target, shifted-target,
      source mask, target mask, source-target "memory" mask, minibatch token
      count
    """
    # Input targets drop the last symbol; output targets drop the first, so
    # position i of `target_y` is the symbol following position i of `target`.
    target = target_in[:, :-1]
    target_y = target_in[:, 1:]
    batch_size, source_len = source.shape[0], source.shape[-1]
    source_mask = np.reshape(source != pad, (batch_size, 1, 1, source_len))
    target_mask = MakeTargetMask(target, pad)
    # Column vector (one row per input-target position): True where that
    # position index is below the source length.
    positions = np.arange(target.shape[-1])
    memory_mask = np.reshape(positions < source_len, [-1, 1])
    # Number of non-padding symbols to be predicted in this batch.
    ntokens = np.sum(target_y != pad)
    return (source, target, target_y, source_mask, target_mask, memory_mask,
            ntokens)
Пример #10
0
 def test_batch_norm(self):
     """BatchNorm running stats average the batch stats seen so far; eval mode
     normalizes with the stored stats and leaves the state untouched."""
     input_shape = (2, 3, 4)
     input_dtype = np.float32
     eps = 1e-5
     rng = backend.random.get_prng(0)

     def normalized(x, mean, var):
         return (x - mean) / np.sqrt(var + eps)

     inp1 = np.reshape(np.arange(np.prod(input_shape), dtype=input_dtype),
                       input_shape)
     m1 = 11.5     # mean of the sequence 0..23
     v1 = 47.9167  # population variance of the sequence 0..23
     layer = normalization.BatchNorm(axis=(0, 1, 2))
     params, state = layer.initialize(input_shape, input_dtype, rng)
     onp.testing.assert_allclose(state[0], 0)
     onp.testing.assert_allclose(state[1], 0)
     self.assertEqual(state[2], 0)

     # First batch: the running stats equal this batch's stats.
     out, state = layer(inp1, params, state)
     onp.testing.assert_allclose(state[0], m1)
     onp.testing.assert_allclose(state[1], v1, rtol=1e-6)
     self.assertEqual(state[2], 1)
     onp.testing.assert_allclose(out, normalized(inp1, m1, v1), rtol=1e-6)

     # Second batch is the affine map 2 * x + 3 of the first, so its stats are
     # known in closed form; the running stats become the two-batch average.
     inp2 = inp1 * 2 + 3
     m2, v2 = m1 * 2 + 3, v1 * 4
     m12, v12 = (m1 + m2) / 2, (v1 + v2) / 2
     out, state = layer(inp2, params, state)
     onp.testing.assert_allclose(state[0], m12)
     onp.testing.assert_allclose(state[1], v12, rtol=1e-6)
     self.assertEqual(state[2], 2)
     onp.testing.assert_allclose(out, normalized(inp2, m2, v2), rtol=1e-6)

     # Eval mode: normalize with the stored running stats, don't update them.
     eval_layer = normalization.BatchNorm(axis=(0, 1, 2), mode="eval")
     inp3 = inp1 * 5 + 7
     out, state_unchanged = eval_layer(inp3, params, state)
     for idx in range(3):
         onp.testing.assert_allclose(state_unchanged[idx], state[idx])
     onp.testing.assert_allclose(out, normalized(inp3, m12, v12), rtol=1e-6)
Пример #11
0
    def call(self, inputs, params=(), state=(), rng=None, **kwargs):
        """Hash-bucketed causal self-attention with several hashing rounds.

        qk and v are tiled once per hashing round, positions are hashed with
        `self.hash_vectors` and sorted by (bin, time), and attention is computed
        inside fixed-size chunks of the sorted sequence, each also seeing one
        chunk back. The results are un-sorted. The per-round outputs are then
        combined as a weighted average; the weights are a softmax, taken across
        rounds, of each round's attention logsumexp.

        Args:
          inputs: (qk, k, v). qk is used as both query and key; the separate
            key input is ignored.
          params: unused.
          state: returned unchanged.
          rng: forwarded to `self.hash_vectors`.
          **kwargs: unused.

        Returns:
          (out, state), where out has the same shape as v (inputs[2]).
        """
        del params, kwargs
        # We use the same vector as both a query and a key. For now we haven't
        # adjusted any of the surrounding code, so we still get a separate "key"
        # input that we ignore.
        qk, _, v = inputs
        seqlen = qk.shape[-2]

        # qk/v are n_hashes*n_batch*n_heads, seqlen, d_head
        # TODO(kitaev): is it faster to fuse this tiling into gather/scatter ops?
        qk = np.tile(qk, (self.n_hashes, 1, 1))
        v = np.tile(v, (self.n_hashes, 1, 1))

        # bins are n_hashes*n_batch*n_heads, seqlen
        # They specify which hash bucket the query/key/value vectors fall in.
        bins = self.hash_vectors(qk, rng=rng)

        # joint_t is n_hashes*n_batch*n_heads, seqlen
        joint_t = jax.lax.tie_in(qk, np.arange(seqlen))
        joint_t = np.reshape(joint_t, (1, seqlen))
        joint_t = np.broadcast_to(joint_t, qk.shape[:-1])

        assert int(
            (self.n_buckets_per_bin * self.n_bins + 1) * seqlen
        ) < 2**31, (
            'Potential 32-bit integer overflow; please double-check the code.')
        # Composite sort key: primarily the bin, secondarily the time step
        # (joint_t < seqlen, so it never spills into the next bin's range).
        joint_bins_and_t = seqlen * bins + joint_t

        def chunk_scalars(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], self.n_bins, -1))

        def chunk_vectors(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

        def unchunk_vectors(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

        # Sort everything by bin number, with a secondary sort by time
        # (variables starting with "s" are sorted)
        _, sjoint_t = jax.lax.sort_key_val(joint_bins_and_t,
                                           joint_t,
                                           dimension=-1)
        # Sorting the identity by the permutation sjoint_t yields its inverse,
        # used below to put sorted rows back into time order.
        _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)
        # TODO(kitaev): why does jax flag integer indices as differentiable?
        # If we don't call stop_gradient here, custom gradients below won't work
        # because the primitive functions close over "differentiable" variables.
        sjoint_t = jax.lax.stop_gradient(sjoint_t)
        undo_sort = jax.lax.stop_gradient(undo_sort)

        # The backward pass of gather is in general a scatter operation, but we know
        # we're dealing with permutations so we use gather for the backward pass
        # too. This custom gradient should be about 2x faster than having jax infer
        # one that uses scatter ops instead.
        # NOTE(review): jax.custom_transforms / jax.defvjp_all belong to older
        # JAX releases (newer JAX offers jax.custom_vjp) — confirm the pinned
        # JAX version still provides them.
        def permute_impl(vecs):
            assert len(vecs.shape) == 3
            return np.take_along_axis(vecs, sjoint_t[:, :, None], axis=-2)

        def unpermute_impl(vecs):
            assert len(vecs.shape) == 3
            return np.take_along_axis(vecs, undo_sort[:, :, None], axis=-2)

        @jax.custom_transforms
        def permute(vecs):
            return permute_impl(vecs)

        def permute_vjp(vecs):
            out_vecs = permute_impl(vecs)

            def vjpfun(grad):
                # The gradient of a permutation is the inverse permutation.
                return (unpermute_impl(grad), )

            return out_vecs, vjpfun

        @jax.custom_transforms
        def unpermute(vecs):
            return unpermute_impl(vecs)

        def unpermute_vjp(vecs):
            out_vecs = unpermute_impl(vecs)

            def vjpfun(grad):
                return (permute_impl(grad), )

            return out_vecs, vjpfun

        jax.defvjp_all(permute, permute_vjp)
        jax.defvjp_all(unpermute, unpermute_vjp)

        sqk = permute(qk)
        sv = permute(v)

        # Split off a "bin" axis so that attention only occurs within chunks.
        # Queries and keys share the same sorted time indices (shared qk).
        bq_t = bkv_t = chunk_scalars(sjoint_t)
        bqk = chunk_vectors(sqk)
        bv = chunk_vectors(sv)

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = self.make_unit_length(bqk)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bin, so this increases the chances of attending to relevant items.
        # (Chunk 0's "previous" chunk wraps around to the last chunk.)
        # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
        bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]], axis=1)
        bk = np.concatenate([bk, bk_extra], axis=2)
        bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]], axis=1)
        bv = np.concatenate([bv, bv_extra], axis=2)
        bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]],
                                     axis=1)
        bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

        # Dot-product attention.
        dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

        # Causal masking: penalize keys whose time index is later than the
        # query's.
        mask = jax.lax.convert_element_type(
            jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]), np.float32)
        dots = dots - 1e9 * mask

        # Mask out attention to self except when no other targets are available.
        # The penalty (32) is finite and much smaller than the causal 1e9, so
        # self-attention still wins when every other key is causally masked.
        self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
        self_mask = jax.lax.tie_in(dots, self_mask)
        dots = dots - 32 * self_mask

        # Softmax.
        dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
        dots = np.exp(dots - dots_logsumexp)

        # NOTE(review): the k-th largest weight is itself reduced to 0 by the
        # subtract-then-clamp below, so at most hard_k - 1 weights survive. For
        # hard_k == 1 every weight becomes 0, dots_sum is 0, and the division and
        # log below produce NaN / -inf. Confirm the intended threshold (e.g. the
        # (k+1)-th largest weight) before relying on this path.
        if self._hard_k > 0:
            top_k = np.sort(dots)[...,
                                  -self._hard_k]  # Get the top-kth weight.
            top_k = jax.lax.stop_gradient(top_k)
            dots -= top_k[..., np.newaxis]  # Subtract (be 0 for lower ones).
            dots = np.maximum(dots, 0)
            dots_sum = np.sum(dots, axis=-1,
                              keepdims=True)  # Sum to re-normalize.
            dots_logsumexp += np.log(dots_sum)  # Add it to the weight.
            dots /= dots_sum  # Re-normalize.

        bo = np.matmul(dots, bv)
        so = unchunk_vectors(bo)
        slogits = unchunk_vectors(dots_logsumexp)

        o = unpermute(so)
        logits = unpermute(slogits)

        # Combine the hashing rounds: weight each round's output by the softmax,
        # over the round axis (axis 0), of that round's attention logsumexp.
        o = np.reshape(o, (self.n_hashes, -1, seqlen, o.shape[-1]))
        logits = np.reshape(logits, (self.n_hashes, -1, seqlen, 1))
        probs = np.exp(logits -
                       backend.logsumexp(logits, axis=0, keepdims=True))
        out = np.sum(o * probs, axis=0)
        assert out.shape == inputs[2].shape

        return out, state
Пример #12
0
    def call_and_grad(self, inputs, ct, rng=None, **kwargs):
        """Single-round hash-bucketed causal attention with a hand-written VJP.

        Positions are hashed with `self.hash_vectors`, sorted by (bin, time),
        attended within chunks (plus one chunk back), and un-sorted.

        Args:
          inputs: (qk, k, v). qk is used as both query and key; k is ignored
            and receives an all-zero cotangent.
          ct: cotangent of the output, or None for a forward-only call.
          rng: forwarded to `self.hash_vectors`.
          **kwargs: unused.

        Returns:
          (out, None) when ct is None; otherwise
          (out, (qk_ct, np.zeros_like(k), v_ct)).
        """
        del kwargs
        # We use the same vector as both a query and a key. For now we haven't
        # adjusted any of the surrounding code, so we still get a separate "key"
        # input that we ignore.
        qk, ignored_k, v = inputs
        seqlen = qk.shape[-2]
        # qk/v are n_batch*n_heads, seqlen, d_head

        # bins are n_batch*n_heads, seqlen
        # They specify which hash bucket the query/key/value vectors fall in.
        bins = self.hash_vectors(qk, rng=rng)

        # joint_t is n_batch*n_heads, seqlen
        joint_t = jax.lax.tie_in(qk, np.arange(seqlen))
        joint_t = np.reshape(joint_t, (1, seqlen))
        joint_t = np.broadcast_to(joint_t, qk.shape[:-1])

        assert int((self.n_bins + 1) * seqlen) < 2**31, (
            'Potential 32-bit integer overflow; please double-check the code.')
        # Composite sort key: primarily the bin, secondarily the time step
        # (joint_t < seqlen, so it never spills into the next bin's range).
        joint_bins_and_t = seqlen * bins + joint_t

        def chunk_scalars(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], self.n_bins, -1))

        def chunk_vectors(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

        def unchunk_vectors(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

        # Sort everything by bin number, with a secondary sort by time
        # (variables starting with "s" are sorted)
        _, sjoint_t = jax.lax.sort_key_val(joint_bins_and_t,
                                           joint_t,
                                           dimension=-1)

        sqk = np.take_along_axis(qk, sjoint_t[:, :, None], axis=-2)
        sv = np.take_along_axis(v, sjoint_t[:, :, None], axis=-2)

        # so_ct is only defined (and only used) when a cotangent is supplied.
        if ct is not None:
            so_ct = np.take_along_axis(ct, sjoint_t[:, :, None], axis=-2)

        @jax.jit
        def binned_attn(sqk, sv):  # pylint: disable=invalid-name
            """Performs attention on sorted queries/keys/values."""
            # Split off a "bin" axis so that attention only occurs within chunks.
            bq_t = bkv_t = chunk_scalars(sjoint_t)
            bqk = chunk_vectors(sqk)
            bv = chunk_vectors(sv)

            # Hashing operates on unit-length vectors. Unnormalized query vectors are
            # fine because they effectively provide a learnable temperature for the
            # attention softmax, but normalizing keys is needed so that similarity for
            # the purposes of attention correctly corresponds to hash locality.
            bq = bqk
            bk = self.make_unit_length(bqk)

            # Allow each chunk to attend within itself, and also one chunk back. Chunk
            # boundaries might occur in the middle of a sequence of items from the
            # same bin, so this increases the chances of attending to relevant items.
            # (Chunk 0's "previous" chunk wraps around to the last chunk.)
            # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
            bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]],
                                      axis=1)
            bk = np.concatenate([bk, bk_extra], axis=2)
            bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]],
                                      axis=1)
            bv = np.concatenate([bv, bv_extra], axis=2)
            bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]],
                                         axis=1)
            bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

            # Dot-product attention.
            dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(
                bq.shape[-1])

            # Causal masking: penalize keys whose time index is later than the
            # query's.
            mask = jax.lax.convert_element_type(
                jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]),
                np.float32)
            dots = dots - 1e9 * mask

            # Mask out attention to self except when no other targets are available.
            # The finite penalty (32) keeps self-attention possible when every
            # other key is causally masked.
            self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
            self_mask = jax.lax.tie_in(dots, self_mask)
            dots = dots - 32 * self_mask

            # Softmax.
            dots = np.exp(dots -
                          backend.logsumexp(dots, axis=-1, keepdims=True))
            bo = np.matmul(dots, bv)

            so = unchunk_vectors(bo)
            return so

        @jax.jit
        def binned_attn_vjp(sqk, sv, so_ct):  # pylint: disable=invalid-name
            # Forward pass plus the cotangents w.r.t. the sorted inputs.
            so, vjpfun = jax.vjp(binned_attn, sqk, sv)
            sqkv_ct = vjpfun(so_ct)
            return so, sqkv_ct

        if ct is None:
            so = binned_attn(sqk, sv)
            # Sorting the identity by the permutation yields its inverse.
            _, undo_sort = jax.lax.sort_key_val(sjoint_t,
                                                joint_t,
                                                dimension=-1)
            out = np.take_along_axis(so, undo_sort[:, :, None], axis=-2)
            return out, None
        else:
            # Jax can construct a backward pass automatically, but it's about 2x
            # slower than writing our own. The main reason is that the backward pass
            # of gather is in general a scatter operation, but we know we're dealing
            # with permutations so we use gather for the backward pass too.
            so, (sqk_ct, sv_ct) = binned_attn_vjp(sqk, sv, so_ct)

            _, undo_sort = jax.lax.sort_key_val(sjoint_t,
                                                joint_t,
                                                dimension=-1)
            out = np.take_along_axis(so, undo_sort[:, :, None], axis=-2)

            # Un-sort the cotangents with the same inverse permutation.
            qk_ct = np.take_along_axis(sqk_ct, undo_sort[:, :, None], axis=-2)
            v_ct = np.take_along_axis(sv_ct, undo_sort[:, :, None], axis=-2)

            return out, (qk_ct, np.zeros_like(ignored_k), v_ct)
Пример #13
0
    def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
        """Chunked causal attention over fixed (non-hashed) bins, with a VJP.

        Time steps are grouped into `self.n_bins` contiguous bins, sorted by
        bin, attended within each chunk under a causal mask, and un-sorted.

        Args:
          inputs: (q, k, v); per the comment below each is
            (n_batch, n_heads, seqlen, d_head).
          ct: cotangent of the output, or None for a forward-only call.
          params: unused.
          **kwargs: unused.

        Returns:
          (out, None) when ct is None; otherwise (out, (q_ct, k_ct, v_ct)).
        """
        del params, kwargs
        q, k, v = inputs
        # q/k/v are n_batch, n_heads, seqlen, d_head

        assert k.shape[2] % self.n_bins == 0
        bin_size = int(k.shape[2] // self.n_bins)

        # q_bins/kv_bins are n_batch, n_heads, seqlen
        # They specify which hash bucket the query/key/value vectors fall in. For
        # now, instead of hashing we just put consecutive items in the same bucket.
        q_bins = np.arange(q.shape[2], dtype=np.int32) // bin_size
        q_bins = jax.lax.tie_in(q, q_bins)
        q_bins = q_bins[None, None, :]
        q_bins = np.broadcast_to(q_bins, q.shape[:-1])
        # Negated, so the ascending sort below puts later bins first.
        # kv_bins = 2 * q_bins preserves the key order, so keys/values should be
        # sorted into the same order as queries (assuming a stable sort).
        q_bins = -q_bins
        kv_bins = q_bins * 2

        # q_t/kv_t are n_batch, n_heads, seqlen
        q_t = jax.lax.tie_in(q, np.arange(q.shape[2]))
        q_t = np.reshape(q_t, (1, 1, q_t.shape[0]))
        q_t = np.broadcast_to(q_t, q.shape[:-1])
        kv_t = q_t

        def chunk_rank3(x):
            return np.reshape(x, (x.shape[0], x.shape[1], self.n_bins, -1))

        def chunk_rank4(x):
            return np.reshape(
                x, (x.shape[0], x.shape[1], self.n_bins, -1, x.shape[-1]))

        def unchunk_rank4(x):
            return np.reshape(x, (x.shape[0], x.shape[1], -1, x.shape[-1]))

        # Sort everything by bin number (variables starting with "s" are sorted)

        _, sq_t = jax.lax.sort_key_val(q_bins, q_t, dimension=2)

        sq = np.take_along_axis(q, sq_t[:, :, :, None], axis=2)
        # so_ct is only defined (and only used) when a cotangent is supplied.
        if ct is not None:
            so_ct = np.take_along_axis(ct, sq_t[:, :, :, None], axis=2)

        _, skv_t = jax.lax.sort_key_val(kv_bins, kv_t, dimension=2)
        sk = np.take_along_axis(k, skv_t[:, :, :, None], axis=2)
        sv = np.take_along_axis(v, skv_t[:, :, :, None], axis=2)

        @jax.jit
        def binned_attn(sq, sk, sv):
            """Performs attention on sorted queries/keys/values."""
            # Split off a "bin" axis so that attention only occurs within chunks.
            bq_t = chunk_rank3(sq_t)
            bkv_t = chunk_rank3(skv_t)
            bq = chunk_rank4(sq)
            bk = chunk_rank4(sk)
            bv = chunk_rank4(sv)

            dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(
                bq.shape[-1])

            # Causal masking: penalize keys whose time index is later than the
            # query's.
            mask = jax.lax.convert_element_type(
                jax.lax.lt(bq_t[:, :, :, :, None], bkv_t[:, :, :, None, :]),
                np.float32)
            dots = dots - 1e9 * mask

            # Softmax (max-subtracted for numerical stability).
            dots = np.exp(dots - dots.max(axis=-1, keepdims=True))
            dots = dots / dots.sum(axis=-1, keepdims=True)
            bo = np.matmul(dots, bv)

            so = unchunk_rank4(bo)
            return so

        @jax.jit
        def binned_attn_vjp(sq, sk, sv, so_ct):
            # Forward pass plus the cotangents w.r.t. the sorted inputs.
            so, vjpfun = jax.vjp(binned_attn, sq, sk, sv)
            sqkv_ct = vjpfun(so_ct)
            return so, sqkv_ct

        if ct is None:
            so = binned_attn(sq, sk, sv)
            # Sorting the identity by the permutation yields its inverse.
            _, undo_q_sort = jax.lax.sort_key_val(sq_t, q_t, dimension=2)
            out = np.take_along_axis(so, undo_q_sort[:, :, :, None], axis=2)
            return out, None
        else:
            # Jax can construct a backward pass automatically, but it's about 2x
            # slower than writing our own. The main reason is that the backward pass
            # of gather is in general a scatter operation, but we know we're dealing
            # with permutations so we use gather for the backward pass too.
            so, (sq_ct, sk_ct, sv_ct) = binned_attn_vjp(sq, sk, sv, so_ct)

            _, undo_q_sort = jax.lax.sort_key_val(sq_t, q_t, dimension=2)
            out = np.take_along_axis(so, undo_q_sort[:, :, :, None], axis=2)
            q_ct = np.take_along_axis(sq_ct,
                                      undo_q_sort[:, :, :, None],
                                      axis=2)

            # Keys/values were sorted with their own permutation; invert it.
            _, undo_kv_sort = jax.lax.sort_key_val(skv_t, kv_t, dimension=2)
            k_ct = np.take_along_axis(sk_ct,
                                      undo_kv_sort[:, :, :, None],
                                      axis=2)
            v_ct = np.take_along_axis(sv_ct,
                                      undo_kv_sort[:, :, :, None],
                                      axis=2)

            return out, (q_ct, k_ct, v_ct)
Пример #14
0
def one_hot(x, size, dtype=np.float32):
    """Make a n+1 dim one-hot array from n dim int-categorical array.

    Every element of `x` is compared with the class ids 0..size-1 along a new
    trailing axis, and the boolean result is cast to `dtype`.
    """
    class_ids = np.arange(size)
    is_class = x[..., np.newaxis] == class_ids
    return np.array(is_class, dtype)
Пример #15
0
  def _forward_train_eval(self, inputs, rng):
    """Chunked causal attention over consecutive time bins (train/eval path).

    The padded sequence is split into `n_bins` chunks of consecutive
    positions. Each chunk attends to itself and to the previous chunk; the
    first chunk's "previous" chunk wraps around to the last one and is removed
    by the causal mask.

    Args:
      inputs: (q, k, v), padded by `self._pad_inputs`; per the comment below
        each is (n_batch*n_heads, seqlen, d_head).
      rng: PRNG key for attention dropout (used only when self.dropout > 0).

    Returns:
      Attention output with the padding trimmed off (first `original_len`
      positions along the sequence axis).
    """
    (inputs, original_len, n_bins) = self._pad_inputs(inputs)
    q, k, v = inputs
    seqlen = q.shape[-2]
    # q/k/v are n_batch*n_heads, seqlen, d_head
    # Time indices for causal masking.
    t = jax.lax.tie_in(q, np.arange(seqlen))

    # Split off a "bin" axis for chunks of consecutive items.
    bq_t = np.reshape(t, (n_bins, -1))
    bq = np.reshape(q, (q.shape[0], n_bins, -1, q.shape[-1]))
    if self._share_qk:
      # Shared query/key: keys are the unit-length-normalized queries.
      bk = self.make_unit_length(bq)
    else:
      bk = np.reshape(k, (k.shape[0], n_bins, -1, k.shape[-1]))
    bv = np.reshape(v, (v.shape[0], n_bins, -1, v.shape[-1]))

    # Allow each chunk to attend within itself, and also one chunk back.
    def look_one_back(x):
      # Output: pairs [ bin_i bin_{i-1} ] concatenated on the time axis.
      if len(x.shape) == 2:
        x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
        return np.concatenate([x, x_extra], axis=1)
      else:
        assert len(x.shape) == 4
        x_extra = np.concatenate([x[:, -1:, :, :], x[:, :-1, :, :]], axis=1)
        return np.concatenate([x, x_extra], axis=2)

    bkv_t = look_one_back(bq_t)
    bk = look_one_back(bk)
    bv = look_one_back(bv)

    # Dot-product attention.
    dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

    # Causal masking based on the time indices.
    mask = jax.lax.convert_element_type(
        jax.lax.lt(bq_t[None, :, :, None], bkv_t[None, :, None, :]),
        np.float32)
    dots = dots - 1e9 * mask

    # Mask out attention to self except when no other targets are available.
    if self._share_qk:
      self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
      self_mask = jax.lax.tie_in(dots, self_mask)
      dots = dots - 1e5 * self_mask

    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

    # Dropout acts on the normalized attention weights, not on the logits.
    # Zeroing a logit would turn a -1e9 (causally masked) entry into 0 and let
    # the query attend to future positions.
    if self.dropout > 0.0:
      # Dropout is broadcast across the batch+head dimension
      dropout_shape = (1, dots.shape[-3], dots.shape[-2], dots.shape[-1])
      keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
      keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
      # Inverted dropout: rescale kept weights so the expectation is unchanged.
      multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
      dots = dots * multiplier

    bo = np.matmul(dots, bv)

    output = np.reshape(bo, (bo.shape[0], -1, bo.shape[-1]))
    assert output.shape == v.shape
    return output[..., :original_len, :]
Пример #16
0
  def single_call(self, qk, v, buckets, hash_rng=None):
    # We use the same vector as both a query and a key.
    seqlen = qk.shape[-2]
    assert int(buckets.shape[0]) == self.n_hashes * seqlen

    ticker = jax.lax.tie_in(qk, np.arange(self.n_hashes * seqlen))
    buckets_and_t = seqlen * buckets + (ticker % seqlen)
    buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

    # Hash-based sort ("s" at the start of variable names means "sorted")
    sbuckets_and_t, sticker = jax.lax.sort_key_val(
        buckets_and_t, ticker, dimension=-1)
    _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
    sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
    sticker = jax.lax.stop_gradient(sticker)
    undo_sort = jax.lax.stop_gradient(undo_sort)

    st = (sticker % seqlen)
    sqk = np.take(qk, st, axis=0)
    sv = np.take(v, st, axis=0)

    # Split off a "bin" axis so that attention only occurs within chunks.
    bq_t = bkv_t = np.reshape(st, (self.n_hashes * self.n_bins, -1))
    bqk = np.reshape(sqk, (self.n_hashes * self.n_bins, -1, sqk.shape[-1]))
    bv = np.reshape(sv, (self.n_hashes * self.n_bins, -1, sv.shape[-1]))
    bq_buckets = bkv_buckets = np.reshape(
        sbuckets_and_t // seqlen, (self.n_hashes * self.n_bins, -1))

    # Hashing operates on unit-length vectors. Unnormalized query vectors are
    # fine because they effectively provide a learnable temperature for the
    # attention softmax, but normalizing keys is needed so that similarity for
    # the purposes of attention correctly corresponds to hash locality.
    bq = bqk
    bk = self.make_unit_length(bqk)

    # Allow each chunk to attend within itself, and also one chunk back. Chunk
    # boundaries might occur in the middle of a sequence of items from the
    # same bucket, so this increases the chances of attending to relevant items.
    # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
    def look_one_back(x):
      if len(x.shape) == 2:
        x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
      else:
        x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
      return np.concatenate([x, x_extra], axis=1)

    bk = look_one_back(bk)
    bv = look_one_back(bv)
    bkv_t = look_one_back(bkv_t)
    bkv_buckets = look_one_back(bkv_buckets)

    # Dot-product attention.
    dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

    # Causal masking
    mask = jax.lax.convert_element_type(
        jax.lax.lt(bq_t[:, :, None], bkv_t[:, None, :]),
        np.float32)
    dots = dots - 1e9 * mask

    # Mask out attention to self except when no other targets are available.
    self_mask = jax.lax.convert_element_type(
        jax.lax.eq(bq_t[:, :, None], bkv_t[:, None, :]),
        np.float32)
    dots = dots - 1e5 * self_mask

    # Mask out attention to other hash buckets.
    if not self._attend_across_buckets:
      bucket_mask = jax.lax.convert_element_type(
          jax.lax.ne(bq_buckets[:, :, None], bkv_buckets[:, None, :]),
          np.float32)
      dots = dots - 1e7 * bucket_mask

    # Don't double-count query-key pairs across multiple rounds of hashing.
    # There are two possible strategies here. (1) The default is to count how
    # many times a query-key pair is repeated, and to lower its log-prob
    # correspondingly at each repetition. (2) When hard_k is set, the code
    # instead masks all but the first occurence of each query-key pair.
    # TODO(kitaev): is one strategy faster or more numerically stable?
    if not self._allow_duplicate_attention:
      locs1 = undo_sort // bq_t.shape[-1]
      locs2 = (locs1 + 1) % (self.n_hashes * self.n_bins)
      if not self._attend_across_buckets:
        locs1 = buckets * (self.n_hashes * self.n_bins) + locs1
        locs2 = buckets * (self.n_hashes * self.n_bins) + locs2
      locs = np.moveaxis(np.concatenate([
          np.reshape(locs1, (self.n_hashes, seqlen)),
          np.reshape(locs2, (self.n_hashes, seqlen)),
      ], 0), 0, -1)  # produces shape (seqlen, 2 * self.n_hashes)
      slocs = np.take(locs, st, axis=0)
      b_locs = np.reshape(
          slocs, (self.n_hashes * self.n_bins, -1, 2 * self.n_hashes))
      # Queries always use the primary location (based on locs1).
      b_locs1 = b_locs[:, :, None, :self.n_hashes]
      if self._hard_k > 0:
        range_n_hashes = jax.lax.tie_in(b_locs, np.arange(self.n_hashes))
        nouse_locs = (range_n_hashes[:, None] > range_n_hashes[None, :])
        nouse_locs = 2 * nouse_locs - 1  # 1 = use, -1 = don't use
        nouse_locs = np.reshape(
            np.broadcast_to(nouse_locs[:, None, :],
                            (self.n_hashes, self.n_bins, self.n_hashes)),
            (self.n_hashes * self.n_bins, 1, 1, self.n_hashes))
        b_locs1 = b_locs1 * nouse_locs
      bq_locs = np.broadcast_to(
          b_locs1,
          b_locs.shape[:2] + (2, self.n_hashes))
      bq_locs = np.reshape(bq_locs, b_locs.shape)
      bkv_locs = look_one_back(b_locs)

      dup_counts = np.sum(
          jax.lax.convert_element_type(
              jax.lax.eq(bq_locs[:, :, None, :], bkv_locs[:, None, :, :]),
              np.float32),
          axis=-1)
      assert dup_counts.shape == dots.shape
      if self._hard_k > 0:
        dots = dots - 1e7 * jax.lax.stop_gradient(dup_counts)
      else:
        dots = dots - jax.lax.stop_gradient(np.log(dup_counts + 1e-9))

    # Each query only attends to the top k most relevant keys.
    if self._hard_k > 0:
      b_top_dots = np.sort(dots)[..., -self._hard_k:]  # Get the top k dots.
      b_top_dots = jax.lax.stop_gradient(b_top_dots)
      s_top_dots = np.reshape(b_top_dots, (-1, self._hard_k))
      top_dots = np.take(s_top_dots, undo_sort, axis=0)

      merged_top_dots = np.moveaxis(
          np.reshape(top_dots, (self.n_hashes, seqlen, self._hard_k)), 0, -1)
      merged_top_dots = np.reshape(merged_top_dots, (seqlen, -1))

      dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k]
      # It's possible to compute the partition function at this point, but right
      # now this codepath isn't set up for backprop, and there might also be
      # issues computing it this way if two dot-products are exactly equal.

      sdots_thresh = dots_thresh[st]
      bdots_thresh = np.reshape(sdots_thresh, (self.n_hashes * self.n_bins, -1))
      bdots_thresh = jax.lax.stop_gradient(bdots_thresh)

      top_k_mask = jax.lax.convert_element_type(
          dots < bdots_thresh[..., None], np.float32)
      dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)

    # Softmax.
    dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
    dots = np.exp(dots - dots_logsumexp)

    bo = np.matmul(dots, bv)
    so = np.reshape(bo, (-1, bo.shape[-1]))
    slogits = np.reshape(dots_logsumexp, (-1,))

    def unsort_for_output_impl(so, slogits):
      """Reorders sorted outputs and log-normalizers back to position order.

      `so` is gathered along axis 0 with `undo_sort`, and `slogits` is permuted
      by sorting it with (a uniformly shifted copy of) `sticker` as the key.
      Both `undo_sort` and `sticker` are captured from the enclosing scope.
      NOTE(review): presumably `undo_sort` and `sticker` are the inverse and
      forward permutations of the earlier sort, so that both reorderings undo
      it. They are defined outside this view; confirm.

      Args:
        so: outputs in sorted order, one row per position (the 2-D reshape of
          `bo` at the call site).
        slogits: log-normalizers in sorted order, as a 1-D array (the
          flattened `dots_logsumexp` at the call site).

      Returns:
        A tuple `(o, logits)`: the reordered rows of `so` and the reordered
        entries of `slogits`.
      """
      o = np.take(so, undo_sort, axis=0)
      # Sorting is considerably faster than gather, but first we need to get the
      # XLA compiler to abandon the idea of fusing this sort with the input sort
      # (which introduces a computation cycle and leads to a crash).
      # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
      # Adding the same 0/1 value to every key keeps the keys' relative order;
      # it only makes the sort key data-dependent on `slogits`.
      sticker_ = sticker + jax.lax.convert_element_type(
          slogits[0] > 0, sticker.dtype)
      _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
      return o, logits

    def unsort_for_output_vjp(so, slogits):
      """Custom gradient for unsort_for_output.

      Written in the `(primal_out, vjpfun)` form that `jax.defvjp_all` takes;
      it is registered that way right after this definition. The inputs are
      wrapped in `stop_gradient`, so JAX does not differentiate through the
      forward computation; gradients come only from `vjpfun`.

      NOTE(review): the forward call after the registration invokes
      `unsort_for_output_impl` directly rather than the wrapped
      `unsort_for_output`, so this custom gradient appears not to be used.
      Confirm whether that is intended.
      NOTE(review): `jax.custom_transforms` / `jax.defvjp_all` are removed in
      newer JAX releases; `jax.custom_vjp` is the replacement API.

      Args:
        so: outputs in sorted order; passed (after stop_gradient) to
          `unsort_for_output_impl`.
        slogits: 1-D log-normalizers in sorted order; passed (after
          stop_gradient) to `unsort_for_output_impl`.

      Returns:
        `((o, logits), vjpfun)`, where `(o, logits)` is the forward result and
        `vjpfun` maps the output cotangent pair to `(so_grad, slogits_grad)`.
      """
      so = jax.lax.stop_gradient(so)
      slogits = jax.lax.stop_gradient(slogits)
      o, logits = unsort_for_output_impl(so, slogits)
      def vjpfun(o_logits_grads):
        # o_logits_grads is the cotangent pair (d_o, d_logits).
        # Gathering with `sticker` is meant to reverse the forward gather with
        # `undo_sort` (assumes the two are inverse permutations; confirm).
        so_grad = np.take(o_logits_grads[0], sticker, axis=0)
        # TODO(kitaev): this exists to match the forward pass, but I'm not sure
        # if it's actually required.
        buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
            o_logits_grads[1][0] > 0, buckets_and_t.dtype)
        # Sorting the logits cotangent keyed on `buckets_and_t` permutes it
        # back into sorted order (presumably the order of the original sort).
        _, slogits_grad = jax.lax.sort_key_val(
            buckets_and_t_, o_logits_grads[1], dimension=-1)
        return (so_grad, slogits_grad)
      return (o, logits), vjpfun

    unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
    jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
    o, logits = unsort_for_output_impl(so, slogits)

    if self.n_hashes == 1:
      out = o
    else:
      o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
      logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
      probs = np.exp(logits - backend.logsumexp(logits, axis=0, keepdims=True))
      out = np.sum(o * probs, axis=0)

    assert out.shape == v.shape
    return out