def forward(self, inputs):
    """Scans the sublayer over `inputs`, threading carry and layer state.

    The trailing `self._n_carry` elements of the input stack seed the carry;
    the remaining elements are the values handed to `fastmath.scan`. At each
    step the sublayer's `pure_fn` is called (with `use_cache=True`) on the
    current slice joined with the carry, and its result is split back into
    per-step outputs and the next carry.

    Args:
      inputs: the input stack. A list is converted to a tuple first.
        NOTE(review): `+` is used to join `x` and the carry, so these are
        presumably tuples of tensors -- confirm against callers.

    Returns:
      The scanned outputs, followed by the final carry when
      `self._n_carry > 0`.

    Side effects:
      `self.state` is replaced by a one-tuple holding the state returned by
      the scan.
    """
    sublayer_weights = self.weights[0]
    if isinstance(inputs, list):
        # A tuple keeps the structure of the inputs matching the outputs.
        inputs = tuple(inputs)
    num_carry = self._n_carry
    has_carry = num_carry > 0

    def scannable_fn(x, carry_and_state):  # pylint: disable=invalid-name
        carry, state = carry_and_state
        stack = x + carry if has_carry else x
        out, next_state = self.sublayer.pure_fn(
            stack, sublayer_weights, state, self.rng, use_cache=True)
        if not has_carry:
            return (out, ([], next_state))
        # The last `num_carry` results become the carry for the next step.
        return (out[:-num_carry], (out[-num_carry:], next_state))

    if has_carry:
        # Split the input stack into scanned inputs and the initial carry.
        xs, init_carry = inputs[:-num_carry], inputs[-num_carry:]
    else:
        xs, init_carry = inputs, []
    init = (init_carry, self.state[0])

    ys, (final_carry, final_state) = fastmath.scan(
        scannable_fn, xs, init, axis=self._axis, remat=self._remat)

    # Put outputs and carry back on the stack.
    outputs = ys + final_carry if has_carry else ys
    self.state = (final_state,)
    return outputs
def favor_denominator_fwd(init_prefix_sum_value, precision, query_prime,
                          key_prime):
    """Forward scan: contracts each query with the running sum of keys.

    At every scan step the current key slice is added to the carried prefix
    sum, and the query slice is then contracted with that updated sum over
    their shared trailing axis (`m` in the einsum subscripts).

    Args:
      init_prefix_sum_value: initial value of the carried prefix sum.
      precision: forwarded as the `precision` argument of `jnp.einsum`.
      query_prime: queries, scanned in lockstep with `key_prime`.
      key_prime: keys accumulated into the prefix sum.
        NOTE(review): which axis is scanned depends on `fastmath.scan`;
        presumably the leading one -- confirm.

    Returns:
      A pair `(r, residuals)`. `r` is the scan output built from the per-step
      contractions; `residuals` is the tuple
      `(precision, query_prime, key_prime, final_prefix_sum)`.
    """
    def step(prefix_sum, query_and_key):
        query, key = query_and_key
        prefix_sum += key
        contracted = jnp.einsum(
            '...m,...m->...', query, prefix_sum, precision=precision)
        return prefix_sum, contracted

    final_prefix_sum, r = fastmath.scan(
        step, init_prefix_sum_value, (query_prime, key_prime))
    residuals = (precision, query_prime, key_prime, final_prefix_sum)
    return r, residuals
def favor_numerator_fwd(init_prefix_sum_value, precision, query_prime,
                        key_prime, value):
    """Forward scan: contracts each query with a running key-value sum.

    At every scan step the outer product of the key slice (trailing axis
    `m`) and the value slice (trailing axis `d`) is added to the carried
    prefix sum, and the query slice is then contracted with that updated sum
    over `m`.

    Args:
      init_prefix_sum_value: initial value of the carried prefix sum.
      precision: forwarded as the `precision` argument of `jnp.einsum`.
      query_prime: queries, scanned in lockstep with the keys and values.
      key_prime: keys entering the outer products.
      value: values entering the outer products.
        NOTE(review): which axis is scanned depends on `fastmath.scan`;
        presumably the leading one -- confirm.

    Returns:
      A pair `(w, residuals)`. `w` is the scan output built from the per-step
      contractions; `residuals` is the tuple
      `(precision, final_prefix_sum, query_prime, key_prime, value)`.
    """
    def step(prefix_sum, slices):
        query, key, val = slices
        prefix_sum += jnp.einsum(
            '...m,...d->...md', key, val, precision=precision)
        out_slice = jnp.einsum(
            '...m,...md->...d', query, prefix_sum, precision=precision)
        return prefix_sum, out_slice

    final_prefix_sum, w = fastmath.scan(
        step, init_prefix_sum_value, (query_prime, key_prime, value))
    residuals = (precision, final_prefix_sum, query_prime, key_prime, value)
    return w, residuals
