Exemple #1
0
    def forward(self, inputs):
        """Runs the sublayer over `inputs` with `fastmath.scan`.

        When `self._n_carry` is positive, the last `n_carry` elements of
        `inputs` are used as the initial carry and the remaining elements are
        scanned over; otherwise all of `inputs` is scanned and the carry is an
        empty list.

        Args:
          inputs: the input stack; a list is converted to a tuple first.

        Returns:
          The scanned outputs, followed by the final carry when
          `self._n_carry` is positive.

        Side effects:
          Replaces `self.state` with a 1-tuple holding the sublayer state
          returned by the scan.
        """
        sublayer_weights = self.weights[0]
        if isinstance(inputs, list):
            # A tuple keeps the structure of the inputs equal to the outputs'.
            inputs = tuple(inputs)
        n_carry = self._n_carry
        has_carry = n_carry > 0

        def scannable_fn(x, carry_and_state):  # pylint: disable=invalid-name
            carry, state = carry_and_state
            if has_carry:
                stack = x + carry
            else:
                stack = x
            out, next_state = self.sublayer.pure_fn(
                stack, sublayer_weights, state, self.rng, use_cache=True)
            if not has_carry:
                return (out, ([], next_state))
            # The trailing n_carry outputs become the carry of the next step.
            return (out[:-n_carry], (out[-n_carry:], next_state))

        if has_carry:
            # Split the input stack into scanned inputs and initial carry.
            xs = inputs[:-n_carry]
            init = (inputs[-n_carry:], self.state[0])
        else:
            xs = inputs
            init = ([], self.state[0])

        ys, (carry, new_state) = fastmath.scan(
            scannable_fn, xs, init, axis=self._axis, remat=self._remat)

        # Put the outputs and the carry back on the stack.
        if has_carry:
            res = ys + carry
        else:
            res = ys
        self.state = (new_state, )
        return res
Exemple #2
0
  def favor_denominator_fwd(init_prefix_sum_value, precision,
                            query_prime, key_prime):
    """Forward computation of the FAVOR denominator.

    At each scan step the key slice is added to a running prefix sum, and the
    query slice is contracted with that sum over the `m` index.

    NOTE(review): presumably the forward half of a custom-gradient pair whose
    backward half consumes the returned residuals -- confirm at the call site.

    Args:
      init_prefix_sum_value: initial value of the running key sum.
      precision: forwarded to `jnp.einsum` as `precision`.
      query_prime: queries, scanned together with `key_prime`.
      key_prime: keys, scanned together with `query_prime`.

    Returns:
      A pair `(r, residuals)`: `r` is the stacked per-step contractions and
      `residuals` is `(precision, query_prime, key_prime, final_prefix_sum)`.
    """
    def step(prefix_sum, query_and_key):
      query, key = query_and_key
      prefix_sum += key
      contracted = jnp.einsum(
          '...m,...m->...', query, prefix_sum, precision=precision)
      return prefix_sum, contracted

    scanned = (query_prime, key_prime)
    final_prefix_sum, r = fastmath.scan(step, init_prefix_sum_value, scanned)
    residuals = (precision, query_prime, key_prime, final_prefix_sum)
    return r, residuals
Exemple #3
0
 def favor_numerator_fwd(init_prefix_sum_value, precision,
                         query_prime, key_prime, value):
   """Forward computation of the FAVOR numerator.

   At each scan step the outer product of the key and value slices is added
   to a running prefix sum, and the query slice is contracted with that sum
   over the `m` index.

   NOTE(review): presumably the forward half of a custom-gradient pair whose
   backward half consumes the returned residuals -- confirm at the call site.

   Args:
     init_prefix_sum_value: initial value of the running key/value sum.
     precision: forwarded to `jnp.einsum` as `precision`.
     query_prime: queries, scanned together with keys and values.
     key_prime: keys, scanned together with queries and values.
     value: values, scanned together with queries and keys.

   Returns:
     A pair `(w, residuals)`: `w` is the stacked per-step results and
     `residuals` is
     `(precision, final_prefix_sum, query_prime, key_prime, value)`.
   """
   def step(prefix_sum, query_key_value):
     query, key, val = query_key_value
     prefix_sum += jnp.einsum(
         '...m,...d->...md', key, val, precision=precision)
     out_slice = jnp.einsum(
         '...m,...md->...d', query, prefix_sum, precision=precision)
     return prefix_sum, out_slice

   scanned = (query_prime, key_prime, value)
   final_prefix_sum, w = fastmath.scan(step, init_prefix_sum_value, scanned)
   residuals = (precision, final_prefix_sum, query_prime, key_prime, value)
   return w, residuals
