Beispiel #1
0
 def unpermute_impl(vecs):
     """Gather rows of `vecs` along axis -2 using the enclosing `undo_sort`.

     `undo_sort` is a free variable taken from the surrounding scope (not
     visible here); judging by its name it is presumably the inverse of a
     sorting permutation -- confirm against the enclosing function.

     Args:
         vecs: A rank-3 array.

     Returns:
         An array whose entries along axis -2 are picked by `undo_sort`.
     """
     rank = len(vecs.shape)
     assert rank == 3
     # A trailing singleton axis lets the 2-D indices broadcast over the
     # last (feature) axis of `vecs`.
     gather_idx = undo_sort[:, :, None]
     return np.take_along_axis(vecs, gather_idx, axis=-2)
Beispiel #2
0
    def call_and_grad(self, inputs, ct, rng=None, **kwargs):
        """Run bucketed causal attention and, if `ct` is given, its gradient.

        Items are hashed into buckets, sorted by (bucket, position), split
        into `self.n_bins` equally sized chunks, and attention is computed
        within each chunk plus the chunk that precedes it. The result is
        then restored to the original order.

        Args:
            inputs: A `(qk, k, v)` triple. `qk` serves as both query and key;
                the second entry is ignored and only receives an all-zero
                cotangent. Per the original shape notes, `qk` and `v` are
                (n_batch * n_heads, seqlen, d_head).
            ct: Cotangent of the output, or None to run the forward pass
                only.
            rng: Forwarded to `self.hash_vectors`.
            **kwargs: Accepted and discarded.

        Returns:
            `(out, None)` when `ct` is None, otherwise
            `(out, (qk_ct, zeros_like(k), v_ct))`.
        """
        del kwargs
        qk, ignored_k, v = inputs
        seqlen = qk.shape[-2]

        # Hash bucket id of every position (original note: shape is
        # n_batch*n_heads, seqlen).
        bins = self.hash_vectors(qk, rng=rng)

        # Position index of every item, broadcast to qk.shape[:-1].
        positions = jax.lax.tie_in(qk, np.arange(seqlen))
        joint_t = np.broadcast_to(np.reshape(positions, (1, seqlen)),
                                  qk.shape[:-1])

        largest_key_bound = int((self.n_bins + 1) * seqlen)
        assert largest_key_bound < 2**31, (
            'Potential 32-bit integer overflow; please double-check the code.')

        # A single integer key that orders by bucket first and by position
        # second. Names starting with "s" below refer to sorted data.
        sort_keys = seqlen * bins + joint_t
        _, sjoint_t = jax.lax.sort_key_val(sort_keys, joint_t, dimension=-1)

        # Sorting the sorted positions back into ascending order yields the
        # inverse permutation.
        _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)

        sort_idx = sjoint_t[:, :, None]
        unsort_idx = undo_sort[:, :, None]

        sqk = np.take_along_axis(qk, sort_idx, axis=-2)
        sv = np.take_along_axis(v, sort_idx, axis=-2)

        def split_scalars(x):  # pylint: disable=invalid-name
            """Reshapes (batch, seqlen) into (batch, n_bins, chunk)."""
            return np.reshape(x, (x.shape[0], self.n_bins, -1))

        def split_vectors(x):  # pylint: disable=invalid-name
            """Reshapes (batch, seqlen, d) into (batch, n_bins, chunk, d)."""
            return np.reshape(x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

        def merge_vectors(x):  # pylint: disable=invalid-name
            """Collapses the bin and chunk axes back into one sequence axis."""
            return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

        def with_previous_chunk(x):  # pylint: disable=invalid-name
            """Appends, along axis 2, the contents of the preceding chunk.

            The chunk axis (axis 1) is rotated by one, so the first chunk is
            paired with the last one.
            """
            shifted = np.concatenate([x[:, -1:], x[:, :-1]], axis=1)
            return np.concatenate([x, shifted], axis=2)

        def unsort(x):  # pylint: disable=invalid-name
            """Restores the original sequence order of a sorted array."""
            return np.take_along_axis(x, unsort_idx, axis=-2)

        @jax.jit
        def binned_attn(sqk, sv):  # pylint: disable=invalid-name
            """Performs attention on sorted queries/keys/values."""
            # Split off a "bin" axis so that attention only occurs within
            # chunks.
            bq_t = split_scalars(sjoint_t)
            bq = split_vectors(sqk)

            # Hashing operates on unit-length vectors. Unnormalized query
            # vectors are fine because they effectively provide a learnable
            # temperature for the attention softmax, but normalizing keys is
            # needed so that similarity for the purposes of attention
            # correctly corresponds to hash locality.
            #
            # Each chunk attends within itself and also one chunk back. Chunk
            # boundaries might occur in the middle of a run of items from the
            # same bin, so this increases the chances of attending to
            # relevant items.
            # TODO(kitaev): benchmark whether XLA pad operation is noticeably
            # faster.
            bk = with_previous_chunk(self.make_unit_length(bq))
            bv = with_previous_chunk(split_vectors(sv))
            bkv_t = with_previous_chunk(bq_t)

            # Scaled dot-product attention logits.
            scale = np.sqrt(bq.shape[-1])
            dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / scale

            # Causal masking: penalize keys that come later than the query.
            in_future = jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :])
            mask = jax.lax.convert_element_type(in_future, np.float32)
            dots = dots - 1e9 * mask

            # Mask out attention to self except when no other targets are
            # available.
            self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
            self_mask = jax.lax.tie_in(dots, self_mask)
            dots = dots - 32 * self_mask

            # Softmax, computed via log-sum-exp.
            log_norm = backend.logsumexp(dots, axis=-1, keepdims=True)
            probs = np.exp(dots - log_norm)

            return merge_vectors(np.matmul(probs, bv))

        if ct is None:
            # Forward pass only.
            so = binned_attn(sqk, sv)
            return unsort(so), None

        @jax.jit
        def binned_attn_vjp(sqk, sv, so_ct):  # pylint: disable=invalid-name
            """Returns the sorted output and the cotangents of (sqk, sv)."""
            so, vjpfun = jax.vjp(binned_attn, sqk, sv)
            return so, vjpfun(so_ct)

        # Jax can construct a backward pass automatically, but (per the
        # original author) it's about 2x slower than writing our own. The
        # main reason is that the backward pass of gather is in general a
        # scatter operation, but we know we're dealing with permutations so
        # we use gather for the backward pass too.
        so_ct = np.take_along_axis(ct, sort_idx, axis=-2)
        so, (sqk_ct, sv_ct) = binned_attn_vjp(sqk, sv, so_ct)

        out = unsort(so)
        qk_ct = unsort(sqk_ct)
        v_ct = unsort(sv_ct)

        return out, (qk_ct, np.zeros_like(ignored_k), v_ct)
