def SumLearnedPick(positions):
    """Get a pair (vec, pos) and pick new pos.

    Builds a layer that duplicates its input pair four times (via `Dup2`) and
    runs five `LearnedQP` lookups in parallel -- identity, successor,
    predecessor, addition, and subtraction clipped at 0 -- whose outputs are
    then combined by `Softmax5Branches(n_branches=5)`.

    Args:
      positions: 2-D array indexed as `positions[i, :]`, one row per position
        (presumably a position embedding per row -- TODO confirm).

    Returns:
      A `tl.Serial` layer.
    """
    # Unary tables: p_i -> p_{i+1} (successor) and p_{i+1} -> p_i (predecessor).
    earlier_rows = positions[:-1, :]
    later_rows = positions[1:, :]

    # Binary tables only cover the first half of the rows, which keeps a + b a
    # valid row index.
    half = int(positions.shape[0]) // 2
    grid = [(a, b) for a in range(half) for b in range(half)]

    def pair_key(first, second):
        # Binary lookup key: two position rows placed side by side.
        return np.concatenate([positions[first, :], positions[second, :]])

    # Addition: (p_a, p_b) -> p_{a+b}.
    add_keys = np.array([pair_key(a, b) for a, b in grid])
    add_values = np.array([positions[a + b, :] for a, b in grid])
    # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
    # Subtraction: (p_b, p_a) -> p_{max(b - a, 0)}, with a as the outer index.
    sub_keys = np.array([pair_key(b, a) for a, b in grid])
    sub_values = np.array([positions[max(b - a, 0), :] for a, b in grid])

    return tl.Serial(
        Dup2(), Dup2(), Dup2(), Dup2(),
        tl.Parallel(
            LearnedQP(),
            LearnedQP(keys=earlier_rows, values=later_rows),
            LearnedQP(keys=later_rows, values=earlier_rows),
            LearnedQP(keys=add_keys, values=add_values, binary=True),
            LearnedQP(keys=sub_keys, values=sub_values, binary=True),
        ),
        Softmax5Branches(n_branches=5))
def look_one_back(x):
    """Pair every chunk with the chunk just before it (cyclically).

    Output: pairs [ bin_i bin_{i-1} ] concatenated on the time axis.

    Args:
      x: a 2-D array whose axis 0 indexes chunks, or a 4-D array whose axis 1
        indexes chunks. Any other rank fails the assertion.

    Returns:
      `x` concatenated, on the axis right after the chunk axis, with a copy of
      `x` shifted by one chunk. Chunk 0 is paired with the last chunk.
    """
    rank = len(x.shape)
    if rank == 2:
        chunk_axis = 0
    else:
        assert rank == 4
        chunk_axis = 1
    # Rolling by +1 moves the last chunk to the front, so entry i holds i - 1.
    previous = np.roll(x, 1, axis=chunk_axis)
    return np.concatenate([x, previous], axis=chunk_axis + 1)
def reverse(self, output, weights=(), state=(), new_state=(), **kwargs):
    """Undo the section-wise concatenation done by `forward`.

    Every element of `output` is cut in half along its last axis. The first
    halves are joined along `self._axis` to rebuild `x1`, and the second
    halves likewise rebuild `x2`.

    Returns:
      The tuple `(x1, x2)`.
    """
    del weights, kwargs
    halves = [np.split(section, 2, -1) for section in output]
    firsts = [pair[0] for pair in halves]
    seconds = [pair[1] for pair in halves]
    return (np.concatenate(firsts, self._axis),
            np.concatenate(seconds, self._axis))