Exemple #4
0
  def favor_denominator_bwd(qkp, r_ct):
    """Backward computation of the FAVOR denominator.

    Scans in reverse over the queries, keys and output cotangents. The carry
    holds the key prefix sum, rebuilt step by step by subtracting each key
    from the saved final value, and the accumulated prefix-sum cotangent.

    Args:
      qkp: residuals `(precision, qs, ks, p)`, where `p` is the prefix sum
        the reverse scan starts from.
      r_ct: cotangent of the forward output, scanned alongside `qs` and `ks`.

    Returns:
      `(None, None, qs_ct, ks_ct)`: two `None` placeholders, then the
      cotangents for the queries and the keys.
    """
    precision, qs, ks, last_prefix_sum = qkp

    def step(carry, scanned):
      prefix_sum, prefix_sum_ct = carry
      query, key, out_ct = scanned
      # The query cotangent needs the prefix sum *before* the key is removed.
      query_ct = jnp.einsum(
          '...,...m->...m', out_ct, prefix_sum, precision=precision)
      prefix_sum_ct += jnp.einsum(
          '...,...m->...m', out_ct, query, precision=precision)
      key_ct = prefix_sum_ct
      prefix_sum -= key
      return (prefix_sum, prefix_sum_ct), (query_ct, key_ct)

    init_carry = (last_prefix_sum, jnp.zeros_like(last_prefix_sum))
    _, (qs_ct, ks_ct) = fastmath.scan(
        step, init_carry, (qs, ks, r_ct), reverse=True)
    return (None, None, qs_ct, ks_ct)
Exemple #5
0
  def favor_numerator_bwd(pqkv, w_ct):
    """Backward computation of the FAVOR numerator.

    Scans in reverse over the queries, keys, values and output cotangents.
    The carry holds the key/value prefix sum, rebuilt step by step by
    subtracting each key/value outer product from the saved final value, and
    the accumulated prefix-sum cotangent.

    Args:
      pqkv: residuals `(precision, p, qs, ks, vs)`, where `p` is the prefix
        sum the reverse scan starts from.
      w_ct: cotangent of the forward output, scanned alongside `qs`, `ks`
        and `vs`.

    Returns:
      `(None, None, qs_ct, ks_ct, vs_ct)`: two `None` placeholders, then the
      cotangents for the queries, keys and values.
    """
    precision, last_prefix_sum, qs, ks, vs = pqkv

    def step(carry, scanned):
      prefix_sum, prefix_sum_ct = carry
      query, key, val, out_ct = scanned
      # The query cotangent needs the prefix sum *before* this step's
      # key/value contribution is removed.
      query_ct = jnp.einsum(
          '...d,...md->...m', out_ct, prefix_sum, precision=precision)
      prefix_sum_ct += jnp.einsum(
          '...d,...m->...md', out_ct, query, precision=precision)
      key_ct = jnp.einsum(
          '...md,...d->...m', prefix_sum_ct, val, precision=precision)
      val_ct = jnp.einsum(
          '...md,...m->...d', prefix_sum_ct, key, precision=precision)
      prefix_sum -= jnp.einsum(
          '...m,...d->...md', key, val, precision=precision)
      return (prefix_sum, prefix_sum_ct), (query_ct, key_ct, val_ct)

    init_carry = (last_prefix_sum, jnp.zeros_like(last_prefix_sum))
    _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(
        step, init_carry, (qs, ks, vs, w_ct), reverse=True)
    return (None, None, qs_ct, ks_ct, vs_ct)
Exemple #6
0
    def favor_denominator_bwd(init_prefix_sum_value, precision, qkp, r_ct):
        """Backward computation of the FAVOR denominator.

        Scans in reverse over the queries, keys and output cotangents. The
        carry holds the key prefix sum, rebuilt step by step by subtracting
        each key from the saved final value, and the accumulated prefix-sum
        cotangent.

        NOTE(review): `np.einsum` is called with a `precision` keyword, so
        `np` is presumably a JAX-compatible numpy module -- confirm against
        the file's imports.

        Args:
          init_prefix_sum_value: unused; deleted immediately.
          precision: forwarded to `np.einsum` as `precision`.
          qkp: residuals `(qs, ks, p)`, where `p` is the prefix sum the
            reverse scan starts from.
          r_ct: cotangent of the forward output, scanned alongside `qs` and
            `ks`.

        Returns:
          `(qs_ct, ks_ct)`: cotangents for the queries and the keys.
        """
        del init_prefix_sum_value
        qs, ks, last_prefix_sum = qkp

        def step(carry, scanned):
            prefix_sum, prefix_sum_ct = carry
            query, key, out_ct = scanned
            # The query cotangent needs the prefix sum *before* the key is
            # removed from it.
            query_ct = np.einsum(
                '...,...m->...m', out_ct, prefix_sum, precision=precision)
            prefix_sum_ct += np.einsum(
                '...,...m->...m', out_ct, query, precision=precision)
            key_ct = prefix_sum_ct
            prefix_sum -= key
            return (prefix_sum, prefix_sum_ct), (query_ct, key_ct)

        init_carry = (last_prefix_sum, np.zeros_like(last_prefix_sum))
        _, (qs_ct, ks_ct) = fastmath.scan(
            step, init_carry, (qs, ks, r_ct), reverse=True)
        return (qs_ct, ks_ct)