def favor_denominator_bwd(qkp, r_ct):
    """Backward scan matching `favor_denominator_fwd`'s residual layout.

    Runs the scan with `reverse=True`, starting from the final prefix sum
    stored in the residuals. Each step subtracts the current key from the
    prefix sum, so the value seen at a step is the one the corresponding
    forward step used, without storing every intermediate sum.

    Args:
      qkp: residual tuple `(precision, qs, ks, p)`, where `p` is the prefix
        sum to start the reverse scan from.
      r_ct: value scanned alongside `qs` and `ks`.
        NOTE(review): the `_ct` suffix presumably means "cotangent" of the
        forward output -- confirm.

    Returns:
      The 4-tuple `(None, None, qs_ct, ks_ct)`.
      NOTE(review): the two leading `None`s presumably stand for the
      forward function's first two arguments -- confirm against the
      custom-gradient registration.
    """
    precision, queries, keys, final_prefix_sum = qkp

    def step(carry, slices):
        prefix_sum, prefix_sum_ct = carry
        query, key, out_ct = slices
        query_ct = jnp.einsum(
            '...,...m->...m', out_ct, prefix_sum, precision=precision)
        prefix_sum_ct += jnp.einsum(
            '...,...m->...m', out_ct, query, precision=precision)
        key_ct = prefix_sum_ct
        # Roll the prefix sum back to its value before this key was added.
        prefix_sum -= key
        return (prefix_sum, prefix_sum_ct), (query_ct, key_ct)

    init_carry = (final_prefix_sum, jnp.zeros_like(final_prefix_sum))
    _, (qs_ct, ks_ct) = fastmath.scan(
        step, init_carry, (queries, keys, r_ct), reverse=True)
    return (None, None, qs_ct, ks_ct)
def favor_numerator_bwd(pqkv, w_ct):
    """Backward scan matching `favor_numerator_fwd`'s residual layout.

    Runs the scan with `reverse=True`, starting from the final prefix sum
    stored in the residuals. Each step subtracts the current key-value outer
    product from the prefix sum, so the value seen at a step is the one the
    corresponding forward step used.

    Args:
      pqkv: residual tuple `(precision, p, qs, ks, vs)`, where `p` is the
        prefix sum to start the reverse scan from.
      w_ct: value scanned alongside `qs`, `ks` and `vs`.
        NOTE(review): the `_ct` suffix presumably means "cotangent" of the
        forward output -- confirm.

    Returns:
      The 5-tuple `(None, None, qs_ct, ks_ct, vs_ct)`.
      NOTE(review): the two leading `None`s presumably stand for the
      forward function's first two arguments -- confirm against the
      custom-gradient registration.
    """
    precision, final_prefix_sum, queries, keys, values = pqkv

    def step(carry, slices):
        prefix_sum, prefix_sum_ct = carry
        query, key, val, out_ct = slices
        query_ct = jnp.einsum(
            '...d,...md->...m', out_ct, prefix_sum, precision=precision)
        prefix_sum_ct += jnp.einsum(
            '...d,...m->...md', out_ct, query, precision=precision)
        key_ct = jnp.einsum(
            '...md,...d->...m', prefix_sum_ct, val, precision=precision)
        val_ct = jnp.einsum(
            '...md,...m->...d', prefix_sum_ct, key, precision=precision)
        # Roll the prefix sum back to its value before this step's update.
        prefix_sum -= jnp.einsum(
            '...m,...d->...md', key, val, precision=precision)
        return (prefix_sum, prefix_sum_ct), (query_ct, key_ct, val_ct)

    init_carry = (final_prefix_sum, jnp.zeros_like(final_prefix_sum))
    _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(
        step, init_carry, (queries, keys, values, w_ct), reverse=True)
    return (None, None, qs_ct, ks_ct, vs_ct)
def favor_denominator_bwd(init_prefix_sum_value, precision, qkp, r_ct):
    """Backward scan for the denominator; the initial prefix sum is unused.

    Runs the scan with `reverse=True`, starting from the final prefix sum
    stored in the residuals. Each step subtracts the current key from the
    prefix sum, so the value seen at a step is the one the corresponding
    forward step used.

    Args:
      init_prefix_sum_value: accepted for signature compatibility and
        discarded immediately.
      precision: forwarded as the `precision` argument of `np.einsum`.
        NOTE(review): `np` must therefore be an einsum that accepts
        `precision` (presumably a JAX-backed numpy) -- confirm the import.
      qkp: residual tuple `(qs, ks, p)`, where `p` is the prefix sum to
        start the reverse scan from.
      r_ct: value scanned alongside `qs` and `ks`.
        NOTE(review): the `_ct` suffix presumably means "cotangent" of the
        forward output -- confirm.

    Returns:
      The pair `(qs_ct, ks_ct)`.
    """
    del init_prefix_sum_value
    queries, keys, final_prefix_sum = qkp

    def step(carry, slices):
        prefix_sum, prefix_sum_ct = carry
        query, key, out_ct = slices
        query_ct = np.einsum(
            '...,...m->...m', out_ct, prefix_sum, precision=precision)
        prefix_sum_ct += np.einsum(
            '...,...m->...m', out_ct, query, precision=precision)
        key_ct = prefix_sum_ct
        # Roll the prefix sum back to its value before this key was added.
        prefix_sum -= key
        return (prefix_sum, prefix_sum_ct), (query_ct, key_ct)

    init_carry = (final_prefix_sum, np.zeros_like(final_prefix_sum))
    _, (qs_ct, ks_ct) = fastmath.scan(
        step, init_carry, (queries, keys, r_ct), reverse=True)
    return (qs_ct, ks_ct)