def forward(self, inputs, weights):
    """Run one GRU step.

    Args:
      inputs: pair `(x, gru_state)`.
      weights: `(w1, b1, w2, b2)`. `w1, b1` map `[x, h]` to the update and
        reset gates. `w2, b2` map `[x, r * h]` to the candidate state.

    Returns:
      `(new_state, new_state)`: the new hidden state is both the output and
      the carried state.
    """
    x, hidden = inputs
    w_gates, b_gates, w_cand, b_cand = weights

    # Dense layer on the concatenation of x and h gives both gates at once.
    gate_logits = np.dot(np.concatenate([x, hidden], axis=-1), w_gates) + b_gates
    update, reset = np.split(backend.sigmoid(gate_logits), 2, axis=-1)

    # Candidate sees the reset-scaled previous state.
    candidate_in = np.concatenate([x, reset * hidden], axis=-1)
    candidate = np.tanh(np.dot(candidate_in, w_cand) + b_cand)

    new_hidden = update * hidden + (1 - update) * candidate
    return new_hidden, new_hidden
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
    """Add axial positional embeddings to `inputs`.

    Each array in `weights` is broadcast to
    `(batch,) + self._shape + (its last dim,)`. The results are concatenated on
    the last axis to form the embedding table.

    Behaviour by configuration:
      * `self._mode == 'predict'`: `state` is used as the index of the single
        position to embed, and `state + 1` is returned. Dropout must be 0.
      * `self._dropout == 0`: the whole table is reshaped to `inputs.shape`
        and added.
      * otherwise: a Bernoulli keep-mask (size-1 on every dim listed in
        `self._dropout_broadcast_dims`) scales the table before adding.

    Returns:
      `(inputs + embedding, new_state)`.
    """
    n_batch = inputs.shape[0]
    per_axis = [
        np.broadcast_to(table, (n_batch,) + self._shape + (table.shape[-1],))
        for table in weights
    ]
    emb = np.concatenate(per_axis, -1)

    if self._mode == 'predict':
        assert self._dropout == 0.0
        flat_emb = np.reshape(emb, (n_batch, -1, emb.shape[-1]))
        return inputs + flat_emb[:, state, :][:, None, :], state + 1

    if self._dropout == 0:
        return inputs + np.reshape(emb, inputs.shape), state

    # Dropout on the embedding table; the listed dims share one mask value.
    noise_shape = list(emb.shape)
    for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
    keep_prob = 1.0 - self._dropout
    if backend.get_name() == 'jax':
        keep_prob = jax.lax.tie_in(
            inputs, np.full((), keep_prob, dtype=inputs.dtype))
    keep = backend.random.bernoulli(rng, keep_prob, tuple(noise_shape))
    multiplier = keep.astype(inputs.dtype) / keep_prob
    return inputs + np.reshape(emb * multiplier, inputs.shape), state
def CombineHeadsPos(x, n_heads=1, **unused_kwargs):
    """Mix x = (x0, p0, ..., xH, pH) into (x0, ...., xH), p_combined.

    The positions are averaged as vectors.

    `x` is reshaped as (-1, n_heads, seqlen, d_head), so its leading axis is
    taken to hold batch * n_heads rows. Within each head's d_head channels,
    the first d_head - POS_VECTOR_SIZE are content and the last
    POS_VECTOR_SIZE are that head's position vector.

    Args:
      x: input vector, concatenated (x0, p0, ..., xH, pH).
      n_heads: number of heads.

    Returns:
      A pair:
        - the head contents concatenated on the last axis;
        - the mean of the per-head position vectors.
    """
    seqlen, d_head = x.shape[1], x.shape[2]
    # (batch * n_heads, seqlen, d_head) -> (batch, seqlen, n_heads * d_head)
    per_head = np.reshape(x, (-1, n_heads, seqlen, d_head))
    per_head = np.transpose(per_head, (0, 2, 1, 3))
    flat = np.reshape(per_head, (-1, seqlen, n_heads * d_head))

    head_size = int(d_head) - POS_VECTOR_SIZE
    stride = head_size + POS_VECTOR_SIZE  # channels per head, == d_head
    contents = [flat[:, :, k * stride:k * stride + head_size]
                for k in range(n_heads)]
    positions = [flat[:, :, k * stride + head_size:(k + 1) * stride]
                 for k in range(n_heads)]
    combined_position = sum(positions) / float(len(positions))
    return np.concatenate(contents, axis=-1), combined_position
def forward(self, inputs, weights):
    """Run one LSTM step.

    Args:
      inputs: pair `(x, lstm_state)`, where `lstm_state` holds the cell state
        `c` and hidden state `h` concatenated on the last axis.
      weights: `(w, b)` of a single dense layer over `[x, h]` producing all
        four gate pre-activations.

    Returns:
      `(new_h, concat([new_c, new_h], axis=-1))`.
    """
    x, lstm_state = inputs
    cell, hidden = np.split(lstm_state, 2, axis=-1)
    kernel, bias = weights
    gates = np.dot(np.concatenate([x, hidden], axis=-1), kernel) + bias
    # Gate order on the last axis: input gate, new input, forget gate, output gate.
    in_gate, new_input, forget_gate, out_gate = np.split(gates, 4, axis=-1)
    new_cell = (cell * backend.sigmoid(forget_gate)
                + backend.sigmoid(in_gate) * np.tanh(new_input))
    new_hidden = np.tanh(new_cell) * backend.sigmoid(out_gate)
    return new_hidden, np.concatenate([new_cell, new_hidden], axis=-1)
