Example #1
  def call(self, x, params, state, **kwargs):
    del kwargs
    seqlen = x.shape[1]
    d_head = x.shape[2]

    x = np.reshape(x, (-1, self._n_heads, seqlen, d_head))
    x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
    x = np.reshape(x, (-1, seqlen, self._n_heads * d_head))

    return np.dot(x, params), state
Example #2
 def predict(x, params=(), rng=None):
     """Predict function jited and parallelized as requested."""
     # On one device, jit and run.
     pred = mapped_predict(reshape_by_device(x, n_devices), params,
                           jax_random.split(rng, n_devices))
     # Need to reduce the [device, per-device-batch, ...] tensors back to
     # a [batch, ...] tensor. The tensors may be nested.
     if not isinstance(pred, (list, tuple)):  # Not nested.
         batch_size = pred.shape[0] * pred.shape[1]
         return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
     batch_size = pred[0].shape[0] * pred[0].shape[1]
     return [np.reshape(p, [batch_size] + list(p.shape[2:])) for p in pred]
Example #3
def multigaussian_loss(preds, targets, ngauss=1):  # pylint: disable=invalid-name
    """Compute mixture of gaussians loss."""
    ndims = targets.shape[-1]
    logits = preds[:, :ngauss]
    mus = preds[:, ngauss:ngauss * (ndims + 1)]
    sigmas = preds[:, ngauss * (ndims + 1):]
    sigmas = sigmas * sigmas + 1e-6  # Make positive.
    loglogits = logits - backend.logsumexp(logits, axis=-1, keepdims=True)
    mus = np.reshape(mus, [-1, ngauss, ndims])
    sigmas = np.reshape(sigmas, [-1, ngauss, ndims])
    targets = np.reshape(targets, [-1, 1, ndims])
    glogprobs = log_gaussian_diag_pdf(targets, mus, sigmas)
    return backend.logsumexp(loglogits + glogprobs, axis=-1)
Example #4
 def call(self, x, params=(), state=(), **kwargs):
     del kwargs
     w, b = params
     x_shape = list(x.shape)
     if len(x_shape) > 4:
         self._check_nhwc()
         new_batch_dim = six.moves.reduce(operator.mul, x_shape[:-3])
         x = np.reshape(x, [new_batch_dim] + x_shape[-3:])
     res = backend.conv(x, w, self._strides, self._padding,
                        self._dimension_numbers, self._one) + b
     if len(x_shape) > 4:
         res = np.reshape(res, x_shape[:-3] + list(res.shape[-3:]))
     return res, state
Example #5
  def call(self, x, params, state, **kwargs):
    del kwargs
    seqlen = x.shape[1]
    res = np.dot(x, params)

    # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
    res = np.reshape(res, (x.shape[0], seqlen, self._n_heads, self._d_head))
    # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
    res = np.transpose(res, (0, 2, 1, 3))
    # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
    res = np.reshape(res, (-1, seqlen, self._d_head))

    return res, state
Example #6
def reshape_by_device(train_data, num_devices):
    """Reshape the train_data into a shape [num_devices, ...]."""
    x, y = train_data
    x_shape, y_shape = list(x.shape), list(y.shape)
    assert x_shape[0] == y_shape[0]  # Same batch size.
    batch_size = x_shape[0]
    batch_size_per_device = batch_size // num_devices
    # We require that num_devices divides batch_size evenly.
    assert batch_size_per_device * num_devices == batch_size
    # New shapes.
    new_shape_prefix = [num_devices, batch_size_per_device]
    x = np.reshape(x, new_shape_prefix + x_shape[1:])
    y = np.reshape(y, new_shape_prefix + y_shape[1:])
    return x, y
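A quick usage sketch for reshape_by_device, given the definition above (the batch of 8 and num_devices=4 are made up for illustration; np is jax.numpy in the source, but plain NumPy reshapes identically):

import numpy as np

x = np.zeros((8, 28, 28))   # (batch, height, width) inputs
y = np.zeros((8,))          # (batch,) labels
x_dev, y_dev = reshape_by_device((x, y), num_devices=4)
print(x_dev.shape)  # (4, 2, 28, 28): [num_devices, per-device-batch, ...]
print(y_dev.shape)  # (4, 2)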
Example #7
def Flatten(x, params, n_axes_to_keep=1, **kwargs):
    del params, kwargs
    if n_axes_to_keep >= len(x.shape):
        raise ValueError(
            "n_axes_to_keep[%d] should be less than input's rank[%d]" %
            (n_axes_to_keep, len(x.shape)))
    return np.reshape(x, (x.shape[:n_axes_to_keep] + (-1, )))
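Given the definition above, a small usage sketch (array shapes and contents are arbitrary):

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
print(Flatten(x, params=()).shape)                    # (2, 12)
print(Flatten(x, params=(), n_axes_to_keep=2).shape)  # (2, 3, 4)
# Flatten(x, params=(), n_axes_to_keep=3) raises ValueError: the number
# of kept axes must be less than the input's rank.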
Example #8
def EncoderDecoderMask(x, **unused_kwargs):
    """Makes encoder-decoder mask from decoder input and a padding mask."""
    decoder_input, padding_mask = x
    padding_mask = np.reshape(
        padding_mask, (padding_mask.shape[0], 1, 1, padding_mask.shape[-1]))
    # Final mask shape is [batch, 1 for heads, decoder-len, encoder-len].
    return padding_mask + np.zeros((1, 1, decoder_input.shape[1], 1))
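A shape sketch for the function above (a batch of 2, decoder length 5, encoder length 4; all values made up):

import numpy as np

decoder_input = np.zeros((2, 5, 8))          # (batch, decoder_len, d_model)
padding_mask = np.array([[1., 1., 1., 0.],
                         [1., 1., 0., 0.]])  # (batch, encoder_len)
mask = EncoderDecoderMask((decoder_input, padding_mask))
print(mask.shape)  # (2, 1, 5, 4): [batch, 1 for heads, decoder-len, encoder-len]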
Example #9
 def combine(x):
     if len(x.shape) > 1:
         batch_size = x.shape[0] * x.shape[1]
         return np.reshape(x, [batch_size] + list(x.shape[2:]))
     # TODO(lukaszkaiser): is returning averages for scalars the right choice?
     # If it is only scalar, return the average.
     return np.mean(x, axis=0)
Example #10
def SplitHeads(x, params, n_heads=1, **kwargs):
    del params, kwargs
    d_model = x.shape[-1]
    assert d_model % n_heads == 0
    d_head = d_model // n_heads
    n_batch = np.shape(x)[0]
    # n_batch, seqlen, d_model --> n_batch, n_heads, seqlen, d_head
    return np.transpose(np.reshape(x, (n_batch, -1, n_heads, d_head)),
                        (0, 2, 1, 3))
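Usage sketch for the SplitHeads above (d_model=12 split into n_heads=4 heads of d_head=3; the sizes are made up):

import numpy as np

x = np.zeros((2, 7, 12))   # (n_batch, seqlen, d_model)
heads = SplitHeads(x, params=(), n_heads=4)
print(heads.shape)  # (2, 4, 7, 3): (n_batch, n_heads, seqlen, d_head)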
Example #11
    def predict(x, params=(), rng=None):
        """Predict function jited and parallelized as requested."""
        # On one device, jit and run.
        if num_devices == 1:
            return backend.jit(model_predict)(x, params, rng=rng)

        # On multiple devices, pmap and run.
        @functools.partial(backend.pmap, axis_name="batch")
        def mapped_predict(x, params, rng):
            return model_predict(x, params, rng=rng)

        pred = mapped_predict(reshape_by_device(x, num_devices), params,
                              jax_random.split(rng, num_devices))
        # Need to reduce the [device, per-device-batch, ...] tensors back to
        # a [batch, ...] tensor. The tensors may be nested.
        if not isinstance(pred, (list, tuple)):  # Not nested.
            batch_size = pred.shape[0] * pred.shape[1]
            return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
        batch_size = pred[0].shape[0] * pred[0].shape[1]
        return [np.reshape(p, [batch_size] + list(p.shape[2:])) for p in pred]
Example #12
  def hash_vectors(self, vecs, rng):
    # See https://arxiv.org/pdf/1509.02897.pdf
    # We sample a different random rotation for each round of hashing to
    # decrease the probability of hash misses.
    assert self.n_buckets % 2 == 0
    random_rotations_shape = (
        vecs.shape[-1],
        self.n_hashes if self._rehash_each_round else 1,
        self.n_buckets // 2)

    rng = jax.lax.tie_in(vecs, rng)
    rng, subrng = backend.random.split(rng)
    random_rotations = jax.random.normal(
        rng, random_rotations_shape).astype('float32')
    # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
    # hashing, so it's shared between them. Check if that's what we want.
    dropped_vecs = self.drop_for_hash(vecs, subrng)
    rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)
    rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)

    if self._rehash_each_round:
      buckets = np.argmax(rotated_vecs, axis=-1)
      # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
      # bucket numbers from different hashing rounds don't overlap.
      offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
      offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
      buckets = np.reshape(buckets + offsets, (-1,))
    else:
      # In this configuration, we map each item to the top self.n_hashes buckets
      rotated_vecs = np.squeeze(rotated_vecs, 0)
      bucket_range = jax.lax.tie_in(vecs, np.arange(rotated_vecs.shape[-1]))
      bucket_range = np.reshape(bucket_range, (1, -1))
      bucket_range = np.broadcast_to(bucket_range, rotated_vecs.shape)

      _, buckets = jax.lax.sort_key_val(
          rotated_vecs, bucket_range, dimension=-1)
      buckets = buckets[:, -self.n_hashes:]
      buckets = np.reshape(np.moveaxis(buckets, 0, -1), (-1,))

    return buckets
Example #13
def PreparePairedSequenceBatch(source, target_in, pad=0):
    """Build masks for this batch.

  Args:
    source: (batch, source_len) array of integer-coded symbols for inputs
    target_in: (batch, target_len) array of integer-coded symbols for targets
    pad: int: the padding symbol used to pad the above

  Returns:
    Tuple of arrays for the prepared batch: source, input-target,
    shifted-target, source mask, target mask, source-target "memory" mask,
    and the minibatch token count.
  """
    target = target_in[:, :-1]
    target_y = target_in[:, 1:]
    source_mask = np.reshape(source != pad,
                             (source.shape[0], 1, 1, source.shape[-1]))
    target_mask = MakeTargetMask(target, pad)
    memory_mask = (np.reshape(
        np.arange(target.shape[-1]) < source.shape[-1], [-1, 1]))
    ntokens = np.sum(target_y != pad)
    return (source, target, target_y, source_mask, target_mask, memory_mask,
            ntokens)
Example #14
def _reshape_by_device_single(x, n_devices):
    """Reshape x into a shape [n_devices, ...]."""
    x_shape = list(x.shape)
    batch_size = x_shape[0]
    batch_size_per_device = batch_size // n_devices
    # We require that n_devices divides batch_size evenly.
    if batch_size_per_device * n_devices != batch_size:
        logging.fatal(
            "We require that n_devices[%d] divides batch_size[%d] evenly.",
            n_devices, batch_size)
    # New shape.
    new_shape_prefix = [n_devices, batch_size_per_device]
    return np.reshape(x, new_shape_prefix + x_shape[1:])
Example #15
    def predict(x, params=(), rng=None):
        """Predict function jited and parallelized as requested."""
        # On one device, jit and run.
        if num_devices == 1:
            return backend.jit(model_predict)(x, params, rng=rng)

        # On multiple devices, pmap and run.
        @functools.partial(backend.pmap, axis_name="batch")
        def mapped_predict(x, params, rng):
            return model_predict(x, params, rng=rng)

        pred = mapped_predict(reshape_by_device(x, num_devices), params,
                              jax_random.split(rng, num_devices))
        batch_size = x.shape[0]
        return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
Example #16
 def test_batch_norm(self):
     input_shape = (2, 3, 4)
     input_dtype = np.float32
     eps = 1e-5
     rng = backend.random.get_prng(0)
     inp1 = np.reshape(np.arange(np.prod(input_shape), dtype=input_dtype),
                       input_shape)
     m1 = 11.5  # Mean of this input.
     v1 = 47.9167  # Variance of this input.
     layer = normalization.BatchNorm(axis=(0, 1, 2))
     params, state = layer.initialize(input_shape, input_dtype, rng)
     onp.testing.assert_allclose(state[0], 0)
     onp.testing.assert_allclose(state[1], 1)
     self.assertEqual(state[2], 0)
     out, state = layer(inp1, params, state)
     onp.testing.assert_allclose(state[0], m1 * 0.001)
     onp.testing.assert_allclose(state[1], 0.999 + v1 * 0.001, rtol=1e-6)
     self.assertEqual(state[2], 1)
     onp.testing.assert_allclose(out, (inp1 - m1) / np.sqrt(v1 + eps),
                                 rtol=1e-6)
Example #17
  def predict(params, batch, rng=None):
    """Predict function jited and parallelized as requested."""
    # If not jit'ing, just run the function.
    if not jit_eval:
      return model_predict(params, batch, rng=rng)

    # On one device, jit and run.
    if num_devices == 1:
      return backend.jit(model_predict)(params, batch, rng=rng)

    # On multiple devices, pmap and run.
    @functools.partial(backend.pmap, axis_name="batch")
    def mapped_predict(params, batch, rng):
      return model_predict(params, batch, rng=rng)
    pred = mapped_predict(
        jax.replicate(params),
        reshape_by_device(batch, num_devices),
        jax.replicate(rng))
    batch_size = batch.shape[0]
    return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
Example #18
 def _update_sketched(self, grads, params, m, v, opt_params):
   """Update for higher-rank parameters."""
   (learning_rate, momentum) = opt_params
   shape = params.shape
   rank = len(shape)
   reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                            for i in range(rank)]
   current_accumulator = self._minimum(reshaped_accumulators)
   current_accumulator += grads * grads
   accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                   1.0 / np.sqrt(current_accumulator),
                                   np.zeros_like(current_accumulator))
   preconditioned_gradient = grads * accumulator_inv_sqrt
   m = (1.0 - momentum) * preconditioned_gradient + momentum * m
   params = params - (learning_rate * m).astype(params.dtype)
   for i in range(len(v)):
     axes = list(range(int(i))) + list(range(int(i) + 1, rank))
     dim_accumulator = np.amax(current_accumulator, axis=axes)
     v[i] = dim_accumulator
   return params, (m, v)
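The update above relies on two helpers that are not shown. Here is a minimal sketch of what they plausibly look like, inferred from how they are used here; the real definitions live elsewhere in the optimizer class, so treat these as assumptions:

import numpy as np

def expanded_shape(shape, axis):
    # Keep dimension `axis`; set every other dimension to 1 so the
    # per-axis accumulator broadcasts against the full parameter tensor.
    return tuple(d if i == axis else 1 for i, d in enumerate(shape))

def minimum(tensors):
    # Elementwise minimum over a list of mutually broadcastable tensors.
    out = tensors[0]
    for t in tensors[1:]:
        out = np.minimum(out, t)
    return out

# For a (3, 4) parameter, v holds a (3,) and a (4,) accumulator; reshaping
# them to (3, 1) and (1, 4) and taking the elementwise minimum produces a
# full (3, 4) approximation of the second-moment accumulator.
v = [np.ones(3), np.ones(4)]
shape = (3, 4)
reshaped = [np.reshape(v[i], expanded_shape(shape, i)) for i in range(2)]
print(minimum(reshaped).shape)  # (3, 4)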
Example #19
 def _update_sketched(self, step, g, x, m, v):
     """Update for higher-rank parameters."""
     shape = x.shape
     rank = len(shape)
     reshaped_accumulators = [
         np.reshape(v[i], self._expanded_shape(shape, i))
         for i in range(rank)
     ]
     current_accumulator = self._minimum(reshaped_accumulators)
     current_accumulator += g * g
     accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                     1.0 / np.sqrt(current_accumulator),
                                     np.zeros_like(current_accumulator))
     preconditioned_gradient = g * accumulator_inv_sqrt
     m = (1.0 -
          self._momentum) * preconditioned_gradient + self._momentum * m
     x = x - self.step_size(step) * m
     for i in range(len(v)):
         axes = list(range(int(i))) + list(range(int(i) + 1, rank))
         dim_accumulator = np.amax(current_accumulator, axis=axes)
         v[i] = dim_accumulator
     return x, (m, v)
Example #20
 def test_batch_norm(self):
     input_shape = (2, 3, 4)
     input_dtype = np.float32
     eps = 1e-5
     rng = backend.random.get_prng(0)
     inp1 = np.reshape(np.arange(np.prod(input_shape), dtype=input_dtype),
                       input_shape)
     m1 = 11.5
     v1 = 47.9167
     layer = normalization.BatchNorm(axis=(0, 1, 2))
     params, state = layer.initialize(input_shape, input_dtype, rng)
     onp.testing.assert_allclose(state[0], 0)
     onp.testing.assert_allclose(state[1], 0)
     self.assertEqual(state[2], 0)
     out, state = layer(inp1, params, state)
     onp.testing.assert_allclose(state[0], m1)
     onp.testing.assert_allclose(state[1], v1, rtol=1e-6)
     self.assertEqual(state[2], 1)
     onp.testing.assert_allclose(out, (inp1 - m1) / np.sqrt(v1 + eps),
                                 rtol=1e-6)
     inp2 = inp1 * 2 + 3
     m2 = m1 * 2 + 3
     v2 = v1 * 4
     m12 = (m1 + m2) / 2
     v12 = (v1 + v2) / 2
     out, state = layer(inp2, params, state)
     onp.testing.assert_allclose(state[0], m12)
     onp.testing.assert_allclose(state[1], v12, rtol=1e-6)
     self.assertEqual(state[2], 2)
     onp.testing.assert_allclose(out, (inp2 - m2) / np.sqrt(v2 + eps),
                                 rtol=1e-6)
     layer = normalization.BatchNorm(axis=(0, 1, 2), mode="eval")
     inp3 = inp1 * 5 + 7
     out, state_unchanged = layer(inp3, params, state)
     for i in range(3):
         onp.testing.assert_allclose(state_unchanged[i], state[i])
     onp.testing.assert_allclose(out, (inp3 - m12) / np.sqrt(v12 + eps),
                                 rtol=1e-6)
Example #21
 def combine(x):
     batch_size = x.shape[0] * x.shape[1]
     return np.reshape(x, [batch_size] + list(x.shape[2:]))
Example #22
 def unchunk_vectors(x):  # pylint: disable=invalid-name
     return np.reshape(x, (x.shape[0], -1, x.shape[-1]))
Example #23
 def chunk_scalars(x):  # pylint: disable=invalid-name
     return np.reshape(x, (x.shape[0], self.n_bins, -1))
Example #24
    def call(self, inputs, params=(), state=(), rng=None, **kwargs):
        del params, kwargs
        # We use the same vector as both a query and a key. For now we haven't
        # adjusted any of the surrounding code, so we still get a separate "key"
        # input that we ignore.
        qk, _, v = inputs
        seqlen = qk.shape[-2]

        # qk/v are n_hashes*n_batch*n_heads, seqlen, d_head
        # TODO(kitaev): is it faster to fuse this tiling into gather/scatter ops?
        qk = np.tile(qk, (self.n_hashes, 1, 1))
        v = np.tile(v, (self.n_hashes, 1, 1))

        # bins are n_hashes*n_batch*n_heads, seqlen
        # They specify which hash bucket the query/key/value vectors fall in.
        bins = self.hash_vectors(qk, rng=rng)

        # joint_t is n_hashes*n_batch*n_heads, seqlen
        joint_t = jax.lax.tie_in(qk, np.arange(seqlen))
        joint_t = np.reshape(joint_t, (1, seqlen))
        joint_t = np.broadcast_to(joint_t, qk.shape[:-1])

        assert int(
            (self.n_buckets_per_bin * self.n_bins + 1) * seqlen
        ) < 2**31, (
            'Potential 32-bit integer overflow; please double-check the code.')
        joint_bins_and_t = seqlen * bins + joint_t

        def chunk_scalars(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], self.n_bins, -1))

        def chunk_vectors(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

        def unchunk_vectors(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

        # Sort everything by bin number, with a secondary sort by time
        # (variables starting with "s" are sorted)
        _, sjoint_t = jax.lax.sort_key_val(joint_bins_and_t,
                                           joint_t,
                                           dimension=-1)
        _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)
        # TODO(kitaev): why does jax flag integer indices as differentiable?
        # If we don't call stop_gradient here, custom gradients below won't work
        # because the primitive functions close over "differentiable" variables.
        sjoint_t = jax.lax.stop_gradient(sjoint_t)
        undo_sort = jax.lax.stop_gradient(undo_sort)

        # The backward pass of gather is in general a scatter operation, but we know
        # we're dealing with permutations so we use gather for the backward pass
        # too. This custom gradient should be about 2x faster than having jax infer
        # one that uses scatter ops instead.
        def permute_impl(vecs):
            assert len(vecs.shape) == 3
            return np.take_along_axis(vecs, sjoint_t[:, :, None], axis=-2)

        def unpermute_impl(vecs):
            assert len(vecs.shape) == 3
            return np.take_along_axis(vecs, undo_sort[:, :, None], axis=-2)

        @jax.custom_transforms
        def permute(vecs):
            return permute_impl(vecs)

        def permute_vjp(vecs):
            out_vecs = permute_impl(vecs)

            def vjpfun(grad):
                return (unpermute_impl(grad), )

            return out_vecs, vjpfun

        @jax.custom_transforms
        def unpermute(vecs):
            return unpermute_impl(vecs)

        def unpermute_vjp(vecs):
            out_vecs = unpermute_impl(vecs)

            def vjpfun(grad):
                return (permute_impl(grad), )

            return out_vecs, vjpfun

        jax.defvjp_all(permute, permute_vjp)
        jax.defvjp_all(unpermute, unpermute_vjp)

        sqk = permute(qk)
        sv = permute(v)

        # Split off a "bin" axis so that attention only occurs within chunks.
        bq_t = bkv_t = chunk_scalars(sjoint_t)
        bqk = chunk_vectors(sqk)
        bv = chunk_vectors(sv)

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = self.make_unit_length(bqk)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bin, so this increases the chances of attending to relevant items.
        # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
        bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]], axis=1)
        bk = np.concatenate([bk, bk_extra], axis=2)
        bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]], axis=1)
        bv = np.concatenate([bv, bv_extra], axis=2)
        bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]],
                                     axis=1)
        bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

        # Dot-product attention.
        dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

        # Causal masking
        mask = jax.lax.convert_element_type(
            jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]), np.float32)
        dots = dots - 1e9 * mask

        # Mask out attention to self except when no other targets are available.
        self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
        self_mask = jax.lax.tie_in(dots, self_mask)
        dots = dots - 32 * self_mask

        # Softmax.
        dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
        dots = np.exp(dots - dots_logsumexp)

        if self._hard_k > 0:
            top_k = np.sort(dots)[...,
                                  -self._hard_k]  # Get the top-kth weight.
            top_k = jax.lax.stop_gradient(top_k)
            dots -= top_k[..., np.newaxis]  # Subtract (be 0 for lower ones).
            dots = np.maximum(dots, 0)
            dots_sum = np.sum(dots, axis=-1,
                              keepdims=True)  # Sum to re-normalize.
            dots_logsumexp += np.log(dots_sum)  # Add it to the weight.
            dots /= dots_sum  # Re-normalize.

        bo = np.matmul(dots, bv)
        so = unchunk_vectors(bo)
        slogits = unchunk_vectors(dots_logsumexp)

        o = unpermute(so)
        logits = unpermute(slogits)

        o = np.reshape(o, (self.n_hashes, -1, seqlen, o.shape[-1]))
        logits = np.reshape(logits, (self.n_hashes, -1, seqlen, 1))
        probs = np.exp(logits -
                       backend.logsumexp(logits, axis=0, keepdims=True))
        out = np.sum(o * probs, axis=0)
        assert out.shape == inputs[2].shape

        return out, state
Example #25
    def call_and_grad(self, inputs, ct, rng=None, **kwargs):
        del kwargs
        # We use the same vector as both a query and a key. For now we haven't
        # adjusted any of the surrounding code, so we still get a separate "key"
        # input that we ignore.
        qk, ignored_k, v = inputs
        seqlen = qk.shape[-2]
        # qk/v are n_batch*n_heads, seqlen, d_head

        # bins are n_batch*n_heads, seqlen
        # They specify which hash bucket the query/key/value vectors fall in.
        bins = self.hash_vectors(qk, rng=rng)

        # joint_t is n_batch*n_heads, seqlen
        joint_t = jax.lax.tie_in(qk, np.arange(seqlen))
        joint_t = np.reshape(joint_t, (1, seqlen))
        joint_t = np.broadcast_to(joint_t, qk.shape[:-1])

        assert int((self.n_bins + 1) * seqlen) < 2**31, (
            'Potential 32-bit integer overflow; please double-check the code.')
        joint_bins_and_t = seqlen * bins + joint_t

        def chunk_scalars(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], self.n_bins, -1))

        def chunk_vectors(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

        def unchunk_vectors(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

        # Sort everything by bin number, with a secondary sort by time
        # (variables starting with "s" are sorted)
        _, sjoint_t = jax.lax.sort_key_val(joint_bins_and_t,
                                           joint_t,
                                           dimension=-1)

        sqk = np.take_along_axis(qk, sjoint_t[:, :, None], axis=-2)
        sv = np.take_along_axis(v, sjoint_t[:, :, None], axis=-2)

        if ct is not None:
            so_ct = np.take_along_axis(ct, sjoint_t[:, :, None], axis=-2)

        @jax.jit
        def binned_attn(sqk, sv):  # pylint: disable=invalid-name
            """Performs attention on sorted queries/keys/values."""
            # Split off a "bin" axis so that attention only occurs within chunks.
            bq_t = bkv_t = chunk_scalars(sjoint_t)
            bqk = chunk_vectors(sqk)
            bv = chunk_vectors(sv)

            # Hashing operates on unit-length vectors. Unnormalized query vectors are
            # fine because they effectively provide a learnable temperature for the
            # attention softmax, but normalizing keys is needed so that similarity for
            # the purposes of attention correctly corresponds to hash locality.
            bq = bqk
            bk = self.make_unit_length(bqk)

            # Allow each chunk to attend within itself, and also one chunk back. Chunk
            # boundaries might occur in the middle of a sequence of items from the
            # same bin, so this increases the chances of attending to relevant items.
            # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
            bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]],
                                      axis=1)
            bk = np.concatenate([bk, bk_extra], axis=2)
            bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]],
                                      axis=1)
            bv = np.concatenate([bv, bv_extra], axis=2)
            bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]],
                                         axis=1)
            bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

            # Dot-product attention.
            dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(
                bq.shape[-1])

            # Causal masking
            mask = jax.lax.convert_element_type(
                jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]),
                np.float32)
            dots = dots - 1e9 * mask

            # Mask out attention to self except when no other targets are available.
            self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
            self_mask = jax.lax.tie_in(dots, self_mask)
            dots = dots - 32 * self_mask

            # Softmax.
            dots = np.exp(dots -
                          backend.logsumexp(dots, axis=-1, keepdims=True))
            bo = np.matmul(dots, bv)

            so = unchunk_vectors(bo)
            return so

        @jax.jit
        def binned_attn_vjp(sqk, sv, so_ct):  # pylint: disable=invalid-name
            so, vjpfun = jax.vjp(binned_attn, sqk, sv)
            sqkv_ct = vjpfun(so_ct)
            return so, sqkv_ct

        if ct is None:
            so = binned_attn(sqk, sv)
            _, undo_sort = jax.lax.sort_key_val(sjoint_t,
                                                joint_t,
                                                dimension=-1)
            out = np.take_along_axis(so, undo_sort[:, :, None], axis=-2)
            return out, None
        else:
            # Jax can construct a backward pass automatically, but it's about 2x
            # slower than writing our own. The main reason is that the backward pass
            # of gather is in general a scatter operation, but we know we're dealing
            # with permutations so we use gather for the backward pass too.
            so, (sqk_ct, sv_ct) = binned_attn_vjp(sqk, sv, so_ct)

            _, undo_sort = jax.lax.sort_key_val(sjoint_t,
                                                joint_t,
                                                dimension=-1)
            out = np.take_along_axis(so, undo_sort[:, :, None], axis=-2)

            qk_ct = np.take_along_axis(sqk_ct, undo_sort[:, :, None], axis=-2)
            v_ct = np.take_along_axis(sv_ct, undo_sort[:, :, None], axis=-2)

            return out, (qk_ct, np.zeros_like(ignored_k), v_ct)
Example #26
def PaddingMask(x, params, pad=0, **kwargs):
    del params, kwargs
    return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))
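Usage sketch (the token ids are made up; 0 is the padding symbol by default):

import numpy as np

x = np.array([[3, 7, 0, 0],
              [4, 0, 0, 0]])   # (batch, seqlen)
mask = PaddingMask(x, params=())
print(mask.shape)     # (2, 1, 1, 4)
print(mask[0, 0, 0])  # [ True  True False False]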
Example #27
 def JoinHeads(x):  # pylint: disable=invalid-name
     return np.reshape(np.transpose(x, (0, 2, 1, 3)),
                       (nbatch, -1, n_heads * d_head))
Example #28
 def SplitHeads(x):
     return np.transpose(np.reshape(x, (nbatch, -1, n_heads, d_head)),
                         (0, 2, 1, 3))
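Examples #27 and #28 are closures over nbatch, n_heads and d_head from an enclosing layer. With those names bound to made-up values, the two transforms invert each other exactly:

import numpy as np

nbatch, n_heads, d_head = 2, 4, 3

def SplitHeads(x):
    return np.transpose(np.reshape(x, (nbatch, -1, n_heads, d_head)),
                        (0, 2, 1, 3))

def JoinHeads(x):
    return np.reshape(np.transpose(x, (0, 2, 1, 3)),
                      (nbatch, -1, n_heads * d_head))

x = np.arange(2 * 5 * 12, dtype=np.float32).reshape(2, 5, 12)
assert SplitHeads(x).shape == (2, 4, 5, 3)
assert np.array_equal(JoinHeads(SplitHeads(x)), x)  # Round trip is exact.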
Example #29
 def unchunk_rank4(x):
     return np.reshape(x, (x.shape[0], x.shape[1], -1, x.shape[-1]))
 def chunk_rank4(x):
     return np.reshape(
         x, (x.shape[0], x.shape[1], self.n_bins, -1, x.shape[-1]))
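chunk_rank4 splits the length axis of a rank-4 tensor into self.n_bins chunks, and unchunk_rank4 merges it back. A standalone round-trip sketch with n_bins bound to a made-up value (the sequence length must divide evenly by n_bins):

import numpy as np

n_bins = 4
x = np.zeros((2, 3, 8, 5))   # (batch, heads, seqlen, d_head)
chunked = np.reshape(x, (x.shape[0], x.shape[1], n_bins, -1, x.shape[-1]))
print(chunked.shape)  # (2, 3, 4, 2, 5)
restored = np.reshape(chunked,
                      (chunked.shape[0], chunked.shape[1], -1,
                       chunked.shape[-1]))
assert restored.shape == x.shape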