Exemple #1
0
 def forward(self, x, weights):
     """Look up entries of `weights` selected by the indices in `x`.

     Args:
         x: indices into the first axis of `weights`.
         weights: array to gather from along axis 0.

     Returns:
         The result of `np.take(weights, x, axis=0)`.
     """
     selected_rows = np.take(weights, x, axis=0)
     return selected_rows
Exemple #2
0
    def forward_unbatched(self, x, *, weights, state, update_state):
        """Run hash-bucketed self-attention on a single (unbatched) example.

        Positions are assigned to buckets, sorted so that equal buckets are
        adjacent, attended to in chunks, and finally restored to their original
        order. When several hash rounds are used, their outputs are combined with
        softmax weights computed from the per-round logits.

        Args:
            x: input whose first axis is the sequence (`seqlen = x.shape[0]`).
            weights: a 3-tuple `(w_q, w_v, w_o)` of projection matrices.
            state: a 2-tuple `(buckets, rng)`.
            update_state: if true, a new rng is derived and the buckets are
                recomputed from the queries; otherwise both are read from `state`.

        Returns:
            A tuple `(out, state)`; `state` is the (possibly updated)
            `(buckets, rng)` pair.
        """
        w_q, w_v, w_o = weights

        # Query and value projections. No separate key projection is computed.
        q = np.matmul(x, w_q)
        v = np.matmul(x, w_v)

        if update_state:
            # Derive a fresh rng from the stored one, and from it a separate key
            # used only for hashing; store the new buckets with the new rng.
            _, old_rng = state
            rng = jax.random.fold_in(old_rng, 0)
            hash_rng = jax.random.fold_in(rng, 1)
            buckets = self.hash_vectors(q, hash_rng)
            state = (buckets, rng)
        else:
            buckets, rng = state

        # Distinct key passed to `attend` below (alongside the dropout setting).
        rng = jax.random.fold_in(rng, 2)

        seqlen = x.shape[0]
        # One bucket id per (hash round, position) pair.
        assert int(buckets.shape[0]) == self.n_hashes * seqlen

        # NOTE(review): jax.lax.tie_in is, as far as I know, deprecated in newer
        # JAX releases -- confirm against the JAX version this file targets.
        ticker = jax.lax.tie_in(x, np.arange(self.n_hashes * seqlen))
        # Composite sort key: bucket id is the major key, position within the
        # sequence (ticker % seqlen) is the minor key.
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t,
                                                       ticker,
                                                       dimension=-1)
        # Sorting `ticker` by `sticker` yields the inverse permutation, i.e. the
        # indices that map sorted order back to the original order.
        _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
        sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
        sticker = jax.lax.stop_gradient(sticker)
        undo_sort = jax.lax.stop_gradient(undo_sort)

        # Sequence position of every sorted entry; used to gather q and v into
        # sorted order.
        st = (sticker % seqlen)
        sq = np.take(q, st, axis=0)
        sv = np.take(v, st, axis=0)

        mask_fn = functools.partial(mask_self_attention,
                                    causal=self.causal,
                                    exclude_self=True)
        q_info = st
        # NOTE(review): k=None presumably makes `attend` reuse the queries as
        # keys -- confirm against the definition of `attend`.
        so, slogits = attend(
            sq,
            k=None,
            v=sv,
            q_chunk_len=self.chunk_len,
            n_chunks_before=self.n_chunks_before,
            n_chunks_after=self.n_chunks_after,
            mask_fn=mask_fn,
            q_info=q_info,
            dropout=self.attention_dropout,
            rng=rng,
        )

        # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would
        # also work, but these helpers include performance optimizations for TPU.
        o = permute_via_gather(so, undo_sort, sticker, axis=0)
        logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1)

        if self.n_hashes > 1:
            # Split the leading axis into (hash round, position) and combine the
            # rounds with a softmax over the per-round logits.
            o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
            logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
            probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True))
            o = np.sum(o * probs, axis=0)

        assert o.shape == (seqlen, w_v.shape[-1])
        # Final output projection.
        out = np.matmul(o, w_o)
        return out, state
Exemple #3
0
 def permute_impl(val):
     """Gather `val` along `axis` in the order given by `permutation`.

     `permutation` and `axis` are not parameters; they are read from the
     enclosing scope.
     """
     permuted = np.take(val, permutation, axis=axis)
     return permuted
Exemple #4
0
 def vjpfun(permuted_grad):
     """Gather the incoming gradient with `inverse_permutation` along `axis`.

     `inverse_permutation` and `axis` are read from the enclosing scope. The
     result is returned as a 1-tuple.
     """
     # Left to itself, JAX autodiff would synthesize a scatter operation here,
     # because it doesn't know that the indices are a permutation. On TPU,
     # however, gathers are faster than scatters (at least in the regime the
     # LSH attention uses), so the gradient is gathered explicitly instead.
     unpermuted_grad = np.take(permuted_grad, inverse_permutation, axis=axis)
     return (unpermuted_grad, )
