def f(x):
    """Replicate ``x`` once per device when more than one device is in use.

    ``n``, ``backend`` and ``_multi_device_put`` are free names resolved in the
    enclosing scope (not visible in this block); ``n`` is presumably the device
    count -- confirm against the enclosing function.

    Args:
      x: array-like object exposing ``.shape``.

    Returns:
      ``x`` unchanged when ``n`` is not greater than 1; otherwise the result of
      ``_multi_device_put(x)`` on the 'jax' backend, or ``x`` broadcast to shape
      ``(n,) + x.shape`` on any other backend.
    """
    if n <= 1:
        return x
    on_jax = backend.get_name() == 'jax'
    return _multi_device_put(x) if on_jax else np.broadcast_to(
        x, (n, ) + x.shape)
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
    """Add the concatenated per-axis embeddings to ``inputs``.

    Every entry of ``weights`` is broadcast to
    ``(batch,) + self._shape + (its own last dim,)`` and the results are
    concatenated along the last axis to form the full embedding.

    Args:
      inputs: array the embedding is added to; ``inputs.shape[0]`` is used as
        the batch size.
      weights: iterable of per-axis embedding arrays.
      state: in 'predict' mode, the index of the embedding row to use.
      rng: random key handed to ``backend.random.bernoulli`` (dropout path only).
      **kwargs: accepted and ignored.

    Returns:
      A ``(output, new_state)`` pair. ``new_state`` is ``state + 1`` in
      'predict' mode and ``state`` unchanged otherwise.
    """
    batch = inputs.shape[0]
    per_axis = [
        np.broadcast_to(w, (batch, ) + self._shape + (w.shape[-1], ))
        for w in weights
    ]
    emb = np.concatenate(per_axis, -1)

    if self._mode == 'predict':
        assert self._dropout == 0.0
        # Flatten the axial grid so a single position can be picked by index.
        flat_emb = np.reshape(emb, (batch, -1, emb.shape[-1]))
        current = flat_emb[:, state, :][:, None, :]
        return inputs + current, state + 1

    if self._dropout == 0:
        return inputs + np.reshape(emb, inputs.shape), state

    # Dropout path: the keep-mask has size 1 along every broadcast dim so the
    # same mask entry is shared across that dim.
    mask_shape = list(emb.shape)
    for axis in self._dropout_broadcast_dims:
        mask_shape[axis] = 1
    keep_prob = 1.0 - self._dropout
    if backend.get_name() == 'jax':
        keep_prob = jax.lax.tie_in(
            inputs, np.full((), keep_prob, dtype=inputs.dtype))
    keep = backend.random.bernoulli(rng, keep_prob, tuple(mask_shape))
    # Scale kept entries by 1 / keep_prob.
    scale = keep.astype(inputs.dtype) / keep_prob
    return inputs + np.reshape(emb * scale, inputs.shape), state
