def CombineHeadsPos(x, n_heads=1, **unused_kwargs):
  """Mix x = (x0, p0, ..., xH, pH) into (x0, ..., xH), p_combined.

  The positions are averaged as vectors.

  Args:
    x: input vector, concatenated (x0, p0, ..., xH, pH).
    n_heads: number of heads.

  Returns:
    the vector with combined xs and one with combined positions.
  """
  seqlen = x.shape[1]
  d_head = x.shape[2]
  x = np.reshape(x, (-1, n_heads, seqlen, d_head))
  x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
  x = np.reshape(x, (-1, seqlen, n_heads * d_head))
  head_size = int(d_head) - POS_VECTOR_SIZE
  res, positions, idx = [], [], 0
  for _ in range(n_heads):
    res.append(x[:, :, idx:idx + head_size])
    idx += head_size
    positions.append(x[:, :, idx:idx + POS_VECTOR_SIZE])
    idx += POS_VECTOR_SIZE
  combined_position = sum(positions) / float(len(positions))
  return np.concatenate(res, axis=-1), combined_position

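# A minimal, self-contained shape check of the head/position mixing above,
# using plain NumPy in place of the backend `np` and assuming
# POS_VECTOR_SIZE = 4 (the constant is defined elsewhere in the module).
# The helper name below is ours, not part of the library.
def _combine_heads_pos_shape_check():
  import numpy as onp
  pos_size, n_heads, n_batch, seqlen, d_head = 4, 2, 3, 8, 16
  x = onp.arange(n_batch * n_heads * seqlen * d_head, dtype=onp.float32)
  x = x.reshape((n_batch * n_heads, seqlen, d_head))
  x = x.reshape((-1, n_heads, seqlen, d_head)).transpose((0, 2, 1, 3))
  x = x.reshape((-1, seqlen, n_heads * d_head))
  head_size = d_head - pos_size
  res, positions, idx = [], [], 0
  for _ in range(n_heads):
    res.append(x[:, :, idx:idx + head_size])
    idx += head_size
    positions.append(x[:, :, idx:idx + pos_size])
    idx += pos_size
  out = onp.concatenate(res, axis=-1)
  pos = sum(positions) / float(len(positions))
  assert out.shape == (n_batch, seqlen, n_heads * head_size)  # (3, 8, 24)
  assert pos.shape == (n_batch, seqlen, pos_size)             # (3, 8, 4)
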
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  embs = []
  for ax_emb in weights:
    ax_emb = np.broadcast_to(
        ax_emb, (inputs.shape[0],) + self._shape + (ax_emb.shape[-1],))
    embs.append(ax_emb)
  emb = np.concatenate(embs, -1)

  if self._mode == 'predict':
    assert self._dropout == 0.0
    emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
    return inputs + emb[:, state, :][:, None, :], state + 1
  elif self._dropout == 0:
    return inputs + np.reshape(emb, inputs.shape), state
  else:
    noise_shape = list(emb.shape)
    for dim in self._dropout_broadcast_dims:
      noise_shape[dim] = 1
    keep_prob = 1.0 - self._dropout
    if backend.get_name() == 'jax':
      keep_prob = jax.lax.tie_in(
          inputs, np.full((), keep_prob, dtype=inputs.dtype))
    keep = backend.random.bernoulli(rng, keep_prob, tuple(noise_shape))
    multiplier = keep.astype(inputs.dtype) / keep_prob
    return inputs + np.reshape(emb * multiplier, inputs.shape), state

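# A plain-NumPy sketch of the axial embedding trick above: one table per
# axis is broadcast over the other axes, the tables are concatenated on the
# depth axis, and the result covers a (4 * 8)-long sequence with only
# 4*2 + 8*6 embedding parameters instead of 32*8. Names and sizes here are
# illustrative, not the library's.
def _axial_position_sketch():
  import numpy as onp
  rng = onp.random.RandomState(0)
  shape, d_embs = (4, 8), (2, 6)
  w0 = rng.randn(1, shape[0], 1, d_embs[0])  # embeds axis 0
  w1 = rng.randn(1, 1, shape[1], d_embs[1])  # embeds axis 1
  batch = 3
  embs = [onp.broadcast_to(w, (batch,) + shape + (w.shape[-1],))
          for w in (w0, w1)]
  emb = onp.concatenate(embs, -1)
  assert emb.shape == (batch, 4, 8, 8)
  flat = emb.reshape(batch, -1, emb.shape[-1])  # (batch, 32, d_embedding)
  assert flat.shape == (batch, 32, 8)
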
def forward(self, x, weights):
  seqlen = x.shape[1]
  d_head = x.shape[2]
  x = np.reshape(x, (-1, self._n_heads, seqlen, d_head))
  x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
  x = np.reshape(x, (-1, seqlen, self._n_heads * d_head))
  return np.dot(x, weights)

def forward(self, x, weights):
  w, b = weights
  x_shape = list(x.shape)
  if len(x_shape) > 4:
    self._check_nhwc()
    new_batch_dim = six.moves.reduce(operator.mul, x_shape[:-3])
    x = np.reshape(x, [new_batch_dim] + x_shape[-3:])
  res = backend.conv(
      x, w, self._strides, self._padding, self._dimension_numbers,
      self._one) + b
  if len(x_shape) > 4:
    res = np.reshape(res, x_shape[:-3] + list(res.shape[-3:]))
  return res

def forward(self, x, weights):
  seqlen = x.shape[1]
  res = np.dot(x, weights)
  # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
  res = np.reshape(res, (x.shape[0], seqlen, self._n_heads, self._d_head))
  # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
  res = np.transpose(res, (0, 2, 1, 3))
  # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
  res = np.reshape(res, (-1, seqlen, self._d_head))
  return res

def multigaussian_loss(preds, targets, ngauss=1):  # pylint: disable=invalid-name
  """Compute mixture of gaussians loss."""
  ndims = targets.shape[-1]
  logits = preds[:, :ngauss]
  mus = preds[:, ngauss:ngauss * (ndims + 1)]
  sigmas = preds[:, ngauss * (ndims + 1):]
  sigmas = sigmas * sigmas + 1e-6  # Make positive.
  loglogits = logits - backend.logsumexp(logits, axis=-1, keepdims=True)
  mus = np.reshape(mus, [-1, ngauss, ndims])
  sigmas = np.reshape(sigmas, [-1, ngauss, ndims])
  targets = np.reshape(targets, [-1, 1, ndims])
  glogprobs = log_gaussian_diag_pdf(targets, mus, sigmas)
  return backend.logsumexp(loglogits + glogprobs, axis=-1)

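# A small worked example of the parameter packing and the mixture
# log-likelihood above, in plain NumPy. For ngauss=2, ndims=3, each row of
# `preds` packs [2 logits | 6 means | 6 sigma-params], and the result is
# log sum_k softmax(logits)_k * N(target; mu_k, diag(sigma_k)). The local
# logsumexp and diagonal-Gaussian log-pdf stand in for `backend.logsumexp`
# and `log_gaussian_diag_pdf`, assuming `sigmas` holds covariance diagonals.
def _multigaussian_loss_sketch():
  import numpy as onp

  def logsumexp(x, axis, keepdims=False):
    m = onp.max(x, axis=axis, keepdims=True)
    out = m + onp.log(onp.sum(onp.exp(x - m), axis=axis, keepdims=True))
    return out if keepdims else onp.squeeze(out, axis=axis)

  ngauss, ndims, batch = 2, 3, 5
  rng = onp.random.RandomState(0)
  preds = rng.randn(batch, ngauss * (2 * ndims + 1))
  targets = rng.randn(batch, ndims)
  logits = preds[:, :ngauss]
  mus = preds[:, ngauss:ngauss * (ndims + 1)].reshape(batch, ngauss, ndims)
  sigmas = preds[:, ngauss * (ndims + 1):].reshape(batch, ngauss, ndims)
  sigmas = sigmas * sigmas + 1e-6  # variances, made positive
  loglogits = logits - logsumexp(logits, axis=-1, keepdims=True)
  x = targets[:, None, :]  # (batch, 1, ndims), broadcasts against mus
  glogprobs = -0.5 * onp.sum(
      (x - mus) ** 2 / sigmas + onp.log(2 * onp.pi * sigmas), axis=-1)
  loglik = logsumexp(loglogits + glogprobs, axis=-1)
  assert loglik.shape == (batch,)
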
def EncoderDecoderMask(x, **unused_kwargs):
  """Makes encoder-decoder mask from decoder input and a padding mask."""
  decoder_input, padding_mask = x
  padding_mask = np.reshape(
      padding_mask, (padding_mask.shape[0], 1, 1, padding_mask.shape[-1]))
  # Final mask shape is [batch, 1 for heads, decoder-len, encoder-len].
  return padding_mask + np.zeros((1, 1, decoder_input.shape[1], 1))

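# A tiny broadcasting sketch (plain NumPy standing in for the backend):
# adding zeros of shape (1, 1, decoder_len, 1) to the reshaped padding mask
# tiles the same encoder padding pattern across every decoder position.
def _encoder_decoder_mask_sketch():
  import numpy as onp
  batch, enc_len, dec_len = 2, 5, 3
  padding_mask = onp.array([[1, 1, 1, 0, 0],
                            [1, 1, 0, 0, 0]], dtype=onp.float32)
  mask = padding_mask.reshape(batch, 1, 1, enc_len)
  mask = mask + onp.zeros((1, 1, dec_len, 1), dtype=onp.float32)
  assert mask.shape == (batch, 1, dec_len, enc_len)
  # Every decoder position sees the same encoder padding pattern.
  assert (mask[0, 0, 0] == mask[0, 0, -1]).all()
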
def Unchunk(x, weights, n_sections=2, **kwargs):
  del weights, kwargs
  assert x.shape[0] % n_sections == 0
  return np.reshape(x, (
      x.shape[0] // n_sections,
      x.shape[1] * n_sections,
  ) + x.shape[2:])

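# A round-trip sketch for the reshape above (plain NumPy): the presumed
# inverse operation ("Chunk") folds pieces of the length axis into the batch
# axis, and Unchunk's reshape restores the original tensor exactly.
def _unchunk_sketch():
  import numpy as onp
  n_sections, batch, length, depth = 2, 3, 8, 4
  x = onp.arange(batch * length * depth).reshape(batch, length, depth)
  chunked = x.reshape(batch * n_sections, length // n_sections, depth)
  restored = chunked.reshape(
      (chunked.shape[0] // n_sections,
       chunked.shape[1] * n_sections) + chunked.shape[2:])
  assert (restored == x).all()
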
def Init(shape, rng):
  """Returns orthogonalized random normal values with the given `shape`."""
  # Have at least 2 elements in shape.
  cur_shape = list(shape)
  while len(cur_shape) < 2:
    cur_shape = [1] + cur_shape

  # Flatten the input shape with the last dimension remaining.
  n_rows = 1
  for dim in cur_shape[:-1]:
    n_rows *= dim
  n_cols = cur_shape[-1]
  flat_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)

  # Generate a random matrix.
  a = random.normal(rng, flat_shape, dtype=np.float32)

  # Compute the qr factorization.
  q, r = np.linalg.qr(a)

  # Make Q uniform.
  d = np.diag(r)
  q *= np.sign(d)

  # Transpose and reshape back q if needed.
  if n_rows < n_cols:
    q = np.transpose(q)
  q = np.reshape(q, shape)

  # Return scaled as requested (`stddev` is closed over from the enclosing
  # initializer factory).
  return stddev * q

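# A quick property check (plain NumPy sketch): after the sign correction,
# the columns of the tall factor Q are orthonormal, so Q^T Q = I up to the
# `stddev` scale applied at the end.
def _orthogonal_init_sketch():
  import numpy as onp
  rng = onp.random.RandomState(0)
  a = rng.normal(size=(8, 4))
  q, r = onp.linalg.qr(a)
  q *= onp.sign(onp.diag(r))  # Make the factorization unique/uniform.
  onp.testing.assert_allclose(q.T @ q, onp.eye(4), atol=1e-10)
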
def forward(self, inp, weights):
  """Reshape input to have heads dimension and concatenate positions there."""
  x = inp[0]
  n_batches, seqlen = x.shape[0], x.shape[1]
  d_head = x.shape[-1] // self._n_heads
  res = np.reshape(x, (n_batches, seqlen, self._n_heads, d_head))
  res = np.transpose(res, (0, 2, 1, 3))  # (batch, heads, len, depth)
  if self._n_pos == 1:  # Just one position given, tile into each head.
    pos_shape = list(res.shape)[:-1] + [inp[1].shape[-1]]
    pos = inp[1][:, None, :, :] + np.zeros(pos_shape)  # Add 0 to broadcast.
  else:  # As many positions as heads, concatenate them in.
    pos = [p[:, None, :, :] for p in inp[1:]]
    pos = np.concatenate(pos, axis=1)
  res = np.concatenate([res, pos], axis=-1)
  # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
  res = np.reshape(res, (-1, seqlen, d_head + POS_VECTOR_SIZE))
  return res

def _update_sketched(self, grads, params, m, v, opt_params):
  """Update for higher-rank parameters."""
  learning_rate = opt_params['learning_rate']
  momentum = opt_params['momentum']
  shape = params.shape
  rank = len(shape)
  reshaped_accumulators = [
      np.reshape(v[i], self._expanded_shape(shape, i)) for i in range(rank)]
  current_accumulator = self._minimum(reshaped_accumulators)
  current_accumulator += grads * grads
  accumulator_inv_sqrt = np.where(
      current_accumulator > 0.0,
      1.0 / np.sqrt(current_accumulator),
      np.zeros_like(current_accumulator))
  preconditioned_gradient = grads * accumulator_inv_sqrt
  m = (1.0 - momentum) * preconditioned_gradient + momentum * m
  params = params - (learning_rate * m).astype(params.dtype)
  for i in range(len(v)):
    axes = list(range(int(i))) + list(range(int(i) + 1, rank))
    dim_accumulator = np.amax(current_accumulator, axis=axes)
    v[i] = dim_accumulator
  return params, (m, v)

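# A minimal sketch of the sketched-accumulator idea above (SM3-style), in
# plain NumPy, assuming `_expanded_shape(shape, i)` puts dim i's size at
# axis i and 1 elsewhere: each rank-1 accumulator v[i] covers one axis, the
# broadcast minimum reconstructs a full-rank accumulator, and per-axis
# maxima fold the updated accumulator back into the sketches.
def _sm3_accumulator_sketch():
  import numpy as onp
  shape = (3, 4)
  grads = onp.random.RandomState(0).rand(*shape)
  v = [onp.zeros(shape[0]), onp.zeros(shape[1])]  # one accumulator per axis
  reshaped = [v[0].reshape(3, 1), v[1].reshape(1, 4)]
  current = onp.minimum(reshaped[0], reshaped[1]) + grads * grads
  v[0] = current.max(axis=1)  # max over all axes except axis 0
  v[1] = current.max(axis=0)  # max over all axes except axis 1
  assert v[0].shape == (3,) and v[1].shape == (4,)
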
def test_batch_norm(self):
  input_shape = (2, 3, 4)
  input_dtype = np.float32
  input_signature = ShapeDtype(input_shape, input_dtype)
  eps = 1e-5
  inp1 = np.reshape(np.arange(np.prod(input_shape), dtype=input_dtype),
                    input_shape)
  m1 = 11.5  # Mean of this input.
  v1 = 47.9167  # Variance of this input.
  layer = normalization.BatchNorm(axis=(0, 1, 2))
  _, _ = layer.initialize_once(input_signature)
  state = layer.state
  onp.testing.assert_allclose(state[0], 0)
  onp.testing.assert_allclose(state[1], 1)
  self.assertEqual(state[2], 0)
  out = layer(inp1)
  state = layer.state
  onp.testing.assert_allclose(state[0], m1 * 0.001)
  onp.testing.assert_allclose(state[1], 0.999 + v1 * 0.001, rtol=1e-6)
  self.assertEqual(state[2], 1)
  onp.testing.assert_allclose(out, (inp1 - m1) / np.sqrt(v1 + eps),
                              rtol=1e-6)

def single_call(self, qk, v, buckets, rng=None):
  # We use the same vector as both a query and a key.
  seqlen = qk.shape[-2]
  assert int(buckets.shape[0]) == self.n_hashes * seqlen

  ticker = jax.lax.tie_in(qk, np.arange(self.n_hashes * seqlen))
  buckets_and_t = seqlen * buckets + (ticker % seqlen)
  buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

  # Hash-based sort ("s" at the start of variable names means "sorted").
  sbuckets_and_t, sticker = jax.lax.sort_key_val(
      buckets_and_t, ticker, dimension=-1)
  _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
  sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
  sticker = jax.lax.stop_gradient(sticker)
  undo_sort = jax.lax.stop_gradient(undo_sort)

  st = (sticker % seqlen)
  sqk = np.take(qk, st, axis=0)
  sv = np.take(v, st, axis=0)

  # Split off a "bin" axis so that attention only occurs within chunks.
  bq_t = bkv_t = np.reshape(st, (self.n_hashes * self.n_bins, -1))
  bqk = np.reshape(sqk, (self.n_hashes * self.n_bins, -1, sqk.shape[-1]))
  bv = np.reshape(sv, (self.n_hashes * self.n_bins, -1, sv.shape[-1]))
  bq_buckets = bkv_buckets = np.reshape(
      sbuckets_and_t // seqlen, (self.n_hashes * self.n_bins, -1))

  # Hashing operates on unit-length vectors. Unnormalized query vectors are
  # fine because they effectively provide a learnable temperature for the
  # attention softmax, but normalizing keys is needed so that similarity for
  # the purposes of attention correctly corresponds to hash locality.
  bq = bqk
  bk = self.make_unit_length(bqk)

  # Allow each chunk to attend within itself, and also one chunk back. Chunk
  # boundaries might occur in the middle of a sequence of items from the
  # same bucket, so this increases the chances of attending to relevant items.
  # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
  def look_one_back(x):
    if len(x.shape) == 2:
      x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
    else:
      x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
    return np.concatenate([x, x_extra], axis=1)

  bk = look_one_back(bk)
  bv = look_one_back(bv)
  bkv_t = look_one_back(bkv_t)
  bkv_buckets = look_one_back(bkv_buckets)

  # Dot-product attention.
  dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

  # Causal masking.
  mask = jax.lax.convert_element_type(
      jax.lax.lt(bq_t[:, :, None], bkv_t[:, None, :]), np.float32)
  dots = dots - 1e9 * mask

  # Mask out attention to self except when no other targets are available.
  self_mask = jax.lax.convert_element_type(
      jax.lax.eq(bq_t[:, :, None], bkv_t[:, None, :]), np.float32)
  dots = dots - 1e5 * self_mask

  # Mask out attention to other hash buckets.
  if not self._attend_across_buckets:
    bucket_mask = jax.lax.convert_element_type(
        jax.lax.ne(bq_buckets[:, :, None], bkv_buckets[:, None, :]),
        np.float32)
    dots = dots - 1e7 * bucket_mask

  # Don't double-count query-key pairs across multiple rounds of hashing.
  # There are two possible strategies here. (1) The default is to count how
  # many times a query-key pair is repeated, and to lower its log-prob
  # correspondingly at each repetition. (2) When hard_k is set, the code
  # instead masks all but the first occurrence of each query-key pair.
  # TODO(kitaev): is one strategy faster or more numerically stable?
  if not self._allow_duplicate_attention:
    locs1 = undo_sort // bq_t.shape[-1]
    locs2 = (locs1 + 1) % (self.n_hashes * self.n_bins)
    if not self._attend_across_buckets:
      locs1 = buckets * (self.n_hashes * self.n_bins) + locs1
      locs2 = buckets * (self.n_hashes * self.n_bins) + locs2
    locs = np.moveaxis(np.concatenate([
        np.reshape(locs1, (self.n_hashes, seqlen)),
        np.reshape(locs2, (self.n_hashes, seqlen)),
    ], 0), 0, -1)  # produces shape (seqlen, 2 * self.n_hashes)
    slocs = np.take(locs, st, axis=0)
    b_locs = np.reshape(
        slocs, (self.n_hashes * self.n_bins, -1, 2 * self.n_hashes))
    # Queries always use the primary location (based on locs1).
    b_locs1 = b_locs[:, :, None, :self.n_hashes]
    if self._hard_k > 0:
      range_n_hashes = jax.lax.tie_in(b_locs, np.arange(self.n_hashes))
      nouse_locs = (range_n_hashes[:, None] > range_n_hashes[None, :])
      nouse_locs = 2 * nouse_locs - 1  # 1 = use, -1 = don't use
      nouse_locs = np.reshape(
          np.broadcast_to(nouse_locs[:, None, :],
                          (self.n_hashes, self.n_bins, self.n_hashes)),
          (self.n_hashes * self.n_bins, 1, 1, self.n_hashes))
      b_locs1 = b_locs1 * nouse_locs
    bq_locs = np.broadcast_to(
        b_locs1, b_locs.shape[:2] + (2, self.n_hashes))
    bq_locs = np.reshape(bq_locs, b_locs.shape)
    bkv_locs = look_one_back(b_locs)

    dup_counts = np.sum(
        jax.lax.convert_element_type(
            jax.lax.eq(bq_locs[:, :, None, :], bkv_locs[:, None, :, :]),
            np.float32),
        axis=-1)
    assert dup_counts.shape == dots.shape
    if self._hard_k > 0:
      dots = dots - 1e7 * jax.lax.stop_gradient(dup_counts)
    else:
      dots = dots - jax.lax.stop_gradient(np.log(dup_counts + 1e-9))

  # Each query only attends to the top k most relevant keys.
  if self._hard_k > 0:
    b_top_dots = np.sort(dots)[..., -self._hard_k:]  # Get the top k dots.
    b_top_dots = jax.lax.stop_gradient(b_top_dots)
    s_top_dots = np.reshape(b_top_dots, (-1, self._hard_k))
    top_dots = np.take(s_top_dots, undo_sort, axis=0)

    merged_top_dots = np.moveaxis(
        np.reshape(top_dots, (self.n_hashes, seqlen, self._hard_k)), 0, -1)
    merged_top_dots = np.reshape(merged_top_dots, (seqlen, -1))

    dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k]
    # It's possible to compute the partition function at this point, but right
    # now this codepath isn't set up for backprop, and there might also be
    # issues computing it this way if two dot-products are exactly equal.

    sdots_thresh = dots_thresh[st]
    bdots_thresh = np.reshape(sdots_thresh, (self.n_hashes * self.n_bins, -1))
    bdots_thresh = jax.lax.stop_gradient(bdots_thresh)

    top_k_mask = jax.lax.convert_element_type(
        dots < bdots_thresh[..., None], np.float32)
    dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)

  # Softmax.
  dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
  dots = np.exp(dots - dots_logsumexp)

  if self._dropout > 0.0:
    # Dropout is broadcast across the bin dimension.
    dropout_shape = (1, dots.shape[-2], dots.shape[-1])
    keep_prob = jax.lax.tie_in(dots, 1.0 - self._dropout)
    keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
    multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
    dots = dots * multiplier

  bo = np.matmul(dots, bv)
  so = np.reshape(bo, (-1, bo.shape[-1]))
  slogits = np.reshape(dots_logsumexp, (-1,))

  def unsort_for_output_impl(so, slogits):
    o = np.take(so, undo_sort, axis=0)
    # Sorting is considerably faster than gather, but first we need to get the
    # XLA compiler to abandon the idea of fusing this sort with the input sort
    # (which introduces a computation cycle and leads to a crash).
    # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
    sticker_ = sticker + jax.lax.convert_element_type(
        slogits[0] > 0, sticker.dtype)
    _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
    return o, logits

  def unsort_for_output_vjp(so, slogits):
    """Custom gradient for unsort_for_output."""
    so = jax.lax.stop_gradient(so)
    slogits = jax.lax.stop_gradient(slogits)
    o, logits = unsort_for_output_impl(so, slogits)
    def vjpfun(o_logits_grads):
      so_grad = np.take(o_logits_grads[0], sticker, axis=0)
      # TODO(kitaev): this exists to match the forward pass, but I'm not sure
      # if it's actually required.
      buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
          o_logits_grads[1][0] > 0, buckets_and_t.dtype)
      _, slogits_grad = jax.lax.sort_key_val(
          buckets_and_t_, o_logits_grads[1], dimension=-1)
      return (so_grad, slogits_grad)
    return (o, logits), vjpfun

  unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
  jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
  # Call the wrapped version: calling the raw impl would skip the custom
  # gradient registered above.
  o, logits = unsort_for_output(so, slogits)

  if self.n_hashes == 1:
    out = o
  else:
    o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
    logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
    probs = np.exp(logits - backend.logsumexp(logits, axis=0, keepdims=True))
    out = np.sum(o * probs, axis=0)

  assert out.shape == v.shape
  return out

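# A tiny demo of `look_one_back` (plain NumPy sketch): each bin is paired
# with its predecessor along the extended key axis, so attention within a
# bin can also see the previous bin. The first bin wraps around to the last;
# causal masking on the time indices removes those wrapped-around targets.
def _look_one_back_sketch():
  import numpy as onp
  x = onp.array([[0, 1], [2, 3], [4, 5]])  # (n_bins=3, bin_size=2)
  x_extra = onp.concatenate([x[-1:, :], x[:-1, :]], axis=0)
  paired = onp.concatenate([x, x_extra], axis=1)
  # paired[i] == [bin_i, bin_{i-1}]:
  assert (paired == onp.array([[0, 1, 4, 5],
                               [2, 3, 0, 1],
                               [4, 5, 2, 3]])).all()
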
def hash_vectors(self, vecs, rng):
  # See https://arxiv.org/pdf/1509.02897.pdf
  # We sample a different random rotation for each round of hashing to
  # decrease the probability of hash misses.
  assert self.n_buckets % 2 == 0

  # If we factorize the hash, find a factor dividing n_buckets nicely.
  rot_size, factor_list = self.n_buckets, [self.n_buckets]
  if self._factorize_hash:
    # If we are given a list of factors, verify it and use later.
    if isinstance(self._factorize_hash, list):
      rot_size, product = 0, 1
      factor_list = self._factorize_hash
      for factor in factor_list:
        assert factor % 2 == 0
        product *= factor
        rot_size += factor
      assert product == self.n_buckets
    else:
      # Find one factor if just set to True.
      # We want to represent self.n_buckets = factor * rest so that
      # (1) both factor and rest are even, and (2) factor + rest is minimal.
      # To compute this we start from factor = sqrt(n_buckets) and go down
      # with it until we find one that satisfies the constraints above.
      factor = int(math.sqrt(self.n_buckets))
      while factor > 0 and not (
          self.n_buckets % factor == 0 and
          factor % 2 == 0 and
          (self.n_buckets // factor) % 2 == 0):
        factor -= 1
      if factor > 2:  # Factor of 2 does not warrant the effort.
        rot_size = factor + (self.n_buckets // factor)
        factor_list = [factor, self.n_buckets // factor]

  rotations_shape = (
      vecs.shape[-1],
      self.n_hashes if self._rehash_each_round else 1,
      rot_size // 2)

  rng = jax.lax.tie_in(vecs, rng)
  rng, subrng = backend.random.split(rng)
  random_rotations = self._sample_rotation(rotations_shape, vecs, rng)

  # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
  # hashing, so it's shared between them. Check if that's what we want.
  dropped_vecs = self.drop_for_hash(vecs, subrng)
  rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)

  if self._rehash_each_round:
    if self._factorize_hash and len(factor_list) > 1:
      # We factorized self.n_buckets as the product of factor_list.
      # Get the buckets for them and combine.
      buckets, cur_sum, cur_product = None, 0, 1
      for factor in factor_list:
        rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
        cur_sum += factor // 2
        rv = np.concatenate([rv, -rv], axis=-1)
        if buckets is None:
          buckets = np.argmax(rv, axis=-1)
        else:
          buckets += cur_product * np.argmax(rv, axis=-1)
        cur_product *= factor
    else:
      rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
      buckets = np.argmax(rotated_vecs, axis=-1)
    # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
    # bucket numbers from different hashing rounds don't overlap.
    offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
    offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
    buckets = np.reshape(buckets + offsets, (-1,))
  else:
    assert not self._factorize_hash
    rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
    # In this configuration, we map each item to the top self.n_hashes buckets.
    rotated_vecs = np.squeeze(rotated_vecs, 0)
    bucket_range = jax.lax.tie_in(vecs, np.arange(rotated_vecs.shape[-1]))
    bucket_range = np.reshape(bucket_range, (1, -1))
    bucket_range = np.broadcast_to(bucket_range, rotated_vecs.shape)

    _, buckets = jax.lax.sort_key_val(
        rotated_vecs, bucket_range, dimension=-1)
    buckets = buckets[:, -self.n_hashes:]
    buckets = np.reshape(np.moveaxis(buckets, 0, -1), (-1,))

  return buckets

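# A plain-NumPy sketch of one round of the random-rotation hash above
# (single hash round, no factorization): project onto rot_size // 2 random
# directions, append the negated projections, and take the argmax. With
# n_buckets = 4 this uses 2 random directions; names here are illustrative.
def _hash_vectors_sketch():
  import numpy as onp
  rng = onp.random.RandomState(0)
  vecs = rng.randn(6, 8)       # (seqlen, d)
  rotations = rng.randn(8, 2)  # (d, n_buckets // 2)
  rotated = vecs @ rotations   # (seqlen, 2)
  rotated = onp.concatenate([rotated, -rotated], axis=-1)  # (seqlen, 4)
  buckets = onp.argmax(rotated, axis=-1)  # values in [0, n_buckets)
  assert buckets.shape == (6,) and buckets.max() < 4
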
def _sample_rotation(self, shape, vecs, rng):
  """Samples a rotation matrix, either randomly or based on `vecs`."""
  if not self._data_rotation:
    return jax.random.normal(rng, shape).astype('float32')

  assert len(shape) == 3
  unused_n_dim, n_hashes, r_div_2 = shape

  assert len(vecs.shape) == 2
  n_vecs = vecs.shape[0]

  rng1, rng2 = backend.random.split(rng, num=2)

  # We need to sample 2 * n_hashes * r_div_2 vectors from `vecs` at random.
  num_needed = 2 * n_hashes * r_div_2
  if n_vecs < num_needed:
    # shape = (n_hashes, r_div_2)
    random_idxs_1 = jax.random.randint(rng1, (n_hashes, r_div_2), 0, n_vecs)
    random_idxs_2 = jax.random.randint(rng2, (n_hashes, r_div_2), 0, n_vecs)
  else:
    # Sample without replacement.
    shuffled_indices = jax.random.shuffle(rng1, np.arange(n_vecs))
    random_idxs = np.reshape(shuffled_indices[:num_needed],
                             (2, n_hashes, r_div_2))
    random_idxs_1 = random_idxs[0]
    random_idxs_2 = random_idxs[1]

  if self._data_rotation_farthest:
    # shape = (n_hashes * r_div_2, )
    random_idxs_1 = np.reshape(random_idxs_1, (-1,))
    random_vecs_1 = vecs[random_idxs_1]

    # Sample candidates for vec2s.
    rng, subrng = backend.random.split(rng)
    # shape = (self._data_rotation_farthest_num, n_hashes * r_div_2)
    candidate_idxs_2 = jax.random.randint(
        subrng, (self._data_rotation_farthest_num, n_hashes * r_div_2), 0,
        n_vecs)
    candidate_vecs_2 = vecs[candidate_idxs_2]
    # shape = candidate_idxs_2.shape
    distances = -np.abs(
        np.einsum('hd,chd->ch', random_vecs_1, candidate_vecs_2))
    # shape = (n_hashes * r_div_2,)
    farthest_idxs = np.argmax(distances, axis=0)
    # candidate_vecs_2.shape
    random_vecs_2 = candidate_vecs_2[farthest_idxs,
                                     np.arange(n_hashes * r_div_2)]

    # reshape to (n_hashes, r_div_2, n_dim)
    random_vecs_1 = np.reshape(random_vecs_1, (n_hashes, r_div_2, -1))
    random_vecs_2 = np.reshape(random_vecs_2, (n_hashes, r_div_2, -1))
  else:
    # shape = (n_hashes, r_div_2, n_dim)
    random_vecs_1 = vecs[random_idxs_1]
    random_vecs_2 = vecs[random_idxs_2]

  # shape = (n_dim, n_hashes, r_div_2)
  return np.transpose(random_vecs_2 - random_vecs_1, axes=[2, 0, 1])

def _forward_train_eval(self, inputs, rng):
  (inputs, original_len, n_bins) = self._pad_inputs(inputs)
  q, k, v = inputs
  seqlen = q.shape[-2]
  # q/k/v are n_batch*n_heads, seqlen, d_head

  # Time indices for causal masking.
  t = jax.lax.tie_in(q, np.arange(seqlen))

  # Split off a "bin" axis for chunks of consecutive items.
  bq_t = np.reshape(t, (n_bins, -1))
  bq = np.reshape(q, (q.shape[0], n_bins, -1, q.shape[-1]))
  if self._share_qk:
    bk = self.make_unit_length(bq)
  else:
    bk = np.reshape(k, (k.shape[0], n_bins, -1, k.shape[-1]))
  bv = np.reshape(v, (v.shape[0], n_bins, -1, v.shape[-1]))

  # Allow each chunk to attend within itself, and also one chunk back.
  def look_one_back(x):
    # Output: pairs [ bin_i bin_{i-1} ] concatenated on the time axis.
    if len(x.shape) == 2:
      x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
      return np.concatenate([x, x_extra], axis=1)
    else:
      assert len(x.shape) == 4
      x_extra = np.concatenate([x[:, -1:, :, :], x[:, :-1, :, :]], axis=1)
      return np.concatenate([x, x_extra], axis=2)

  bkv_t = look_one_back(bq_t)
  bk = look_one_back(bk)
  bv = look_one_back(bv)

  # Dot-product attention.
  dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

  # Causal masking based on the time indices.
  mask = jax.lax.convert_element_type(
      jax.lax.lt(bq_t[None, :, :, None], bkv_t[None, :, None, :]),
      np.float32)
  dots = dots - 1e9 * mask

  # Mask out attention to self except when no other targets are available.
  if self._share_qk:
    self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
    self_mask = jax.lax.tie_in(dots, self_mask)
    dots = dots - 1e5 * self_mask

  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

  if self.dropout > 0.0:
    # Dropout is broadcast across the batch+head dimension.
    dropout_shape = (1, dots.shape[-3], dots.shape[-2], dots.shape[-1])
    keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
    keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
    multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
    dots = dots * multiplier

  bo = np.matmul(dots, bv)

  output = np.reshape(bo, (bo.shape[0], -1, bo.shape[-1]))
  assert output.shape == v.shape
  return output[..., :original_len, :]

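# A small demo of the within-bin causal mask above (plain NumPy sketch): a
# query at time t may only attend to keys with time <= t, so strictly-later
# key times get mask value 1 and are pushed to -inf by `dots - 1e9 * mask`.
def _causal_mask_sketch():
  import numpy as onp
  bq_t = onp.array([[0, 1], [2, 3]])               # (n_bins=2, bin_size=2)
  bkv_t = onp.array([[0, 1, 2, 3], [2, 3, 0, 1]])  # after look_one_back
  mask = (bq_t[:, :, None] < bkv_t[:, None, :]).astype(onp.float32)
  # Bin 0, query t=0 can only see key t=0:
  assert (mask[0, 0] == onp.array([0., 1., 1., 1.])).all()
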
def SplitHeads(x):
  # `nbatch`, `n_heads` and `d_head` are closed over from the enclosing scope.
  return np.transpose(
      np.reshape(x, (nbatch, -1, n_heads, d_head)), (0, 2, 1, 3))

def Flatten(x, n_axes_to_keep=1, **unused_kwargs):
  if n_axes_to_keep >= len(x.shape):
    raise ValueError(
        "n_axes_to_keep[%d] should be less than input's rank[%d]" %
        (n_axes_to_keep, len(x.shape)))
  return np.reshape(x, (x.shape[:n_axes_to_keep] + (-1,)))

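# Usage sketch (plain NumPy standing in for the backend): keeping the first
# axis flattens everything else into a single trailing dimension.
def _flatten_sketch():
  import numpy as onp
  x = onp.zeros((2, 3, 4))
  flat = onp.reshape(x, x.shape[:1] + (-1,))
  assert flat.shape == (2, 12)
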
def JoinHeads(x):  # pylint: disable=invalid-name
  # `nbatch`, `n_heads` and `d_head` are closed over from the enclosing scope.
  return np.reshape(
      np.transpose(x, (0, 2, 1, 3)), (nbatch, -1, n_heads * d_head))

def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      n_attention_chunks=1,
                      attention_type=tl.DotProductCausalAttention,
                      share_qk=False,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      mode='train'):
  """Reversible transformer language model with shortening.

  When shorten_factor is F and processing an input of shape [batch, length],
  we embed the (shifted-right) input and then group each F elements (on length)
  into a single vector -- so that in the end we process a tensor of shape
  [batch, length // F, d_model] almost until the end -- at the end it's
  un-shortened and an SRU is applied. This reduces the length processed inside
  the main model body, effectively making the model faster but possibly
  slightly less accurate.

  Args:
    vocab_size: int: vocab size
    shorten_factor: by how much to shorten, see above
    d_embedding: the depth of the embedding layer and final logits
    d_model: int: depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, values must sum to d_embedding.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  assert mode != 'predict'  # TODO(lukaszkaiser,kitaev): fast inference

  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)

  positional_embedder = [
      tl.Embedding(d_embedding, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  decoder_blocks = []

  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
  else:
    attention_type = [attention_type]
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # pylint: disable=g-long-lambda
  return tl.Serial(
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),  # Stack has (x, x), the first will be shortened.
      # Before shortening, we need to pad by shorten factor so as not to leak
      # information into the future. To understand why, imagine shorten factor
      # of 2 and sequence of length 4, so ABCD. If we shift just by 1, then we
      # would have 0ABC, which gets grouped to [0A][BC] on input, which is
      # predicting ABCD as targets. The problem is that [0A] has access to A
      # and [BC] has access to C -- it will learn to copy it, peek into
      # the future. Shifting twice to [00][AB] solves the problem as the first
      # "big" symbol becomes all-0 and the rest is shifted enough.
      # (See the sketch after this function for a worked example.)
      tl.ShiftRight(n_shifts=shorten_factor - 1),
      tl.Fn(lambda x: np.reshape(  # Shorten -- move to depth.
          x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1),
      tl.Dense(d_model),
      tl.Dup(),  # Stack has (short_x, short_x, x)
      tl.ReversibleSerial(decoder_blocks),
      tl.Select([0], n_in=2),
      tl.LayerNorm(),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Dense(shorten_factor * d_embedding),
      tl.Fn(lambda x: np.reshape(  # Prolong back.
          x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1),
      tl.Concatenate(),  # Concatenate with just the embeddings.
      tl.CausalConv(d_embedding),
      tl.Relu(),
      tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
      tl.Dense(vocab_size),
      tl.LogSoftmax())

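# A plain-NumPy sketch of the shorten/prolong reshapes and the shift
# reasoning in the comment above, with shorten_factor = 2 and a length-4
# sequence ABCD (A, B encoded as 1, 2; 0 = padding). Shifting right twice in
# total gives [0, 0, A, B], which groups into big symbols [00][AB]; each big
# symbol then only sees strictly-past tokens of the targets ABCD.
def _shorten_sketch():
  import numpy as onp
  shorten_factor = 2
  x = onp.array([[0, 0, 1, 2]])  # (batch=1, length=4) after two shifts
  emb = x[..., None].repeat(3, axis=-1).astype(onp.float32)  # fake d_emb=3
  short = emb.reshape(emb.shape[0], emb.shape[1] // shorten_factor, -1)
  assert short.shape == (1, 2, 6)  # two "big" symbols: [00] and [AB]
  prolonged = short.reshape(short.shape[0],
                            short.shape[1] * shorten_factor, -1)
  assert (prolonged == emb).all()  # prolonging inverts the shortening
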
def PaddingMask(x, weights, pad=0, **kwargs):
  del weights, kwargs
  return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))

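# Usage sketch (plain NumPy): token ids equal to `pad` become False in a
# mask shaped (batch, 1 for heads, 1 for query positions, seqlen).
def _padding_mask_sketch():
  import numpy as onp
  x = onp.array([[5, 7, 0, 0]])
  mask = onp.reshape(x != 0, (x.shape[0], 1, 1, x.shape[-1]))
  assert mask.shape == (1, 1, 1, 4)
  assert (mask[0, 0, 0] == onp.array([True, True, False, False])).all()
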