Exemple #5
0
    def forward_unbatched(self, x, *, weights, state, update_state):
        """Run hash-bucketed self-attention on a single (unbatched) example.

        Positions are assigned to buckets, sorted so that equal buckets are
        adjacent, attended to in chunks, and finally restored to their original
        order. When several hash rounds are used, their outputs are combined with
        softmax weights computed from the per-round logits.

        Args:
            x: input whose first axis is the sequence (`seqlen = x.shape[0]`).
            weights: a 3-tuple `(w_q, w_v, w_o)` of projection matrices.
            state: a 2-tuple `(buckets, rng)`.
            update_state: if true, a new rng is derived and the buckets are
                recomputed from the queries; otherwise both are read from `state`.

        Returns:
            A tuple `(out, state)`; `state` is the (possibly updated)
            `(buckets, rng)` pair.
        """
        w_q, w_v, w_o = weights

        # Query and value projections. No separate key projection is computed.
        q = np.matmul(x, w_q)
        v = np.matmul(x, w_v)

        if update_state:
            # Derive a fresh rng from the stored one, and from it a separate key
            # used only for hashing; store the new buckets with the new rng.
            _, old_rng = state
            rng = jax.random.fold_in(old_rng, 0)
            hash_rng = jax.random.fold_in(rng, 1)
            buckets = self.hash_vectors(q, hash_rng)
            state = (buckets, rng)
        else:
            buckets, rng = state

        # Distinct key passed to `attend` below (alongside the dropout setting).
        rng = jax.random.fold_in(rng, 2)

        seqlen = x.shape[0]
        # One bucket id per (hash round, position) pair.
        assert int(buckets.shape[0]) == self.n_hashes * seqlen

        # NOTE(review): jax.lax.tie_in is, as far as I know, deprecated in newer
        # JAX releases -- confirm against the JAX version this file targets.
        ticker = jax.lax.tie_in(x, np.arange(self.n_hashes * seqlen))
        # Composite sort key: bucket id is the major key, position within the
        # sequence (ticker % seqlen) is the minor key.
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t,
                                                       ticker,
                                                       dimension=-1)
        # Sorting `ticker` by `sticker` yields the inverse permutation, i.e. the
        # indices that map sorted order back to the original order.
        _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
        sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
        sticker = jax.lax.stop_gradient(sticker)
        undo_sort = jax.lax.stop_gradient(undo_sort)

        # Sequence position of every sorted entry; used to gather q and v into
        # sorted order.
        st = (sticker % seqlen)
        sq = np.take(q, st, axis=0)
        sv = np.take(v, st, axis=0)

        mask_fn = functools.partial(mask_self_attention,
                                    causal=self.causal,
                                    exclude_self=True)
        q_info = st
        # NOTE(review): k=None presumably makes `attend` reuse the queries as
        # keys -- confirm against the definition of `attend`.
        so, slogits = attend(
            sq,
            k=None,
            v=sv,
            q_chunk_len=self.chunk_len,
            n_chunks_before=self.n_chunks_before,
            n_chunks_after=self.n_chunks_after,
            mask_fn=mask_fn,
            q_info=q_info,
            dropout=self.attention_dropout,
            rng=rng,
        )

        def unsort_for_output_impl(so, slogits):
            """Restore the sorted outputs and logits to their original order.

            The outputs are un-sorted with a gather through `undo_sort`; the
            logits are un-sorted by sorting them with `sticker` as the key.
            """
            o = np.take(so, undo_sort, axis=0)
            # Sorting is considerably faster than gather, but first we need to get the
            # XLA compiler to abandon the idea of fusing this sort with the input sort
            # (which introduces a computation cycle and leads to a crash).
            # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
            sticker_ = sticker + jax.lax.convert_element_type(
                slogits[0] > 0, sticker.dtype)
            _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
            return o, logits

        def unsort_for_output_vjp(so, slogits):
            """Custom gradient for unsort_for_output."""
            # The forward values are computed on gradient-stopped inputs; the
            # backward pass is supplied explicitly by `vjpfun` below.
            so = jax.lax.stop_gradient(so)
            slogits = jax.lax.stop_gradient(slogits)
            o, logits = unsort_for_output_impl(so, slogits)

            def vjpfun(o_logits_grads):
                """Map gradients of `(o, logits)` back to `(so, slogits)`."""
                # The forward pass gathered with `undo_sort`; gathering the
                # gradient with `sticker` applies the inverse permutation.
                so_grad = np.take(o_logits_grads[0], sticker, axis=0)
                # TODO(kitaev): this exists to match the forward pass, but I'm not sure
                # if it's actually required.
                buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
                    o_logits_grads[1][0] > 0, buckets_and_t.dtype)
                # Re-sort the logits gradient by the original sort key to put it
                # back into sorted order.
                _, slogits_grad = jax.lax.sort_key_val(buckets_and_t_,
                                                       o_logits_grads[1],
                                                       dimension=-1)
                return (so_grad, slogits_grad)

            return (o, logits), vjpfun

        unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
        jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
        # NOTE(review): the call below uses `unsort_for_output_impl`, not the
        # wrapped `unsort_for_output` built above, so the custom VJP registered
        # on the wrapper is not applied to this call. Confirm whether this is an
        # intentional workaround or an oversight before changing it.
        o, logits = unsort_for_output_impl(so, slogits)

        if self.n_hashes > 1:
            # Split the leading axis into (hash round, position) and combine the
            # rounds with a softmax over the per-round logits.
            o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
            logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
            probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True))
            o = np.sum(o * probs, axis=0)

        assert o.shape == (seqlen, w_v.shape[-1])
        # Final output projection.
        out = np.matmul(o, w_o)
        return out, state