def _maybe_replicate(self, x): if self._n_devices > 1: if backend.get_name() == 'jax': return multi_device_put(x) else: return np.broadcast_to(x, (self._n_devices,) + x.shape) else: return x
def single_call(self, qk, v, buckets, rng=None):
    """Compute bucket-sorted, chunked dot-product attention for one example.

    Items are sorted by (bucket, position), split into
    ``self.n_hashes * self.n_bins`` chunks, and each chunk attends to itself
    and to the preceding chunk. Results are un-sorted back to the input order
    and, when ``self.n_hashes > 1``, combined across hashing rounds with
    weights given by a softmax over each round's log-partition value.

    Args:
      qk: shared query/key vectors. ``seqlen`` is read from ``qk.shape[-2]``
        and rows are gathered along axis 0 -- presumably a 2-D
        ``(seqlen, d_qk)`` array; confirm with the caller.
      v: value vectors, gathered along axis 0 with the same indices as ``qk``.
      buckets: bucket ids; ``buckets.shape[0]`` must equal
        ``self.n_hashes * seqlen`` (asserted).
      rng: random key passed to ``backend.random.bernoulli``; only read when
        ``self._dropout > 0.0``.

    Returns:
      The attention output; its shape is asserted to equal ``v.shape``.
    """
    # We use the same vector as both a query and a key.
    seqlen = qk.shape[-2]
    assert int(buckets.shape[0]) == self.n_hashes * seqlen

    # NOTE(review): jax.lax.tie_in, jax.custom_transforms and jax.defvjp_all
    # look like legacy JAX APIs -- confirm against the pinned JAX version.
    ticker = jax.lax.tie_in(qk, np.arange(self.n_hashes * seqlen))
    # Combined sort key: bucket id is the major key, position (ticker % seqlen)
    # the minor key.
    buckets_and_t = seqlen * buckets + (ticker % seqlen)
    buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

    # Hash-based sort ("s" at the start of variable names means "sorted")
    sbuckets_and_t, sticker = jax.lax.sort_key_val(
        buckets_and_t, ticker, dimension=-1)
    # Sorting the original indices by their sorted position yields the inverse
    # permutation used later to restore the input order.
    _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
    sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
    sticker = jax.lax.stop_gradient(sticker)
    undo_sort = jax.lax.stop_gradient(undo_sort)

    # Sequence position of every sorted item (hash-round offset removed).
    st = (sticker % seqlen)
    sqk = np.take(qk, st, axis=0)
    sv = np.take(v, st, axis=0)

    # Split off a "bin" axis so that attention only occurs within chunks.
    bq_t = bkv_t = np.reshape(st, (self.n_hashes * self.n_bins, -1))
    bqk = np.reshape(sqk, (self.n_hashes * self.n_bins, -1, sqk.shape[-1]))
    bv = np.reshape(sv, (self.n_hashes * self.n_bins, -1, sv.shape[-1]))
    bq_buckets = bkv_buckets = np.reshape(
        sbuckets_and_t // seqlen, (self.n_hashes * self.n_bins, -1))

    # Hashing operates on unit-length vectors. Unnormalized query vectors are
    # fine because they effectively provide a learnable temperature for the
    # attention softmax, but normalizing keys is needed so that similarity for
    # the purposes of attention correctly corresponds to hash locality.
    bq = bqk
    bk = self.make_unit_length(bqk)

    # Allow each chunk to attend within itself, and also one chunk back. Chunk
    # boundaries might occur in the middle of a sequence of items from the
    # same bucket, so this increases the chances of attending to relevant items.
    # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
    def look_one_back(x):
        """Append to each chunk (axis 1) the contents of the previous chunk.

        Chunks are rotated by one along axis 0, so chunk 0 receives the last
        chunk. Handles 2-D inputs; any other rank is indexed as 3-D.
        """
        if len(x.shape) == 2:
            x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
        else:
            x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
        return np.concatenate([x, x_extra], axis=1)

    bk = look_one_back(bk)
    bv = look_one_back(bv)
    bkv_t = look_one_back(bkv_t)
    bkv_buckets = look_one_back(bkv_buckets)

    # Dot-product attention.
    dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

    # Causal masking
    # (penalize pairs whose query position is strictly before the key position)
    mask = jax.lax.convert_element_type(
        jax.lax.lt(bq_t[:, :, None], bkv_t[:, None, :]), np.float32)
    dots = dots - 1e9 * mask

    # Mask out attention to self except when no other targets are available.
    # The 1e5 penalty is smaller than the other mask penalties (1e7, 1e9), so
    # self-attention still outranks any otherwise-masked target.
    self_mask = jax.lax.convert_element_type(
        jax.lax.eq(bq_t[:, :, None], bkv_t[:, None, :]), np.float32)
    dots = dots - 1e5 * self_mask

    # Mask out attention to other hash buckets.
    if not self._attend_across_buckets:
        bucket_mask = jax.lax.convert_element_type(
            jax.lax.ne(bq_buckets[:, :, None], bkv_buckets[:, None, :]),
            np.float32)
        dots = dots - 1e7 * bucket_mask

    # Don't double-count query-key pairs across multiple rounds of hashing.
    # There are two possible strategies here. (1) The default is to count how
    # many times a query-key pair is repeated, and to lower its log-prob
    # correspondingly at each repetition. (2) When hard_k is set, the code
    # instead masks all but the first occurrence of each query-key pair.
    # TODO(kitaev): is one strategy faster or more numerically stable?
    if not self._allow_duplicate_attention:
        # locs1: chunk index each item was sorted into; locs2: the following
        # chunk (wrapping around), which also sees the item via look_one_back.
        locs1 = undo_sort // bq_t.shape[-1]
        locs2 = (locs1 + 1) % (self.n_hashes * self.n_bins)
        if not self._attend_across_buckets:
            # Fold the bucket id into the location so that sharing a chunk but
            # not a bucket does not count as a duplicate.
            locs1 = buckets * (self.n_hashes * self.n_bins) + locs1
            locs2 = buckets * (self.n_hashes * self.n_bins) + locs2
        locs = np.moveaxis(
            np.concatenate([
                np.reshape(locs1, (self.n_hashes, seqlen)),
                np.reshape(locs2, (self.n_hashes, seqlen)),
            ], 0), 0, -1)  # produces shape (seqlen, 2 * self.n_hashes)
        slocs = np.take(locs, st, axis=0)
        b_locs = np.reshape(
            slocs, (self.n_hashes * self.n_bins, -1, 2 * self.n_hashes))
        # Queries always use the primary location (based on locs1).
        b_locs1 = b_locs[:, :, None, :self.n_hashes]
        if self._hard_k > 0:
            range_n_hashes = jax.lax.tie_in(b_locs, np.arange(self.n_hashes))
            nouse_locs = (range_n_hashes[:, None] > range_n_hashes[None, :])
            nouse_locs = 2 * nouse_locs - 1  # 1 = use, -1 = don't use
            nouse_locs = np.reshape(
                np.broadcast_to(
                    nouse_locs[:, None, :],
                    (self.n_hashes, self.n_bins, self.n_hashes)),
                (self.n_hashes * self.n_bins, 1, 1, self.n_hashes))
            b_locs1 = b_locs1 * nouse_locs
        bq_locs = np.broadcast_to(
            b_locs1, b_locs.shape[:2] + (2, self.n_hashes))
        bq_locs = np.reshape(bq_locs, b_locs.shape)
        bkv_locs = look_one_back(b_locs)

        # Number of location entries on which each query/key pair agrees.
        dup_counts = np.sum(jax.lax.convert_element_type(
            jax.lax.eq(bq_locs[:, :, None, :], bkv_locs[:, None, :, :]),
            np.float32), axis=-1)
        assert dup_counts.shape == dots.shape
        if self._hard_k > 0:
            dots = dots - 1e7 * jax.lax.stop_gradient(dup_counts)
        else:
            # The 1e-9 keeps the log finite when dup_counts is zero.
            dots = dots - jax.lax.stop_gradient(np.log(dup_counts + 1e-9))

    # Each query only attends to the top k most relevant keys.
    if self._hard_k > 0:
        b_top_dots = np.sort(dots)[..., -self._hard_k:]  # Get the top k dots.
        b_top_dots = jax.lax.stop_gradient(b_top_dots)
        s_top_dots = np.reshape(b_top_dots, (-1, self._hard_k))
        top_dots = np.take(s_top_dots, undo_sort, axis=0)

        # Gather every hashing round's top-k candidates for the same position
        # into a single row.
        merged_top_dots = np.moveaxis(
            np.reshape(top_dots, (self.n_hashes, seqlen, self._hard_k)), 0, -1)
        merged_top_dots = np.reshape(merged_top_dots, (seqlen, -1))

        # Per-position threshold: the k-th largest candidate over all rounds.
        dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k]
        # It's possible to compute the partition function at this point, but right
        # now this codepath isn't set up for backprop, and there might also be
        # issues computing it this way if two dot-products are exactly equal.

        sdots_thresh = dots_thresh[st]
        bdots_thresh = np.reshape(
            sdots_thresh, (self.n_hashes * self.n_bins, -1))
        bdots_thresh = jax.lax.stop_gradient(bdots_thresh)

        top_k_mask = jax.lax.convert_element_type(
            dots < bdots_thresh[..., None], np.float32)
        dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)

    # Softmax.
    dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
    dots = np.exp(dots - dots_logsumexp)

    if self._dropout > 0.0:
        # Dropout is broadcast across the bin dimension
        dropout_shape = (1, dots.shape[-2], dots.shape[-1])
        keep_prob = jax.lax.tie_in(dots, 1.0 - self._dropout)
        keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
        multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
            keep, keep_prob)
        dots = dots * multiplier

    bo = np.matmul(dots, bv)
    # Flatten the chunk axis again ("s" = still in sorted order).
    so = np.reshape(bo, (-1, bo.shape[-1]))
    slogits = np.reshape(dots_logsumexp, (-1, ))

    def unsort_for_output_impl(so, slogits):
        """Restore the pre-sort order of the outputs and log-partition values."""
        o = np.take(so, undo_sort, axis=0)
        # Sorting is considerably faster than gather, but first we need to get the
        # XLA compiler to abandon the idea of fusing this sort with the input sort
        # (which introduces a computation cycle and leads to a crash).
        # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
        sticker_ = sticker + jax.lax.convert_element_type(
            slogits[0] > 0, sticker.dtype)
        _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
        return o, logits

    def unsort_for_output_vjp(so, slogits):
        """Custom gradient for unsort_for_output."""
        so = jax.lax.stop_gradient(so)
        slogits = jax.lax.stop_gradient(slogits)
        o, logits = unsort_for_output_impl(so, slogits)

        def vjpfun(o_logits_grads):
            """Map output cotangents back into sorted order."""
            so_grad = np.take(o_logits_grads[0], sticker, axis=0)
            # TODO(kitaev): this exists to match the forward pass, but I'm not sure
            # if it's actually required.
            buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
                o_logits_grads[1][0] > 0, buckets_and_t.dtype)
            _, slogits_grad = jax.lax.sort_key_val(
                buckets_and_t_, o_logits_grads[1], dimension=-1)
            return (so_grad, slogits_grad)
        return (o, logits), vjpfun

    unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
    jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
    # NOTE(review): the call below invokes unsort_for_output_impl directly, not
    # the `unsort_for_output` wrapper, so the VJP registered above is not
    # attached to this call -- confirm this is intentional.
    o, logits = unsort_for_output_impl(so, slogits)

    if self.n_hashes == 1:
        out = o
    else:
        # Combine hashing rounds, weighting each round's output by the softmax
        # (over rounds) of its log-partition value.
        o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
        logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
        probs = np.exp(
            logits - backend.logsumexp(logits, axis=0, keepdims=True))
        out = np.sum(o * probs, axis=0)

    assert out.shape == v.shape
    return out
