Example 1
def run_policy(
    policy_and_value_net_apply,
    observations,
    lengths,
    weights,
    state,
    rng,
    action_space,
):
    """Runs the policy network."""
    # TODO(pkozakowski): Pass the actual actions here, to enable autoregressive
    # action sampling.
    (B, T_plus_1) = observations.shape[:2]  # pylint: disable=invalid-name
    dummy_actions = onp.zeros((B, T_plus_1 - 1) + action_space.shape,
                              dtype=action_space.dtype)
    policy_input = (observations, dummy_actions)
    (rng, subrng) = trax_random.split(rng)
    (log_probs, value_preds) = policy_and_value_net_apply(policy_input,
                                                          weights=weights,
                                                          state=state,
                                                          rng=subrng)
    # We need the log_probs of those actions that correspond to the last actual
    # time-step.
    index = lengths - 1  # Since we want to index using lengths.
    log_probs = log_probs[np.arange(B), index]
    value_preds = value_preds[np.arange(B), index]
    return (log_probs, value_preds, state, rng)
Example 2
def run_policy(
    policy_and_value_net_apply,
    observations,
    lengths,
    weights,
    state,
    rng,
    action_space,
):
    """Runs the policy network and returns lps, vps for the last timestep."""
    log_probs, value_preds, state, rng = run_policy_all_timesteps(
        policy_and_value_net_apply,
        observations,
        weights,
        state,
        rng,
        action_space,
    )

    # We need the log_probs of those actions that correspond to the last actual
    # time-step.
    (B, unused_T_plus_1) = observations.shape[:2]  # pylint: disable=invalid-name
    index = lengths - 1  # Since we want to index using lengths.
    log_probs = log_probs[np.arange(B), index]
    value_preds = value_preds[np.arange(B), index]
    return (log_probs, value_preds, state, rng)
Example 3
def actor_loss(actions, advantage_weights, log_probab_actions_new, state=None):
    """Actor loss."""

    # log_probab_actions_new's shape is (AB, 1, #C, #A), AB is actor batch.
    lp = jnp.squeeze(log_probab_actions_new, axis=1)
    AB, NC = actions.shape  # pylint: disable=invalid-name
    log_probs = lp[jnp.arange(AB)[:, None], jnp.arange(NC)[None, :], actions]

    # TODO(afrozm): Clarify this.
    #   log_probs are shaped (AB, #C), however advantage_weights are (AB,)
    return -1.0 * jnp.mean(log_probs * advantage_weights[:, None]), state
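A minimal toy call of the loss above, assuming `jnp` is `jax.numpy` and `actor_loss` is in scope; the shapes follow the comment (AB, 1, #C, #A) and the numbers are purely illustrative.

import jax.numpy as jnp

# Hypothetical toy shapes: actor batch of 2, 3 controls, 4 actions per control.
lp_new = jnp.log(jnp.full((2, 1, 3, 4), 0.25))   # uniform log-probs
actions = jnp.array([[0, 1, 2], [3, 0, 1]])      # chosen action per control
advantages = jnp.array([1.0, -0.5])
loss, _ = actor_loss(actions, advantages, lp_new)
# loss == -mean(log(0.25) * advantage, broadcast over the control axis)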
Example 4
def one_hot(x, n_categories, dtype=np.float32):  # pylint: disable=invalid-name
    """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims)."""
    indices_less_than_n = np.arange(n_categories)
    if math.backend_name() == 'jax':
        # Work around a jax broadcasting issue.
        indices_less_than_n = jax.lax.tie_in(x, indices_less_than_n)
    return np.array(x[..., np.newaxis] == indices_less_than_n, dtype)
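A small standalone sketch of the same comparison-based encoding, assuming plain NumPy in place of the backend-dispatching `np`/`math` used above.

import numpy as onp

# Hypothetical usage: one-hot encode a batch of integer class ids.
x = onp.array([0, 2, 1])
encoded = onp.array(x[..., onp.newaxis] == onp.arange(3), dtype=onp.float32)
# encoded == [[1., 0., 0.], [0., 0., 1.], [0., 1., 0.]]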
Example 5
    def forward_with_state(self,
                           inputs,
                           weights=layer_base.EMPTY_WEIGHTS,
                           state=layer_base.EMPTY_STATE,
                           rng=None,
                           **kwargs):
        depth = inputs.shape[-1]

        if self._mode == 'predict':
            emb = self._get_embeddings(t=state)
            emb = emb[:, np.newaxis, :]
            state = state + 1
        else:
            input_len = inputs.shape[-2]
            emb = self._get_embeddings(t=np.arange(input_len, dtype=np.int32))
            # Leave batch axis as 1 for broadcasting:
            emb = emb[np.newaxis, :, :]
            emb = np.broadcast_to(emb, inputs.shape[:-1] + (3, ))

        # Replace the last num_features channels of input.
        inputs = np.concatenate([inputs[..., :-self.num_features], emb], -1)
        if inputs.shape[-1] > depth:
            logging.warning('dropping feature(s): %d down to %d',
                            inputs.shape[-1], depth)
            inputs = inputs[..., -depth:]

        assert inputs.shape[-1] == depth, inputs.shape
        return inputs, state
Example 6
 def forward_with_state(self, x, weights, state, rng):
     batch_size, length = x.shape[0], x.shape[1]
     max_pos = min(self._bases)**self._n_digits
     rng1, rng2, rng3 = math.random.split(rng, 3)
     assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length,
                                                               max_pos)
     positions = jnp.arange(0, length)[None, :]
     if self._mode == 'train':
         # In 1% of training cases still start from 0 to be exactly as in eval.
         start_from_nonzero = jax.random.randint(
             rng1, (batch_size, ), 0, self._start_from_zero_one_in)
         start_from_nonzero = jnp.minimum(1, start_from_nonzero)
         random_start = jax.random.randint(rng2, (batch_size, ), 0,
                                           max_pos - length)
         random_start *= start_from_nonzero
         positions += random_start[:, None]
     res = []
     for bn, base in enumerate(self._bases):
         pos_embeddings = []
         cur_positions = positions
         for i in range(self._n_digits):
             cur_indices = jnp.mod(cur_positions, base)
             cur_positions = cur_positions // base
             s = weights[bn][i]
             pos_embeddings.append(
                 cur_indices.astype(jnp.float32)[:, :, None] * s)
         embeddings = jnp.concatenate(pos_embeddings, axis=-1)
         if self._mode == 'train':
             base_dropout = jax.random.randint(rng3, (batch_size, ), 0,
                                               self._base_dropout_one_in)
             base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
             embeddings *= base_dropout[:, None, None]
         res.append(embeddings)
     res = sum(res) + jnp.zeros_like(x)
     return jnp.concatenate([x, res], axis=-1), state
Example 7
def one_hot(x, size, dtype=np.float32):  # pylint: disable=invalid-name
    """Make a n+1 dim one-hot array from n dim int-categorical array."""
    arange_size = np.arange(size)
    if math.backend_name() == 'jax':
        # Work around a jax broadcasting issue.
        arange_size = jax.lax.tie_in(x, arange_size)
    return np.array(x[..., np.newaxis] == arange_size, dtype)
Example 8
    def forward_unbatched(self, x, *, weights, state, update_state):
        del update_state
        if self.share_qk:
            w_q, w_v, w_o = weights
        else:
            w_q, w_k, w_v, w_o = weights

        q = np.matmul(x, w_q)
        k = None
        if not self.share_qk:
            k = np.matmul(x, w_k)
        v = np.matmul(x, w_v)

        mask_fn = functools.partial(mask_self_attention,
                                    causal=self.causal,
                                    exclude_self=self.share_qk)
        q_info = kv_info = jax.lax.tie_in(x, np.arange(q.shape[-2]))
        o, _ = attend(
            q,
            k,
            v,
            q_chunk_len=self.chunk_len,
            kv_chunk_len=self.chunk_len,
            n_chunks_before=self.n_chunks_before,
            n_chunks_after=self.n_chunks_after,
            mask_fn=mask_fn,
            q_info=q_info,
            kv_info=kv_info,
            dropout=self.attention_dropout,
            rng=None,  # TODO(kitaev): support RNG
        )

        out = np.matmul(o, w_o)
        return out, state
Example 9
 def log_prob(self, inputs, point):
     # Flatten the prefix dimensions for easy indexing.
     flat_point = np.reshape(point, -1)
     flat_inputs = np.reshape(inputs, (point.size, -1))
     flat_log_probs = flat_inputs[np.arange(point.size),
                                  flat_point.astype(int)]
     return np.reshape(flat_log_probs, point.shape)
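A small standalone illustration of the gather above, assuming plain NumPy; `inputs` holds per-category log-probs and `point` the chosen category per example.

import numpy as onp

inputs = onp.log(onp.array([[0.2, 0.8], [0.6, 0.4]]))  # (2, n_categories) log-probs
point = onp.array([1, 0])                              # chosen category per example
flat_inputs = onp.reshape(inputs, (point.size, -1))
picked = flat_inputs[onp.arange(point.size), point.astype(int)]
# picked == [log(0.8), log(0.6)]; log_prob reshapes this back to point.shape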
Example 10
    def forward(self, inputs, weights):
        state = self.state
        depth = inputs.shape[-1]

        if self._mode == 'predict':
            emb = self._get_embeddings(t=state)
            emb = emb[:, jnp.newaxis, :]
            state = state + 1
        else:
            input_len = inputs.shape[-2]
            emb = self._get_embeddings(
                t=jnp.arange(input_len, dtype=jnp.int32))
            # Leave batch axis as 1 for broadcasting:
            emb = emb[jnp.newaxis, :, :]
            emb = jnp.broadcast_to(emb, inputs.shape[:-1] + (3, ))

        # Replace the last num_features channels of input.
        inputs = jnp.concatenate([inputs[..., :-self.num_features], emb], -1)
        if inputs.shape[-1] > depth:
            logging.warning('dropping feature(s): %d down to %d',
                            inputs.shape[-1], depth)
            inputs = inputs[..., -depth:]

        assert inputs.shape[-1] == depth, inputs.shape
        self.state = state
        return inputs
Example 11
def actor_loss(actions,
               advantage_weights,
               log_probab_actions_new,
               state=None):
  """Actor loss."""
  lp = np.squeeze(log_probab_actions_new)
  b = len(lp)
  log_probs = np.squeeze(lp[np.arange(b)[np.newaxis, :], actions])

  return -1.0 * np.mean(log_probs * advantage_weights), state
Example 12
  def F(vec_e, vec_d, mask_e, mask_d):
    # pylint: disable=invalid-name
    L1 = mask_e.shape[1]
    L2 = mask_d.shape[1]
    # pylint: enable=invalid-name

    # [-(L1+L2), -L2) but with padding 0-ed out - (B, L1).
    mask_e_key = jnp.arange(-(L1 + L2), -L2) * mask_e
    # [-L2,0) but with padding 0-ed out - (B, L2).
    mask_d_key = jnp.arange(-L2, 0) * mask_d

    # Shape (B, L1+L2, H)
    enc_dec_concat = jnp.concatenate([vec_e, vec_d], axis=1)
    # Shape (B, L1+L2)
    mask_concat = jnp.concatenate([mask_e_key, mask_d_key], axis=1)
    # Make `mask_concat` the same shape as `enc_dec_concat`
    mask_concat = (
        mask_concat[..., jnp.newaxis] +
        jnp.zeros_like(enc_dec_concat, dtype=jnp.int32))
    # Sort on `mask_concat` so padding with key=0 goes to the right end, axis=1.
    _, enc_dec_pad = math.sort_key_val(mask_concat, enc_dec_concat, 1)

    return enc_dec_pad
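A hedged standalone sketch of the sort-based padding trick, using `jax.lax.sort_key_val` (stable by default) in place of `math.sort_key_val`: positions with key 0 (padding) move to the right end of the axis.

import jax.numpy as jnp
from jax import lax

keys = jnp.array([[-3, 0, -1, 0]])        # 0 marks padded positions
vals = jnp.array([[10, 99, 20, 98]])
_, sorted_vals = lax.sort_key_val(keys, vals, dimension=1)
# sorted_vals == [[10, 20, 99, 98]] -- padding pushed to the end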
Example 13
    def hash_vectors(self, vecs, rng):
        # See https://arxiv.org/pdf/1509.02897.pdf
        # We sample a different random rotation for each round of hashing to
        # decrease the probability of hash misses.
        if isinstance(self.n_buckets, int):
            assert self.n_buckets % 2 == 0
            rot_size = self.n_buckets
            n_buckets = self.n_buckets
        else:
            # Factorize the hash if self.n_buckets is a list or tuple
            rot_size, n_buckets = 0, 1
            for factor in self.n_buckets:
                assert factor % 2 == 0
                rot_size += factor
                n_buckets *= factor

        rotations_shape = (vecs.shape[-1], self.n_hashes, rot_size // 2)

        rng = jax.lax.stop_gradient(jax.lax.tie_in(vecs, rng))
        random_rotations = jax.random.normal(rng,
                                             rotations_shape).astype('float32')
        rotated_vecs = np.einsum('tf,fhb->htb', vecs, random_rotations)

        if isinstance(self.n_buckets, int) or len(self.n_buckets) == 1:
            rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs],
                                          axis=-1)
            buckets = np.argmax(rotated_vecs, axis=-1)
        else:
            # Get the buckets for them and combine.
            buckets, cur_sum, cur_product = None, 0, 1
            for factor in self.n_buckets:
                rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
                cur_sum += factor // 2
                rv = np.concatenate([rv, -rv], axis=-1)
                if buckets is None:
                    buckets = np.argmax(rv, axis=-1)
                else:
                    buckets += cur_product * np.argmax(rv, axis=-1)
                cur_product *= factor

        # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
        # bucket numbers from different hashing rounds don't overlap.
        offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
        offsets = np.reshape(offsets * n_buckets, (-1, 1))
        buckets = np.reshape(buckets + offsets, (-1, ))

        return buckets
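A minimal standalone sketch of one hashing round under assumed sizes (n_hashes=1, n_buckets=4), mirroring the rotate / concatenate-with-negation / argmax steps above.

import jax
import jax.numpy as jnp

vecs = jax.random.normal(jax.random.PRNGKey(0), (6, 8))     # (seqlen, d)
rots = jax.random.normal(jax.random.PRNGKey(1), (8, 1, 2))  # (d, n_hashes, rot_size // 2)
rotated = jnp.einsum('tf,fhb->htb', vecs, rots)
rotated = jnp.concatenate([rotated, -rotated], axis=-1)
buckets = jnp.argmax(rotated, axis=-1)                      # (n_hashes, seqlen), values in [0, 4)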
Example 14
def _fast_inference_update_state(inputs, state):
    """Updates state of a causal attention layer for fast inference."""
    assert math.backend_name() == 'jax', (
        'JAX backend is required to use the predict mode.')
    for x in inputs:
        assert x.shape[1] == 1, (
            'In predict mode the input sequence must be of length 1.')
    # Fast inference: run with only 1 query in each step, storing the sequence
    # of keys and values calculated so far in state.
    (_, new_k, new_v) = inputs
    (ks, vs, mask, seq_indices) = state
    batch_indices = np.arange(ks.shape[0])
    ks = jax.ops.index_update(ks, jax.ops.index[batch_indices, seq_indices, :],
                              new_k[:, 0, :])
    vs = jax.ops.index_update(vs, jax.ops.index[batch_indices, seq_indices, :],
                              new_v[:, 0, :])
    mask = jax.ops.index_update(mask, jax.ops.index[batch_indices, :,
                                                    seq_indices], 1)
    return (ks, vs, mask, seq_indices + 1)
Example 15
    def forward_unbatched(self, x, mask=None, *, weights, state, update_state):
        del update_state
        if self.share_qk:
            w_q, w_v, w_o = weights
        else:
            w_q, w_k, w_v, w_o = weights

        q = np.matmul(x, w_q)
        k = None
        if not self.share_qk:
            k = np.matmul(x, w_k)
        v = np.matmul(x, w_v)

        mask_fn = functools.partial(mask_self_attention,
                                    causal=self.causal,
                                    exclude_self=self.share_qk,
                                    masked=self.masked)
        q_info = kv_info = jax.lax.tie_in(x, np.arange(q.shape[-2]))

        assert (mask is not None) == self.masked
        if self.masked:
            # mask is a boolean array (True means "is valid token")
            ones_like_mask = jax.lax.tie_in(x,
                                            np.ones_like(mask, dtype=np.int32))
            kv_info = kv_info * np.where(mask, ones_like_mask, -ones_like_mask)

        o, _ = attend(
            q,
            k,
            v,
            q_chunk_len=self.chunk_len,
            kv_chunk_len=self.chunk_len,
            n_chunks_before=self.n_chunks_before,
            n_chunks_after=self.n_chunks_after,
            mask_fn=mask_fn,
            q_info=q_info,
            kv_info=kv_info,
            dropout=self.attention_dropout,
            rng=None,  # TODO(kitaev): support RNG
        )

        out = np.matmul(o, w_o)
        return out, state
Example 16
 def forward_with_state(self,
                        x,
                        weights=layer_base.EMPTY_WEIGHTS,
                        state=layer_base.EMPTY_STATE,
                        rng=None,
                        **kwargs):
     length = np.shape(x)[1]
     max_pos = self._base**self._n_digits
     assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length,
                                                               max_pos)
     positions = np.arange(0, length)
     if self._mode == 'train':
         positions += jax.random.randint(rng, (), 0, max_pos - length)
     pos_embeddings = []
     cur_positions = positions
     for i in range(self._n_digits):
         cur_indices = np.mod(cur_positions, self._base)
         cur_positions //= self._base
         pos_embeddings.append(np.take(weights[i], cur_indices, axis=0))
     embeddings = np.concatenate(pos_embeddings, axis=-1)
     return (x + embeddings[None, :, :], state)
Example 17
def _fast_inference_update_state(inputs, state):
  """Updates state of a causal attention layer for fast inference."""
  if math.backend_name() != 'jax':
    raise ValueError(f'JAX backend is required in predict mode, but found '
                     f'backend ({math.backend_name()}).')
  for x in inputs:
    if x.shape[1] != 1:
      raise ValueError(f'In predict mode, input sequence must have length 1, '
                       f'instead has length {x.shape[1]}.')
  # Fast inference: run with only 1 query in each step, storing the sequence
  # of keys and values calculated so far in state.
  (_, new_k, new_v) = inputs
  (ks, vs, mask, seq_indices) = state
  batch_indices = jnp.arange(ks.shape[0])
  ks = jax.ops.index_update(
      ks, jax.ops.index[batch_indices, seq_indices, :], new_k[:, 0, :])
  vs = jax.ops.index_update(
      vs, jax.ops.index[batch_indices, seq_indices, :], new_v[:, 0, :])
  mask = jax.ops.index_update(
      mask, jax.ops.index[batch_indices, :, seq_indices], 1)
  return (ks, vs, mask, seq_indices + 1)
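In recent JAX releases `jax.ops.index_update` is no longer available; the same cache write can be expressed with the `.at[...]` indexed-update API. A minimal sketch on assumed toy shapes:

import jax.numpy as jnp

ks = jnp.zeros((2, 4, 3))          # (batch, max_len, d_k) key cache
new_k = jnp.ones((2, 1, 3))        # single-step keys
batch_indices = jnp.arange(ks.shape[0])
seq_indices = jnp.array([0, 2])    # current write position per example
ks = ks.at[batch_indices, seq_indices, :].set(new_k[:, 0, :])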
Example 18
 def test_batch_norm(self):
     input_shape = (2, 3, 4)
     input_dtype = np.float32
     input_signature = ShapeDtype(input_shape, input_dtype)
     eps = 1e-5
     inp1 = np.reshape(np.arange(np.prod(input_shape), dtype=input_dtype),
                       input_shape)
      m1 = 11.5  # Mean of this input.
      v1 = 47.9167  # Variance of this input.
     layer = normalization.BatchNorm(axis=(0, 1, 2))
     _, _ = layer.init(input_signature)
     state = layer.state
     onp.testing.assert_allclose(state[0], 0)
     onp.testing.assert_allclose(state[1], 1)
     self.assertEqual(state[2], 0)
     out = layer(inp1)
     state = layer.state
     onp.testing.assert_allclose(state[0], m1 * 0.001)
     onp.testing.assert_allclose(state[1], 0.999 + v1 * 0.001, rtol=1e-6)
     self.assertEqual(state[2], 1)
     onp.testing.assert_allclose(out, (inp1 - m1) / np.sqrt(v1 + eps),
                                 rtol=1e-6)
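A quick standalone check of the expected statistics used in the test, assuming plain NumPy (aliased `onp` as in the test):

import numpy as onp

vals = onp.arange(onp.prod((2, 3, 4)), dtype=onp.float32)
print(vals.mean())  # 11.5
print(vals.var())   # 47.916668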
Example 19
def threefry_2x32_prange(key, lo: int = 0, hi: int = 2):
    """Splits a key into a stream of random keys.

  This uses the little-endian counter mode.

  Args:
    key: uint32[2] the key to split
    lo: the range to start extracting from
    hi: the range to stop extracting from

  Returns:
    keys: uint32[hi - lo, 2] the split keys
  """
    if not (key.shape == (2, ) and key.dtype == np.uint32):
        raise ValueError('key must be uint32[2]')
    if not hi < 2**32:
        # You shouldn't really be using more than half the key size anyways.
        raise NotImplementedError('only 32-bit sizes are supported')
    # Create a 64-bit counter:
    i_lo = np.arange(lo, hi, dtype=np.uint32)
    i_hi = np.zeros_like(i_lo)
    i = np.stack([i_lo, i_hi], axis=-1)
    return threefry_2x32_prf(key, i)
Example 20
def top_k(x, k):
  """Select the top k slices from the last dimension."""
  bcast_idxs = jnp.broadcast_to(np.arange(x.shape[-1]), x.shape)
  sorted_vals, sorted_idxs = lax.sort_key_val(x, bcast_idxs)
  # TODO(levskaya): use lax.slice here instead to benefit from XLA optimization
  return sorted_vals[..., -k:], sorted_idxs[..., -k:]
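A small standalone check of the sort-based top-k idea above, assuming `jax.numpy` and `jax.lax`:

import jax.numpy as jnp
from jax import lax

x = jnp.array([[3.0, 1.0, 4.0, 1.5]])
idxs = jnp.broadcast_to(jnp.arange(x.shape[-1]), x.shape)
sorted_vals, sorted_idxs = lax.sort_key_val(x, idxs)
# sorted_vals[..., -2:] == [[3.0, 4.0]], sorted_idxs[..., -2:] == [[0, 2]]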
Example 21
    def forward_unbatched(self, x, *, weights, state, update_state):
        w_q, w_v, w_o = weights

        q = np.matmul(x, w_q)
        v = np.matmul(x, w_v)

        if update_state:
            _, old_rng = state
            rng = jax.random.fold_in(old_rng, 0)
            hash_rng = jax.random.fold_in(rng, 1)
            buckets = self.hash_vectors(q, hash_rng)
            state = (buckets, rng)
        else:
            buckets, rng = state

        rng = jax.random.fold_in(rng, 2)

        seqlen = x.shape[0]
        assert int(buckets.shape[0]) == self.n_hashes * seqlen

        ticker = jax.lax.tie_in(x, np.arange(self.n_hashes * seqlen))
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t,
                                                       ticker,
                                                       dimension=-1)
        _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
        sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
        sticker = jax.lax.stop_gradient(sticker)
        undo_sort = jax.lax.stop_gradient(undo_sort)

        st = (sticker % seqlen)
        sq = np.take(q, st, axis=0)
        sv = np.take(v, st, axis=0)

        mask_fn = functools.partial(mask_self_attention,
                                    causal=self.causal,
                                    exclude_self=True)
        q_info = st
        so, slogits = attend(
            sq,
            k=None,
            v=sv,
            q_chunk_len=self.chunk_len,
            n_chunks_before=self.n_chunks_before,
            n_chunks_after=self.n_chunks_after,
            mask_fn=mask_fn,
            q_info=q_info,
            dropout=self.attention_dropout,
            rng=rng,
        )

        def unsort_for_output_impl(so, slogits):
            o = np.take(so, undo_sort, axis=0)
            # Sorting is considerably faster than gather, but first we need to get the
            # XLA compiler to abandon the idea of fusing this sort with the input sort
            # (which introduces a computation cycle and leads to a crash).
            # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
            sticker_ = sticker + jax.lax.convert_element_type(
                slogits[0] > 0, sticker.dtype)
            _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
            return o, logits

        def unsort_for_output_vjp(so, slogits):
            """Custom gradient for unsort_for_output."""
            so = jax.lax.stop_gradient(so)
            slogits = jax.lax.stop_gradient(slogits)
            o, logits = unsort_for_output_impl(so, slogits)

            def vjpfun(o_logits_grads):
                so_grad = np.take(o_logits_grads[0], sticker, axis=0)
                # TODO(kitaev): this exists to match the forward pass, but I'm not sure
                # if it's actually required.
                buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
                    o_logits_grads[1][0] > 0, buckets_and_t.dtype)
                _, slogits_grad = jax.lax.sort_key_val(buckets_and_t_,
                                                       o_logits_grads[1],
                                                       dimension=-1)
                return (so_grad, slogits_grad)

            return (o, logits), vjpfun

        unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
        jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
        o, logits = unsort_for_output(so, slogits)

        if self.n_hashes > 1:
            o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
            logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
            probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True))
            o = np.sum(o * probs, axis=0)

        assert o.shape == (seqlen, w_v.shape[-1])
        out = np.matmul(o, w_o)
        return out, state
Example 22
def attend(
    q,
    k=None,
    v=None,
    q_chunk_len=None,
    kv_chunk_len=None,
    n_chunks_before=0,
    n_chunks_after=0,
    mask_fn=None,
    q_info=None,
    kv_info=None,
    dropout=0.0,
    rng=None,
):
    """Dot-product attention, with optional chunking and/or masking.

  Args:
    q: Query vectors, shape [q_len, d_qk]
    k: Key vectors, shape [kv_len, d_qk]; or None
    v: Value vectors, shape [kv_len, d_v]
    q_chunk_len: Set to an integer chunk length to enable chunking for query vectors
    kv_chunk_len: Set to an integer chunk length to enable chunking for key/value vectors
    n_chunks_before: Number of adjacent previous chunks to attend to
    n_chunks_after: Number of adjacent subsequent chunks to attend to
    mask_fn: TODO(kitaev) doc
    q_info: Query-associated metadata for masking
    kv_info: Key-associated metadata for masking
    dropout: Dropout rate
    rng: RNG for dropout

  Returns:
    A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and
    dots_logsumexp has shape [q_len]. The logsumexp of the attention
    probabilities is useful for combining multiple rounds of attention (as in
    LSH attention).
  """
    assert v is not None
    share_qk = (k is None)

    if q_info is None:
        q_info = np.arange(q.shape[-2])

    if kv_info is None and not share_qk:
        kv_info = np.arange(v.shape[-2])

    # Split q/k/v into chunks along the time axis, if desired.
    if q_chunk_len is not None:
        q = np.reshape(q, (-1, q_chunk_len, q.shape[-1]))
        q_info = np.reshape(q_info, (-1, q_chunk_len))

    if share_qk:
        assert kv_chunk_len is None or kv_chunk_len == q_chunk_len
        k = q
        kv_chunk_len = q_chunk_len
        kv_info = q_info
    elif kv_chunk_len is not None:
        k = np.reshape(k, (-1, kv_chunk_len, k.shape[-1]))
        kv_info = np.reshape(kv_info, (-1, kv_chunk_len))

    if kv_chunk_len is not None:
        v = np.reshape(v, (-1, kv_chunk_len, v.shape[-1]))

    if share_qk:
        k = length_normalized(k)
    k = k / np.sqrt(k.shape[-1])

    # Optionally include adjacent chunks.
    if q_chunk_len is not None or kv_chunk_len is not None:
        assert q_chunk_len is not None and kv_chunk_len is not None
    else:
        assert n_chunks_before == 0 and n_chunks_after == 0

    k = look_adjacent(k, n_chunks_before, n_chunks_after)
    v = look_adjacent(v, n_chunks_before, n_chunks_after)
    kv_info = look_adjacent(kv_info, n_chunks_before, n_chunks_after)

    # Dot-product attention.
    dots = np.matmul(q, np.swapaxes(k, -1, -2))

    # Masking
    if mask_fn is not None:
        dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])

    # Softmax.
    dots_logsumexp = logsumexp(dots, axis=-1, keepdims=True)
    dots = np.exp(dots - dots_logsumexp)

    if dropout > 0.0:
        assert rng is not None
        # Dropout is broadcast across the bin dimension
        dropout_shape = (dots.shape[-2], dots.shape[-1])
        # TODO(kitaev): verify that tie-in is safe to remove (in light of jax fix)
        keep_prob = jax.lax.tie_in(dots, 1.0 - dropout)
        keep = jax.random.bernoulli(rng, keep_prob, dropout_shape)
        multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
        dots = dots * multiplier

    # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn.
    out = np.matmul(dots, v)
    out = np.reshape(out, (-1, out.shape[-1]))
    dots_logsumexp = np.reshape(dots_logsumexp, (-1, ))
    return out, dots_logsumexp
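An unchunked, unmasked sketch of the core computation in `attend`, assuming `jax.numpy` and `jax.scipy.special.logsumexp`; it produces both the attention output and the per-query logsumexp used to merge multiple hashing rounds.

import jax.numpy as jnp
from jax.scipy.special import logsumexp

q = jnp.ones((5, 8))                                  # [q_len, d_qk]
k = q / jnp.sqrt(q.shape[-1])
v = jnp.ones((5, 4))                                  # [kv_len, d_v]
dots = jnp.matmul(q, jnp.swapaxes(k, -1, -2))
dots_logsumexp = logsumexp(dots, axis=-1, keepdims=True)
weights = jnp.exp(dots - dots_logsumexp)
out = jnp.matmul(weights, v)                          # [q_len, d_v]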
Example 23
    def forward_unbatched(self, x, *, weights, state, update_state):
        w_q, w_v, w_o = weights

        q = np.matmul(x, w_q)
        v = np.matmul(x, w_v)

        if update_state:
            _, old_rng = state
            rng = jax.random.fold_in(old_rng, 0)
            hash_rng = jax.random.fold_in(rng, 1)
            buckets = self.hash_vectors(q, hash_rng)
            state = (buckets, rng)
        else:
            buckets, rng = state

        rng = jax.random.fold_in(rng, 2)

        seqlen = x.shape[0]
        assert int(buckets.shape[0]) == self.n_hashes * seqlen

        ticker = jax.lax.tie_in(x, np.arange(self.n_hashes * seqlen))
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t,
                                                       ticker,
                                                       dimension=-1)
        _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
        sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
        sticker = jax.lax.stop_gradient(sticker)
        undo_sort = jax.lax.stop_gradient(undo_sort)

        st = (sticker % seqlen)
        sq = np.take(q, st, axis=0)
        sv = np.take(v, st, axis=0)

        mask_fn = functools.partial(mask_self_attention,
                                    causal=self.causal,
                                    exclude_self=True)
        q_info = st
        so, slogits = attend(
            sq,
            k=None,
            v=sv,
            q_chunk_len=self.chunk_len,
            n_chunks_before=self.n_chunks_before,
            n_chunks_after=self.n_chunks_after,
            mask_fn=mask_fn,
            q_info=q_info,
            dropout=self.attention_dropout,
            rng=rng,
        )

        # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would
        # also work, but these helpers include performance optimizations for TPU.
        o = permute_via_gather(so, undo_sort, sticker, axis=0)
        logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1)

        if self.n_hashes > 1:
            o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
            logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
            probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True))
            o = np.sum(o * probs, axis=0)

        assert o.shape == (seqlen, w_v.shape[-1])
        out = np.matmul(o, w_o)
        return out, state