Example #1
  def _do_custom_gradients(self, x, weights, state, rng):
    """Calls this layer for a forward pass, but with custom gradients."""
    assert math.backend_name() == 'jax', (
        'Custom gradients are only supported in JAX for now.')

    # See this link for how custom transformations are defined in JAX:
    # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
    @jax.custom_transforms
    def _do_forward(y, weights):
      old_weights, old_state, old_rng = self._weights, self._state, self._rng
      res = self.forward(y, weights)
      s = self._state
      self._weights, self._state, self._rng = old_weights, old_state, old_rng
      return res, s

    # This is the custom gradient (vector-jacobian product in JAX) function.
    # For the exact specification of this custom transformation see this link:
    # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
    def do_forward_vjp(y, weights):
      """Custom gradient (vjp) function."""
      old_weights, old_state, old_rng = self._weights, self._state, self._rng
      output = self.forward(y, weights)
      new_state = self._state
      self._weights, self._state, self._rng = old_weights, old_state, old_rng
      def vjpfun(grad):
        grad = grad[0]  # Ignore dummy gradient wrt state.
        res = self.backward(y, output, grad, weights, state, new_state, rng)
        return res
      return (output, new_state), vjpfun

    jax.defvjp_all(_do_forward, do_forward_vjp)
    output, state = _do_forward(x, weights)
    state = jax.lax.stop_gradient(state)
    return output, state
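Distilled to its essentials, this is the legacy pattern these examples share: wrap the primal computation with jax.custom_transforms, then register a rule via jax.defvjp_all that returns (output, vjpfun), where vjpfun maps the incoming cotangent to a tuple containing one cotangent per positional argument. A minimal sketch of that pattern (note that both helpers were removed from later JAX releases in favor of jax.custom_vjp):

import jax
import jax.numpy as jnp

def f_impl(x):
    return jnp.sin(x)

def f_vjp(x):
    out = f_impl(x)
    def vjpfun(ct):
        return (ct * jnp.cos(x),)   # one cotangent per primal input
    return out, vjpfun

f = jax.custom_transforms(f_impl)
jax.defvjp_all(f, f_vjp)

print(jax.grad(f)(0.0))  # 1.0, taken from the registered rule rather than by tracing f_impl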
Example #2
def make_ito_integrate(flat_f, flat_g, ts, dt, bm, method='milstein'):
    # `flat_f` and `flat_g` each takes in *flat* states and parameters and returns *flat* gradients.

    # Make fast jitted helper functions for Milstein correction.
    @jax.jit
    def flat_g_prod(flat_y, t, args, noise):
        return flat_g(flat_y, t, args) * noise

    flat_gdg = jax.jit(make_gdg_prod(flat_g_prod))

    @jax.custom_transforms
    def ito_integrate_(flat_y0, flat_args):
        return ito_integrate(flat_f,
                             flat_g,
                             flat_y0,
                             ts,
                             bm,
                             dt,
                             args=flat_args,
                             g_prod=flat_g_prod,
                             gdg=flat_gdg,
                             method=method)  # (T, batch_size * D)

    def vjp_all(flat_y0, flat_args):
        ans = ito_integrate(flat_f,
                            flat_g,
                            flat_y0,
                            ts,
                            bm,
                            dt,
                            args=flat_args,
                            g_prod=flat_g_prod,
                            gdg=flat_gdg,
                            method=method)  # (T, batch_size * D).

        def actual_vjp_all(cotan):
            T, _ = cotan.shape

            v_flat_y, v_flat_args = cotan[-1, :], np.zeros_like(flat_args)
            for i in range(T - 1, 0, -1):
                ts_local = np.array([ts[i - 1], ts[i]])
                _, (v_flat_y,
                    v_flat_args) = vjp_ito_integrate(v_yt=v_flat_y,
                                                     v_argst=v_flat_args,
                                                     yt=ans[i, :],
                                                     f=flat_f,
                                                     g=flat_g,
                                                     ts=ts_local,
                                                     bm=bm,
                                                     dt=dt,
                                                     args=flat_args,
                                                     method=method)
                v_flat_y = v_flat_y + cotan[i - 1, :]
            return v_flat_y, v_flat_args

        return ans, actual_vjp_all

    jax.defvjp_all(ito_integrate_, vjp_all)

    return ito_integrate_
Example #3
    def test_custom_transforms_vjp_nones(self):
        # issue raised by jsnoek@ and jumper@
        @jax.custom_transforms
        def solve(a, b):
            return np.dot(np.linalg.inv(a), b)

        # print(solve(a, b))

        def solve_vjp(a, b):
            x = solve(a, b)

            def vjp(x_tangent):
                dx = np.dot(solve(a, x_tangent), x.T)
                out = (dx, b * 0.)
                return out

            return x, vjp

        jax.defvjp_all(solve, solve_vjp)
        gf = grad(lambda a, b: np.sum(solve(a, b)))

        n = 3
        a_in = np.linspace(0, 1, n)[:, None]
        a = np.dot(a_in, a_in.T) + np.eye(n) * 0.1
        real_x = onp.random.RandomState(0).randn(n)
        b = np.dot(a + np.eye(a.shape[0]), real_x)
        print(gf(a, b))  # doesn't crash
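For reference, the textbook reverse-mode rule for x = A^{-1} b is b_bar = A^{-T} g and A_bar = -b_bar x^T. Below is a self-contained sketch of that rule with the current jax.custom_vjp API; it is our own illustration with made-up inputs, not the rule exercised by the test above.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def linsolve(a, b):
    return jnp.linalg.solve(a, b)

def linsolve_fwd(a, b):
    x = jnp.linalg.solve(a, b)
    return x, (a, x)

def linsolve_bwd(res, g):
    a, x = res
    gb = jnp.linalg.solve(a.T, g)   # b_bar = A^{-T} g
    return -jnp.outer(gb, x), gb    # A_bar = -b_bar x^T

linsolve.defvjp(linsolve_fwd, linsolve_bwd)

a = jnp.eye(3) + 0.1
b = jnp.arange(1.0, 4.0)
print(jax.grad(lambda a, b: jnp.sum(linsolve(a, b)), argnums=(0, 1))(a, b))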
Example #4
def _custom_gradient(f):
    """Jax implementation of tf.custom_gradient."""
    if not JAX_MODE:
        # Numpy backend ignores custom gradients, so we do too.
        return lambda *args, **kwargs: f(*args, **kwargs)[0]
    import jax  # pylint: disable=g-import-not-at-top

    def f_(*args, **kwargs):
        value, vjp = f(*args, **kwargs)

        def vjp_(cts_out):
            cts_in = vjp(cts_out)
            if not isinstance(cts_in, tuple):
                cts_in = (cts_in, )
            return cts_in

        return value, vjp_

    @jax.custom_transforms
    def wrapped(*args, **kwargs):
        value, _ = f(*args, **kwargs)
        return value

    jax.defvjp_all(wrapped, f_)
    return wrapped
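A hypothetical usage sketch of the wrapper above, on a JAX version that still provides custom_transforms; the wrapped function and its derivative are our own example, not part of the original module. f returns a (value, vjp) pair, and the decorator installs vjp as the reverse-mode rule.

import jax
import jax.numpy as jnp

def softplus_pair(x):
    value = jnp.logaddexp(0.0, x)                # log(1 + e^x), computed stably
    vjp = lambda g: (g * jax.nn.sigmoid(x),)     # d/dx log(1 + e^x) = sigmoid(x)
    return value, vjp

softplus = _custom_gradient(softplus_pair)
print(jax.grad(softplus)(3.0))                   # ~0.9526, from the supplied vjp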
Example #5
def _custom_gradient(f):
    """JAX implementation of tf.custom_gradient."""
    if not JAX_MODE:
        # Numpy backend ignores custom gradients, so we do too.
        return lambda *args, **kwargs: f(*args, **kwargs)[0]

    def f_(*args, **kwargs):
        value, vjp = f(*args, **kwargs)

        def vjp_(cts_out):
            cts_in = vjp(cts_out)
            if isinstance(cts_in, list):
                cts_in = tuple(cts_in)
            elif not isinstance(cts_in, tuple):
                cts_in = (cts_in, )
            return cts_in

        return value, vjp_

    @jax.custom_transforms
    def wrapped(*args, **kwargs):
        value, _ = f(*args, **kwargs)
        return value

    jax.defvjp_all(wrapped, f_)
    return wrapped
Example #6
    def __call__(self, x, params=(), **kwargs):
        assert backend.get_name() == 'jax', (
            'Reversible layers are only supported in JAX')

        if params is () and self._params:  # pylint: disable=literal-comparison
            # TODO(kitaev): Figure out why parameter sharing doesn't work (if this
            # explicit error isn't thrown, a jax tracer error occurs instead)
            raise NotImplementedError(
                'Parameter sharing between reversible layers is not implemented.'
            )

        @jax.custom_transforms
        def do_call(x, params, kwargs):
            return super(ReversibleLayerMixin,
                         self).__call__(x, params, **kwargs)

        def do_call_vjp(x, params, kwargs):
            output = super(ReversibleLayerMixin,
                           self).__call__(x, params, **kwargs)

            def vjpfun(ct):
                _, input_ct = self.inverse_and_vjp(output, ct, params,
                                                   **kwargs)
                return input_ct

            return output, vjpfun

        jax.defvjp_all(do_call, do_call_vjp)
        return do_call(x, params, kwargs)
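The reversible trick above avoids keeping the layer input alive: the backward pass reconstructs it from the output (via inverse_and_vjp) and backpropagates through a recomputed forward. A minimal, self-contained illustration of that idea with the current jax.custom_vjp API; this is our own construction, not the trax mixin.

import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x)   # first sub-block (any differentiable function)

def g(x):
    return 0.5 * x       # second sub-block

def couple(x1, x2):
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

@jax.custom_vjp
def rev_couple(x1, x2):
    return couple(x1, x2)

def rev_couple_fwd(x1, x2):
    y = couple(x1, x2)
    return y, y          # keep only the outputs as residuals

def rev_couple_bwd(y, cts):
    y1, y2 = y
    x2 = y2 - g(y1)      # invert the coupling to recover the inputs
    x1 = y1 - f(x2)
    _, vjp = jax.vjp(couple, x1, x2)   # recompute the forward and take its VJP
    return vjp(cts)

rev_couple.defvjp(rev_couple_fwd, rev_couple_bwd)

x1 = jnp.array([0.3, -0.7])
x2 = jnp.array([1.0, 2.0])
print(jax.grad(lambda a, b: jnp.sum(rev_couple(a, b)[0] ** 2), argnums=(0, 1))(x1, x2))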
Example #7
def build_odeint(ofunc, rtol=1.4e-8, atol=1.4e-8):
    """Return `f(y0, t, args) = odeint(ofunc(y, t, *args), y0, t, args)`.

  Given the function ofunc(y, t, *args), return the jitted function
  `f(y0, t, args) = odeint(ofunc(y, t, *args), y0, t, args)` with
  the VJP of `f` defined using `vjp_odeint`, where:

    `y0` is the initial condition of the ODE integration,
    `t` is the time course of the integration, and
    `*args` are all other arguments to `ofunc`.

  Args:
    ofunc: The function to be wrapped into an ODE integration.
    rtol: relative local error tolerance for solver.
    atol: absolute local error tolerance for solver.

  Returns:
    `f(y0, t, args) = odeint(ofunc(y, t, *args), y0, t, args)`
  """
    ct_odeint = jax.custom_transforms(
        lambda y0, t, *args: odeint(ofunc, y0, t, *args, rtol=rtol, atol=atol))

    v = lambda y0, t, *args: vjp_odeint(
        ofunc, y0, t, *args, rtol=rtol, atol=atol)
    jax.defvjp_all(ct_odeint, v)

    return jax.jit(ct_odeint)
Example #8
def permute_via_sort(val, keys, inverse_keys, axis=0):
    """Permutation helper for LSH attention."""
    # It is *not* safe to use jax.custom_vjp here (see permute_via_gather).
    keys = jax.lax.stop_gradient(keys)
    inverse_keys = jax.lax.stop_gradient(inverse_keys)

    def permute_impl(val):
        # On TPU, sorting scalars by key is faster than a gather.
        _, permuted = jax.lax.sort_key_val(keys, val, dimension=axis)
        return permuted

    def permute_vjp(val):
        permuted = permute_impl(jax.lax.stop_gradient(val))

        def vjpfun(permuted_grad):
            _, val_grad = jax.lax.sort_key_val(inverse_keys,
                                               permuted_grad,
                                               dimension=axis)
            return (val_grad, )

        return permuted, vjpfun

    permute = jax.custom_transforms(permute_impl)
    jax.defvjp_all(permute, permute_vjp)
    return permute(val)
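The custom rule above rests on a simple identity: sorting values by keys and then sorting the result by inverse_keys (the original positions emitted by the first sort) restores the original order, so the backward pass of a key-sort permutation is just the inverse key-sort applied to the incoming gradient. A tiny self-contained check with made-up values:

import jax
import jax.numpy as jnp

keys = jnp.array([2.0, 0.0, 3.0, 1.0])
val = jnp.array([10.0, 20.0, 30.0, 40.0])
ticker = jnp.arange(keys.shape[0])

_, inverse_keys = jax.lax.sort_key_val(keys, ticker)        # original position of each sorted slot
_, permuted = jax.lax.sort_key_val(keys, val)               # forward permutation
_, restored = jax.lax.sort_key_val(inverse_keys, permuted)  # inverse, as used in the vjp
print(restored)  # [10. 20. 30. 40.] -- the original order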
Example #9
def permute_via_gather(val, permutation, inverse_permutation, axis=0):
    """Permutation helper for LSH attention."""
    # It is *not* safe to use jax.custom_vjp here. The most likely cause is that
    # it can't close over values: https://github.com/google/jax/issues/2676
    # The error only occurs in some configurations (e.g. use_python_loop = True,
    # num_parallel_heads = 1) but not others.
    permutation = jax.lax.stop_gradient(permutation)
    inverse_permutation = jax.lax.stop_gradient(inverse_permutation)

    def permute_impl(val):
        return jnp.take(val, permutation, axis=axis)

    def permute_vjp(val):
        permuted = permute_impl(jax.lax.stop_gradient(val))

        def vjpfun(permuted_grad):
            # JAX autodiff would synthesize a scatter operation because it doesn't
            # know that the indices are a permutation. However on TPU, gathers are
            # faster than scatters (at least in the regime the LSH attention uses).
            return (jnp.take(permuted_grad, inverse_permutation, axis=axis), )

        return permuted, vjpfun

    permute = jax.custom_transforms(permute_impl)
    jax.defvjp_all(permute, permute_vjp)
    return permute(val)
Example #10
    def _do_custom_gradients(self, x, weights, state, **kwargs):
        """Calls this layer for a forward pass, but with custom gradients."""
        assert backend.get_name() == 'jax', (
            'Custom gradients are only supported in JAX for now.')

        # TODO(wangpeng): JAX doesn't support custom grads for functions with
        #   auxiliary output yet (https://github.com/google/jax/issues/844). Will
        #   remove the constraints on state below when this feature is added to
        #   JAX.

        assert not jax.tree_util.tree_leaves(state), (
            'Custom gradients require trivial start state. Got %s' %
            str(state))

        def check_end_state(output_state):
            output, state = output_state
            assert not jax.tree_util.tree_leaves(state), (
                'Custom gradients require trivial end state. Got %s' %
                str(state))
            return output

        # See this link for how custom transformations are defined in JAX:
        # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
        # Note that we capture the kwargs and don't calculate gradients wrt. them.
        @jax.custom_transforms
        def _do_forward(y, weights):
            return check_end_state(
                self.forward_with_state(y,
                                        weights=weights,
                                        state=state,
                                        **kwargs))

        # This is the custom gradient (vector-jacobian product in JAX) function.
        # For the exact specification of this custom transformation see this link:
        # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
        def do_forward_vjp(y, weights):
            """Custom gradient (vjp) function."""
            stash = None
            if Layer._STASH_IN is None:
                Layer._STASH_IN = stash = {}
            output = check_end_state(
                self.forward_with_state(y,
                                        weights=weights,
                                        state=state,
                                        **kwargs))
            if stash is not None:
                Layer._STASH_IN = None

            def vjpfun(grad):
                assert Layer._STASH_OUT is None
                Layer._STASH_OUT = stash
                res = self.backward(y, output, grad, weights, state, **kwargs)
                Layer._STASH_OUT = None
                return res

            return output, vjpfun

        jax.defvjp_all(_do_forward, do_forward_vjp)
        return _do_forward(x, weights), state
Example #11
    def __call__(self, x, params=(), state=(), **kwargs):
        try:
            # If params are nothing, we may be reusing this layer.
            # Use the cached parameters to calculate the value.
            # Note: to make sure jit tracers can decide this branch in python we
            #   use "params is ()" instead of, e.g., "not params" or "params == ()".
            if params is ():  # pylint: disable=literal-comparison
                params = self._params
            else:
                # In this case, we're called for the first time: cache parameters.
                self._params = params

            if not self.has_custom_grad:
                return self.call(x, params=params, state=state, **kwargs)

            # TODO(wangpeng): JAX doesn't support custom grads for functions with
            #   auxiliary output yet (https://github.com/google/jax/issues/844). Will
            #   remove the constraints on state below when this feature is added to
            #   JAX.

            assert state is (), (  # pylint: disable=literal-comparison
                'Custom gradients require trivial start state. Got %s' %
                str(state))

            def check_end_state(output_state):
                output, state = output_state
                assert state is (), (  # pylint: disable=literal-comparison
                    'Custom gradients require trivial end state. Got %s' %
                    str(state))
                return output

            # See this link for how custom transformations are defined in JAX:
            # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
            # Note that we capture the kwargs and don't calculate gradients wrt. them.
            @jax.custom_transforms
            def do_call(y, params):
                return check_end_state(
                    self.call(y, params=params, state=(), **kwargs))

            # This is the custom gradient (vector-jacobian product in JAX) function.
            # For the exact specification of this custom transformation see this link:
            # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
            def do_call_vjp(y, params):
                output = check_end_state(
                    self.call(y, params=params, state=(), **kwargs))

                def vjpfun(grad):
                    return self.custom_grad(y, output, grad, params, **kwargs)

                return output, vjpfun

            jax.defvjp_all(do_call, do_call_vjp)
            return do_call(x, params), ()

        except Exception:
            name, trace = self.__class__.__name__, _short_traceback()
            raise LayerError(name, 'call', self._caller, shapes(x), trace)
Example #12
def permute_via_gather(val, permutation, inverse_permutation, axis=0):
  """Permutation helper for LSH attention."""
  def permute_impl(val):
    return np.take(val, permutation, axis=axis)
  def permute_vjp(val):
    permuted = permute_impl(jax.lax.stop_gradient(val))
    def vjpfun(permuted_grad):
      # JAX autodiff would synthesize a scatter operation because it doesn't
      # know that the indices are a permutation. However on TPU, gathers are
      # faster than scatters (at least in the regime the LSH attention uses).
      return (np.take(permuted_grad, inverse_permutation, axis=axis),)
    return permuted, vjpfun
  permute = jax.custom_transforms(permute_impl)
  jax.defvjp_all(permute, permute_vjp)
  return permute(val)
Example #13
    def __call__(self, x, params=(), **kwargs):
        try:
            # If params are nothing, we may be reusing this layer.
            # Use the cached parameters to calculate the value.
            # Note: to make sure jit tracers can decide this branch in python we
            #   use "params is ()" instead of, e.g., "not params" or "params == ()".
            if params is ():  # pylint: disable=literal-comparison
                params = self._params
            # In this case, we're called for the first time: cache parameters.
            self._params = params

            if not self.has_custom_grad:
                return self.call(x, params=params, **kwargs)

            # Custom gradients part.
            assert backend.get_name() == 'jax', (
                'Custom gradients are only supported in JAX for now.')

            # See this link for how custom transformations are defined in JAX:
            # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
            @jax.custom_transforms
            def do_call(y, params, kwargs):
                return self.call(y, params=params, **kwargs)

            # This is the custom gradient (vector-jacobian product in JAX) function.
            # For the exact specification of this custom transformation see this link:
            # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
            # Note that we make arguments positional to allow gradients wrt. them.
            def do_call_vjp(y, params, kwargs):
                output = self.call(y, params=params, **kwargs)

                def vjpfun(grad):
                    return self.custom_grad(y, output, grad, params, **kwargs)

                return output, vjpfun

            jax.defvjp_all(do_call, do_call_vjp)
            return do_call(x, params, kwargs)

        except Exception:
            name, trace = self.__class__.__name__, _short_traceback()
            raise LayerError(name, 'call', self._caller, shapes(x), trace)
Example #14
    def _do_custom_gradients(self, x, weights, state, **kwargs):
        """Calls this layer for a forward pass, but with custom gradients."""
        assert backend.get_name() == 'jax', (
            'Custom gradients are only supported in JAX for now.')

        # See this link for how custom transformations are defined in JAX:
        # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
        # Note that we capture the kwargs and don't calculate gradients wrt. them.
        @jax.custom_transforms
        def _do_forward(y, weights):
            res = self.forward_with_state(y,
                                          weights=weights,
                                          state=state,
                                          **kwargs)
            return res

        # This is the custom gradient (vector-jacobian product in JAX) function.
        # For the exact specification of this custom transformation see this link:
        # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
        def do_forward_vjp(y, weights):
            """Custom gradient (vjp) function."""
            output, new_state = self.forward_with_state(y,
                                                        weights=weights,
                                                        state=state,
                                                        **kwargs)

            def vjpfun(grad):
                grad = grad[0]  # Ignore dummy gradient wrt state.
                res = self.backward(y, output, grad, weights, state, new_state,
                                    **kwargs)
                return res

            return (output, state), vjpfun

        jax.defvjp_all(_do_forward, do_forward_vjp)
        output, state = _do_forward(x, weights)
        state = jax.lax.stop_gradient(state)
        return output, state
Example #15
    x[(x < up_limit) & (x > low_limit)] = np.log(
        softplus(x[(x < up_limit) & (x > low_limit)]))
    #x[x<low_limit] = x[x<low_limit]
    return x


def safe_logsoftplus_vjp(ans, x, low_limit=-30):
    x_shape = x.shape
    operator = np.ones(x.shape)
    operator[x > low_limit] = 1 / (
        (1 + np.exp(-x[x > low_limit])) * softplus(x[x > low_limit]))
    return lambda g: np.full(x_shape, g) * operator
    #return lambda g: np.full(x_shape, g) * 1/ ((1+np.exp(-x))*safe_softplus(x))


jax.defvjp_all(safe_logsoftplus, safe_logsoftplus_vjp)


def make_cov(N, rh, len_sc):
    M1 = np.array([np.arange(N)]) - np.transpose(np.array([np.arange(N)]))
    K = rh * np.exp(-(np.square(M1) / (2 * np.square(len_sc))))
    return K


def bbvi(logprob, N, num_samples):
    """Implements http://arxiv.org/abs/1401.0118, and uses the
    local reparameterization trick from http://arxiv.org/abs/1506.02557

    Structure of function taken from: https://github.com/HIPS/autograd/blob/master/examples/black_box_svi.py
	
	inputs: 
Example #16
    def forward_unbatched(self, x, *, weights, state, update_state):
        w_q, w_v, w_o = weights

        q = np.matmul(x, w_q)
        v = np.matmul(x, w_v)

        if update_state:
            _, old_rng = state
            rng = jax.random.fold_in(old_rng, 0)
            hash_rng = jax.random.fold_in(rng, 1)
            buckets = self.hash_vectors(q, hash_rng)
            state = (buckets, rng)
        else:
            buckets, rng = state

        rng = jax.random.fold_in(rng, 2)

        seqlen = x.shape[0]
        assert int(buckets.shape[0]) == self.n_hashes * seqlen

        ticker = jax.lax.tie_in(x, np.arange(self.n_hashes * seqlen))
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t,
                                                       ticker,
                                                       dimension=-1)
        _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
        sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
        sticker = jax.lax.stop_gradient(sticker)
        undo_sort = jax.lax.stop_gradient(undo_sort)

        st = (sticker % seqlen)
        sq = np.take(q, st, axis=0)
        sv = np.take(v, st, axis=0)

        mask_fn = functools.partial(mask_self_attention,
                                    causal=self.causal,
                                    exclude_self=True)
        q_info = st
        so, slogits = attend(
            sq,
            k=None,
            v=sv,
            q_chunk_len=self.chunk_len,
            n_chunks_before=self.n_chunks_before,
            n_chunks_after=self.n_chunks_after,
            mask_fn=mask_fn,
            q_info=q_info,
            dropout=self.attention_dropout,
            rng=rng,
        )

        def unsort_for_output_impl(so, slogits):
            o = np.take(so, undo_sort, axis=0)
            # Sorting is considerably faster than gather, but first we need to get the
            # XLA compiler to abandon the idea of fusing this sort with the input sort
            # (which introduces a computation cycle and leads to a crash).
            # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
            sticker_ = sticker + jax.lax.convert_element_type(
                slogits[0] > 0, sticker.dtype)
            _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
            return o, logits

        def unsort_for_output_vjp(so, slogits):
            """Custom gradient for unsort_for_output."""
            so = jax.lax.stop_gradient(so)
            slogits = jax.lax.stop_gradient(slogits)
            o, logits = unsort_for_output_impl(so, slogits)

            def vjpfun(o_logits_grads):
                so_grad = np.take(o_logits_grads[0], sticker, axis=0)
                # TODO(kitaev): this exists to match the forward pass, but I'm not sure
                # if it's actually required.
                buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
                    o_logits_grads[1][0] > 0, buckets_and_t.dtype)
                _, slogits_grad = jax.lax.sort_key_val(buckets_and_t_,
                                                       o_logits_grads[1],
                                                       dimension=-1)
                return (so_grad, slogits_grad)

            return (o, logits), vjpfun

        unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
        jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
        o, logits = unsort_for_output(so, slogits)

        if self.n_hashes > 1:
            o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
            logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
            probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True))
            o = np.sum(o * probs, axis=0)

        assert o.shape == (seqlen, w_v.shape[-1])
        out = np.matmul(o, w_o)
        return out, state
Example #17
    def single_call(self, qk, v, buckets, rng=None):
        # We use the same vector as both a query and a key.
        seqlen = qk.shape[-2]
        assert int(buckets.shape[0]) == self.n_hashes * seqlen

        ticker = jax.lax.tie_in(qk, np.arange(self.n_hashes * seqlen))
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t,
                                                       ticker,
                                                       dimension=-1)
        _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
        sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
        sticker = jax.lax.stop_gradient(sticker)
        undo_sort = jax.lax.stop_gradient(undo_sort)

        st = (sticker % seqlen)
        sqk = np.take(qk, st, axis=0)
        sv = np.take(v, st, axis=0)

        # Split off a "bin" axis so that attention only occurs within chunks.
        bq_t = bkv_t = np.reshape(st, (self.n_hashes * self.n_bins, -1))
        bqk = np.reshape(sqk, (self.n_hashes * self.n_bins, -1, sqk.shape[-1]))
        bv = np.reshape(sv, (self.n_hashes * self.n_bins, -1, sv.shape[-1]))
        bq_buckets = bkv_buckets = np.reshape(
            sbuckets_and_t // seqlen, (self.n_hashes * self.n_bins, -1))

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = self.make_unit_length(bqk)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bucket, so this increases the chances of attending to relevant items.
        # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
        def look_one_back(x):
            if len(x.shape) == 2:
                x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
            else:
                x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
            return np.concatenate([x, x_extra], axis=1)

        bk = look_one_back(bk)
        bv = look_one_back(bv)
        bkv_t = look_one_back(bkv_t)
        bkv_buckets = look_one_back(bkv_buckets)

        # Dot-product attention.
        dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

        # Causal masking
        mask = jax.lax.convert_element_type(
            jax.lax.lt(bq_t[:, :, None], bkv_t[:, None, :]), np.float32)
        dots = dots - 1e9 * mask

        # Mask out attention to self except when no other targets are available.
        self_mask = jax.lax.convert_element_type(
            jax.lax.eq(bq_t[:, :, None], bkv_t[:, None, :]), np.float32)
        dots = dots - 1e5 * self_mask

        # Mask out attention to other hash buckets.
        if not self._attend_across_buckets:
            bucket_mask = jax.lax.convert_element_type(
                jax.lax.ne(bq_buckets[:, :, None], bkv_buckets[:, None, :]),
                np.float32)
            dots = dots - 1e7 * bucket_mask

        # Don't double-count query-key pairs across multiple rounds of hashing.
        # There are two possible strategies here. (1) The default is to count how
        # many times a query-key pair is repeated, and to lower its log-prob
        # correspondingly at each repetition. (2) When hard_k is set, the code
        # instead masks all but the first occurrence of each query-key pair.
        # TODO(kitaev): is one strategy faster or more numerically stable?
        if not self._allow_duplicate_attention:
            locs1 = undo_sort // bq_t.shape[-1]
            locs2 = (locs1 + 1) % (self.n_hashes * self.n_bins)
            if not self._attend_across_buckets:
                locs1 = buckets * (self.n_hashes * self.n_bins) + locs1
                locs2 = buckets * (self.n_hashes * self.n_bins) + locs2
            locs = np.moveaxis(
                np.concatenate([
                    np.reshape(locs1, (self.n_hashes, seqlen)),
                    np.reshape(locs2, (self.n_hashes, seqlen)),
                ], 0), 0, -1)  # produces shape (seqlen, 2 * self.n_hashes)
            slocs = np.take(locs, st, axis=0)
            b_locs = np.reshape(
                slocs, (self.n_hashes * self.n_bins, -1, 2 * self.n_hashes))
            # Queries always use the primary location (based on locs1).
            b_locs1 = b_locs[:, :, None, :self.n_hashes]
            if self._hard_k > 0:
                range_n_hashes = jax.lax.tie_in(b_locs,
                                                np.arange(self.n_hashes))
                nouse_locs = (range_n_hashes[:, None] >
                              range_n_hashes[None, :])
                nouse_locs = 2 * nouse_locs - 1  # 1 = use, -1 = don't use
                nouse_locs = np.reshape(
                    np.broadcast_to(
                        nouse_locs[:, None, :],
                        (self.n_hashes, self.n_bins, self.n_hashes)),
                    (self.n_hashes * self.n_bins, 1, 1, self.n_hashes))
                b_locs1 = b_locs1 * nouse_locs
            bq_locs = np.broadcast_to(b_locs1,
                                      b_locs.shape[:2] + (2, self.n_hashes))
            bq_locs = np.reshape(bq_locs, b_locs.shape)
            bkv_locs = look_one_back(b_locs)

            dup_counts = np.sum(jax.lax.convert_element_type(
                jax.lax.eq(bq_locs[:, :, None, :], bkv_locs[:, None, :, :]),
                np.float32),
                                axis=-1)
            assert dup_counts.shape == dots.shape
            if self._hard_k > 0:
                dots = dots - 1e7 * jax.lax.stop_gradient(dup_counts)
            else:
                dots = dots - jax.lax.stop_gradient(np.log(dup_counts + 1e-9))

        # Each query only attends to the top k most relevant keys.
        if self._hard_k > 0:
            b_top_dots = np.sort(dots)[...,
                                       -self._hard_k:]  # Get the top k dots.
            b_top_dots = jax.lax.stop_gradient(b_top_dots)
            s_top_dots = np.reshape(b_top_dots, (-1, self._hard_k))
            top_dots = np.take(s_top_dots, undo_sort, axis=0)

            merged_top_dots = np.moveaxis(
                np.reshape(top_dots, (self.n_hashes, seqlen, self._hard_k)), 0,
                -1)
            merged_top_dots = np.reshape(merged_top_dots, (seqlen, -1))

            dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k]
            # It's possible to compute the partition function at this point, but right
            # now this codepath isn't set up for backprop, and there might also be
            # issues computing it this way if two dot-products are exactly equal.

            sdots_thresh = dots_thresh[st]
            bdots_thresh = np.reshape(sdots_thresh,
                                      (self.n_hashes * self.n_bins, -1))
            bdots_thresh = jax.lax.stop_gradient(bdots_thresh)

            top_k_mask = jax.lax.convert_element_type(
                dots < bdots_thresh[..., None], np.float32)
            dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)

        # Softmax.
        dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
        dots = np.exp(dots - dots_logsumexp)

        if self._dropout > 0.0:
            # Dropout is broadcast across the bin dimension
            dropout_shape = (1, dots.shape[-2], dots.shape[-1])
            keep_prob = jax.lax.tie_in(dots, 1.0 - self._dropout)
            keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
            multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
                keep, keep_prob)
            dots = dots * multiplier

        bo = np.matmul(dots, bv)
        so = np.reshape(bo, (-1, bo.shape[-1]))
        slogits = np.reshape(dots_logsumexp, (-1, ))

        def unsort_for_output_impl(so, slogits):
            o = np.take(so, undo_sort, axis=0)
            # Sorting is considerably faster than gather, but first we need to get the
            # XLA compiler to abandon the idea of fusing this sort with the input sort
            # (which introduces a computation cycle and leads to a crash).
            # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
            sticker_ = sticker + jax.lax.convert_element_type(
                slogits[0] > 0, sticker.dtype)
            _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
            return o, logits

        def unsort_for_output_vjp(so, slogits):
            """Custom gradient for unsort_for_output."""
            so = jax.lax.stop_gradient(so)
            slogits = jax.lax.stop_gradient(slogits)
            o, logits = unsort_for_output_impl(so, slogits)

            def vjpfun(o_logits_grads):
                so_grad = np.take(o_logits_grads[0], sticker, axis=0)
                # TODO(kitaev): this exists to match the forward pass, but I'm not sure
                # if it's actually required.
                buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
                    o_logits_grads[1][0] > 0, buckets_and_t.dtype)
                _, slogits_grad = jax.lax.sort_key_val(buckets_and_t_,
                                                       o_logits_grads[1],
                                                       dimension=-1)
                return (so_grad, slogits_grad)

            return (o, logits), vjpfun

        unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
        jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
        o, logits = unsort_for_output(so, slogits)

        if self.n_hashes == 1:
            out = o
        else:
            o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
            logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
            probs = np.exp(logits -
                           backend.logsumexp(logits, axis=0, keepdims=True))
            out = np.sum(o * probs, axis=0)

        assert out.shape == v.shape
        return out
Example #18
def _custom_grad(f_vjp, f_original):
  f_ = jax.custom_transforms(f_original)
  jax.defvjp_all(f_, f_vjp)
  return f_
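custom_transforms and defvjp_all were removed in later JAX releases; the same adapter idea is written today with jax.custom_vjp, which splits the rule into a forward part (returning residuals) and a backward part. A minimal sketch under that assumption, using a straight-through rounding rule as the illustrative custom gradient:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def round_ste(x):
    return jnp.round(x)

def round_ste_fwd(x):
    return jnp.round(x), None   # no residuals needed for a straight-through rule

def round_ste_bwd(_, g):
    return (g,)                 # pretend d(round)/dx = 1

round_ste.defvjp(round_ste_fwd, round_ste_bwd)

print(jax.grad(lambda x: round_ste(x) ** 2)(0.7))  # 2.0 instead of the true 0.0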
Example #19
    def call(self, inputs, params=(), state=(), rng=None, **kwargs):
        del params, kwargs
        # We use the same vector as both a query and a key. For now we haven't
        # adjusted any of the surrounding code, so we still get a separate "key"
        # input that we ignore.
        qk, _, v = inputs
        seqlen = qk.shape[-2]

        # qk/v are n_hashes*n_batch*n_heads, seqlen, d_head
        # TODO(kitaev): is it faster to fuse this tiling into gather/scatter ops?
        qk = np.tile(qk, (self.n_hashes, 1, 1))
        v = np.tile(v, (self.n_hashes, 1, 1))

        # bins are n_hashes*n_batch*n_heads, seqlen
        # They specify which hash bucket the query/key/value vectors fall in.
        bins = self.hash_vectors(qk, rng=rng)

        # joint_t is n_hashes*n_batch*n_heads, seqlen
        joint_t = jax.lax.tie_in(qk, np.arange(seqlen))
        joint_t = np.reshape(joint_t, (1, seqlen))
        joint_t = np.broadcast_to(joint_t, qk.shape[:-1])

        assert int(
            (self.n_buckets_per_bin * self.n_bins + 1) * seqlen
        ) < 2**31, (
            'Potential 32-bit integer overflow; please double-check the code.')
        joint_bins_and_t = seqlen * bins + joint_t

        def chunk_scalars(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], self.n_bins, -1))

        def chunk_vectors(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

        def unchunk_vectors(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

        # Sort everything by bin number, with a secondary sort by time
        # (variables starting with "s" are sorted)
        _, sjoint_t = jax.lax.sort_key_val(joint_bins_and_t,
                                           joint_t,
                                           dimension=-1)
        _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)
        # TODO(kitaev): why does jax flag integer indices as differentiable?
        # If we don't call stop_gradient here, custom gradients below won't work
        # because the primitive functions close over "differentiable" variables.
        sjoint_t = jax.lax.stop_gradient(sjoint_t)
        undo_sort = jax.lax.stop_gradient(undo_sort)

        # The backward pass of gather is in general a scatter operation, but we know
        # we're dealing with permutations so we use gather for the backward pass
        # too. This custom gradient should be about 2x faster than having jax infer
        # one that uses scatter ops instead.
        def permute_impl(vecs):
            assert len(vecs.shape) == 3
            return np.take_along_axis(vecs, sjoint_t[:, :, None], axis=-2)

        def unpermute_impl(vecs):
            assert len(vecs.shape) == 3
            return np.take_along_axis(vecs, undo_sort[:, :, None], axis=-2)

        @jax.custom_transforms
        def permute(vecs):
            return permute_impl(vecs)

        def permute_vjp(vecs):
            out_vecs = permute_impl(vecs)

            def vjpfun(grad):
                return (unpermute_impl(grad), )

            return out_vecs, vjpfun

        @jax.custom_transforms
        def unpermute(vecs):
            return unpermute_impl(vecs)

        def unpermute_vjp(vecs):
            out_vecs = unpermute_impl(vecs)

            def vjpfun(grad):
                return (permute_impl(grad), )

            return out_vecs, vjpfun

        jax.defvjp_all(permute, permute_vjp)
        jax.defvjp_all(unpermute, unpermute_vjp)

        sqk = permute(qk)
        sv = permute(v)

        # Split off a "bin" axis so that attention only occurs within chunks.
        bq_t = bkv_t = chunk_scalars(sjoint_t)
        bqk = chunk_vectors(sqk)
        bv = chunk_vectors(sv)

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = self.make_unit_length(bqk)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bin, so this increases the chances of attending to relevant items.
        # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
        bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]], axis=1)
        bk = np.concatenate([bk, bk_extra], axis=2)
        bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]], axis=1)
        bv = np.concatenate([bv, bv_extra], axis=2)
        bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]],
                                     axis=1)
        bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

        # Dot-product attention.
        dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

        # Causal masking
        mask = jax.lax.convert_element_type(
            jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]), np.float32)
        dots = dots - 1e9 * mask

        # Mask out attention to self except when no other targets are available.
        self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
        self_mask = jax.lax.tie_in(dots, self_mask)
        dots = dots - 32 * self_mask

        # Softmax.
        dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
        dots = np.exp(dots - dots_logsumexp)

        if self._hard_k > 0:
            top_k = np.sort(dots)[...,
                                  -self._hard_k]  # Get the top-kth weight.
            top_k = jax.lax.stop_gradient(top_k)
            dots -= top_k[..., np.newaxis]  # Subtract (be 0 for lower ones).
            dots = np.maximum(dots, 0)
            dots_sum = np.sum(dots, axis=-1,
                              keepdims=True)  # Sum to re-normalize.
            dots_logsumexp += np.log(dots_sum)  # Add it to the weight.
            dots /= dots_sum  # Re-normalize.

        bo = np.matmul(dots, bv)
        so = unchunk_vectors(bo)
        slogits = unchunk_vectors(dots_logsumexp)

        o = unpermute(so)
        logits = unpermute(slogits)

        o = np.reshape(o, (self.n_hashes, -1, seqlen, o.shape[-1]))
        logits = np.reshape(logits, (self.n_hashes, -1, seqlen, 1))
        probs = np.exp(logits -
                       backend.logsumexp(logits, axis=0, keepdims=True))
        out = np.sum(o * probs, axis=0)
        assert out.shape == inputs[2].shape

        return out, state
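As with the sort-based variant, the permute/unpermute pair above is a valid custom gradient pair because the two gathers are inverse permutations of each other. A tiny self-contained check with made-up indices:

import jax.numpy as jnp

perm = jnp.array([[2, 0, 1]])         # (batch, seqlen)
undo = jnp.argsort(perm, axis=-1)     # inverse permutation
x = jnp.arange(6.0).reshape(1, 3, 2)  # (batch, seqlen, d)

y = jnp.take_along_axis(x, perm[:, :, None], axis=-2)
print(jnp.take_along_axis(y, undo[:, :, None], axis=-2))  # equals x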
Example #20
  @jax.custom_transforms
  def _minimum_(x, y, name=None):  # pylint: disable=unused-argument
    return np.minimum(x, y)

  # TF and Jax have differing behavior
  # when the inputs to maximum/minimum are equal.
  # This custom transforms rule
  # modifies Jax to match TF's behavior.

  def _maximum_vjp(x, y):
    out_primals = _maximum_(x, y)
    def vjp(g):
      gx = g * np.where(x >= y, np.ones_like(x), np.zeros_like(x))
      return (gx.astype(x.dtype), (g - gx).astype(y.dtype))
    return out_primals, vjp
  jax.defvjp_all(_maximum_, _maximum_vjp)

  def _minimum_vjp(x, y):
    out_primals = _minimum_(x, y)
    def vjp(g):
      gx = g * np.where(x <= y, np.ones_like(x), np.zeros_like(x))
      return (gx.astype(x.dtype), (g - gx).astype(y.dtype))
    return out_primals, vjp
  jax.defvjp_all(_minimum_, _minimum_vjp)
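
  # Illustrative standalone sketch (ours, not part of the original module): the same
  # TF-style tie-breaking written with the newer jax.custom_vjp API. On equal inputs
  # the entire cotangent is attributed to x, matching tf.maximum.
  @jax.custom_vjp
  def _tf_maximum_sketch(x, y):
    return np.maximum(x, y)

  def _tf_maximum_sketch_fwd(x, y):
    return np.maximum(x, y), (x, y)

  def _tf_maximum_sketch_bwd(res, g):
    x, y = res
    gx = g * np.where(x >= y, np.ones_like(x), np.zeros_like(x))
    return gx.astype(x.dtype), (g - gx).astype(y.dtype)

  _tf_maximum_sketch.defvjp(_tf_maximum_sketch_fwd, _tf_maximum_sketch_bwd)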

  # Need to wrap in a function because custom_transforms
  # returns an object, not a function
  # which breaks docstring wrapping

  def _promote_dtypes(x, y):
    # Need to explicitly promote types