def hash_vectors(self, vecs, rng):
    """Assign every vector to hash buckets using random rotations.

    See https://arxiv.org/pdf/1509.02897.pdf. A different random rotation is
    sampled for each round of hashing (when ``self._rehash_each_round`` is set)
    to decrease the probability of hash misses.

    Args:
      vecs: vectors to hash; contracted as ``'tf'`` (item, feature) in the
        einsum with the sampled rotations.
      rng: random key. It is split once: one half samples the rotations, the
        other is passed to ``self.drop_for_hash``.

    Returns:
      A flat array of bucket ids. When rehashing each round, the ids of round
      ``h`` are offset by ``h * self.n_buckets`` so rounds never share ids;
      otherwise each item is mapped to its top ``self.n_hashes`` buckets.
    """
    n_buckets = self.n_buckets
    assert n_buckets % 2 == 0

    # Without factorization a single argmax over n_buckets candidates is used.
    rotation_width = n_buckets
    factors = [n_buckets]
    factorize = self._factorize_hash

    if factorize:
        if isinstance(factorize, list):
            # Caller-supplied factors: each must be even and their product
            # must equal n_buckets.
            factors = factorize
            width_sum, running_product = 0, 1
            for factor in factors:
                assert factor % 2 == 0
                running_product *= factor
                width_sum += factor
            assert running_product == n_buckets
            rotation_width = width_sum
        else:
            # Only told to factorize: write n_buckets = factor * rest with both
            # parts even and factor + rest minimal, searching downwards from
            # sqrt(n_buckets).
            def splits_evenly(candidate):
                return (n_buckets % candidate == 0
                        and candidate % 2 == 0
                        and (n_buckets // candidate) % 2 == 0)

            candidate = int(math.sqrt(n_buckets))
            while candidate > 0 and not splits_evenly(candidate):
                candidate -= 1
            if candidate > 2:  # Factor of 2 does not warrant the effort.
                cofactor = n_buckets // candidate
                rotation_width = candidate + cofactor
                factors = [candidate, cofactor]

    n_rotation_sets = self.n_hashes if self._rehash_each_round else 1
    rotations_shape = (vecs.shape[-1], n_rotation_sets, rotation_width // 2)

    rng = jax.lax.tie_in(vecs, rng)
    rng, subrng = backend.random.split(rng)
    random_rotations = self._sample_rotation(rotations_shape, vecs, rng)
    # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
    # hashing, so it's shared between them. Check if that's what we want.
    dropped_vecs = self.drop_for_hash(vecs, subrng)
    rotated = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)

    if not self._rehash_each_round:
        assert not factorize
        rotated = np.concatenate([rotated, -rotated], axis=-1)
        # In this configuration, each item goes to its top self.n_hashes
        # buckets.
        rotated = np.squeeze(rotated, 0)
        bucket_ids = jax.lax.tie_in(vecs, np.arange(rotated.shape[-1]))
        bucket_ids = np.reshape(bucket_ids, (1, -1))
        bucket_ids = np.broadcast_to(bucket_ids, rotated.shape)
        _, ranked = jax.lax.sort_key_val(rotated, bucket_ids, dimension=-1)
        top = ranked[:, -self.n_hashes:]
        return np.reshape(np.moveaxis(top, 0, -1), (-1, ))

    if factorize and len(factors) > 1:
        # n_buckets was factorized: compute one partial bucket id per factor
        # and combine them as digits of a mixed-radix number.
        buckets, start, radix = None, 0, 1
        for factor in factors:
            half = factor // 2
            part = rotated[..., start:start + half]
            start += half
            part = np.concatenate([part, -part], axis=-1)
            part_ids = np.argmax(part, axis=-1)
            if buckets is None:
                buckets = part_ids
            else:
                buckets += radix * part_ids
            radix *= factor
    else:
        rotated = np.concatenate([rotated, -rotated], axis=-1)
        buckets = np.argmax(rotated, axis=-1)

    # buckets is now (self.n_hashes, seqlen). Offsets are added so that bucket
    # numbers from different hashing rounds don't overlap.
    round_offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
    round_offsets = np.reshape(round_offsets * n_buckets, (-1, 1))
    return np.reshape(buckets + round_offsets, (-1, ))