Beispiel #3
0
 def permute_impl(vecs):
     """Gather rows of `vecs` along axis -2 using the enclosing `sjoint_t`.

     `sjoint_t` is a free variable taken from the surrounding scope (not
     visible here); presumably it holds the sorted position indices --
     confirm against the enclosing function.

     Args:
         vecs: A rank-3 array.

     Returns:
         An array whose entries along axis -2 are picked by `sjoint_t`.
     """
     rank = len(vecs.shape)
     assert rank == 3
     # A trailing singleton axis lets the 2-D indices broadcast over the
     # last (feature) axis of `vecs`.
     gather_idx = sjoint_t[:, :, None]
     return np.take_along_axis(vecs, gather_idx, axis=-2)
    def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
        """Run binned causal attention and, if `ct` is given, its gradient.

        Queries and keys/values are each sorted by a bucket id, split into
        `self.n_bins` chunks along the sequence axis, and attention is
        computed within each chunk. The result is then restored to the
        original order.

        Args:
            inputs: A `(q, k, v)` triple. Per the original shape notes each
                is (n_batch, n_heads, seqlen, d_head).
            ct: Cotangent of the output, or None to run the forward pass
                only.
            params: Accepted and discarded.
            **kwargs: Accepted and discarded.

        Returns:
            `(out, None)` when `ct` is None, otherwise
            `(out, (q_ct, k_ct, v_ct))`.
        """
        del params, kwargs
        q, k, v = inputs

        assert k.shape[2] % self.n_bins == 0
        bin_size = int(k.shape[2] // self.n_bins)

        # Bucket ids, broadcast to q.shape[:-1]. For now, instead of hashing,
        # consecutive items simply share a bucket. The ids are negated for
        # queries and additionally doubled for keys/values; both keep items
        # of the same bucket together while reversing the bucket order.
        bucket_ids = jax.lax.tie_in(
            q, np.arange(q.shape[2], dtype=np.int32) // bin_size)
        q_bins = -np.broadcast_to(bucket_ids[None, None, :], q.shape[:-1])
        kv_bins = q_bins * 2

        # Position indices, broadcast to q.shape[:-1]; shared by queries and
        # keys/values.
        positions = jax.lax.tie_in(q, np.arange(q.shape[2]))
        q_t = np.broadcast_to(
            np.reshape(positions, (1, 1, positions.shape[0])), q.shape[:-1])
        kv_t = q_t

        def split_rank3(x):
            """Splits the sequence axis of a rank-3 array into bins."""
            return np.reshape(x, (x.shape[0], x.shape[1], self.n_bins, -1))

        def split_rank4(x):
            """Splits the sequence axis of a rank-4 array into bins."""
            return np.reshape(
                x, (x.shape[0], x.shape[1], self.n_bins, -1, x.shape[-1]))

        def merge_rank4(x):
            """Collapses the bin and chunk axes back into one sequence axis."""
            return np.reshape(x, (x.shape[0], x.shape[1], -1, x.shape[-1]))

        # Sort everything by bin number (names starting with "s" are sorted).
        _, sq_t = jax.lax.sort_key_val(q_bins, q_t, dimension=2)
        _, skv_t = jax.lax.sort_key_val(kv_bins, kv_t, dimension=2)

        q_sort_idx = sq_t[:, :, :, None]
        kv_sort_idx = skv_t[:, :, :, None]

        sq = np.take_along_axis(q, q_sort_idx, axis=2)
        sk = np.take_along_axis(k, kv_sort_idx, axis=2)
        sv = np.take_along_axis(v, kv_sort_idx, axis=2)

        # Sorting the sorted positions back into ascending order yields the
        # inverse permutation for the query side.
        _, undo_q_sort = jax.lax.sort_key_val(sq_t, q_t, dimension=2)
        q_unsort_idx = undo_q_sort[:, :, :, None]

        @jax.jit
        def binned_attn(sq, sk, sv):
            """Performs attention on sorted queries/keys/values."""
            # Split off a "bin" axis so that attention only occurs within
            # chunks.
            bq_t = split_rank3(sq_t)
            bkv_t = split_rank3(skv_t)
            bq = split_rank4(sq)
            bk = split_rank4(sk)
            bv = split_rank4(sv)

            # Scaled dot-product attention logits.
            scale = np.sqrt(bq.shape[-1])
            dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / scale

            # Causal masking: penalize keys that come later than the query.
            in_future = jax.lax.lt(bq_t[:, :, :, :, None],
                                   bkv_t[:, :, :, None, :])
            mask = jax.lax.convert_element_type(in_future, np.float32)
            dots = dots - 1e9 * mask

            # Softmax, shifted by the row maximum before exponentiating.
            unnormalized = np.exp(dots - dots.max(axis=-1, keepdims=True))
            weights = unnormalized / unnormalized.sum(axis=-1, keepdims=True)

            return merge_rank4(np.matmul(weights, bv))

        if ct is None:
            # Forward pass only.
            so = binned_attn(sq, sk, sv)
            return np.take_along_axis(so, q_unsort_idx, axis=2), None

        @jax.jit
        def binned_attn_vjp(sq, sk, sv, so_ct):
            """Returns the sorted output and the cotangents of (sq, sk, sv)."""
            so, vjpfun = jax.vjp(binned_attn, sq, sk, sv)
            return so, vjpfun(so_ct)

        # Jax can construct a backward pass automatically, but (per the
        # original author) it's about 2x slower than writing our own. The
        # main reason is that the backward pass of gather is in general a
        # scatter operation, but we know we're dealing with permutations so
        # we use gather for the backward pass too.
        so_ct = np.take_along_axis(ct, q_sort_idx, axis=2)
        so, (sq_ct, sk_ct, sv_ct) = binned_attn_vjp(sq, sk, sv, so_ct)

        out = np.take_along_axis(so, q_unsort_idx, axis=2)
        q_ct = np.take_along_axis(sq_ct, q_unsort_idx, axis=2)

        # Keys and values share one sort order, hence one inverse permutation.
        _, undo_kv_sort = jax.lax.sort_key_val(skv_t, kv_t, dimension=2)
        kv_unsort_idx = undo_kv_sort[:, :, :, None]
        k_ct = np.take_along_axis(sk_ct, kv_unsort_idx, axis=2)
        v_ct = np.take_along_axis(sv_ct, kv_unsort_idx, axis=2)

        return out, (q_ct, k_ct, v_ct)