def CombineHeadsPos(x, n_heads=1, **unused_kwargs):
  """Mix x = (x0, p0, ..., xH, pH) into (x0, ...., xH), p_combined.

  The positions are averaged as vectors.

  Args:
    x: input vector, concatenated (x0, p0, ..., xH, pH).
    n_heads: number of heads.

  Returns:
    the vector with combined xs and one with combined positions.
  """
  seqlen = x.shape[1]
  d_head = x.shape[2]
  x = np.reshape(x, (-1, n_heads, seqlen, d_head))
  x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
  x = np.reshape(x, (-1, seqlen, n_heads * d_head))
  head_size = int(d_head) - POS_VECTOR_SIZE
  res, positions, idx = [], [], 0
  for _ in range(n_heads):
    res.append(x[:, :, idx:idx+head_size])
    idx += head_size
    positions.append(x[:, :, idx:idx+POS_VECTOR_SIZE])
    idx += POS_VECTOR_SIZE
  combined_position = sum(positions) / float(len(positions))
  return np.concatenate(res, axis=-1), combined_position
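
A standalone NumPy sketch of the split-and-average step above; POS_VECTOR_SIZE and the toy shapes here are assumptions for illustration only.

import numpy as np

POS_VECTOR_SIZE = 4                                    # assumed value
n_batch, n_heads, seqlen, d_head = 3, 2, 8, 16 + POS_VECTOR_SIZE

# Per-head input as CombineHeadsPos expects: (n_batch*n_heads, seqlen, d_head).
x = np.random.rand(n_batch * n_heads, seqlen, d_head).astype(np.float32)

# Same reshape/transpose as above: bring all heads next to the feature axis.
y = np.reshape(x, (n_batch, n_heads, seqlen, d_head))
y = np.transpose(y, (0, 2, 1, 3))
y = np.reshape(y, (n_batch, seqlen, n_heads * d_head))

head_size = d_head - POS_VECTOR_SIZE
xs = [y[:, :, i * d_head:i * d_head + head_size] for i in range(n_heads)]
ps = [y[:, :, i * d_head + head_size:(i + 1) * d_head] for i in range(n_heads)]

combined_x = np.concatenate(xs, axis=-1)               # (3, 8, 32)
combined_pos = sum(ps) / float(len(ps))                # (3, 8, 4)
print(combined_x.shape, combined_pos.shape)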
Example No. 2
    def forward_with_state(self,
                           inputs,
                           weights=base.EMPTY_WEIGHTS,
                           state=base.EMPTY_STATE,
                           rng=None,
                           **kwargs):
        embs = []
        for ax_emb in weights:
            ax_emb = np.broadcast_to(ax_emb, (inputs.shape[0], ) +
                                     self._shape + (ax_emb.shape[-1], ))
            embs.append(ax_emb)
        emb = np.concatenate(embs, -1)

        if self._mode == 'predict':
            assert self._dropout == 0.0
            emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
            return inputs + emb[:, state, :][:, None, :], state + 1
        elif self._dropout == 0:
            return inputs + np.reshape(emb, inputs.shape), state
        else:
            noise_shape = list(emb.shape)
            for dim in self._dropout_broadcast_dims:
                noise_shape[dim] = 1
            keep_prob = 1.0 - self._dropout
            if backend.get_name() == 'jax':
                keep_prob = jax.lax.tie_in(
                    inputs, np.full((), keep_prob, dtype=inputs.dtype))
            keep = backend.random.bernoulli(rng, keep_prob, tuple(noise_shape))
            multiplier = keep.astype(inputs.dtype) / keep_prob

            return inputs + np.reshape(emb * multiplier, inputs.shape), state
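
A plain-NumPy sketch of the broadcast-and-concatenate step above, assuming an axial shape of (4, 4) and per-axis embedding depths of (8, 8); in the real layer the per-axis tables live in weights.

import numpy as np

batch, shape, d_embs = 2, (4, 4), (8, 8)

# One table per axis, shaped so it broadcasts along the other axis:
# axis 0 -> (1, 4, 1, 8), axis 1 -> (1, 1, 4, 8).
ax_emb_0 = np.random.rand(1, shape[0], 1, d_embs[0]).astype(np.float32)
ax_emb_1 = np.random.rand(1, 1, shape[1], d_embs[1]).astype(np.float32)

embs = []
for ax_emb in (ax_emb_0, ax_emb_1):
    embs.append(np.broadcast_to(ax_emb, (batch,) + shape + (ax_emb.shape[-1],)))
emb = np.concatenate(embs, -1)                        # (2, 4, 4, 16)

inputs = np.zeros((batch, shape[0] * shape[1], sum(d_embs)), np.float32)
out = inputs + np.reshape(emb, inputs.shape)          # one position vector per token
print(out.shape)                                      # (2, 16, 16)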
Example No. 3
    def forward(self, x, weights):
        seqlen = x.shape[1]
        d_head = x.shape[2]

        x = np.reshape(x, (-1, self._n_heads, seqlen, d_head))
        x = np.transpose(x,
                         (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
        x = np.reshape(x, (-1, seqlen, self._n_heads * d_head))

        return np.dot(x, weights)
 def forward(self, x, weights):
     w, b = weights
     x_shape = list(x.shape)
     if len(x_shape) > 4:
         self._check_nhwc()
         new_batch_dim = six.moves.reduce(operator.mul, x_shape[:-3])
         x = np.reshape(x, [new_batch_dim] + x_shape[-3:])
     res = backend.conv(x, w, self._strides, self._padding,
                        self._dimension_numbers, self._one) + b
     if len(x_shape) > 4:
         res = np.reshape(res, x_shape[:-3] + list(res.shape[-3:]))
     return res
Example No. 5
    def forward(self, x, weights):
        seqlen = x.shape[1]
        res = np.dot(x, weights)

        # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
        res = np.reshape(res,
                         (x.shape[0], seqlen, self._n_heads, self._d_head))
        # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
        res = np.transpose(res, (0, 2, 1, 3))
        # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
        res = np.reshape(res, (-1, seqlen, self._d_head))

        return res
Example No. 6
def multigaussian_loss(preds, targets, ngauss=1):  # pylint: disable=invalid-name
  """Compute mixture of gaussians loss."""
  ndims = targets.shape[-1]
  logits = preds[:, :ngauss]
  mus = preds[:, ngauss:ngauss*(ndims + 1)]
  sigmas = preds[:, ngauss*(ndims + 1):]
  sigmas = sigmas * sigmas + 1e-6  # Make positive.
  loglogits = logits - backend.logsumexp(logits, axis=-1, keepdims=True)
  mus = np.reshape(mus, [-1, ngauss, ndims])
  sigmas = np.reshape(sigmas, [-1, ngauss, ndims])
  targets = np.reshape(targets, [-1, 1, ndims])
  glogprobs = log_gaussian_diag_pdf(targets, mus, sigmas)
  return backend.logsumexp(loglogits + glogprobs, axis=-1)
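
A small sketch of how preds is partitioned into mixture logits, means, and scales, assuming ngauss=3 components and ndims=2 target dimensions:

import numpy as np

ngauss, ndims, batch = 3, 2, 5
# preds packs [logits | mus | sigmas]: ngauss + ngauss*ndims + ngauss*ndims = 15.
preds = np.random.rand(batch, ngauss * (2 * ndims + 1)).astype(np.float32)

logits = preds[:, :ngauss]                      # (5, 3)
mus = preds[:, ngauss:ngauss * (ndims + 1)]     # (5, 6), reshaped to (5, 3, 2)
sigmas = preds[:, ngauss * (ndims + 1):]        # (5, 6), reshaped to (5, 3, 2)
print(logits.shape, mus.shape, sigmas.shape)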
Example No. 7
def EncoderDecoderMask(x, **unused_kwargs):
    """Makes encoder-decoder mask from decoder input and a padding mask."""
    decoder_input, padding_mask = x
    padding_mask = np.reshape(
        padding_mask, (padding_mask.shape[0], 1, 1, padding_mask.shape[-1]))
    # Final mask shape is [batch, 1 for heads, decoder-len, encoder-len].
    return padding_mask + np.zeros((1, 1, decoder_input.shape[1], 1))
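
A shape sketch of the broadcasting above, with a toy batch, decoder length, and encoder length:

import numpy as np

batch, dec_len, enc_len = 2, 5, 7
decoder_input = np.zeros((batch, dec_len, 512), np.float32)
padding_mask = np.array([[1, 1, 1, 1, 0, 0, 0],
                         [1, 1, 1, 0, 0, 0, 0]], np.float32)   # 1 = real token

mask = np.reshape(padding_mask, (batch, 1, 1, enc_len))
mask = mask + np.zeros((1, 1, decoder_input.shape[1], 1))      # broadcast over decoder positions
print(mask.shape)                                              # (2, 1, 5, 7)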
Example No. 8
def Unchunk(x, weights, n_sections=2, **kwargs):
    del weights, kwargs
    assert x.shape[0] % n_sections == 0
    return np.reshape(x, (
        x.shape[0] // n_sections,
        x.shape[1] * n_sections,
    ) + x.shape[2:])
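
A shape sketch of the Unchunk reshape, assuming n_sections=2 and toy dimensions:

import numpy as np

n_sections, n_chunks, chunk_len, d = 2, 4, 3, 8    # 4 chunks = 2 sequences * 2 sections
x = np.zeros((n_chunks, chunk_len, d), np.float32)
y = np.reshape(x, (n_chunks // n_sections, chunk_len * n_sections) + x.shape[2:])
print(y.shape)   # (2, 6, 8): consecutive chunks are glued back on the time axis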
Example No. 9
    def Init(shape, rng):
        """Returns orthogonalized random normal values with the given `shape`."""
        # Have at least 2 elements in shape.
        cur_shape = list(shape)
        while len(cur_shape) < 2:
            cur_shape = [1] + cur_shape

        # Flatten the input shape with the last dimension remaining.
        n_rows = 1
        for dim in cur_shape[:-1]:
            n_rows *= dim
        n_cols = cur_shape[-1]
        flat_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)

        # Generate a random matrix
        a = random.normal(rng, flat_shape, dtype=np.float32)

        # Compute the qr factorization
        q, r = np.linalg.qr(a)

        # Make Q uniform
        d = np.diag(r)
        q *= np.sign(d)

        # Transpose and reshape back q if needed.
        if n_rows < n_cols:
            q = np.transpose(q)
        q = np.reshape(q, shape)

        # Return scaled as requested.
        return stddev * q
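
A quick NumPy check of the same QR-based construction (stddev scaling omitted, shapes assumed):

import numpy as np

rng = np.random.default_rng(0)
shape = (6, 4)                          # n_rows=6 >= n_cols=4, so flat_shape=(6, 4)
a = rng.normal(size=shape).astype(np.float32)
q, r = np.linalg.qr(a)
q *= np.sign(np.diag(r))                # make the factorization unique ("uniform" Q)

# Columns of q are orthonormal: q.T @ q is (numerically) the identity.
print(np.allclose(q.T @ q, np.eye(4), atol=1e-5))   # True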
 def forward(self, inp, weights):
   """Reshape input to have heads dimension and concatenate positions there."""
   x = inp[0]
   n_batches, seqlen = x.shape[0], x.shape[1]
   d_head = x.shape[-1] // self._n_heads
   res = np.reshape(x, (n_batches, seqlen, self._n_heads, d_head))
   res = np.transpose(res, (0, 2, 1, 3))  # (batch, heads, len, depth)
   if self._n_pos == 1:  # Just one position given, tile into each head.
     pos_shape = list(res.shape)[:-1] + [inp[1].shape[-1]]
     pos = inp[1][:, None, :, :] + np.zeros(pos_shape)  # Add 0 to broadcast.
   else:  # As many positions as heads, concatenate them in.
     pos = [p[:, None, :, :] for p in inp[1:]]
     pos = np.concatenate(pos, axis=1)
   res = np.concatenate([res, pos], axis=-1)
   # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
   res = np.reshape(res, (-1, seqlen, d_head + POS_VECTOR_SIZE))
   return res
 def _update_sketched(self, grads, params, m, v, opt_params):
   """Update for higher-rank parameters."""
   learning_rate = opt_params['learning_rate']
   momentum = opt_params['momentum']
   shape = params.shape
   rank = len(shape)
   reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                            for i in range(rank)]
   current_accumulator = self._minimum(reshaped_accumulators)
   current_accumulator += grads * grads
   accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                   1.0 / np.sqrt(current_accumulator),
                                   np.zeros_like(current_accumulator))
   preconditioned_gradient = grads * accumulator_inv_sqrt
   m = (1.0 - momentum) * preconditioned_gradient + momentum * m
   params = params - (learning_rate * m).astype(params.dtype)
   for i in range(len(v)):
     axes = list(range(int(i))) + list(range(int(i) + 1, rank))
     dim_accumulator = np.amax(current_accumulator, axis=axes)
     v[i] = dim_accumulator
   return params, (m, v)
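
The helpers _expanded_shape and _minimum are not shown in this snippet; the sketch below reimplements the rank-2 case of this SM3-style accumulator with plain NumPy, all names and shapes assumed:

import numpy as np

rows, cols = 3, 4
grads = np.random.rand(rows, cols).astype(np.float32)

# One accumulator vector per parameter dimension.
v = [np.zeros(rows, np.float32), np.zeros(cols, np.float32)]

# Expand each accumulator so it broadcasts against the full parameter shape,
# then take the elementwise minimum as the current second-moment estimate.
reshaped = [v[0].reshape(rows, 1), v[1].reshape(1, cols)]
current = np.minimum(reshaped[0], reshaped[1]) + grads * grads

# Fold back: each per-axis accumulator keeps the max over the other axes.
v[0] = np.amax(current, axis=1)
v[1] = np.amax(current, axis=0)
print(current.shape, v[0].shape, v[1].shape)   # (3, 4) (3,) (4,)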
Example No. 12
 def test_batch_norm(self):
   input_shape = (2, 3, 4)
   input_dtype = np.float32
   input_signature = ShapeDtype(input_shape, input_dtype)
   eps = 1e-5
   inp1 = np.reshape(np.arange(np.prod(input_shape), dtype=input_dtype),
                     input_shape)
   m1 = 11.5  # Mean of this input.
   v1 = 47.9167  # Variance of this input.
   layer = normalization.BatchNorm(axis=(0, 1, 2))
   _, _ = layer.initialize_once(input_signature)
   state = layer.state
   onp.testing.assert_allclose(state[0], 0)
   onp.testing.assert_allclose(state[1], 1)
   self.assertEqual(state[2], 0)
   out = layer(inp1)
   state = layer.state
   onp.testing.assert_allclose(state[0], m1 * 0.001)
   onp.testing.assert_allclose(state[1], 0.999 + v1 * 0.001, rtol=1e-6)
   self.assertEqual(state[2], 1)
   onp.testing.assert_allclose(out, (inp1 - m1) / np.sqrt(v1 + eps),
                               rtol=1e-6)
Example No. 13
    def single_call(self, qk, v, buckets, rng=None):
        # We use the same vector as both a query and a key.
        seqlen = qk.shape[-2]
        assert int(buckets.shape[0]) == self.n_hashes * seqlen

        ticker = jax.lax.tie_in(qk, np.arange(self.n_hashes * seqlen))
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t,
                                                       ticker,
                                                       dimension=-1)
        _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
        sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
        sticker = jax.lax.stop_gradient(sticker)
        undo_sort = jax.lax.stop_gradient(undo_sort)

        st = (sticker % seqlen)
        sqk = np.take(qk, st, axis=0)
        sv = np.take(v, st, axis=0)

        # Split off a "bin" axis so that attention only occurs within chunks.
        bq_t = bkv_t = np.reshape(st, (self.n_hashes * self.n_bins, -1))
        bqk = np.reshape(sqk, (self.n_hashes * self.n_bins, -1, sqk.shape[-1]))
        bv = np.reshape(sv, (self.n_hashes * self.n_bins, -1, sv.shape[-1]))
        bq_buckets = bkv_buckets = np.reshape(
            sbuckets_and_t // seqlen, (self.n_hashes * self.n_bins, -1))

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = self.make_unit_length(bqk)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bucket, so this increases the chances of attending to relevant items.
        # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
        def look_one_back(x):
            if len(x.shape) == 2:
                x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
            else:
                x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
            return np.concatenate([x, x_extra], axis=1)

        bk = look_one_back(bk)
        bv = look_one_back(bv)
        bkv_t = look_one_back(bkv_t)
        bkv_buckets = look_one_back(bkv_buckets)

        # Dot-product attention.
        dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

        # Causal masking
        mask = jax.lax.convert_element_type(
            jax.lax.lt(bq_t[:, :, None], bkv_t[:, None, :]), np.float32)
        dots = dots - 1e9 * mask

        # Mask out attention to self except when no other targets are available.
        self_mask = jax.lax.convert_element_type(
            jax.lax.eq(bq_t[:, :, None], bkv_t[:, None, :]), np.float32)
        dots = dots - 1e5 * self_mask

        # Mask out attention to other hash buckets.
        if not self._attend_across_buckets:
            bucket_mask = jax.lax.convert_element_type(
                jax.lax.ne(bq_buckets[:, :, None], bkv_buckets[:, None, :]),
                np.float32)
            dots = dots - 1e7 * bucket_mask

        # Don't double-count query-key pairs across multiple rounds of hashing.
        # There are two possible strategies here. (1) The default is to count how
        # many times a query-key pair is repeated, and to lower its log-prob
        # correspondingly at each repetition. (2) When hard_k is set, the code
        # instead masks all but the first occurrence of each query-key pair.
        # TODO(kitaev): is one strategy faster or more numerically stable?
        if not self._allow_duplicate_attention:
            locs1 = undo_sort // bq_t.shape[-1]
            locs2 = (locs1 + 1) % (self.n_hashes * self.n_bins)
            if not self._attend_across_buckets:
                locs1 = buckets * (self.n_hashes * self.n_bins) + locs1
                locs2 = buckets * (self.n_hashes * self.n_bins) + locs2
            locs = np.moveaxis(
                np.concatenate([
                    np.reshape(locs1, (self.n_hashes, seqlen)),
                    np.reshape(locs2, (self.n_hashes, seqlen)),
                ], 0), 0, -1)  # produces shape (seqlen, 2 * self.n_hashes)
            slocs = np.take(locs, st, axis=0)
            b_locs = np.reshape(
                slocs, (self.n_hashes * self.n_bins, -1, 2 * self.n_hashes))
            # Queries always use the primary location (based on locs1).
            b_locs1 = b_locs[:, :, None, :self.n_hashes]
            if self._hard_k > 0:
                range_n_hashes = jax.lax.tie_in(b_locs,
                                                np.arange(self.n_hashes))
                nouse_locs = (range_n_hashes[:, None] >
                              range_n_hashes[None, :])
                nouse_locs = 2 * nouse_locs - 1  # 1 = use, -1 = don't use
                nouse_locs = np.reshape(
                    np.broadcast_to(
                        nouse_locs[:, None, :],
                        (self.n_hashes, self.n_bins, self.n_hashes)),
                    (self.n_hashes * self.n_bins, 1, 1, self.n_hashes))
                b_locs1 = b_locs1 * nouse_locs
            bq_locs = np.broadcast_to(b_locs1,
                                      b_locs.shape[:2] + (2, self.n_hashes))
            bq_locs = np.reshape(bq_locs, b_locs.shape)
            bkv_locs = look_one_back(b_locs)

            dup_counts = np.sum(jax.lax.convert_element_type(
                jax.lax.eq(bq_locs[:, :, None, :], bkv_locs[:, None, :, :]),
                np.float32),
                                axis=-1)
            assert dup_counts.shape == dots.shape
            if self._hard_k > 0:
                dots = dots - 1e7 * jax.lax.stop_gradient(dup_counts)
            else:
                dots = dots - jax.lax.stop_gradient(np.log(dup_counts + 1e-9))

        # Each query only attends to the top k most relevant keys.
        if self._hard_k > 0:
            b_top_dots = np.sort(dots)[...,
                                       -self._hard_k:]  # Get the top k dots.
            b_top_dots = jax.lax.stop_gradient(b_top_dots)
            s_top_dots = np.reshape(b_top_dots, (-1, self._hard_k))
            top_dots = np.take(s_top_dots, undo_sort, axis=0)

            merged_top_dots = np.moveaxis(
                np.reshape(top_dots, (self.n_hashes, seqlen, self._hard_k)), 0,
                -1)
            merged_top_dots = np.reshape(merged_top_dots, (seqlen, -1))

            dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k]
            # It's possible to compute the partition function at this point, but right
            # now this codepath isn't set up for backprop, and there might also be
            # issues computing it this way if two dot-products are exactly equal.

            sdots_thresh = dots_thresh[st]
            bdots_thresh = np.reshape(sdots_thresh,
                                      (self.n_hashes * self.n_bins, -1))
            bdots_thresh = jax.lax.stop_gradient(bdots_thresh)

            top_k_mask = jax.lax.convert_element_type(
                dots < bdots_thresh[..., None], np.float32)
            dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)

        # Softmax.
        dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
        dots = np.exp(dots - dots_logsumexp)

        if self._dropout > 0.0:
            # Dropout is broadcast across the bin dimension
            dropout_shape = (1, dots.shape[-2], dots.shape[-1])
            keep_prob = jax.lax.tie_in(dots, 1.0 - self._dropout)
            keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
            multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
                keep, keep_prob)
            dots = dots * multiplier

        bo = np.matmul(dots, bv)
        so = np.reshape(bo, (-1, bo.shape[-1]))
        slogits = np.reshape(dots_logsumexp, (-1, ))

        def unsort_for_output_impl(so, slogits):
            o = np.take(so, undo_sort, axis=0)
            # Sorting is considerably faster than gather, but first we need to get the
            # XLA compiler to abandon the idea of fusing this sort with the input sort
            # (which introduces a computation cycle and leads to a crash).
            # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
            sticker_ = sticker + jax.lax.convert_element_type(
                slogits[0] > 0, sticker.dtype)
            _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
            return o, logits

        def unsort_for_output_vjp(so, slogits):
            """Custom gradient for unsort_for_output."""
            so = jax.lax.stop_gradient(so)
            slogits = jax.lax.stop_gradient(slogits)
            o, logits = unsort_for_output_impl(so, slogits)

            def vjpfun(o_logits_grads):
                so_grad = np.take(o_logits_grads[0], sticker, axis=0)
                # TODO(kitaev): this exists to match the forward pass, but I'm not sure
                # if it's actually required.
                buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
                    o_logits_grads[1][0] > 0, buckets_and_t.dtype)
                _, slogits_grad = jax.lax.sort_key_val(buckets_and_t_,
                                                       o_logits_grads[1],
                                                       dimension=-1)
                return (so_grad, slogits_grad)

            return (o, logits), vjpfun

        unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
        jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
        o, logits = unsort_for_output(so, slogits)

        if self.n_hashes == 1:
            out = o
        else:
            o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
            logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
            probs = np.exp(logits -
                           backend.logsumexp(logits, axis=0, keepdims=True))
            out = np.sum(o * probs, axis=0)

        assert out.shape == v.shape
        return out
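
A tiny NumPy illustration of the sort/undo-sort bookkeeping used above, with argsort standing in for jax.lax.sort_key_val and toy buckets assumed:

import numpy as np

seqlen, n_hashes = 4, 2
buckets = np.array([3, 1, 3, 0, 2, 2, 0, 1])           # one bucket id per (hash, position)
ticker = np.arange(n_hashes * seqlen)
buckets_and_t = seqlen * buckets + (ticker % seqlen)   # stable tie-break by position

sticker = np.argsort(buckets_and_t)    # positions in sorted-by-bucket order
undo_sort = np.argsort(sticker)        # inverse permutation

x = np.arange(n_hashes * seqlen) * 10.0
sx = x[sticker]                        # gather into sorted order (chunked attention runs here)
assert np.all(sx[undo_sort] == x)      # scattering back recovers the original order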
Example No. 14
    def hash_vectors(self, vecs, rng):
        # See https://arxiv.org/pdf/1509.02897.pdf
        # We sample a different random rotation for each round of hashing to
        # decrease the probability of hash misses.
        assert self.n_buckets % 2 == 0

        # If we factorize the hash, find a factor dividing n_buckets nicely.
        rot_size, factor_list = self.n_buckets, [self.n_buckets]
        if self._factorize_hash:
            # If we are given a list of factors, verify it and use later.
            if isinstance(self._factorize_hash, list):
                rot_size, product = 0, 1
                factor_list = self._factorize_hash
                for factor in factor_list:
                    assert factor % 2 == 0
                    product *= factor
                    rot_size += factor
                assert product == self.n_buckets
            else:  # Find one factor if just set to True.
                # We want to represent self.n_buckets = factor * rest so that
                # (1) both factor and rest are even, and (2) factor + rest is minimal.
                # To compute this we start from factor = sqrt(n_buckets) and go down
                # with it until we find one that satisfies the constraints above.
                factor = int(math.sqrt(self.n_buckets))
                while factor > 0 and not (self.n_buckets % factor == 0
                                          and factor % 2 == 0 and
                                          (self.n_buckets // factor) % 2 == 0):
                    factor -= 1
                if factor > 2:  # Factor of 2 does not warrant the effort.
                    rot_size = factor + (self.n_buckets // factor)
                    factor_list = [factor, self.n_buckets // factor]

        rotations_shape = (vecs.shape[-1],
                           self.n_hashes if self._rehash_each_round else 1,
                           rot_size // 2)

        rng = jax.lax.tie_in(vecs, rng)
        rng, subrng = backend.random.split(rng)
        random_rotations = self._sample_rotation(rotations_shape, vecs, rng)

        # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
        # hashing, so it's shared between them. Check if that's what we want.
        dropped_vecs = self.drop_for_hash(vecs, subrng)
        rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)

        if self._rehash_each_round:
            if self._factorize_hash and len(factor_list) > 1:
                # We factorized self.n_buckets as the product of factor_list.
                # Get the buckets for them and combine.
                buckets, cur_sum, cur_product = None, 0, 1
                for factor in factor_list:
                    rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
                    cur_sum += factor // 2
                    rv = np.concatenate([rv, -rv], axis=-1)
                    if buckets is None:
                        buckets = np.argmax(rv, axis=-1)
                    else:
                        buckets += cur_product * np.argmax(rv, axis=-1)
                    cur_product *= factor
            else:
                rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs],
                                              axis=-1)
                buckets = np.argmax(rotated_vecs, axis=-1)
            # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
            # bucket numbers from different hashing rounds don't overlap.
            offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
            offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
            buckets = np.reshape(buckets + offsets, (-1, ))
        else:
            assert not self._factorize_hash
            rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs],
                                          axis=-1)
            # In this configuration, we map each item to the top self.n_hashes buckets
            rotated_vecs = np.squeeze(rotated_vecs, 0)
            bucket_range = jax.lax.tie_in(vecs,
                                          np.arange(rotated_vecs.shape[-1]))
            bucket_range = np.reshape(bucket_range, (1, -1))
            bucket_range = np.broadcast_to(bucket_range, rotated_vecs.shape)

            _, buckets = jax.lax.sort_key_val(rotated_vecs,
                                              bucket_range,
                                              dimension=-1)
            buckets = buckets[:, -self.n_hashes:]
            buckets = np.reshape(np.moveaxis(buckets, 0, -1), (-1, ))

        return buckets
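
A stripped-down NumPy version of the single-round, non-factorized bucketing above (random rotations, then argmax over [rv, -rv]); all shapes are assumed:

import numpy as np

rng = np.random.default_rng(0)
seqlen, d, n_buckets = 16, 8, 6                          # n_buckets must be even
vecs = rng.normal(size=(seqlen, d)).astype(np.float32)

rotations = rng.normal(size=(d, n_buckets // 2)).astype(np.float32)
rotated = vecs @ rotations                               # (seqlen, n_buckets // 2)
rotated = np.concatenate([rotated, -rotated], axis=-1)   # (seqlen, n_buckets)
buckets = np.argmax(rotated, axis=-1)                    # bucket id in [0, n_buckets)
print(buckets.shape, int(buckets.min()), int(buckets.max()))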
Example No. 15
    def _sample_rotation(self, shape, vecs, rng):
        """Samples a rotation matrix, either randomly or based on `vecs`."""

        if not self._data_rotation:
            return jax.random.normal(rng, shape).astype('float32')

        assert len(shape) == 3
        unused_n_dim, n_hashes, r_div_2 = shape

        assert len(vecs.shape) == 2
        n_vecs = vecs.shape[0]

        rng1, rng2 = backend.random.split(rng, num=2)

        # We need to sample 2 * n_hashes * r_div_2 vectors from `vecs` at random.
        num_needed = 2 * n_hashes * r_div_2
        if n_vecs < num_needed:
            # shape = (n_hashes, r_div_2)
            random_idxs_1 = jax.random.randint(rng1, (n_hashes, r_div_2), 0,
                                               n_vecs)
            random_idxs_2 = jax.random.randint(rng2, (n_hashes, r_div_2), 0,
                                               n_vecs)
        else:
            # Sample without replacement.
            shuffled_indices = jax.random.shuffle(rng1, np.arange(n_vecs))
            random_idxs = np.reshape(shuffled_indices[:num_needed],
                                     (2, n_hashes, r_div_2))
            random_idxs_1 = random_idxs[0]
            random_idxs_2 = random_idxs[1]

        if self._data_rotation_farthest:
            # shape = (n_hashes * r_div_2, )
            random_idxs_1 = np.reshape(random_idxs_1, (-1, ))
            random_vecs_1 = vecs[random_idxs_1]

            # Sample candidates for vec2s.
            rng, subrng = backend.random.split(rng)
            # shape = (self._data_rotation_farthest_num, n_hashes * r_div_2)
            candidate_idxs_2 = jax.random.randint(
                subrng, (self._data_rotation_farthest_num, n_hashes * r_div_2),
                0, n_vecs)
            candidate_vecs_2 = vecs[candidate_idxs_2]
            # shape = candidate_idxs_2.shape
            distances = -np.abs(
                np.einsum('hd,chd->ch', random_vecs_1, candidate_vecs_2))
            # shape = (n_hashes * r_div_2,)
            farthest_idxs = np.argmax(distances, axis=0)
            # candidate_vecs_2.shape
            random_vecs_2 = candidate_vecs_2[farthest_idxs,
                                             np.arange(n_hashes * r_div_2)]

            # reshape to (n_hashes, r_div_2, n_dim)
            random_vecs_1 = np.reshape(random_vecs_1, (n_hashes, r_div_2, -1))
            random_vecs_2 = np.reshape(random_vecs_2, (n_hashes, r_div_2, -1))
        else:
            # shape = (n_hashes, r_div_2, n_dim)
            random_vecs_1 = vecs[random_idxs_1]
            random_vecs_2 = vecs[random_idxs_2]

        # shape = (n_dim, n_hashes, r_div_2)
        return np.transpose(random_vecs_2 - random_vecs_1, axes=[2, 0, 1])
Example No. 16
    def _forward_train_eval(self, inputs, rng):
        (inputs, original_len, n_bins) = self._pad_inputs(inputs)
        q, k, v = inputs
        seqlen = q.shape[-2]
        # q/k/v are n_batch*n_heads, seqlen, d_head
        # Time indices for causal masking.
        t = jax.lax.tie_in(q, np.arange(seqlen))

        # Split off a "bin" axis for chunks of consecutive items.
        bq_t = np.reshape(t, (n_bins, -1))
        bq = np.reshape(q, (q.shape[0], n_bins, -1, q.shape[-1]))
        if self._share_qk:
            bk = self.make_unit_length(bq)
        else:
            bk = np.reshape(k, (k.shape[0], n_bins, -1, k.shape[-1]))
        bv = np.reshape(v, (v.shape[0], n_bins, -1, v.shape[-1]))

        # Allow each chunk to attend within itself, and also one chunk back.
        def look_one_back(x):
            # Output: pairs [ bin_i bin_{i-1} ] concatenated on the time axis.
            if len(x.shape) == 2:
                x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
                return np.concatenate([x, x_extra], axis=1)
            else:
                assert len(x.shape) == 4
                x_extra = np.concatenate([x[:, -1:, :, :], x[:, :-1, :, :]],
                                         axis=1)
                return np.concatenate([x, x_extra], axis=2)

        bkv_t = look_one_back(bq_t)
        bk = look_one_back(bk)
        bv = look_one_back(bv)

        # Dot-product attention.
        dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

        # Causal masking based on the time indices.
        mask = jax.lax.convert_element_type(
            jax.lax.lt(bq_t[None, :, :, None], bkv_t[None, :, None, :]),
            np.float32)
        dots = dots - 1e9 * mask

        # Mask out attention to self except when no other targets are available.
        if self._share_qk:
            self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
            self_mask = jax.lax.tie_in(dots, self_mask)
            dots = dots - 1e5 * self_mask

        # Softmax.
        dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

        if self.dropout > 0.0:
            # Dropout is broadcast across the batch+head dimension
            dropout_shape = (1, dots.shape[-3], dots.shape[-2], dots.shape[-1])
            keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
            keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
            multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
                keep, keep_prob)
            dots = dots * multiplier

        bo = np.matmul(dots, bv)

        output = np.reshape(bo, (bo.shape[0], -1, bo.shape[-1]))
        assert output.shape == v.shape
        return output[..., :original_len, :]
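
A minimal NumPy demo of the 2-D branch of look_one_back, assuming 3 bins of length 2:

import numpy as np

bq_t = np.arange(6).reshape(3, 2)                                # [[0 1] [2 3] [4 5]]
x_extra = np.concatenate([bq_t[-1:, :], bq_t[:-1, :]], axis=0)   # previous bin, rolled by one
bkv_t = np.concatenate([bq_t, x_extra], axis=1)
print(bkv_t)
# [[0 1 4 5]
#  [2 3 0 1]
#  [4 5 2 3]]   -> each bin also sees the bin before it (wrapping at the start)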
Example No. 17
 def SplitHeads(x):
     return np.transpose(np.reshape(x, (nbatch, -1, n_heads, d_head)),
                         (0, 2, 1, 3))
Example No. 18
def Flatten(x, n_axes_to_keep=1, **unused_kwargs):
  if n_axes_to_keep >= len(x.shape):
    raise ValueError("n_axes_to_keep[%d] should be less than input's rank[%d]" %
                     (n_axes_to_keep, len(x.shape)))
  return np.reshape(x, (x.shape[:n_axes_to_keep] + (-1,)))
Example No. 19
 def JoinHeads(x):  # pylint: disable=invalid-name
     return np.reshape(np.transpose(x, (0, 2, 1, 3)),
                       (nbatch, -1, n_heads * d_head))
Example No. 20
def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      n_attention_chunks=1,
                      attention_type=tl.DotProductCausalAttention,
                      share_qk=False,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      mode='train'):
    """Reversible transformer language model with shortening.

  When shorten_factor is F and processing an input of shape [batch, length],
  we embed the (shifted-right) input and then group every F elements (on length)
  into a single vector -- so that in the end we process a tensor of shape
    [batch, length // F, d_model]
  almost until the end -- at the end it's un-shortened and an SRU is applied.
  This reduces the length processed inside the main model body, effectively
  making the model faster but possibly slightly less accurate.

  Args:
    vocab_size: int: vocab size
    shorten_factor: by how much to shorten, see above
    d_embedding: the depth of the embedding layer and final logits
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, values must sum to d_embedding.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    assert mode != 'predict'  # TODO(lukaszkaiser,kitaev): fast inference

    if not axial_pos_shape:
        positional_encoding = tl.PositionalEncoding(max_len=max_len,
                                                    dropout=dropout,
                                                    mode=mode)
    else:
        assert d_axial_pos_embs is not None
        positional_encoding = tl.AxialPositionalEncoding(
            shape=axial_pos_shape,
            d_embs=d_axial_pos_embs,
            dropout_broadcast_dims=tuple(range(1,
                                               len(axial_pos_shape) + 1)),
            dropout=dropout,
            mode=mode)

    positional_embedder = [
        tl.Embedding(d_embedding, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        positional_encoding,
    ]

    decoder_blocks = []

    if isinstance(attention_type, (tuple, list)):
        assert n_layers % len(attention_type) == 0
    else:
        attention_type = [attention_type]
    for layer_idx in range(n_layers):
        layer_attention_type = attention_type[layer_idx % len(attention_type)]
        decoder_block = DecoderBlock(
            d_model,
            d_ff,
            d_attention_key,
            d_attention_value,
            n_heads,
            n_attention_chunks,
            attention_type=layer_attention_type,
            dropout=dropout,
            share_qk=(share_qk or issubclass(layer_attention_type,
                                             tl.LSHCausalAttention)),
            ff_activation=ff_activation,
            ff_use_sru=ff_use_sru,
            mode=mode)
        decoder_blocks.append(decoder_block)

    # pylint: disable=g-long-lambda
    return tl.Serial(
        tl.ShiftRight(),
        positional_embedder,
        tl.Dup(),  # Stack has (x, x), the first will be shortened
        # Before shortening, we need to pad by shorten factor so as not to leak
        # information into the future. To understand why, imagine shorten factor
        # of 2 and sequence of length 4, so ABCD. If we shift just by 1, then we
        # would have 0ABC, which gets grouped to [0A][BC] on input, which is
        # predicting ABCD as targets. The problem is that [0A] has access to A
        # and [BC] has access to C -- it will learn to copy it, peek into
        # the future. Shifting twice to [00][AB] solves the problem as the first
        # "big" symbol becomes all-0 and the rest is shifted enough.
        tl.ShiftRight(n_shifts=shorten_factor - 1),
        tl.Fn(
            lambda x: np.reshape(  # Shorten -- move to depth.
                x, (x.shape[0], x.shape[1] // shorten_factor, -1)),
            n_out=1),
        tl.Dense(d_model),
        tl.Dup(),  # Stack has (short_x, short_x, x)
        tl.ReversibleSerial(decoder_blocks),
        tl.Select([0], n_in=2),
        tl.LayerNorm(),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        tl.Dense(shorten_factor * d_embedding),
        tl.Fn(
            lambda x: np.reshape(  # Prolong back.
                x, (x.shape[0], x.shape[1] * shorten_factor, -1)),
            n_out=1),
        tl.Concatenate(),  # Concatenate with just the embeddings.
        tl.CausalConv(d_embedding),
        tl.Relu(),
        tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
        tl.Dense(vocab_size),
        tl.LogSoftmax())
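
A small NumPy illustration of the shorten/prolong reshapes used in the Serial stack above, assuming shorten_factor=2 and a toy (batch, length, depth) input:

import numpy as np

shorten_factor, batch, length, d = 2, 1, 4, 3
x = np.arange(batch * length * d, dtype=np.float32).reshape(batch, length, d)

# Shorten: group every shorten_factor positions into one wider vector.
short = np.reshape(x, (batch, length // shorten_factor, -1))                   # (1, 2, 6)

# Prolong back after the reversible body.
prolonged = np.reshape(short, (batch, short.shape[1] * shorten_factor, -1))    # (1, 4, 3)
print(short.shape, prolonged.shape)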
Example No. 21
def PaddingMask(x, weights, pad=0, **kwargs):
    del weights, kwargs
    return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))
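
A quick usage sketch for the padding-mask reshape, assuming pad id 0:

import numpy as np

x = np.array([[5, 7, 0, 0],
              [3, 0, 0, 0]])            # 0 is the padding id
mask = np.reshape(x != 0, (x.shape[0], 1, 1, x.shape[-1]))
print(mask.shape)        # (2, 1, 1, 4) -> broadcasts over heads and query positions
print(mask[0, 0, 0])     # [ True  True False False]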