def NewPositionalEncoding(x, positions=None, **kwargs):
    """Implements new positional encoding.

    Appends row t of a positions table to every element at sequence index t.

    Args:
      x: array concatenated with the table on axis 2. Axis 0 is broadcast
        over (batch) and axis 1 is the sequence length, so `x` must be 3-D.
      positions: array-like of shape (max_len, pos_dim) with max_len at least
        `x.shape[1]`. Required despite the `None` default.
      **kwargs: ignored.

    Returns:
      Array of shape (x.shape[0], x.shape[1], x.shape[2] + pos_dim).

    Raises:
      ValueError: if `positions` is not given.
    """
    del kwargs
    if positions is None:
        raise ValueError('NewPositionalEncoding requires a positions table.')
    batch_size, x_length = np.shape(x)[0], np.shape(x)[1]
    pos = np.array(positions)[np.newaxis, :x_length, :]
    # Broadcast on batch. This must be an out-of-place add: an in-place `+=`
    # cannot grow the leading axis of `pos` from 1 to batch_size under NumPy
    # semantics, and it also fails for integer tables.
    pos = pos + np.zeros((batch_size, 1, 1))
    return np.concatenate([x, pos], axis=2)
def forward(self, inp, weights):
    """Reshape input to have heads dimension and concatenate positions there.

    Args:
      inp: tuple `(x, p_1, ..., p_k)`. `x` is reshaped as
        (n_batch, seqlen, self._n_heads, d_head).
        - If `self._n_pos == 1`, the single positions array `p_1`, of shape
          (n_batch, seqlen, pos_dim), is tiled into every head.
        - Otherwise the positions arrays are stacked one per head.
      weights: unused.

    Returns:
      Array of shape (n_batch * n_heads, seqlen, d_head + pos_dim).
    """
    x = inp[0]
    n_batches, seqlen = x.shape[0], x.shape[1]
    d_head = x.shape[-1] // self._n_heads
    res = np.reshape(x, (n_batches, seqlen, self._n_heads, d_head))
    res = np.transpose(res, (0, 2, 1, 3))  # (batch, heads, len, depth)
    if self._n_pos == 1:  # Just one position given, tile into each head.
        pos_shape = list(res.shape)[:-1] + [inp[1].shape[-1]]
        pos = inp[1][:, None, :, :] + np.zeros(pos_shape)  # Add 0 to broadcast.
    else:  # As many positions as heads, concatenate them in.
        pos = [p[:, None, :, :] for p in inp[1:]]
        pos = np.concatenate(pos, axis=1)
    res = np.concatenate([res, pos], axis=-1)
    # n_batch, n_heads, seqlen, d_head + pos_dim
    #   -> n_batch * n_heads, seqlen, d_head + pos_dim
    # Use the width that was actually concatenated rather than assuming
    # POS_VECTOR_SIZE. With a mismatched constant, the -1 would silently
    # absorb the error and mix data across heads.
    res = np.reshape(res, (-1, seqlen, res.shape[-1]))
    return res
def forward(self, inputs, weights):
    """Cut both inputs into `self._n_sections` pieces and pair them up.

    Args:
      inputs: pair `(x1, x2)`, each split into `self._n_sections` equal
        sections along `self._axis`.
      weights: unused.

    Returns:
      Tuple whose k-th element is section k of `x1` concatenated with section
      k of `x2` along the last axis.
    """
    del weights
    first, second = inputs
    n_sections, axis = self._n_sections, self._axis
    first_parts = np.split(first, n_sections, axis)
    second_parts = np.split(second, n_sections, axis)
    return tuple(np.concatenate([a, b], -1)
                 for a, b in zip(first_parts, second_parts))
def MixHeadsPos(x, h=8, **unused_kwargs):
    """Mix x = (x0, p) into x0_h1, p, x0_h2, p, ....

    The last POS_VECTOR_SIZE channels of `x` (indexed as x[:, :, channel]) are
    one shared position vector p. The remaining channels are cut into `h`
    slices of (x.shape[2] - POS_VECTOR_SIZE) // h channels each, and p is
    appended after every slice.

    Args:
      x: input whose last axis holds (x0, p).
      h: number of heads.

    Returns:
      The interleaved slices and copies of p, concatenated on the last axis.
    """
    head_size = (x.shape[2] - POS_VECTOR_SIZE) // h
    shared_pos = x[:, :, -POS_VECTOR_SIZE:]
    pieces = [
        part
        for k in range(h)
        for part in (x[:, :, k * head_size:(k + 1) * head_size], shared_pos)
    ]
    return np.concatenate(pieces, axis=-1)
def QueryPositionKV(x, keys=None, values=None, binary=False, **unused_kwargs):
    """Query a table with a position vector.

    Args:
      x: query position vector(s).
      keys: lookup-table keys. If None, `x` is returned unchanged.
      values: lookup-table values, converted with `np.array` like `keys`.
      binary: if True, `x` is concatenated with itself on the last axis
        before querying. Presumably this matches keys built from two
        concatenated position vectors.

    Returns:
      `x` when `keys` is None. Otherwise, the result of
      `tl.DotProductAttention` with `x` (or `[x, x]`) as the query.
    """
    if keys is None:
        return x
    key_table = np.array(keys)
    value_table = np.array(values)
    query = np.concatenate([x, x], axis=-1) if binary else x
    return tl.DotProductAttention(
        query, key_table, value_table, None, 0.0, None, None)
