def call(self, x, params, state, **kwargs):
  del kwargs
  seqlen = x.shape[1]
  d_head = x.shape[2]
  x = np.reshape(x, (-1, self._n_heads, seqlen, d_head))
  x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
  x = np.reshape(x, (-1, seqlen, self._n_heads * d_head))
  return np.dot(x, params), state
def predict(x, params=(), rng=None):
  """Predict function jited and parallelized as requested."""
  # On one device, jit and run.
  if n_devices == 1:
    return backend.jit(model_predict)(x, params, rng=rng)
  # On multiple devices, pmap and run.
  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_predict(x, params, rng):
    return model_predict(x, params, rng=rng)
  pred = mapped_predict(
      reshape_by_device(x, n_devices),
      params,
      jax_random.split(rng, n_devices))
  # Need to reduce the [device, per-device-batch, ...] tensors back to
  # a [batch, ...] tensor. The tensors may be nested.
  if not isinstance(pred, (list, tuple)):  # Not nested.
    batch_size = pred.shape[0] * pred.shape[1]
    return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
  batch_size = pred[0].shape[0] * pred[0].shape[1]
  return [np.reshape(p, [batch_size] + list(p.shape[2:])) for p in pred]
def multigaussian_loss(preds, targets, ngauss=1):  # pylint: disable=invalid-name
  """Compute mixture of gaussians loss."""
  ndims = targets.shape[-1]
  logits = preds[:, :ngauss]
  mus = preds[:, ngauss:ngauss * (ndims + 1)]
  sigmas = preds[:, ngauss * (ndims + 1):]
  sigmas = sigmas * sigmas + 1e-6  # Make positive.
  loglogits = logits - backend.logsumexp(logits, axis=-1, keepdims=True)
  mus = np.reshape(mus, [-1, ngauss, ndims])
  sigmas = np.reshape(sigmas, [-1, ngauss, ndims])
  targets = np.reshape(targets, [-1, 1, ndims])
  glogprobs = log_gaussian_diag_pdf(targets, mus, sigmas)
  return backend.logsumexp(loglogits + glogprobs, axis=-1)
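# The slicing above assumes preds packs [logits | mus | sigmas] along the
# last axis, so its width must be ngauss * (2 * ndims + 1). A minimal check
# of that bookkeeping, using plain numpy in place of the jax.numpy backend
# and made-up toy sizes:
import numpy as np

ngauss, ndims, batch = 3, 4, 2
preds = np.zeros((batch, ngauss * (2 * ndims + 1)))

logits = preds[:, :ngauss]                    # (batch, ngauss)
mus = preds[:, ngauss:ngauss * (ndims + 1)]   # (batch, ngauss * ndims)
sigmas = preds[:, ngauss * (ndims + 1):]      # (batch, ngauss * ndims)

assert logits.shape == (batch, ngauss)
assert mus.shape == sigmas.shape == (batch, ngauss * ndims)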
def call(self, x, params=(), state=(), **kwargs):
  del kwargs
  w, b = params
  x_shape = list(x.shape)
  if len(x_shape) > 4:
    self._check_nhwc()
    new_batch_dim = six.moves.reduce(operator.mul, x_shape[:-3])
    x = np.reshape(x, [new_batch_dim] + x_shape[-3:])
  res = backend.conv(
      x, w, self._strides, self._padding, self._dimension_numbers,
      self._one) + b
  if len(x_shape) > 4:
    res = np.reshape(res, x_shape[:-3] + list(res.shape[-3:]))
  return res, state
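# The fold-then-unfold trick above is independent of the convolution itself.
# A minimal sketch of the shape bookkeeping, assuming plain numpy and a
# hypothetical op that only accepts rank-4 (NHWC) input:
import functools
import operator
import numpy as np

def apply_rank4_op(x, op):
  """Folds extra leading axes into the batch, applies op, unfolds again."""
  x_shape = list(x.shape)
  if len(x_shape) > 4:
    new_batch_dim = functools.reduce(operator.mul, x_shape[:-3])
    x = np.reshape(x, [new_batch_dim] + x_shape[-3:])
  res = op(x)
  if len(x_shape) > 4:
    res = np.reshape(res, x_shape[:-3] + list(res.shape[-3:]))
  return res

out = apply_rank4_op(np.zeros((2, 3, 8, 8, 4)), lambda x: x * 2.0)
assert out.shape == (2, 3, 8, 8, 4)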
def call(self, x, params, state, **kwargs):
  del kwargs
  seqlen = x.shape[1]
  res = np.dot(x, params)
  # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
  res = np.reshape(res, (x.shape[0], seqlen, self._n_heads, self._d_head))
  # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
  res = np.transpose(res, (0, 2, 1, 3))
  # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
  res = np.reshape(res, (-1, seqlen, self._d_head))
  return res, state
def reshape_by_device(train_data, num_devices):
  """Reshape the train_data into a shape [num_devices, ...]."""
  x, y = train_data
  x_shape, y_shape = list(x.shape), list(y.shape)
  assert x_shape[0] == y_shape[0]  # Same batch size.
  batch_size = x_shape[0]
  batch_size_per_device = batch_size // num_devices
  # We require that num_devices divides batch_size evenly.
  assert batch_size_per_device * num_devices == batch_size
  # New shapes.
  new_shape_prefix = [num_devices, batch_size_per_device]
  x = np.reshape(x, new_shape_prefix + x_shape[1:])
  y = np.reshape(y, new_shape_prefix + y_shape[1:])
  return x, y
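# The device reshape and the combine() further below are inverses whenever
# num_devices divides the batch evenly. A quick numpy round trip with toy
# sizes (plain numpy in place of the jax.numpy backend):
import numpy as np

num_devices, batch, seqlen = 4, 8, 16
x = np.arange(batch * seqlen).reshape(batch, seqlen)

# [batch, ...] -> [num_devices, batch_size_per_device, ...]
per_device = np.reshape(x, [num_devices, batch // num_devices, seqlen])

# Inverse, as in combine(): fold the two leading axes back together.
restored = np.reshape(per_device, [batch] + list(per_device.shape[2:]))
assert np.array_equal(restored, x)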
def Flatten(x, params, n_axes_to_keep=1, **kwargs):
  del params, kwargs
  if n_axes_to_keep >= len(x.shape):
    raise ValueError(
        "n_axes_to_keep[%d] should be less than input's rank[%d]" %
        (n_axes_to_keep, len(x.shape)))
  return np.reshape(x, (x.shape[:n_axes_to_keep] + (-1,)))
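# For instance, with the default n_axes_to_keep=1 everything after the batch
# axis is flattened. A quick numpy check with a toy shape:
import numpy as np

x = np.zeros((2, 3, 4, 5))
assert np.reshape(x, x.shape[:1] + (-1,)).shape == (2, 60)     # keep 1 axis
assert np.reshape(x, x.shape[:2] + (-1,)).shape == (2, 3, 20)  # keep 2 axes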
def EncoderDecoderMask(x, **unused_kwargs):
  """Makes encoder-decoder mask from decoder input and a padding mask."""
  decoder_input, padding_mask = x
  padding_mask = np.reshape(
      padding_mask, (padding_mask.shape[0], 1, 1, padding_mask.shape[-1]))
  # Final mask shape is [batch, 1 for heads, decoder-len, encoder-len].
  return padding_mask + np.zeros((1, 1, decoder_input.shape[1], 1))
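# Adding the zeros tensor only broadcasts the padding mask along the
# decoder-length axis. A numpy sketch with toy sizes, starting from an
# already-reshaped padding mask:
import numpy as np

batch, encoder_len, decoder_len = 2, 5, 7
padding_mask = np.ones((batch, 1, 1, encoder_len))
decoder_input = np.zeros((batch, decoder_len, 8))

mask = padding_mask + np.zeros((1, 1, decoder_input.shape[1], 1))
assert mask.shape == (batch, 1, decoder_len, encoder_len)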
def combine(x):
  if len(x.shape) > 1:
    batch_size = x.shape[0] * x.shape[1]
    return np.reshape(x, [batch_size] + list(x.shape[2:]))
  # TODO(lukaszkaiser): is returning averages for scalars the right choice?
  # If it is only a scalar, return the average.
  return np.mean(x, axis=0)
def SplitHeads(x, params, n_heads=1, **kwargs):
  del params, kwargs
  d_model = x.shape[-1]
  assert d_model % n_heads == 0
  d_head = d_model // n_heads
  n_batch = np.shape(x)[0]
  # n_batch, seqlen, d_model --> n_batch, n_heads, seqlen, d_head
  return np.transpose(
      np.reshape(x, (n_batch, -1, n_heads, d_head)), (0, 2, 1, 3))
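# SplitHeads and the head-joining reshape (see JoinHeads below) are exact
# inverses. A numpy round-trip check with toy sizes:
import numpy as np

n_batch, seqlen, n_heads, d_head = 2, 5, 4, 8
x = np.arange(n_batch * seqlen * n_heads * d_head, dtype=np.float32)
x = x.reshape(n_batch, seqlen, n_heads * d_head)

# Split: n_batch, seqlen, d_model -> n_batch, n_heads, seqlen, d_head.
split = np.transpose(np.reshape(x, (n_batch, -1, n_heads, d_head)),
                     (0, 2, 1, 3))

# Join: undo the transpose, then fold heads back into the feature axis.
joined = np.reshape(np.transpose(split, (0, 2, 1, 3)),
                    (n_batch, -1, n_heads * d_head))
assert np.array_equal(joined, x)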
def predict(x, params=(), rng=None):
  """Predict function jited and parallelized as requested."""
  # On one device, jit and run.
  if num_devices == 1:
    return backend.jit(model_predict)(x, params, rng=rng)
  # On multiple devices, pmap and run.
  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_predict(x, params, rng):
    return model_predict(x, params, rng=rng)
  pred = mapped_predict(
      reshape_by_device(x, num_devices),
      params,
      jax_random.split(rng, num_devices))
  # Need to reduce the [device, per-device-batch, ...] tensors back to
  # a [batch, ...] tensor. The tensors may be nested.
  if not isinstance(pred, (list, tuple)):  # Not nested.
    batch_size = pred.shape[0] * pred.shape[1]
    return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
  batch_size = pred[0].shape[0] * pred[0].shape[1]
  return [np.reshape(p, [batch_size] + list(p.shape[2:])) for p in pred]
def hash_vectors(self, vecs, rng):
  # See https://arxiv.org/pdf/1509.02897.pdf
  # We sample a different random rotation for each round of hashing to
  # decrease the probability of hash misses.
  assert self.n_buckets % 2 == 0
  random_rotations_shape = (
      vecs.shape[-1],
      self.n_hashes if self._rehash_each_round else 1,
      self.n_buckets // 2)

  rng = jax.lax.tie_in(vecs, rng)
  rng, subrng = backend.random.split(rng)
  random_rotations = jax.random.normal(
      rng, random_rotations_shape).astype('float32')
  # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
  # hashing, so it's shared between them. Check if that's what we want.
  dropped_vecs = self.drop_for_hash(vecs, subrng)
  rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)
  rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)

  if self._rehash_each_round:
    buckets = np.argmax(rotated_vecs, axis=-1)
    # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
    # bucket numbers from different hashing rounds don't overlap.
    offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
    offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
    buckets = np.reshape(buckets + offsets, (-1,))
  else:
    # In this configuration, we map each item to the top self.n_hashes
    # buckets.
    rotated_vecs = np.squeeze(rotated_vecs, 0)
    bucket_range = jax.lax.tie_in(vecs, np.arange(rotated_vecs.shape[-1]))
    bucket_range = np.reshape(bucket_range, (1, -1))
    bucket_range = np.broadcast_to(bucket_range, rotated_vecs.shape)

    _, buckets = jax.lax.sort_key_val(
        rotated_vecs, bucket_range, dimension=-1)
    buckets = buckets[:, -self.n_hashes:]
    buckets = np.reshape(np.moveaxis(buckets, 0, -1), (-1,))

  return buckets
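# The core of this angular-LSH scheme (https://arxiv.org/pdf/1509.02897.pdf)
# is projecting onto random rotations and taking the argmax over the
# concatenation [xR, -xR]. A single-round numpy sketch with toy sizes (no
# multi-round offsets, no dropout):
import numpy as np

seqlen, d_head, n_buckets = 16, 8, 4
rng = np.random.default_rng(0)
vecs = rng.standard_normal((seqlen, d_head))

rotations = rng.standard_normal((d_head, n_buckets // 2))
rotated = vecs @ rotations                              # (seqlen, n_buckets // 2)
rotated = np.concatenate([rotated, -rotated], axis=-1)  # (seqlen, n_buckets)
buckets = np.argmax(rotated, axis=-1)                   # one bucket id per position
assert buckets.shape == (seqlen,) and buckets.max() < n_buckets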
def PreparePairedSequenceBatch(source, target_in, pad=0):
  """Build masks for this batch.

  Args:
    source: (batch, source_len) array of integer-coded symbols for inputs
    target_in: (batch, target_len) array of integer-coded symbols for targets
    pad: int: the padding symbol used to pad the above

  Returns:
    Prepared batch of tuple of arrays: source, input-target, shifted-target,
    source mask, target mask, source-target "memory" mask, minibatch token
    count.
  """
  target = target_in[:, :-1]
  target_y = target_in[:, 1:]
  source_mask = np.reshape(source != pad,
                           (source.shape[0], 1, 1, source.shape[-1]))
  target_mask = MakeTargetMask(target, pad)
  memory_mask = (
      np.reshape(np.arange(target.shape[-1]) < source.shape[-1], [-1, 1]))
  ntokens = np.sum(target_y != pad)
  return (source, target, target_y, source_mask, target_mask, memory_mask,
          ntokens)
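# The target/target_y split is the usual one-step teacher-forcing shift: the
# decoder sees tokens [0, n-2] and is trained to predict tokens [1, n-1].
# A quick numpy illustration with a made-up padded row:
import numpy as np

target_in = np.array([[1, 5, 7, 9, 0, 0]])  # 0 is the pad symbol.
target = target_in[:, :-1]    # [[1, 5, 7, 9, 0]]  decoder input
target_y = target_in[:, 1:]   # [[5, 7, 9, 0, 0]]  prediction targets
assert np.sum(target_y != 0) == 3  # Only non-pad targets count as tokens.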
def _reshape_by_device_single(x, n_devices):
  """Reshape x into a shape [n_devices, ...]."""
  x_shape = list(x.shape)
  batch_size = x_shape[0]
  batch_size_per_device = batch_size // n_devices
  # We require that n_devices divides batch_size evenly.
  if batch_size_per_device * n_devices != batch_size:
    logging.fatal(
        "We require that n_devices[%d] divides batch_size[%d] evenly.",
        n_devices, batch_size)
  # New shape.
  new_shape_prefix = [n_devices, batch_size_per_device]
  return np.reshape(x, new_shape_prefix + x_shape[1:])
def predict(x, params=(), rng=None):
  """Predict function jited and parallelized as requested."""
  # On one device, jit and run.
  if num_devices == 1:
    return backend.jit(model_predict)(x, params, rng=rng)
  # On multiple devices, pmap and run.
  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_predict(x, params, rng):
    return model_predict(x, params, rng=rng)
  pred = mapped_predict(
      reshape_by_device(x, num_devices),
      params,
      jax_random.split(rng, num_devices))
  batch_size = x.shape[0]
  return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
def test_batch_norm(self):
  input_shape = (2, 3, 4)
  input_dtype = np.float32
  eps = 1e-5
  rng = backend.random.get_prng(0)
  inp1 = np.reshape(np.arange(np.prod(input_shape), dtype=input_dtype),
                    input_shape)
  m1 = 11.5  # Mean of this input.
  v1 = 47.9167  # Variance of this input.
  layer = normalization.BatchNorm(axis=(0, 1, 2))
  params, state = layer.initialize(input_shape, input_dtype, rng)
  onp.testing.assert_allclose(state[0], 0)
  onp.testing.assert_allclose(state[1], 1)
  self.assertEqual(state[2], 0)
  out, state = layer(inp1, params, state)
  onp.testing.assert_allclose(state[0], m1 * 0.001)
  onp.testing.assert_allclose(state[1], 0.999 + v1 * 0.001, rtol=1e-6)
  self.assertEqual(state[2], 1)
  onp.testing.assert_allclose(out, (inp1 - m1) / np.sqrt(v1 + eps),
                              rtol=1e-6)
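# The magic numbers follow from the deterministic arange input: the mean of
# 0..23 is 11.5 and the population variance is (24**2 - 1) / 12 ~= 47.9167.
# A one-line numpy check:
import numpy as np
inp1 = np.arange(24, dtype=np.float32)
print(inp1.mean(), inp1.var())  # 11.5 47.916668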
def predict(params, batch, rng=None):
  """Predict function jited and parallelized as requested."""
  # If not jit'ing, just run the function.
  if not jit_eval:
    return model_predict(params, batch, rng=rng)
  # On one device, jit and run.
  if num_devices == 1:
    return backend.jit(model_predict)(params, batch, rng=rng)
  # On multiple devices, pmap and run.
  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_predict(params, batch, rng):
    return model_predict(params, batch, rng=rng)
  pred = mapped_predict(
      jax.replicate(params),
      reshape_by_device(batch, num_devices),
      jax.replicate(rng))
  batch_size = batch.shape[0]
  return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
def _update_sketched(self, grads, params, m, v, opt_params):
  """Update for higher-rank parameters."""
  (learning_rate, momentum) = opt_params
  shape = params.shape
  rank = len(shape)
  reshaped_accumulators = [
      np.reshape(v[i], self._expanded_shape(shape, i)) for i in range(rank)]
  current_accumulator = self._minimum(reshaped_accumulators)
  current_accumulator += grads * grads
  accumulator_inv_sqrt = np.where(
      current_accumulator > 0.0,
      1.0 / np.sqrt(current_accumulator),
      np.zeros_like(current_accumulator))
  preconditioned_gradient = grads * accumulator_inv_sqrt
  m = (1.0 - momentum) * preconditioned_gradient + momentum * m
  params = params - (learning_rate * m).astype(params.dtype)
  for i in range(len(v)):
    axes = list(range(int(i))) + list(range(int(i) + 1, rank))
    dim_accumulator = np.amax(current_accumulator, axis=axes)
    v[i] = dim_accumulator
  return params, (m, v)
def _update_sketched(self, step, g, x, m, v):
  """Update for higher-rank parameters."""
  shape = x.shape
  rank = len(shape)
  reshaped_accumulators = [
      np.reshape(v[i], self._expanded_shape(shape, i)) for i in range(rank)
  ]
  current_accumulator = self._minimum(reshaped_accumulators)
  current_accumulator += g * g
  accumulator_inv_sqrt = np.where(
      current_accumulator > 0.0,
      1.0 / np.sqrt(current_accumulator),
      np.zeros_like(current_accumulator))
  preconditioned_gradient = g * accumulator_inv_sqrt
  m = (1.0 - self._momentum) * preconditioned_gradient + self._momentum * m
  x = x - self.step_size(step) * m
  for i in range(len(v)):
    axes = list(range(int(i))) + list(range(int(i) + 1, rank))
    dim_accumulator = np.amax(current_accumulator, axis=axes)
    v[i] = dim_accumulator
  return x, (m, v)
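# Both variants above implement the same sketched (SM3-style) second-moment
# update: one accumulator per axis instead of a full-size tensor, recombined
# via an elementwise minimum of broadcast copies and shrunk back via per-axis
# maxima. A minimal numpy sketch for a rank-2 parameter with toy sizes:
import numpy as np

g = np.ones((3, 4))             # Gradient for a (3, 4) parameter.
v = [np.zeros(3), np.zeros(4)]  # One accumulator per axis.

# Expand each per-axis accumulator back to the full shape and lower-bound
# the true accumulator by their elementwise minimum.
current = np.minimum(v[0][:, None], v[1][None, :]) + g * g

# Shrink back: each per-axis accumulator is the max over the other axes.
v[0] = np.amax(current, axis=1)
v[1] = np.amax(current, axis=0)
assert v[0].shape == (3,) and v[1].shape == (4,)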
def test_batch_norm(self):
  input_shape = (2, 3, 4)
  input_dtype = np.float32
  eps = 1e-5
  rng = backend.random.get_prng(0)
  inp1 = np.reshape(np.arange(np.prod(input_shape), dtype=input_dtype),
                    input_shape)
  m1 = 11.5  # Mean of this input.
  v1 = 47.9167  # Variance of this input.
  layer = normalization.BatchNorm(axis=(0, 1, 2))
  params, state = layer.initialize(input_shape, input_dtype, rng)
  onp.testing.assert_allclose(state[0], 0)
  onp.testing.assert_allclose(state[1], 0)
  self.assertEqual(state[2], 0)
  out, state = layer(inp1, params, state)
  onp.testing.assert_allclose(state[0], m1)
  onp.testing.assert_allclose(state[1], v1, rtol=1e-6)
  self.assertEqual(state[2], 1)
  onp.testing.assert_allclose(out, (inp1 - m1) / np.sqrt(v1 + eps),
                              rtol=1e-6)

  inp2 = inp1 * 2 + 3
  m2 = m1 * 2 + 3
  v2 = v1 * 4
  m12 = (m1 + m2) / 2
  v12 = (v1 + v2) / 2
  out, state = layer(inp2, params, state)
  onp.testing.assert_allclose(state[0], m12)
  onp.testing.assert_allclose(state[1], v12, rtol=1e-6)
  self.assertEqual(state[2], 2)
  onp.testing.assert_allclose(out, (inp2 - m2) / np.sqrt(v2 + eps),
                              rtol=1e-6)

  layer = normalization.BatchNorm(axis=(0, 1, 2), mode="eval")
  inp3 = inp1 * 5 + 7
  out, state_unchanged = layer(inp3, params, state)
  for i in range(3):
    onp.testing.assert_allclose(state_unchanged[i], state[i])
  onp.testing.assert_allclose(out, (inp3 - m12) / np.sqrt(v12 + eps),
                              rtol=1e-6)
def combine(x):
  batch_size = x.shape[0] * x.shape[1]
  return np.reshape(x, [batch_size] + list(x.shape[2:]))
def unchunk_vectors(x):  # pylint: disable=invalid-name
  return np.reshape(x, (x.shape[0], -1, x.shape[-1]))
def chunk_scalars(x):  # pylint: disable=invalid-name
  return np.reshape(x, (x.shape[0], self.n_bins, -1))
def call(self, inputs, params=(), state=(), rng=None, **kwargs):
  del params, kwargs
  # We use the same vector as both a query and a key. For now we haven't
  # adjusted any of the surrounding code, so we still get a separate "key"
  # input that we ignore.
  qk, _, v = inputs
  seqlen = qk.shape[-2]

  # qk/v are n_hashes*n_batch*n_heads, seqlen, d_head
  # TODO(kitaev): is it faster to fuse this tiling into gather/scatter ops?
  qk = np.tile(qk, (self.n_hashes, 1, 1))
  v = np.tile(v, (self.n_hashes, 1, 1))

  # bins are n_hashes*n_batch*n_heads, seqlen
  # They specify which hash bucket the query/key/value vectors fall in.
  bins = self.hash_vectors(qk, rng=rng)

  # joint_t is n_hashes*n_batch*n_heads, seqlen
  joint_t = jax.lax.tie_in(qk, np.arange(seqlen))
  joint_t = np.reshape(joint_t, (1, seqlen))
  joint_t = np.broadcast_to(joint_t, qk.shape[:-1])

  assert int((self.n_buckets_per_bin * self.n_bins + 1) * seqlen) < 2**31, (
      'Potential 32-bit integer overflow; please double-check the code.')
  joint_bins_and_t = seqlen * bins + joint_t

  def chunk_scalars(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], self.n_bins, -1))

  def chunk_vectors(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

  def unchunk_vectors(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

  # Sort everything by bin number, with a secondary sort by time
  # (variables starting with "s" are sorted).
  _, sjoint_t = jax.lax.sort_key_val(joint_bins_and_t, joint_t, dimension=-1)
  _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)
  # TODO(kitaev): why does jax flag integer indices as differentiable?
  # If we don't call stop_gradient here, custom gradients below won't work
  # because the primitive functions close over "differentiable" variables.
  sjoint_t = jax.lax.stop_gradient(sjoint_t)
  undo_sort = jax.lax.stop_gradient(undo_sort)

  # The backward pass of gather is in general a scatter operation, but we
  # know we're dealing with permutations so we use gather for the backward
  # pass too. This custom gradient should be about 2x faster than having jax
  # infer one that uses scatter ops instead.
  def permute_impl(vecs):
    assert len(vecs.shape) == 3
    return np.take_along_axis(vecs, sjoint_t[:, :, None], axis=-2)

  def unpermute_impl(vecs):
    assert len(vecs.shape) == 3
    return np.take_along_axis(vecs, undo_sort[:, :, None], axis=-2)

  @jax.custom_transforms
  def permute(vecs):
    return permute_impl(vecs)

  def permute_vjp(vecs):
    out_vecs = permute_impl(vecs)
    def vjpfun(grad):
      return (unpermute_impl(grad),)
    return out_vecs, vjpfun

  @jax.custom_transforms
  def unpermute(vecs):
    return unpermute_impl(vecs)

  def unpermute_vjp(vecs):
    out_vecs = unpermute_impl(vecs)
    def vjpfun(grad):
      return (permute_impl(grad),)
    return out_vecs, vjpfun

  jax.defvjp_all(permute, permute_vjp)
  jax.defvjp_all(unpermute, unpermute_vjp)

  sqk = permute(qk)
  sv = permute(v)

  # Split off a "bin" axis so that attention only occurs within chunks.
  bq_t = bkv_t = chunk_scalars(sjoint_t)
  bqk = chunk_vectors(sqk)
  bv = chunk_vectors(sv)

  # Hashing operates on unit-length vectors. Unnormalized query vectors are
  # fine because they effectively provide a learnable temperature for the
  # attention softmax, but normalizing keys is needed so that similarity for
  # the purposes of attention correctly corresponds to hash locality.
  bq = bqk
  bk = self.make_unit_length(bqk)

  # Allow each chunk to attend within itself, and also one chunk back. Chunk
  # boundaries might occur in the middle of a sequence of items from the
  # same bin, so this increases the chances of attending to relevant items.
  # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
  bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]], axis=1)
  bk = np.concatenate([bk, bk_extra], axis=2)
  bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]], axis=1)
  bv = np.concatenate([bv, bv_extra], axis=2)
  bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]], axis=1)
  bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

  # Dot-product attention.
  dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

  # Causal masking.
  mask = jax.lax.convert_element_type(
      jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]),
      np.float32)
  dots = dots - 1e9 * mask

  # Mask out attention to self except when no other targets are available.
  self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
  self_mask = jax.lax.tie_in(dots, self_mask)
  dots = dots - 32 * self_mask

  # Softmax.
  dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
  dots = np.exp(dots - dots_logsumexp)

  if self._hard_k > 0:
    top_k = np.sort(dots)[..., -self._hard_k]  # Get the top-kth weight.
    top_k = jax.lax.stop_gradient(top_k)
    dots -= top_k[..., np.newaxis]  # Subtract (be 0 for lower ones).
    dots = np.maximum(dots, 0)
    dots_sum = np.sum(dots, axis=-1, keepdims=True)  # Sum to re-normalize.
    dots_logsumexp += np.log(dots_sum)  # Add it to the weight.
    dots /= dots_sum  # Re-normalize.

  bo = np.matmul(dots, bv)
  so = unchunk_vectors(bo)
  slogits = unchunk_vectors(dots_logsumexp)

  o = unpermute(so)
  logits = unpermute(slogits)

  o = np.reshape(o, (self.n_hashes, -1, seqlen, o.shape[-1]))
  logits = np.reshape(logits, (self.n_hashes, -1, seqlen, 1))
  probs = np.exp(logits - backend.logsumexp(logits, axis=0, keepdims=True))
  out = np.sum(o * probs, axis=0)
  assert out.shape == inputs[2].shape

  return out, state
def call_and_grad(self, inputs, ct, rng=None, **kwargs):
  del kwargs
  # We use the same vector as both a query and a key. For now we haven't
  # adjusted any of the surrounding code, so we still get a separate "key"
  # input that we ignore.
  qk, ignored_k, v = inputs
  seqlen = qk.shape[-2]
  # qk/v are n_batch*n_heads, seqlen, d_head

  # bins are n_batch*n_heads, seqlen
  # They specify which hash bucket the query/key/value vectors fall in.
  bins = self.hash_vectors(qk, rng=rng)

  # joint_t is n_batch*n_heads, seqlen
  joint_t = jax.lax.tie_in(qk, np.arange(seqlen))
  joint_t = np.reshape(joint_t, (1, seqlen))
  joint_t = np.broadcast_to(joint_t, qk.shape[:-1])

  assert int((self.n_bins + 1) * seqlen) < 2**31, (
      'Potential 32-bit integer overflow; please double-check the code.')
  joint_bins_and_t = seqlen * bins + joint_t

  def chunk_scalars(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], self.n_bins, -1))

  def chunk_vectors(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

  def unchunk_vectors(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

  # Sort everything by bin number, with a secondary sort by time
  # (variables starting with "s" are sorted).
  _, sjoint_t = jax.lax.sort_key_val(joint_bins_and_t, joint_t, dimension=-1)

  sqk = np.take_along_axis(qk, sjoint_t[:, :, None], axis=-2)
  sv = np.take_along_axis(v, sjoint_t[:, :, None], axis=-2)
  if ct is not None:
    so_ct = np.take_along_axis(ct, sjoint_t[:, :, None], axis=-2)

  @jax.jit
  def binned_attn(sqk, sv):  # pylint: disable=invalid-name
    """Performs attention on sorted queries/keys/values."""
    # Split off a "bin" axis so that attention only occurs within chunks.
    bq_t = bkv_t = chunk_scalars(sjoint_t)
    bqk = chunk_vectors(sqk)
    bv = chunk_vectors(sv)

    # Hashing operates on unit-length vectors. Unnormalized query vectors
    # are fine because they effectively provide a learnable temperature for
    # the attention softmax, but normalizing keys is needed so that
    # similarity for the purposes of attention correctly corresponds to
    # hash locality.
    bq = bqk
    bk = self.make_unit_length(bqk)

    # Allow each chunk to attend within itself, and also one chunk back.
    # Chunk boundaries might occur in the middle of a sequence of items from
    # the same bin, so this increases the chances of attending to relevant
    # items.
    # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
    bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]], axis=1)
    bk = np.concatenate([bk, bk_extra], axis=2)
    bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]], axis=1)
    bv = np.concatenate([bv, bv_extra], axis=2)
    bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]], axis=1)
    bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

    # Dot-product attention.
    dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

    # Causal masking.
    mask = jax.lax.convert_element_type(
        jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]),
        np.float32)
    dots = dots - 1e9 * mask

    # Mask out attention to self except when no other targets are available.
    self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
    self_mask = jax.lax.tie_in(dots, self_mask)
    dots = dots - 32 * self_mask

    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
    bo = np.matmul(dots, bv)

    so = unchunk_vectors(bo)
    return so

  @jax.jit
  def binned_attn_vjp(sqk, sv, so_ct):  # pylint: disable=invalid-name
    so, vjpfun = jax.vjp(binned_attn, sqk, sv)
    sqkv_ct = vjpfun(so_ct)
    return so, sqkv_ct

  if ct is None:
    so = binned_attn(sqk, sv)
    _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)
    out = np.take_along_axis(so, undo_sort[:, :, None], axis=-2)
    return out, None
  else:
    # Jax can construct a backward pass automatically, but it's about 2x
    # slower than writing our own. The main reason is that the backward pass
    # of gather is in general a scatter operation, but we know we're dealing
    # with permutations so we use gather for the backward pass too.
    so, (sqk_ct, sv_ct) = binned_attn_vjp(sqk, sv, so_ct)
    _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)
    out = np.take_along_axis(so, undo_sort[:, :, None], axis=-2)
    qk_ct = np.take_along_axis(sqk_ct, undo_sort[:, :, None], axis=-2)
    v_ct = np.take_along_axis(sv_ct, undo_sort[:, :, None], axis=-2)
    return out, (qk_ct, np.zeros_like(ignored_k), v_ct)
def PaddingMask(x, params, pad=0, **kwargs):
  del params, kwargs
  return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))
def JoinHeads(x):  # pylint: disable=invalid-name
  return np.reshape(
      np.transpose(x, (0, 2, 1, 3)), (nbatch, -1, n_heads * d_head))
def SplitHeads(x):
  return np.transpose(
      np.reshape(x, (nbatch, -1, n_heads, d_head)), (0, 2, 1, 3))
def unchunk_rank4(x):
  return np.reshape(x, (x.shape[0], x.shape[1], -1, x.shape[-1]))
def chunk_rank4(x):
  return np.reshape(
      x, (x.shape[0], x.shape[1], self.n_bins, -1, x.shape[-1]))
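# chunk_rank4 and unchunk_rank4 are inverses whenever n_bins divides the
# chunked axis evenly. A numpy round-trip check with toy sizes:
import numpy as np

n_bins = 4
x = np.arange(2 * 3 * 16 * 8, dtype=np.float32).reshape(2, 3, 16, 8)

chunked = np.reshape(x, (x.shape[0], x.shape[1], n_bins, -1, x.shape[-1]))
assert chunked.shape == (2, 3, 4, 4, 8)

unchunked = np.reshape(
    chunked, (chunked.shape[0], chunked.shape[1], -1, chunked.shape[-1]))
assert np.array_equal(unchunked, x)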