Exemple #7
0
    def favor_numerator_bwd(init_prefix_sum_value, precision, pqkv, w_ct):
        """Backward computation of the FAVOR numerator.

        Scans in reverse over the queries, keys, values and output
        cotangents. The carry holds the key/value prefix sum, rebuilt step by
        step by subtracting each key/value outer product from the saved final
        value, and the accumulated prefix-sum cotangent.

        NOTE(review): `np.einsum` is called with a `precision` keyword, so
        `np` is presumably a JAX-compatible numpy module -- confirm against
        the file's imports.

        Args:
          init_prefix_sum_value: unused; deleted immediately.
          precision: forwarded to `np.einsum` as `precision`.
          pqkv: residuals `(p, qs, ks, vs)`, where `p` is the prefix sum the
            reverse scan starts from.
          w_ct: cotangent of the forward output, scanned alongside `qs`,
            `ks` and `vs`.

        Returns:
          `(qs_ct, ks_ct, vs_ct)`: cotangents for the queries, keys and
          values.
        """
        del init_prefix_sum_value
        last_prefix_sum, qs, ks, vs = pqkv

        def step(carry, scanned):
            prefix_sum, prefix_sum_ct = carry
            query, key, val, out_ct = scanned
            # The query cotangent needs the prefix sum *before* this step's
            # key/value contribution is removed.
            query_ct = np.einsum(
                '...d,...md->...m', out_ct, prefix_sum, precision=precision)
            prefix_sum_ct += np.einsum(
                '...d,...m->...md', out_ct, query, precision=precision)
            key_ct = np.einsum(
                '...md,...d->...m', prefix_sum_ct, val, precision=precision)
            val_ct = np.einsum(
                '...md,...m->...d', prefix_sum_ct, key, precision=precision)
            prefix_sum -= np.einsum(
                '...m,...d->...md', key, val, precision=precision)
            return (prefix_sum, prefix_sum_ct), (query_ct, key_ct, val_ct)

        init_carry = (last_prefix_sum, np.zeros_like(last_prefix_sum))
        _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(
            step, init_carry, (qs, ks, vs, w_ct), reverse=True)
        return qs_ct, ks_ct, vs_ct
Exemple #8
0
def _scan(f, xs, init_value, axis=0, remat=False):
    """Scans `f` over the given axis of `xs`.

    In pseudo-python, the scan behaves as follows:

        def scan(f, xs, init_value, axis):
            xs = [xs[..., i, ...] for i in range(xs.shape[axis])]
            cur_value = init_value
            ys = []
            for x in xs:
                y, cur_value = f(x, cur_value)
                ys.append(y)
            return np.stack(ys, axis), cur_value

    The work is delegated to `fastmath.scan`, which is called with a
    carry-first step function. `f` is therefore wrapped to flip its argument
    and result order, and when `axis` is non-zero the scanned axis is swapped
    with axis 0 before the scan and swapped back afterwards.

    Args:
      f: function (x, carry) -> (y, new_carry)
      xs: tensor (or nested structure of tensors), x will be xs slices on axis
      init_value: tensor, initial value of the carry-over
      axis: int, the axis on which to slice xs
      remat: whether to re-materialize f (wraps the step in `fastmath.remat`)

    Returns:
      A pair (ys, last_value) as described above.
    """
    scan_on_other_axis = axis != 0

    def swap_with_leading_axis(x):
        # Permutation that exchanges `axis` and axis 0 and leaves the rest.
        perm = list(range(len(x.shape)))
        perm[axis] = 0
        perm[0] = axis
        return jnp.transpose(x, axes=perm)

    def carry_first_step(carry, x):
        # Adapt f's (x, carry) -> (y, new_carry) convention to the
        # (carry, x) -> (new_carry, y) convention used by fastmath.scan.
        y, new_carry = f(x, carry)
        return new_carry, y

    if scan_on_other_axis:
        xs = fastmath.nested_map(swap_with_leading_axis, xs)

    step = carry_first_step
    if remat:
        step = fastmath.remat(step)

    last_value, ys = fastmath.scan(step, init_value, xs)
    if scan_on_other_axis:
        ys = fastmath.nested_map(swap_with_leading_axis, ys)
    return ys, last_value