def test_fn_layer_difficult_n_out(self):
    """`cb.Fn` needs an explicit n_out when dummy inputs cannot infer it."""
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, 'n_out'):
        # Determining the output of this layer is hard with dummies: the
        # concatenation axis (4) presumably exceeds the rank of the dummy
        # inputs. The misspelled `concatencate` used previously made this
        # fail with an AttributeError instead.
        cb.Fn(lambda x: np.concatenate([x, x], axis=4))

    # Check that this layer works when n_out is set.
    layer = cb.Fn(lambda x: np.concatenate([x, x], axis=4), n_out=1)
    input_signature = ShapeDtype((2, 1, 2, 2, 3))
    expected_shape = (2, 1, 2, 2, 6)
    output_shape = base.check_shape_agreement(layer, input_signature)
    self.assertEqual(output_shape, expected_shape)
def test_fn_layer_example(self):
    """`cb.Fn` infers two outputs and evaluates them correctly."""
    add_and_stack = cb.Fn(lambda x, y: (x + y, np.concatenate([x, y], axis=0)))

    # Shape inference: two (2, 7) inputs -> (2, 7) sum and (4, 7) stack.
    signature = (ShapeDtype((2, 7)), ShapeDtype((2, 7)))
    inferred = base.check_shape_agreement(add_and_stack, signature)
    self.assertEqual(inferred, ((2, 7), (4, 7)))

    # Concrete evaluation on one-element arrays.
    total, stacked = add_and_stack((np.array([2]), np.array([3])))
    self.assertEqual(int(total), 5)
    self.assertEqual([int(item) for item in stacked], [2, 3])
def CopyHeadsPos(x, h=8, **unused_kwargs):
    """Mix x = (x, p) into x_h1, p_h1, x_h2, p_h2, ....

    The last h * POS_VECTOR_SIZE channels of `x` hold one position vector per
    head. The preceding channels are cut into `h` slices of
    (x.shape[2] - h * POS_VECTOR_SIZE) // h channels each. Each slice is
    followed by its own head's position vector.

    Args:
      x: input indexed as x[:, :, channel].
      h: number of heads.

    Returns:
      The interleaved slices and position vectors on the last axis.
    """
    pos_width = h * POS_VECTOR_SIZE
    head_size = (x.shape[2] - pos_width) // h
    all_pos = x[:, :, -pos_width:]
    pieces = [
        part
        for i in range(h)
        for part in (
            x[:, :, i * head_size:(i + 1) * head_size],
            all_pos[:, :, i * POS_VECTOR_SIZE:(i + 1) * POS_VECTOR_SIZE],
        )
    ]
    return np.concatenate(pieces, axis=-1)
def PerformPositionOperations(pos, positions=None):
    """Gets pos and returns (q1, ..., q5).

    Builds five key/value lookup tables over the rows of `positions` and
    applies each resulting `QueryPositionKV` to `pos` with the `@` operator:

      q1: no table (QueryPositionKV() with no keys),
      q2: successor     p_i -> p_{i+1},
      q3: predecessor   p_{i+1} -> p_i,
      q4: addition      (p_i, p_j) -> p_{i+j},          i, j < l,
      q5: subtraction   (p_i, p_j) -> p_{max(i-j, 0)},  i, j < l,

    where l = positions.shape[0] // 2.

    Args:
      pos: the position query, applied to each query layer via `@`.
      positions: 2-D array indexed as `positions[i, :]`. Must not be None
        despite the default, because the first statement slices it.

    Returns:
      A list of five results, in the order above.
    """
    # NOTE(review): `QueryPositionKV(...)` is called below without `x`. If it
    # is the plain function of that name defined in this file (first parameter
    # `x`), this raises TypeError; presumably it relies on a layer-building
    # decorator or is experimental code -- confirm before use.
    # Successor / predecessor tables: the rows of `positions` shifted by one.
    succ_keys = positions[:-1, :]
    succ_values = positions[1:, :]
    subtract_1_keys = positions[1:, :]
    subtract_1_values = positions[:-1, :]
    # Binary tables only use the first half of the rows, which keeps i + j a
    # valid row index.
    l = int(positions.shape[0]) // 2
    add_keys = np.array([np.concatenate([positions[i, :], positions[j, :]])
                         for i in range(l) for j in range(l)])
    add_values = np.array([positions[i + j, :]
                           for i in range(l) for j in range(l)])
    # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
    # Subtraction is clipped at row 0 (max(i - j, 0)).
    sub_keys = np.array([np.concatenate([positions[i, :], positions[j, :]])
                         for j in range(l) for i in range(l)])
    sub_values = np.array([positions[max(i - j, 0), :]
                           for j in range(l) for i in range(l)])
    query_types = [
        QueryPositionKV(),
        QueryPositionKV(keys=succ_keys, values=succ_values),
        QueryPositionKV(keys=subtract_1_keys, values=subtract_1_values),
        QueryPositionKV(keys=add_keys, values=add_values, binary=True),
        QueryPositionKV(keys=sub_keys, values=sub_values, binary=True)]
    return [qt @ pos for qt in query_types]  # pylint: disable=syntax-error
