Example #1
def SumLearnedPick(positions):
    """Get a pair (vec, pos) and pick new pos."""
    succ_keys = positions[:-1, :]
    succ_values = positions[1:, :]
    subtract_1_keys = positions[1:, :]
    subtract_1_values = positions[:-1, :]
    l = int(positions.shape[0]) // 2
    add_keys = np.array([
        np.concatenate([positions[i, :], positions[j, :]]) for i in range(l)
        for j in range(l)
    ])
    add_values = np.array(
        [positions[i + j, :] for i in range(l) for j in range(l)])
    # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
    sub_keys = np.array([
        np.concatenate([positions[i, :], positions[j, :]]) for j in range(l)
        for i in range(l)
    ])
    sub_values = np.array(
        [positions[max(i - j, 0), :] for j in range(l) for i in range(l)])
    return tl.Serial(
        Dup2(), Dup2(), Dup2(), Dup2(),
        tl.Parallel(
            LearnedQP(),
            LearnedQP(keys=succ_keys, values=succ_values),
            LearnedQP(keys=subtract_1_keys, values=subtract_1_values),
            LearnedQP(keys=add_keys, values=add_values, binary=True),
            LearnedQP(keys=sub_keys, values=sub_values, binary=True),
        ), Softmax5Branches(n_branches=5))
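To make the add lookup table above concrete, here is a toy walkthrough with a 4-entry, 1-dimensional position table (the values are illustrative only, and plain NumPy stands in for the backend): each key pairs up two position vectors, and the matching value holds the position at their index sum.

import numpy as np

positions = np.array([[0.], [1.], [2.], [3.]])  # toy position table
l = int(positions.shape[0]) // 2                # == 2
add_keys = np.array([np.concatenate([positions[i, :], positions[j, :]])
                     for i in range(l) for j in range(l)])
add_values = np.array([positions[i + j, :] for i in range(l) for j in range(l)])
# add_keys   == [[0., 0.], [0., 1.], [1., 0.], [1., 1.]]
# add_values == [[0.], [1.], [1.], [2.]]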
Example #2
 def look_one_back(x):
     # Output: pairs [ bin_i bin_{i-1} ] concatenated on the time axis.
     if len(x.shape) == 2:
         x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
         return np.concatenate([x, x_extra], axis=1)
     else:
         assert len(x.shape) == 4
         x_extra = np.concatenate([x[:, -1:, :, :], x[:, :-1, :, :]],
                                  axis=1)
         return np.concatenate([x, x_extra], axis=2)
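A minimal sketch of what look_one_back does in the 2-D case, with plain NumPy standing in for the backend and a toy 3-bin input: each bin is concatenated with the bin before it, and the first bin wraps around to the last.

import numpy as np

bins = np.array([[0, 0], [1, 1], [2, 2]])                  # 3 bins, 2 features
x_extra = np.concatenate([bins[-1:, :], bins[:-1, :]], axis=0)
pairs = np.concatenate([bins, x_extra], axis=1)
# pairs == [[0, 0, 2, 2],
#           [1, 1, 0, 0],
#           [2, 2, 1, 1]]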
Example #3
    def reverse(self, output, weights=(), state=(), new_state=(), **kwargs):
        del weights, kwargs

        x1_split = []
        x2_split = []
        for y in output:
            y1, y2 = np.split(y, 2, -1)
            x1_split.append(y1)
            x2_split.append(y2)

        x1 = np.concatenate(x1_split, self._axis)
        x2 = np.concatenate(x2_split, self._axis)

        return (x1, x2)
Example #4
    def forward(self, inputs, weights):
        x, gru_state = inputs

        # Dense layer on the concatenation of x and h.
        w1, b1, w2, b2 = weights
        y = np.dot(np.concatenate([x, gru_state], axis=-1), w1) + b1

        # Update and reset gates.
        u, r = np.split(backend.sigmoid(y), 2, axis=-1)

        # Candidate.
        c = np.dot(np.concatenate([x, r * gru_state], axis=-1), w2) + b2

        new_gru_state = u * gru_state + (1 - u) * np.tanh(c)
        return new_gru_state, new_gru_state
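For reference, a rough sketch of the weight shapes this GRU cell expects, using plain NumPy and an illustrative feature width d (none of these sizes come from the source): w1/b1 produce the update and reset gates, and w2/b2 produce the candidate state.

import numpy as np

def sigmoid(z):                        # stands in for backend.sigmoid here
    return 1.0 / (1.0 + np.exp(-z))

d = 4                                  # illustrative width
x, gru_state = np.zeros((1, d)), np.zeros((1, d))
w1, b1 = np.zeros((2 * d, 2 * d)), np.zeros((2 * d,))
w2, b2 = np.zeros((2 * d, d)), np.zeros((d,))

y = np.dot(np.concatenate([x, gru_state], axis=-1), w1) + b1
u, r = np.split(sigmoid(y), 2, axis=-1)
c = np.dot(np.concatenate([x, r * gru_state], axis=-1), w2) + b2
new_gru_state = u * gru_state + (1 - u) * np.tanh(c)   # shape (1, d)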
Example #5
    def forward_with_state(self,
                           inputs,
                           weights=base.EMPTY_WEIGHTS,
                           state=base.EMPTY_STATE,
                           rng=None,
                           **kwargs):
        embs = []
        for ax_emb in weights:
            ax_emb = np.broadcast_to(ax_emb, (inputs.shape[0], ) +
                                     self._shape + (ax_emb.shape[-1], ))
            embs.append(ax_emb)
        emb = np.concatenate(embs, -1)

        if self._mode == 'predict':
            assert self._dropout == 0.0
            emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
            return inputs + emb[:, state, :][:, None, :], state + 1
        elif self._dropout == 0:
            return inputs + np.reshape(emb, inputs.shape), state
        else:
            noise_shape = list(emb.shape)
            for dim in self._dropout_broadcast_dims:
                noise_shape[dim] = 1
            keep_prob = 1.0 - self._dropout
            if backend.get_name() == 'jax':
                keep_prob = jax.lax.tie_in(
                    inputs, np.full((), keep_prob, dtype=inputs.dtype))
            keep = backend.random.bernoulli(rng, keep_prob, tuple(noise_shape))
            multiplier = keep.astype(inputs.dtype) / keep_prob

            return inputs + np.reshape(emb * multiplier, inputs.shape), state
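A small sketch of the broadcasting in the loop above, with an assumed 2-axis shape and per-axis weight shapes that are illustrative rather than taken from the source: each per-axis embedding is broadcast over the full (batch,) + shape grid before the final concatenation.

import numpy as np

batch, shape = 2, (3, 4)                    # illustrative axial shape
ax1 = np.zeros((1, 3, 1, 5))                # embedding along the first axis
ax2 = np.zeros((1, 1, 4, 7))                # embedding along the second axis
embs = [np.broadcast_to(a, (batch,) + shape + (a.shape[-1],)) for a in (ax1, ax2)]
emb = np.concatenate(embs, -1)              # shape (2, 3, 4, 12)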
Example #6
def CombineHeadsPos(x, n_heads=1, **unused_kwargs):
  """Mix x = (x0, p0, ..., xH, pH) into (x0, ...., xH), p_combined.

  The positions are averaged as vectors.

  Args:
    x: input vector, concatenated (x0, p0, ..., xH, pH).
    n_heads: number of heads.

  Returns:
    the vector with combined xs and one with combined positions.
  """
  seqlen = x.shape[1]
  d_head = x.shape[2]
  x = np.reshape(x, (-1, n_heads, seqlen, d_head))
  x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
  x = np.reshape(x, (-1, seqlen, n_heads * d_head))
  head_size = int(d_head) - POS_VECTOR_SIZE
  res, positions, idx = [], [], 0
  for _ in range(n_heads):
    res.append(x[:, :, idx:idx+head_size])
    idx += head_size
    positions.append(x[:, :, idx:idx+POS_VECTOR_SIZE])
    idx += POS_VECTOR_SIZE
  combined_position = sum(positions) / float(len(positions))
  return np.concatenate(res, axis=-1), combined_position
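A quick shape check, treating CombineHeadsPos as the plain function shown above (in the library it is typically wrapped as a layer), using illustrative sizes and assuming the module's POS_VECTOR_SIZE constant is 2.

import numpy as np

# Assume POS_VECTOR_SIZE == 2 in the defining module (illustrative only).
n_batch, n_heads, seqlen, head_size = 4, 2, 5, 3
x = np.zeros((n_batch * n_heads, seqlen, head_size + 2))
out, pos = CombineHeadsPos(x, n_heads=n_heads)
# out.shape == (4, 5, 6)   -> n_heads * head_size content features
# pos.shape == (4, 5, 2)   -> the per-head positions, averaged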
Example #7
    def forward(self, inputs, weights):
        x, lstm_state = inputs

        # LSTM state consists of c and h.
        c, h = np.split(lstm_state, 2, axis=-1)

        # Dense layer on the concatenation of x and h.
        w, b = weights
        y = np.dot(np.concatenate([x, h], axis=-1), w) + b

        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = np.split(y, 4, axis=-1)

        new_c = c * backend.sigmoid(f) + backend.sigmoid(i) * np.tanh(j)
        new_h = np.tanh(new_c) * backend.sigmoid(o)
        return new_h, np.concatenate([new_c, new_h], axis=-1)
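As with the GRU above, a shape sketch under an illustrative width d (not taken from the source): the state carries c and h concatenated, so it is twice as wide as x, and the dense layer emits the four gates at once.

import numpy as np

d = 3                                    # illustrative width
x = np.zeros((1, d))
lstm_state = np.zeros((1, 2 * d))        # c and h stacked on the last axis
w, b = np.zeros((2 * d, 4 * d)), np.zeros((4 * d,))

c, h = np.split(lstm_state, 2, axis=-1)
y = np.dot(np.concatenate([x, h], axis=-1), w) + b
i, j, f, o = np.split(y, 4, axis=-1)     # each gate has shape (1, d)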
Example #8
def NewPositionalEncoding(x, positions=None, **kwargs):
    """Implements new positional encoding."""
    del kwargs
    x_length = np.shape(x)[1]
    pos = np.array(positions)[np.newaxis, :x_length, :]
    pos += np.zeros((np.shape(x)[0], 1, 1))  # Broadcast on batch.
    res = np.concatenate([x, pos], axis=2)
    return res
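A small shape sketch of the encoding above, with a toy position table (values and sizes are illustrative): the first x_length rows of the table are broadcast over the batch and appended to the feature axis.

import numpy as np

positions = np.zeros((10, 2))            # toy table of 2-dim position vectors
x = np.zeros((3, 4, 5))                  # (batch, length, depth)
pos = positions[np.newaxis, :x.shape[1], :] + np.zeros((x.shape[0], 1, 1))
res = np.concatenate([x, pos], axis=2)   # shape (3, 4, 7)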
Example #9
 def forward(self, inp, weights):
   """Reshape input to have heads dimension and concatenate positions there."""
   x = inp[0]
   n_batches, seqlen = x.shape[0], x.shape[1]
   d_head = x.shape[-1] // self._n_heads
   res = np.reshape(x, (n_batches, seqlen, self._n_heads, d_head))
   res = np.transpose(res, (0, 2, 1, 3))  # (batch, heads, len, depth)
   if self._n_pos == 1:  # Just one position given, tile into each head.
     pos_shape = list(res.shape)[:-1] + [inp[1].shape[-1]]
     pos = inp[1][:, None, :, :] + np.zeros(pos_shape)  # Add 0 to broadcast.
   else:  # As many positions as heads, concatenate them in.
     pos = [p[:, None, :, :] for p in inp[1:]]
     pos = np.concatenate(pos, axis=1)
   res = np.concatenate([res, pos], axis=-1)
   # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
   res = np.reshape(res, (-1, seqlen, d_head + POS_VECTOR_SIZE))
   return res
Example #10
    def forward(self, inputs, weights):
        del weights
        x1, x2 = inputs

        x1_split = np.split(x1, self._n_sections, self._axis)
        x2_split = np.split(x2, self._n_sections, self._axis)

        res = [np.concatenate(ys, -1) for ys in zip(x1_split, x2_split)]
        return tuple(res)
Example #11
def MixHeadsPos(x, h=8, **unused_kwargs):
    """Mix x = (x0, p) into x0_h1, p, x0_h2, p, ...."""
    head_size = (x.shape[2] - POS_VECTOR_SIZE) // h
    p = x[:, :, -POS_VECTOR_SIZE:]
    res, idx = [], 0
    for _ in range(h):
        res.append(x[:, :, idx:idx + head_size])
        res.append(p)
        idx += head_size
    return np.concatenate(res, axis=-1)
Example #12
def QueryPositionKV(x, keys=None, values=None, binary=False, **unused_kwargs):
  """Query a table with a position vector."""
  if keys is None:
    return x
  k = np.array(keys)
  v = np.array(values)
  q = x
  if binary:
    q = np.concatenate([x, x], axis=-1)
  return tl.DotProductAttention(q, k, v, None, 0.0, None, None)
Example #13
 def test_fn_layer_difficult_n_out(self):
     with self.assertRaisesRegexp(ValueError, 'n_out'):
         # Determining the output of this layer is hard with dummies.
         cb.Fn(lambda x: np.concatenate([x, x], axis=4))
     # Check that this layer works when n_out is set.
     layer = cb.Fn(lambda x: np.concatenate([x, x], axis=4), n_out=1)
     input_signature = ShapeDtype((2, 1, 2, 2, 3))
     expected_shape = (2, 1, 2, 2, 6)
     output_shape = base.check_shape_agreement(layer, input_signature)
     self.assertEqual(output_shape, expected_shape)
Example #14
 def test_fn_layer_example(self):
     layer = cb.Fn(lambda x, y: (x + y, np.concatenate([x, y], axis=0)))
     input_signature = (ShapeDtype((2, 7)), ShapeDtype((2, 7)))
     expected_shape = ((2, 7), (4, 7))
     output_shape = base.check_shape_agreement(layer, input_signature)
     self.assertEqual(output_shape, expected_shape)
     inp = (np.array([2]), np.array([3]))
     x, xs = layer(inp)
     self.assertEqual(int(x), 5)
     self.assertEqual([int(y) for y in xs], [2, 3])
Example #15
def CopyHeadsPos(x, h=8, **unused_kwargs):
    """Mix x = (x, p) into x_h1, p_h1, x_h2, p_h2, ...."""
    head_size = (x.shape[2] - h * POS_VECTOR_SIZE) // h
    p = x[:, :, -h * POS_VECTOR_SIZE:]
    res, idx = [], 0
    for i in range(h):
        res.append(x[:, :, idx:idx + head_size])
        res.append(p[:, :, i * POS_VECTOR_SIZE:(i + 1) * POS_VECTOR_SIZE])
        idx += head_size
    return np.concatenate(res, axis=-1)
Example #16
def PerformPositionOperations(pos, positions=None):
  """Gets pos and returns (q1, ..., q5)."""
  succ_keys = positions[:-1, :]
  succ_values = positions[1:, :]
  subtract_1_keys = positions[1:, :]
  subtract_1_values = positions[:-1, :]
  l = int(positions.shape[0]) // 2
  add_keys = np.array([np.concatenate([positions[i, :], positions[j, :]])
                       for i in range(l) for j in range(l)])
  add_values = np.array([positions[i + j, :]
                         for i in range(l) for j in range(l)])
  # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
  sub_keys = np.array([np.concatenate([positions[i, :], positions[j, :]])
                       for j in range(l) for i in range(l)])
  sub_values = np.array([positions[max(i - j, 0), :]
                         for j in range(l) for i in range(l)])
  query_types = [
      QueryPositionKV(),
      QueryPositionKV(keys=succ_keys, values=succ_values),
      QueryPositionKV(keys=subtract_1_keys, values=subtract_1_values),
      QueryPositionKV(keys=add_keys, values=add_values, binary=True),
      QueryPositionKV(keys=sub_keys, values=sub_values, binary=True)]
  return [qt @ pos for qt in query_types]  # pylint: disable=syntax-error
Example #17
def DiagonalGate(x, **kwargs):
    """Split channels in 3 parts. Shifts 1st and 3rd sections to left/right."""
    del kwargs
    # x : [batch, 1, length, depth]
    x = np.pad(x, [(0, 0), (0, 0), (1, 1), (0, 0)],
               mode='constant',
               constant_values=0.0)
    depth = x.shape[-1] // 3
    assert 3 * depth == x.shape[-1], ('Depth must be divisible by 3', depth,
                                      x.shape)
    xs = [
        x[:, :, :-2, :depth], x[:, :, 1:-1, depth:2 * depth],
        x[:, :, 2:, 2 * depth:3 * depth]
    ]
    return np.concatenate(xs, axis=3)
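A quick check of what DiagonalGate does, treating it as the plain function above (in the library it may be wrapped as a layer) and using an illustrative input: for output position t, the first third of the channels comes from input position t-1, the middle third from t, and the last third from t+1, with zeros padded in at the sequence edges.

import numpy as np

x = np.arange(2 * 1 * 4 * 6, dtype=np.float32).reshape((2, 1, 4, 6))
out = DiagonalGate(x)
# out.shape == (2, 1, 4, 6): same shape as the input, but the first and last
# thirds of the channels have been shifted by one step along the length axis.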
Example #18
def CombineHeadsPos(x, h=8, **unused_kwargs):
    """Mix x = (x0, p0, ..., xH, pH) into x0, ...., xH, p_combined.

  The positions are added as vectors.

  Args:
    x: input vector, concatenated (x0, p0, ..., xH, pH).
    h: number of heads.

  Returns:
    the vector with combined positions.
  """
    head_size = int((x.shape[2] / h) - POS_VECTOR_SIZE)
    res, positions, idx = [], [], 0
    for _ in range(h):
        res.append(x[:, :, idx:idx + head_size])
        idx += head_size
        positions.append(x[:, :, idx:idx + POS_VECTOR_SIZE])
        idx += POS_VECTOR_SIZE
    combined_position = sum(positions)
    res.append(combined_position)
    return np.concatenate(res, axis=-1)
Example #19
 def look_one_back(x):
     if len(x.shape) == 2:
         x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
     else:
         x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
     return np.concatenate([x, x_extra], axis=1)
Example #20
    def single_call(self, qk, v, buckets, rng=None):
        # We use the same vector as both a query and a key.
        seqlen = qk.shape[-2]
        assert int(buckets.shape[0]) == self.n_hashes * seqlen

        ticker = jax.lax.tie_in(qk, np.arange(self.n_hashes * seqlen))
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t,
                                                       ticker,
                                                       dimension=-1)
        _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
        sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
        sticker = jax.lax.stop_gradient(sticker)
        undo_sort = jax.lax.stop_gradient(undo_sort)

        st = (sticker % seqlen)
        sqk = np.take(qk, st, axis=0)
        sv = np.take(v, st, axis=0)

        # Split off a "bin" axis so that attention only occurs within chunks.
        bq_t = bkv_t = np.reshape(st, (self.n_hashes * self.n_bins, -1))
        bqk = np.reshape(sqk, (self.n_hashes * self.n_bins, -1, sqk.shape[-1]))
        bv = np.reshape(sv, (self.n_hashes * self.n_bins, -1, sv.shape[-1]))
        bq_buckets = bkv_buckets = np.reshape(
            sbuckets_and_t // seqlen, (self.n_hashes * self.n_bins, -1))

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = self.make_unit_length(bqk)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bucket, so this increases the chances of attending to relevant items.
        # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
        def look_one_back(x):
            if len(x.shape) == 2:
                x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
            else:
                x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
            return np.concatenate([x, x_extra], axis=1)

        bk = look_one_back(bk)
        bv = look_one_back(bv)
        bkv_t = look_one_back(bkv_t)
        bkv_buckets = look_one_back(bkv_buckets)

        # Dot-product attention.
        dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

        # Causal masking
        mask = jax.lax.convert_element_type(
            jax.lax.lt(bq_t[:, :, None], bkv_t[:, None, :]), np.float32)
        dots = dots - 1e9 * mask

        # Mask out attention to self except when no other targets are available.
        self_mask = jax.lax.convert_element_type(
            jax.lax.eq(bq_t[:, :, None], bkv_t[:, None, :]), np.float32)
        dots = dots - 1e5 * self_mask

        # Mask out attention to other hash buckets.
        if not self._attend_across_buckets:
            bucket_mask = jax.lax.convert_element_type(
                jax.lax.ne(bq_buckets[:, :, None], bkv_buckets[:, None, :]),
                np.float32)
            dots = dots - 1e7 * bucket_mask

        # Don't double-count query-key pairs across multiple rounds of hashing.
        # There are two possible strategies here. (1) The default is to count how
        # many times a query-key pair is repeated, and to lower its log-prob
        # correspondingly at each repetition. (2) When hard_k is set, the code
        # instead masks all but the first occurrence of each query-key pair.
        # TODO(kitaev): is one strategy faster or more numerically stable?
        if not self._allow_duplicate_attention:
            locs1 = undo_sort // bq_t.shape[-1]
            locs2 = (locs1 + 1) % (self.n_hashes * self.n_bins)
            if not self._attend_across_buckets:
                locs1 = buckets * (self.n_hashes * self.n_bins) + locs1
                locs2 = buckets * (self.n_hashes * self.n_bins) + locs2
            locs = np.moveaxis(
                np.concatenate([
                    np.reshape(locs1, (self.n_hashes, seqlen)),
                    np.reshape(locs2, (self.n_hashes, seqlen)),
                ], 0), 0, -1)  # produces shape (seqlen, 2 * self.n_hashes)
            slocs = np.take(locs, st, axis=0)
            b_locs = np.reshape(
                slocs, (self.n_hashes * self.n_bins, -1, 2 * self.n_hashes))
            # Queries always use the primary location (based on locs1).
            b_locs1 = b_locs[:, :, None, :self.n_hashes]
            if self._hard_k > 0:
                range_n_hashes = jax.lax.tie_in(b_locs,
                                                np.arange(self.n_hashes))
                nouse_locs = (range_n_hashes[:, None] >
                              range_n_hashes[None, :])
                nouse_locs = 2 * nouse_locs - 1  # 1 = use, -1 = don't use
                nouse_locs = np.reshape(
                    np.broadcast_to(
                        nouse_locs[:, None, :],
                        (self.n_hashes, self.n_bins, self.n_hashes)),
                    (self.n_hashes * self.n_bins, 1, 1, self.n_hashes))
                b_locs1 = b_locs1 * nouse_locs
            bq_locs = np.broadcast_to(b_locs1,
                                      b_locs.shape[:2] + (2, self.n_hashes))
            bq_locs = np.reshape(bq_locs, b_locs.shape)
            bkv_locs = look_one_back(b_locs)

            dup_counts = np.sum(jax.lax.convert_element_type(
                jax.lax.eq(bq_locs[:, :, None, :], bkv_locs[:, None, :, :]),
                np.float32),
                                axis=-1)
            assert dup_counts.shape == dots.shape
            if self._hard_k > 0:
                dots = dots - 1e7 * jax.lax.stop_gradient(dup_counts)
            else:
                dots = dots - jax.lax.stop_gradient(np.log(dup_counts + 1e-9))

        # Each query only attends to the top k most relevant keys.
        if self._hard_k > 0:
            b_top_dots = np.sort(dots)[...,
                                       -self._hard_k:]  # Get the top k dots.
            b_top_dots = jax.lax.stop_gradient(b_top_dots)
            s_top_dots = np.reshape(b_top_dots, (-1, self._hard_k))
            top_dots = np.take(s_top_dots, undo_sort, axis=0)

            merged_top_dots = np.moveaxis(
                np.reshape(top_dots, (self.n_hashes, seqlen, self._hard_k)), 0,
                -1)
            merged_top_dots = np.reshape(merged_top_dots, (seqlen, -1))

            dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k]
            # It's possible to compute the partition function at this point, but right
            # now this codepath isn't set up for backprop, and there might also be
            # issues computing it this way if two dot-products are exactly equal.

            sdots_thresh = dots_thresh[st]
            bdots_thresh = np.reshape(sdots_thresh,
                                      (self.n_hashes * self.n_bins, -1))
            bdots_thresh = jax.lax.stop_gradient(bdots_thresh)

            top_k_mask = jax.lax.convert_element_type(
                dots < bdots_thresh[..., None], np.float32)
            dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)

        # Softmax.
        dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
        dots = np.exp(dots - dots_logsumexp)

        if self._dropout > 0.0:
            # Dropout is broadcast across the bin dimension
            dropout_shape = (1, dots.shape[-2], dots.shape[-1])
            keep_prob = jax.lax.tie_in(dots, 1.0 - self._dropout)
            keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
            multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
                keep, keep_prob)
            dots = dots * multiplier

        bo = np.matmul(dots, bv)
        so = np.reshape(bo, (-1, bo.shape[-1]))
        slogits = np.reshape(dots_logsumexp, (-1, ))

        def unsort_for_output_impl(so, slogits):
            o = np.take(so, undo_sort, axis=0)
            # Sorting is considerably faster than gather, but first we need to get the
            # XLA compiler to abandon the idea of fusing this sort with the input sort
            # (which introduces a computation cycle and leads to a crash).
            # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
            sticker_ = sticker + jax.lax.convert_element_type(
                slogits[0] > 0, sticker.dtype)
            _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
            return o, logits

        def unsort_for_output_vjp(so, slogits):
            """Custom gradient for unsort_for_output."""
            so = jax.lax.stop_gradient(so)
            slogits = jax.lax.stop_gradient(slogits)
            o, logits = unsort_for_output_impl(so, slogits)

            def vjpfun(o_logits_grads):
                so_grad = np.take(o_logits_grads[0], sticker, axis=0)
                # TODO(kitaev): this exists to match the forward pass, but I'm not sure
                # if it's actually required.
                buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
                    o_logits_grads[1][0] > 0, buckets_and_t.dtype)
                _, slogits_grad = jax.lax.sort_key_val(buckets_and_t_,
                                                       o_logits_grads[1],
                                                       dimension=-1)
                return (so_grad, slogits_grad)

            return (o, logits), vjpfun

        unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
        jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
        o, logits = unsort_for_output(so, slogits)

        if self.n_hashes == 1:
            out = o
        else:
            o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
            logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
            probs = np.exp(logits -
                           backend.logsumexp(logits, axis=0, keepdims=True))
            out = np.sum(o * probs, axis=0)

        assert out.shape == v.shape
        return out
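To make the causal-mask construction in single_call concrete, here is a toy version of the comparison it uses (the time indices are illustrative, and plain NumPy replaces jax.lax): entries where the key's time index exceeds the query's get a 1 and are then pushed toward -inf in the attention dots.

import numpy as np

bq_t = bkv_t = np.array([[0, 1, 2]])      # time indices for one toy bin
mask = (bq_t[:, :, None] < bkv_t[:, None, :]).astype(np.float32)
# mask[0] == [[0., 1., 1.],
#             [0., 0., 1.],
#             [0., 0., 0.]]
dots = np.zeros((1, 3, 3)) - 1e9 * mask   # future positions are masked out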
Example #21
    def hash_vectors(self, vecs, rng):
        # See https://arxiv.org/pdf/1509.02897.pdf
        # We sample a different random rotation for each round of hashing to
        # decrease the probability of hash misses.
        assert self.n_buckets % 2 == 0

        # If we factorize the hash, find a factor dividing n_buckets nicely.
        rot_size, factor_list = self.n_buckets, [self.n_buckets]
        if self._factorize_hash:
            # If we are given a list of factors, verify it and use later.
            if isinstance(self._factorize_hash, list):
                rot_size, product = 0, 1
                factor_list = self._factorize_hash
                for factor in factor_list:
                    assert factor % 2 == 0
                    product *= factor
                    rot_size += factor
                assert product == self.n_buckets
            else:  # Find one factor if just set to True.
                # We want to represent self.n_buckets = factor * rest so that
                # (1) both factor and rest are even, and (2) factor + rest is minimal.
                # To compute this we start from factor = sqrt(n_buckets) and go down
                # with it until we find one that satisfies the constraints above.
                factor = int(math.sqrt(self.n_buckets))
                while factor > 0 and not (self.n_buckets % factor == 0
                                          and factor % 2 == 0 and
                                          (self.n_buckets // factor) % 2 == 0):
                    factor -= 1
                if factor > 2:  # Factor of 2 does not warrant the effort.
                    rot_size = factor + (self.n_buckets // factor)
                    factor_list = [factor, self.n_buckets // factor]

        rotations_shape = (vecs.shape[-1],
                           self.n_hashes if self._rehash_each_round else 1,
                           rot_size // 2)

        rng = jax.lax.tie_in(vecs, rng)
        rng, subrng = backend.random.split(rng)
        random_rotations = self._sample_rotation(rotations_shape, vecs, rng)

        # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
        # hashing, so it's shared between them. Check if that's what we want.
        dropped_vecs = self.drop_for_hash(vecs, subrng)
        rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)

        if self._rehash_each_round:
            if self._factorize_hash and len(factor_list) > 1:
                # We factorized self.n_buckets as the product of factor_list.
                # Get the buckets for them and combine.
                buckets, cur_sum, cur_product = None, 0, 1
                for factor in factor_list:
                    rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
                    cur_sum += factor // 2
                    rv = np.concatenate([rv, -rv], axis=-1)
                    if buckets is None:
                        buckets = np.argmax(rv, axis=-1)
                    else:
                        buckets += cur_product * np.argmax(rv, axis=-1)
                    cur_product *= factor
            else:
                rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs],
                                              axis=-1)
                buckets = np.argmax(rotated_vecs, axis=-1)
            # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
            # bucket numbers from different hashing rounds don't overlap.
            offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
            offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
            buckets = np.reshape(buckets + offsets, (-1, ))
        else:
            assert not self._factorize_hash
            rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs],
                                          axis=-1)
            # In this configuration, we map each item to the top self.n_hashes buckets
            rotated_vecs = np.squeeze(rotated_vecs, 0)
            bucket_range = jax.lax.tie_in(vecs,
                                          np.arange(rotated_vecs.shape[-1]))
            bucket_range = np.reshape(bucket_range, (1, -1))
            bucket_range = np.broadcast_to(bucket_range, rotated_vecs.shape)

            _, buckets = jax.lax.sort_key_val(rotated_vecs,
                                              bucket_range,
                                              dimension=-1)
            buckets = buckets[:, -self.n_hashes:]
            buckets = np.reshape(np.moveaxis(buckets, 0, -1), (-1, ))

        return buckets
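A minimal plain-NumPy sketch of the non-factorized bucketing path above, with illustrative sizes: each vector is projected onto n_buckets / 2 random directions, the projections and their negations are concatenated, and the argmax picks a bucket in [0, n_buckets).

import numpy as np

n_buckets, d, seqlen = 4, 8, 10                      # illustrative sizes
vecs = np.random.randn(seqlen, d)
rotations = np.random.randn(d, 1, n_buckets // 2)    # one hashing round
rotated = np.einsum('tf,fhb->htb', vecs, rotations)  # (1, seqlen, n_buckets // 2)
rotated = np.concatenate([rotated, -rotated], axis=-1)
buckets = np.argmax(rotated, axis=-1)                # (1, seqlen), values in [0, 4)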