def favor_numerator_bwd(init_prefix_sum_value, precision, pqkv, w_ct):
    """Backward scan for the numerator; the initial prefix sum is unused.

    Runs the scan with `reverse=True`, starting from the final prefix sum
    stored in the residuals. Each step subtracts the current key-value outer
    product from the prefix sum, so the value seen at a step is the one the
    corresponding forward step used.

    Args:
      init_prefix_sum_value: accepted for signature compatibility and
        discarded immediately.
      precision: forwarded as the `precision` argument of `np.einsum`.
        NOTE(review): `np` must therefore be an einsum that accepts
        `precision` (presumably a JAX-backed numpy) -- confirm the import.
      pqkv: residual tuple `(p, qs, ks, vs)`, where `p` is the prefix sum to
        start the reverse scan from.
      w_ct: value scanned alongside `qs`, `ks` and `vs`.
        NOTE(review): the `_ct` suffix presumably means "cotangent" of the
        forward output -- confirm.

    Returns:
      The triple `(qs_ct, ks_ct, vs_ct)`.
    """
    del init_prefix_sum_value
    final_prefix_sum, queries, keys, values = pqkv

    def step(carry, slices):
        prefix_sum, prefix_sum_ct = carry
        query, key, val, out_ct = slices
        query_ct = np.einsum(
            '...d,...md->...m', out_ct, prefix_sum, precision=precision)
        prefix_sum_ct += np.einsum(
            '...d,...m->...md', out_ct, query, precision=precision)
        key_ct = np.einsum(
            '...md,...d->...m', prefix_sum_ct, val, precision=precision)
        val_ct = np.einsum(
            '...md,...m->...d', prefix_sum_ct, key, precision=precision)
        # Roll the prefix sum back to its value before this step's update.
        prefix_sum -= np.einsum(
            '...m,...d->...md', key, val, precision=precision)
        return (prefix_sum, prefix_sum_ct), (query_ct, key_ct, val_ct)

    init_carry = (final_prefix_sum, np.zeros_like(final_prefix_sum))
    _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(
        step, init_carry, (queries, keys, values, w_ct), reverse=True)
    return qs_ct, ks_ct, vs_ct
def _scan(f, xs, init_value, axis=0, remat=False):
    """Applies `f` along one axis of `xs` while threading a carry value.

    In pseudo-python the behavior is:

      def scan(f, xs, init_value, axis):
        carry = init_value
        ys = []
        for i in range(xs.shape[axis]):
          y, carry = f(xs[..., i, ...], carry)
          ys.append(y)
        return np.stack(ys, axis), carry

    Args:
      f: function `(x, carry) -> (y, new_carry)`.
      xs: tensor (or nested structure of tensors) whose slices along `axis`
        are fed to `f` as `x`.
      init_value: initial value of the carry.
      axis: int, the axis of `xs` to slice along.
      remat: if true, `f` is wrapped with `fastmath.remat` before scanning.

    Returns:
      A pair `(ys, last_value)`: the per-step outputs and the final carry.
    """
    def swap_with_leading(x):
        # Exchange `axis` with axis 0; applying it twice restores the layout.
        perm = list(range(len(x.shape)))
        perm[axis], perm[0] = 0, axis
        return jnp.transpose(x, axes=perm)

    def carry_first(carry, x):
        # `fastmath.scan` is called with a (carry, x) -> (carry, y) step,
        # whereas `f` uses (x, carry) -> (y, carry); adapt the argument order.
        y, new_carry = f(x, carry)
        return new_carry, y

    needs_swap = axis != 0
    if needs_swap:
        xs = fastmath.nested_map(swap_with_leading, xs)

    step = fastmath.remat(carry_first) if remat else carry_first
    last_value, ys = fastmath.scan(step, init_value, xs)

    if needs_swap:
        ys = fastmath.nested_map(swap_with_leading, ys)
    return ys, last_value