def DiagonalGate(x, **kwargs):
    """Split channels in 3 parts. Shifts 1st and 3rd sections to left/right.

    For `x` of shape [batch, 1, length, depth] with depth divisible by 3, the
    output at time t takes:
      - channel group 0 from t - 1;
      - channel group 1 from t;
      - channel group 2 from t + 1.
    Zeros are used where t - 1 or t + 1 falls outside the sequence. The
    output has the same shape as `x`.
    """
    del kwargs
    # Zero-pad one step on both ends of the length axis.
    padded = np.pad(x, [(0, 0), (0, 0), (1, 1), (0, 0)], mode='constant',
                    constant_values=0.0)
    third = padded.shape[-1] // 3
    assert 3 * third == padded.shape[-1], ('Depth must be divisible by 3',
                                           third, padded.shape)
    length = padded.shape[2] - 2
    # Group k reads the window starting at padded index k (offset k - 1).
    groups = [padded[:, :, k:k + length, k * third:(k + 1) * third]
              for k in range(3)]
    return np.concatenate(groups, axis=3)
def CombineHeadsPos(x, h=8, **unused_kwargs):
    """Mix x = (x0, p0, ..., xH, pH) into x0, ...., xH, p_combined.

    The positions are added as vectors. Each of the `h` heads occupies
    x.shape[2] / h channels on axis 2. The last POS_VECTOR_SIZE of these are
    that head's position vector.

    Args:
      x: input vector, concatenated (x0, p0, ..., xH, pH).
      h: number of heads.

    Returns:
      The head contents followed by the sum of the per-head position vectors,
      concatenated on the last axis.
    """
    head_size = int((x.shape[2] / h) - POS_VECTOR_SIZE)
    stride = head_size + POS_VECTOR_SIZE
    contents = [x[:, :, i * stride:i * stride + head_size] for i in range(h)]
    position_sum = sum(x[:, :, i * stride + head_size:(i + 1) * stride]
                       for i in range(h))
    return np.concatenate(contents + [position_sum], axis=-1)
def look_one_back(x):
    """Concatenate each chunk with its predecessor along axis 1.

    Axis 0 indexes chunks. Chunk 0 is paired with the last chunk (a cyclic
    shift by one).
    """
    if len(x.shape) == 2:
        last, earlier = x[-1:, :], x[:-1, :]
    else:
        last, earlier = x[-1:, :, :], x[:-1, :, :]
    previous_chunks = np.concatenate([last, earlier], axis=0)
    return np.concatenate([x, previous_chunks], axis=1)
def single_call(self, qk, v, buckets, rng=None):
    """Causal LSH attention for one sequence, given precomputed hash buckets.

    Steps:
      1. Every (hashing round, position) pair is sorted by the key
         `seqlen * bucket + position`.
      2. The sorted entries are split into `self.n_hashes * self.n_bins`
         equal chunks.
      3. Each chunk attends to itself and to the preceding chunk.
    The per-round outputs are then combined. Each round is weighted by a
    softmax, taken over rounds, of its per-query log-normalizer.

    Args:
      qk: vectors used as both queries and keys. The sequence length is read
        from `qk.shape[-2]` and positions are gathered along axis 0, so this
        is presumably (seqlen, d) -- TODO confirm.
      v: value vectors, gathered along axis 0 like `qk`.
      buckets: integer bucket ids, length `self.n_hashes * seqlen` (asserted),
        one per (round, position) pair.
      rng: random key; only used when `self._dropout > 0`.

    Returns:
      The attention output, with `out.shape == v.shape` (asserted).
    """
    # We use the same vector as both a query and a key.
    seqlen = qk.shape[-2]
    assert int(buckets.shape[0]) == self.n_hashes * seqlen

    # ticker enumerates the n_hashes * seqlen entries; ticker % seqlen is the
    # position index of each entry.
    ticker = jax.lax.tie_in(qk, np.arange(self.n_hashes * seqlen))
    # Sort key: bucket first, then position within the sequence.
    buckets_and_t = seqlen * buckets + (ticker % seqlen)
    buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

    # Hash-based sort ("s" at the start of variable names means "sorted")
    sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t, ticker, dimension=-1)
    # undo_sort is the inverse permutation of sticker: undo_sort[j] is the
    # slot in sorted order that holds original entry j.
    _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
    sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
    sticker = jax.lax.stop_gradient(sticker)
    undo_sort = jax.lax.stop_gradient(undo_sort)

    # Gather queries/keys and values into sorted order.
    st = (sticker % seqlen)
    sqk = np.take(qk, st, axis=0)
    sv = np.take(v, st, axis=0)

    # Split off a "bin" axis so that attention only occurs within chunks.
    bq_t = bkv_t = np.reshape(st, (self.n_hashes * self.n_bins, -1))
    bqk = np.reshape(sqk, (self.n_hashes * self.n_bins, -1, sqk.shape[-1]))
    bv = np.reshape(sv, (self.n_hashes * self.n_bins, -1, sv.shape[-1]))
    bq_buckets = bkv_buckets = np.reshape(
        sbuckets_and_t // seqlen, (self.n_hashes * self.n_bins, -1))

    # Hashing operates on unit-length vectors. Unnormalized query vectors are
    # fine because they effectively provide a learnable temperature for the
    # attention softmax, but normalizing keys is needed so that similarity for
    # the purposes of attention correctly corresponds to hash locality.
    bq = bqk
    bk = self.make_unit_length(bqk)

    # Allow each chunk to attend within itself, and also one chunk back. Chunk
    # boundaries might occur in the middle of a sequence of items from the
    # same bucket, so this increases the chances of attending to relevant items.
    # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
    def look_one_back(x):
        # Concatenate each chunk (axis 0) with the previous chunk along axis 1;
        # chunk 0 wraps around to the last chunk.
        if len(x.shape) == 2:
            x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
        else:
            x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
        return np.concatenate([x, x_extra], axis=1)

    bk = look_one_back(bk)
    bv = look_one_back(bv)
    bkv_t = look_one_back(bkv_t)
    bkv_buckets = look_one_back(bkv_buckets)

    # Dot-product attention.
    dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

    # Causal masking: penalize keys whose position is after the query's.
    mask = jax.lax.convert_element_type(
        jax.lax.lt(bq_t[:, :, None], bkv_t[:, None, :]),
        np.float32)
    dots = dots - 1e9 * mask

    # Mask out attention to self except when no other targets are available.
    # (The smaller 1e5 penalty lets self win only when everything else is
    # already masked harder.)
    self_mask = jax.lax.convert_element_type(
        jax.lax.eq(bq_t[:, :, None], bkv_t[:, None, :]),
        np.float32)
    dots = dots - 1e5 * self_mask

    # Mask out attention to other hash buckets.
    if not self._attend_across_buckets:
        bucket_mask = jax.lax.convert_element_type(
            jax.lax.ne(bq_buckets[:, :, None], bkv_buckets[:, None, :]),
            np.float32)
        dots = dots - 1e7 * bucket_mask

    # Don't double-count query-key pairs across multiple rounds of hashing.
    # There are two possible strategies here. (1) The default is to count how
    # many times a query-key pair is repeated, and to lower its log-prob
    # correspondingly at each repetition. (2) When hard_k is set, the code
    # instead masks all but the first occurrence of each query-key pair.
    # TODO(kitaev): is one strategy faster or more numerically stable?
    if not self._allow_duplicate_attention:
        # For each entry, the chunk it landed in (locs1) and the next chunk
        # (locs2), which can also see it through look_one_back.
        locs1 = undo_sort // bq_t.shape[-1]
        locs2 = (locs1 + 1) % (self.n_hashes * self.n_bins)
        if not self._attend_across_buckets:
            locs1 = buckets * (self.n_hashes * self.n_bins) + locs1
            locs2 = buckets * (self.n_hashes * self.n_bins) + locs2
        locs = np.moveaxis(
            np.concatenate([
                np.reshape(locs1, (self.n_hashes, seqlen)),
                np.reshape(locs2, (self.n_hashes, seqlen)),
            ], 0),
            0, -1)  # produces shape (seqlen, 2 * self.n_hashes)
        slocs = np.take(locs, st, axis=0)
        b_locs = np.reshape(
            slocs, (self.n_hashes * self.n_bins, -1, 2 * self.n_hashes))
        # Queries always use the primary location (based on locs1).
        b_locs1 = b_locs[:, :, None, :self.n_hashes]
        if self._hard_k > 0:
            range_n_hashes = jax.lax.tie_in(b_locs, np.arange(self.n_hashes))
            nouse_locs = (range_n_hashes[:, None] > range_n_hashes[None, :])
            nouse_locs = 2 * nouse_locs - 1  # 1 = use, -1 = don't use
            nouse_locs = np.reshape(
                np.broadcast_to(
                    nouse_locs[:, None, :],
                    (self.n_hashes, self.n_bins, self.n_hashes)),
                (self.n_hashes * self.n_bins, 1, 1, self.n_hashes))
            b_locs1 = b_locs1 * nouse_locs
        bq_locs = np.broadcast_to(b_locs1, b_locs.shape[:2] + (2, self.n_hashes))
        bq_locs = np.reshape(bq_locs, b_locs.shape)
        bkv_locs = look_one_back(b_locs)

        # Number of hashing rounds in which each query-key pair co-occurs.
        dup_counts = np.sum(jax.lax.convert_element_type(
            jax.lax.eq(bq_locs[:, :, None, :], bkv_locs[:, None, :, :]),
            np.float32), axis=-1)
        assert dup_counts.shape == dots.shape
        if self._hard_k > 0:
            dots = dots - 1e7 * jax.lax.stop_gradient(dup_counts)
        else:
            dots = dots - jax.lax.stop_gradient(np.log(dup_counts + 1e-9))

    # Each query only attends to the top k most relevant keys.
    if self._hard_k > 0:
        b_top_dots = np.sort(dots)[..., -self._hard_k:]  # Get the top k dots.
        b_top_dots = jax.lax.stop_gradient(b_top_dots)
        s_top_dots = np.reshape(b_top_dots, (-1, self._hard_k))
        top_dots = np.take(s_top_dots, undo_sort, axis=0)

        # Merge the top-k candidates of all hashing rounds per position and
        # take the k-th largest as that position's threshold.
        merged_top_dots = np.moveaxis(
            np.reshape(top_dots, (self.n_hashes, seqlen, self._hard_k)), 0, -1)
        merged_top_dots = np.reshape(merged_top_dots, (seqlen, -1))

        dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k]
        # It's possible to compute the partition function at this point, but right
        # now this codepath isn't set up for backprop, and there might also be
        # issues computing it this way if two dot-products are exactly equal.

        sdots_thresh = dots_thresh[st]
        bdots_thresh = np.reshape(sdots_thresh, (self.n_hashes * self.n_bins, -1))
        bdots_thresh = jax.lax.stop_gradient(bdots_thresh)

        top_k_mask = jax.lax.convert_element_type(
            dots < bdots_thresh[..., None], np.float32)
        dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)

    # Softmax.
    dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
    dots = np.exp(dots - dots_logsumexp)

    if self._dropout > 0.0:
        # Dropout is broadcast across the bin dimension
        dropout_shape = (1, dots.shape[-2], dots.shape[-1])
        keep_prob = jax.lax.tie_in(dots, 1.0 - self._dropout)
        keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
        multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
            keep, keep_prob)
        dots = dots * multiplier

    # Per-chunk outputs and log-normalizers, flattened back to sorted order.
    bo = np.matmul(dots, bv)
    so = np.reshape(bo, (-1, bo.shape[-1]))
    slogits = np.reshape(dots_logsumexp, (-1, ))

    def unsort_for_output_impl(so, slogits):
        # Return outputs and log-normalizers to the original (unsorted) order.
        o = np.take(so, undo_sort, axis=0)
        # Sorting is considerably faster than gather, but first we need to get the
        # XLA compiler to abandon the idea of fusing this sort with the input sort
        # (which introduces a computation cycle and leads to a crash).
        # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
        sticker_ = sticker + jax.lax.convert_element_type(
            slogits[0] > 0, sticker.dtype)
        _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
        return o, logits

    def unsort_for_output_vjp(so, slogits):
        """Custom gradient for unsort_for_output."""
        so = jax.lax.stop_gradient(so)
        slogits = jax.lax.stop_gradient(slogits)
        o, logits = unsort_for_output_impl(so, slogits)
        def vjpfun(o_logits_grads):
            # Gradients are mapped back into sorted order.
            so_grad = np.take(o_logits_grads[0], sticker, axis=0)
            # TODO(kitaev): this exists to match the forward pass, but I'm not sure
            # if it's actually required.
            buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
                o_logits_grads[1][0] > 0, buckets_and_t.dtype)
            _, slogits_grad = jax.lax.sort_key_val(buckets_and_t_, o_logits_grads[1], dimension=-1)
            return (so_grad, slogits_grad)
        return (o, logits), vjpfun

    # NOTE(review): `unsort_for_output` (with the custom VJP above) is built
    # but never called -- the line after it calls `unsort_for_output_impl`
    # directly, so the custom gradient is unused. `jax.custom_transforms` /
    # `jax.defvjp_all` are also absent from recent JAX releases (superseded by
    # `jax.custom_vjp`) -- confirm against the pinned JAX version.
    unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
    jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
    o, logits = unsort_for_output_impl(so, slogits)

    if self.n_hashes == 1:
        out = o
    else:
        # Combine hashing rounds, weighting each by the softmax (over rounds)
        # of its log-normalizer.
        o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
        logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
        probs = np.exp(logits - backend.logsumexp(logits, axis=0, keepdims=True))
        out = np.sum(o * probs, axis=0)
    assert out.shape == v.shape
    return out
def hash_vectors(self, vecs, rng):
    """Assign LSH bucket ids to `vecs` via random rotations.

    Each vector is projected with a random rotation R. Its bucket is
    argmax over the concatenation [xR, -xR].

    Args:
      vecs: 2-D array (positions, features), as required by the
        'tf,fhb->htb' einsum below.
      rng: random key. It is split into one key for the rotations and one for
        `self.drop_for_hash`.

    Returns:
      Flat integer bucket ids of length n_hashes * positions, laid out
      round-major.
      - With `self._rehash_each_round`, every round uses its own rotation,
        and round h's ids are offset by h * self.n_buckets.
      - Otherwise a single rotation is used. "Round" h holds the h-th entry
        of each position's top n_hashes buckets, without offsets.

      With `self._factorize_hash`, n_buckets is decomposed into even factors
      that are hashed separately and combined, presumably to shrink the
      rotation size.
    """
    # See https://arxiv.org/pdf/1509.02897.pdf
    # We sample a different random rotation for each round of hashing to
    # decrease the probability of hash misses.
    assert self.n_buckets % 2 == 0

    # If we factorize the hash, find a factor dividing n_buckets nicely.
    rot_size, factor_list = self.n_buckets, [self.n_buckets]
    if self._factorize_hash:
        # If we are given a list of factors, verify it and use later.
        if isinstance(self._factorize_hash, list):
            rot_size, product = 0, 1
            factor_list = self._factorize_hash
            for factor in factor_list:
                assert factor % 2 == 0
                product *= factor
                rot_size += factor
            assert product == self.n_buckets
        else:  # Find one factor if just set to True.
            # We want to represent self.n_buckets = factor * rest so that
            # (1) both factor and rest are even, and (2) factor + rest is minimal.
            # To compute this we start from factor = sqrt(n_buckets) and go down
            # with it until we find one that satisfies the constraints above.
            factor = int(math.sqrt(self.n_buckets))
            while factor > 0 and not (self.n_buckets % factor == 0 and
                                      factor % 2 == 0 and
                                      (self.n_buckets // factor) % 2 == 0):
                factor -= 1
            if factor > 2:  # Factor of 2 does not warrant the effort.
                rot_size = factor + (self.n_buckets // factor)
                factor_list = [factor, self.n_buckets // factor]

    # One rotation per round when rehashing each round, otherwise a single one.
    # Each bucket group of size f needs f // 2 rotated dims (the other half
    # comes from negation).
    rotations_shape = (vecs.shape[-1],
                       self.n_hashes if self._rehash_each_round else 1,
                       rot_size // 2)

    rng = jax.lax.tie_in(vecs, rng)
    rng, subrng = backend.random.split(rng)
    random_rotations = self._sample_rotation(rotations_shape, vecs, rng)

    # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
    # hashing, so it's shared between them. Check if that's what we want.
    dropped_vecs = self.drop_for_hash(vecs, subrng)
    rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)

    if self._rehash_each_round:
        if self._factorize_hash and len(factor_list) > 1:
            # We factorized self.n_buckets as the product of factor_list.
            # Get the buckets for them and combine.
            # Combined id = b_0 + f_0 * b_1 + f_0 * f_1 * b_2 + ...
            buckets, cur_sum, cur_product = None, 0, 1
            for factor in factor_list:
                rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
                cur_sum += factor // 2
                rv = np.concatenate([rv, -rv], axis=-1)
                if buckets is None:
                    buckets = np.argmax(rv, axis=-1)
                else:
                    buckets += cur_product * np.argmax(rv, axis=-1)
                cur_product *= factor
        else:
            rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
            buckets = np.argmax(rotated_vecs, axis=-1)
        # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
        # bucket numbers from different hashing rounds don't overlap.
        offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
        offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
        buckets = np.reshape(buckets + offsets, (-1, ))
    else:
        assert not self._factorize_hash
        rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
        # In this configuration, we map each item to the top self.n_hashes buckets
        rotated_vecs = np.squeeze(rotated_vecs, 0)
        bucket_range = jax.lax.tie_in(vecs, np.arange(rotated_vecs.shape[-1]))
        bucket_range = np.reshape(bucket_range, (1, -1))
        bucket_range = np.broadcast_to(bucket_range, rotated_vecs.shape)

        # Sorting scores ascending while carrying bucket ids; the last
        # n_hashes columns are the highest-scoring buckets.
        _, buckets = jax.lax.sort_key_val(rotated_vecs, bucket_range, dimension=-1)
        buckets = buckets[:, -self.n_hashes:]
        buckets = np.reshape(np.moveaxis(buckets, 0, -1), (-1, ))

    return buckets