def finalize(f):
    """Decorate f with a custom gradient."""

    if JAX_MODE:

      # https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html

      # For JAX, we prefer to specify a custom JVP, as JAX can use a function
      # transform to transpose a JVP (must be linear in the tangents) to a VJP.
      if jvp_fn is not None:
        f_jvp = custom_jvp(f, nondiff_argnums=nondiff_argnums)
        f_jvp.defjvp(jvp_fn)
        return f_jvp

      else:
        from jax import custom_vjp  # pylint: disable=g-import-not-at-top
        f_vjp = custom_vjp(f, nondiff_argnums=nondiff_argnums)
        f_vjp.defvjp(vjp_fwd, vjp_bwd)
        return f_vjp

    else:
      # TF custom gradients support only custom VJPs.
      @tf.custom_gradient
      def f_wrapped(*args, **kwargs):
        val, aux = vjp_fwd(*args, **kwargs)
        def vjp_bwd_wrapped(*g):
          result = vjp_bwd(aux, tf.nest.pack_sequence_as(val, g))
          for i in nondiff_argnums:
            result = tuple(result[:i]) + (None,) + tuple(result[i:])
          return result
        return val, vjp_bwd_wrapped

      return f_wrapped
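The comment above prefers a custom JVP under JAX because a JVP that is linear in its tangents can be transposed into a VJP automatically. A minimal, self-contained sketch of that behaviour (adapted from the JAX custom-derivatives documentation, not part of the code above):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
  return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals
  t, = tangents
  # The tangent output is linear in t, so JAX can transpose it into a VJP.
  return log1pexp(x), t * jax.nn.sigmoid(x)

print(jax.grad(log1pexp)(3.0))  # reverse mode works without a hand-written VJP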
Example #2
    def inner(scope_fn, repack_fn, variable_groups_xs, rng_groups_xs, *args):
        assert len(variable_groups_xs) == 1, (
            'transform does not support multi-scope lifting.')
        grad_variables, other_variables = variable_groups_xs[0]

        def simple_scope_fn(grad_variables):
            return scope_fn(((freeze(grad_variables), other_variables), ),
                            rng_groups_xs)

        def f(grad_variables, *args):
            scope = scope_fn(((grad_variables, other_variables), ),
                             rng_groups_xs)
            y, _ = module_fn(scope, *args)
            vars_out = repack_fn(scope)
            return y, vars_out

        f = jax.custom_vjp(f, nondiff_argnums=nondiff_argnums)

        def f_fwd(grad_variables, *args):
            scope = simple_scope_fn(grad_variables)
            y, res = module_fn(scope, *args)
            vars_out = repack_fn(scope)
            return (y, vars_out), (res, grad_variables)

        def f_bwd(*args):
            nondiff_args = args[:-2]
            res, g = args[-2:]
            g_y, _ = g
            user_res, grad_variables = res
            return backward_fn(*nondiff_args, simple_scope_fn, grad_variables,
                               user_res, g_y)

        f.defvjp(f_fwd, f_bwd)

        return f(grad_variables, *args)
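The f_bwd above relies on the jax.custom_vjp convention that, with nondiff_argnums, the backward rule receives the non-differentiable arguments first, then the residuals, then the output cotangent. A minimal hypothetical sketch of that calling convention (the names here are illustrative, not Flax's):

import jax

def scale(power, x):                        # `power` is declared non-differentiable
    return x ** power

scale = jax.custom_vjp(scale, nondiff_argnums=(0,))

def scale_fwd(power, x):
    return x ** power, x                    # residuals: just x

def scale_bwd(power, x, g):                 # nondiff arg, then residuals, then cotangent
    return (g * power * x ** (power - 1),)  # one cotangent per differentiable arg

scale.defvjp(scale_fwd, scale_bwd)
print(jax.grad(scale, argnums=1)(3, 2.0))   # 3 * 2.0 ** 2 = 12.0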
Example #3
    def test_eigh_vjp(self):
        n = 20
        dtype = np.float64
        a = random_spd_coo(n=n, dtype=dtype)
        a = a.todense()

        def eigh(a):
            w, v = jax.numpy.linalg.eigh(a)
            v = standardize_eigenvector_signs(v)
            return w, v

        def eigh_fwd(a):
            w, v = eigh(a)
            return (w, v), (w, v)

        def eigh_rev(res, g):
            grad_w, grad_v = g
            w, v = res
            grad_a = cg.eigh_rev(grad_w, grad_v, w, v)
            grad_a = symmetrize(grad_a)
            return (grad_a, )

        eigh_fun = jax.custom_vjp(eigh)
        eigh_fun.defvjp(eigh_fwd, eigh_rev)
        jtu.check_grads(eigh_fun, (a, ), order=1, modes="rev", rtol=1e-3)
        w, v = eigh(a)
        self.assertAllClose(a @ v, v * w, rtol=1e-6)
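cg.eigh_rev above is project-specific. For reference, here is a dense sketch of the standard reverse rule for a symmetric eigendecomposition a = v @ diag(w) @ v.T, assuming distinct eigenvalues; this is an assumption about what the helper computes, not its actual implementation:

import jax.numpy as jnp

def dense_eigh_rev(grad_w, grad_v, w, v):
    # grad_a = v (diag(grad_w) + F * (v^T grad_v)) v^T,
    # with F[i, j] = 1 / (w[j] - w[i]) off the diagonal and 0 on it.
    diff = w[None, :] - w[:, None]
    f = jnp.where(jnp.eye(w.size, dtype=bool), 0.0,
                  1.0 / jnp.where(diff == 0.0, 1.0, diff))
    return v @ (jnp.diag(grad_w) + f * (v.T @ grad_v)) @ v.T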
Example #4
File: lift.py Project: pschuh/flax
    def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args):
        grad_variables, other_variables = variable_groups

        def simple_scope_fn(grad_variables):
            grad_variables = tuple(freeze(x) for x in grad_variables)
            return scope_fn((grad_variables, other_variables), rng_groups)

        def f(grad_variables, *args):
            scope = scope_fn((grad_variables, other_variables), rng_groups)
            y, _ = fn(scope, *args)
            vars_out = repack_fn(scope)
            return y, vars_out

        f = jax.custom_vjp(f, nondiff_argnums=nondiff_argnums)

        def f_fwd(grad_variables, *args):
            scope = simple_scope_fn(grad_variables)
            y, res = fn(scope, *args)
            vars_out = repack_fn(scope)
            return (y, vars_out), (res, grad_variables)

        def f_bwd(*args):
            nondiff_args = args[:-2]
            res, g = args[-2:]
            g_y, _ = g
            user_res, grad_variables = res
            return backward_fn(*nondiff_args, simple_scope_fn, grad_variables,
                               user_res, g_y)

        f.defvjp(f_fwd, f_bwd)

        return f(grad_variables, *args)
Example #5
    def test_eigh_partial_vjp(self):
        dtype = np.float64
        n = 20
        k = 4
        largest = False
        a = random_spd_coo(n, dtype=dtype).todense()

        def eigh_partial_fwd(a, k: int, largest: bool):
            w, v = eigh_partial(a, k, largest)
            return (w, v), (w, v, a)

        def eigh_partial_rev(res, g):
            w, v, a = res
            grad_w, grad_v = g
            rng_key = jax.random.PRNGKey(0)
            x0 = jax.random.normal(rng_key, v.shape, dtype=v.dtype)
            grad_a, x0 = cg.eigh_partial_rev(grad_w, grad_v, w, v, a, x0)
            grad_a = symmetrize(grad_a)
            return (grad_a, None, None)

        eigh_partial_fun = jax.custom_vjp(eigh_partial)
        eigh_partial_fun.defvjp(eigh_partial_fwd, eigh_partial_rev)

        jtu.check_grads(
            partial(eigh_partial_fun, k=k, largest=largest),
            (a, ),
            1,
            modes=["rev"],
            rtol=1e-3,
        )
Example #6
    def test_lobpcg_csr_vjp(self):
        m = 50
        k = 10
        largest = False
        dtype = np.float64

        def lobpcg_csr(data, indices, indptr, X0, largest, k):
            data = csr.symmetrize(data, indices)
            A = csr.matmul_fun(data, indices, indptr)
            w, v = lobpcg(A, X0, largest=largest, k=k)
            v = standardize_eigenvector_signs(v)
            return w, v

        def lobpcg_fwd(data, indices, indptr, X0, largest, k):
            w, v = lobpcg_csr(data, indices, indptr, X0, largest, k)
            return (w, v), (w, v, data, indices, indptr)

        def lobpcg_rev(res, g):
            grad_w, grad_v = g
            w, v, data, indices, indptr = res
            A = csr.matmul_fun(data, indices, indptr)
            x0 = jax.random.normal(jax.random.PRNGKey(0),
                                   shape=v.shape,
                                   dtype=v.dtype)
            grad_data, x0 = eigh_partial_rev(
                grad_w,
                grad_v,
                w,
                v,
                A,
                x0,
                outer_impl=csr.masked_outer_fun(indices, indptr),
            )
            grad_data = csr.symmetrize(grad_data, indices)
            return grad_data, None, None, None, None, None

        lobpcg_fun = jax.custom_vjp(lobpcg_csr)
        lobpcg_fun.defvjp(lobpcg_fwd, lobpcg_rev)

        rng = np.random.default_rng(0)
        A = random_spd_csr(m, sparsity=0.1, dtype=dtype)
        data, indices, indptr, _ = csr_components(A)

        X0 = rng.uniform(size=(m, k)).astype(dtype)
        jtu.check_grads(
            partial(lobpcg_fun,
                    indices=indices,
                    indptr=indptr,
                    X0=X0,
                    largest=largest,
                    k=k),
            (data, ),
            order=1,
            modes=["rev"],
            rtol=2e-3,
        )
Example #7
    def test_eigh_partial_coo_vjp(self):
        dtype = np.float64
        n = 20
        k = 4
        largest = False
        a = random_spd_coo(n, dtype=dtype)

        def eigh_partial_coo(data, row, col, size, k: int, largest: bool):
            data = coo.symmetrize(data, row, col, size)
            a = coo.to_dense(data, row, col, (size, size))
            w, v = eigh_partial(a, k, largest)
            v = standardize_eigenvector_signs(v)
            return w, v

        def eigh_partial_fwd(data, row, col, size, k: int, largest: bool):
            w, v = eigh_partial_coo(data, row, col, size, k, largest)
            return (w, v), (w, v, data, row, col)

        def eigh_partial_rev(res, g):
            w, v, data, row, col = res
            size = v.shape[0]
            grad_w, grad_v = g
            rng_key = jax.random.PRNGKey(0)
            x0 = jax.random.normal(rng_key, shape=v.shape, dtype=w.dtype)
            grad_data, x0 = cg.eigh_partial_rev(
                grad_w,
                grad_v,
                w,
                v,
                coo.matmul_fun(data, row, col, jnp.zeros((size, ))),
                x0,
                outer_impl=coo.masked_outer_fun(row, col),
            )
            grad_data = coo.symmetrize(grad_data, row, col, size)
            return (grad_data, None, None, None, None, None)

        eigh_partial_fn = jax.custom_vjp(eigh_partial_coo)
        eigh_partial_fn.defvjp(eigh_partial_fwd, eigh_partial_rev)

        data, row, col, _ = coo_components(a)
        self.assertTrue(coo.is_symmetric(row, col, data))
        self.assertTrue(coo.is_ordered(row, col))
        jtu.check_grads(
            partial(eigh_partial_fn,
                    k=k,
                    largest=largest,
                    row=row,
                    col=col,
                    size=n),
            (data, ),
            1,
            modes=["rev"],
            rtol=1e-3,
        )
Example #8
 def __init__(self,
              fun: Callable[Concatenate[U, P], R_co],
              static_argnums: Tuple[int, ...] = ()):
     """
     Args:
         fun: the function to decorate.
         static_argnums: The indices of the static arguments.
     """
     super().__init__()
     static_argnums = tuple(sorted(static_argnums))
     self.vjp = jax.custom_vjp(fun, nondiff_argnums=static_argnums)
Example #9
    def test_lobpcg_coo_vjp(self):
        m = 50
        k = 10
        largest = False
        dtype = np.float64

        def lobpcg_coo(data, row, col, X0, largest, k):
            size = X0.shape[0]
            data = coo.symmetrize(data, row, col, size)
            A = coo.matmul_fun(data, row, col, jnp.zeros((size, )))
            w, v = lobpcg(A, X0, largest=largest, k=k)
            v = standardize_eigenvector_signs(v)
            return w, v

        def lobpcg_fwd(data, row, col, X0, largest, k):
            w, v = lobpcg_coo(data, row, col, X0, largest, k)
            return (w, v), (w, v, data, row, col)

        def lobpcg_rev(res, g):
            grad_w, grad_v = g
            w, v, data, row, col = res
            size = v.shape[0]
            A = coo.matmul_fun(data, row, col, jnp.zeros((size, )))
            x0 = jax.random.normal(jax.random.PRNGKey(0),
                                   shape=v.shape,
                                   dtype=v.dtype)
            grad_data, x0 = eigh_partial_rev(grad_w,
                                             grad_v,
                                             w,
                                             v,
                                             A,
                                             x0,
                                             outer_impl=coo.masked_outer_fun(
                                                 row, col))
            grad_data = coo.symmetrize(grad_data, row, col, size)
            return grad_data, None, None, None, None, None

        lobpcg_fun = jax.custom_vjp(lobpcg_coo)
        lobpcg_fun.defvjp(lobpcg_fwd, lobpcg_rev)

        rng = np.random.default_rng(0)
        A = random_spd_coo(m, sparsity=0.1, dtype=dtype)
        data, row, col, _ = coo_components(A)

        X0 = rng.uniform(size=(m, k)).astype(dtype)
        jtu.check_grads(
            partial(lobpcg_fun, row=row, col=col, X0=X0, largest=largest, k=k),
            (data, ),
            order=1,
            modes=["rev"],
            rtol=2e-3,
        )
Example #10
  def finalize(f):
    """Decorate f with a custom gradient."""

    if JAX_MODE:

      # https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html

      # For JAX, we prefer to specify a custom JVP, as JAX can use a function
      # transform to transpose a JVP (must be linear in the tangents) to a VJP.
      if jvp_fn is not None:
        f_jvp = custom_jvp(f, nondiff_argnums=nondiff_argnums)
        f_jvp.defjvp(jvp_fn)
        return f_jvp

      else:
        from jax import custom_vjp  # pylint: disable=g-import-not-at-top
        f_vjp = custom_vjp(f, nondiff_argnums=nondiff_argnums)
        f_vjp.defvjp(vjp_fwd, vjp_bwd)
        return f_vjp

    else:
      # TF custom gradients support only custom VJPs.
      def none_wrapper(*args, **kwargs):  # custom_gradient can't handle None.
        closure = {i: a for i, a in enumerate(args)
                   if i in nondiff_argnums or a is None}
        trimmed_args = [a for i, a in enumerate(args) if i not in closure]

        @tf.custom_gradient
        def f_wrapped(*args, **kwargs):
          reconstruct_args = []
          args_structure = tf.nest.map_structure(lambda _: 0, args)
          for i in range(len(args) + len(closure)):
            if i in closure:
              reconstruct_args.append(closure[i])
            else:
              reconstruct_args.append(args[0])
              args = args[1:]
          val, aux = vjp_fwd(*reconstruct_args, **kwargs)

          def vjp_bwd_wrapped(*g):
            result = tf.nest.flatten(
                vjp_bwd(aux, tf.nest.pack_sequence_as(val, g)))
            for i in nondiff_argnums:
              result = tuple(result[:i]) + (None,) + tuple(result[i:])
            result = [a for i, a in enumerate(result) if i not in closure]
            return tf.nest.pack_sequence_as(args_structure, result)

          return val, vjp_bwd_wrapped

        return f_wrapped(*trimmed_args, **kwargs)

      return none_wrapper
Example #11
    def test_eigh_partial_csr_vjp(self):
        dtype = np.float64
        n = 20
        k = 4
        largest = False
        a = random_spd_csr(n, dtype=dtype)

        def eigh_partial_coo(data, indices, indptr, k: int, largest: bool):
            size = indptr.size - 1
            data = csr.symmetrize(data, indices)
            a = csr.to_dense(data, indices, indptr, (size, size))
            w, v = eigh_partial(a, k, largest)
            v = standardize_eigenvector_signs(v)
            return w, v

        def eigh_partial_fwd(data, indices, indptr, k: int, largest: bool):
            w, v = eigh_partial_coo(data, indices, indptr, k, largest)
            return (w, v), (w, v, data, indices, indptr)

        def eigh_partial_rev(res, g):
            w, v, data, indices, indptr = res
            grad_w, grad_v = g
            rng_key = jax.random.PRNGKey(0)
            x0 = jax.random.normal(rng_key, shape=v.shape, dtype=w.dtype)
            grad_data, x0 = cg.eigh_partial_rev(
                grad_w,
                grad_v,
                w,
                v,
                csr.matmul_fun(data, indices, indptr),
                x0,
                outer_impl=csr.masked_outer_fun(indices, indptr),
            )
            grad_data = csr.symmetrize(grad_data, indices)
            return grad_data, None, None, None, None

        eigh_partial_fn = jax.custom_vjp(eigh_partial_coo)
        eigh_partial_fn.defvjp(eigh_partial_fwd, eigh_partial_rev)

        data, indices, indptr, _ = csr_components(a)
        jtu.check_grads(
            partial(eigh_partial_fn,
                    k=k,
                    largest=largest,
                    indices=indices,
                    indptr=indptr),
            (data, ),
            1,
            modes=["rev"],
            rtol=1e-3,
        )
Example #12
    def custom_gradient(self, f: Callable, gradient: Callable) -> Callable:
        jax_fun = jax.custom_vjp(f)  # custom vector-Jacobian product (reverse-mode differentiation)

        def forward(*x):
            y = f(*x)
            return y, (x, y)

        def backward(x_y, dy):
            x, y = x_y
            dx = gradient(x, y, dy)
            return tuple(dx)

        jax_fun.defvjp(forward, backward)
        return jax_fun
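A hypothetical usage of the wrapper above, assuming backend is an instance of the class that defines custom_gradient: the hand-written gradient receives the primal inputs x (as a tuple), the output y, and the cotangent dy, and must return one cotangent per input.

import jax
import jax.numpy as jnp

my_sin = backend.custom_gradient(jnp.sin, lambda x, y, dy: (dy * jnp.cos(x[0]),))
print(jax.grad(my_sin)(0.3))  # ~ jnp.cos(0.3)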
Example #13
def permute_via_gather(val, permutation, inverse_permutation, axis=0):
    """Permutation helper for LSH attention."""
    def permute_impl(p, unused_ip, val):
        return jnp.take(val, p, axis=axis)

    def permute_fwd(p, ip, val):
        return jnp.take(val, p, axis=axis), ip

    def permute_bwd(ip, permuted_grad):
        # JAX autodiff would synthesize a scatter operation because it doesn't
        # know that the indices are a permutation. However on TPU, gathers are
        # faster than scatters (at least in the regime the LSH attention uses).
        return (None, None, jnp.take(permuted_grad, ip, axis=axis))

    permute = jax.custom_vjp(permute_impl)
    permute.defvjp(permute_fwd, permute_bwd)
    return permute(permutation, inverse_permutation, val)
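A hypothetical usage sketch: the gradient with respect to val flows back through a gather with the inverse permutation rather than a scatter.

import jax
import jax.numpy as jnp

perm = jnp.array([2, 0, 1])
inv_perm = jnp.argsort(perm)
x = jnp.arange(6.0).reshape(3, 2)
y = permute_via_gather(x, perm, inv_perm, axis=0)   # rows of x reordered by perm
g = jax.grad(lambda v: permute_via_gather(v, perm, inv_perm, axis=0).sum())(x)
# g is all ones: the cotangent is gathered back with inv_perm, never scattered.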
Example #14
    def test_lobpcg_vjp(self):
        m = 50
        k = 10
        largest = False
        dtype = np.float64

        def lobpcg_simple(A, X0, largest, k):
            A = symmetrize(A)
            w, v = lobpcg(A, X0, largest=largest, k=k)
            v = standardize_eigenvector_signs(v)
            return w, v

        def lobpcg_fwd(A, X0, largest, k):
            w, v = lobpcg_simple(A, X0, largest, k)
            return (w, v), (w, v, A)

        def lobpcg_rev(res, g):
            grad_w, grad_v = g
            w, v, a = res
            x0 = jax.random.normal(jax.random.PRNGKey(0),
                                   shape=v.shape,
                                   dtype=v.dtype)
            grad_a, x0 = eigh_partial_rev(grad_w, grad_v, w, v, a, x0)
            grad_a = symmetrize(grad_a)
            return grad_a, None, None, None

        lobpcg_fun = jax.custom_vjp(lobpcg_simple)
        lobpcg_fun.defvjp(lobpcg_fwd, lobpcg_rev)

        A = random_spd_coo(m, dtype=dtype).todense()
        rng = np.random.default_rng(0)
        X0 = rng.uniform(size=(m, k)).astype(dtype)
        jtu.check_grads(
            partial(lobpcg_fun, X0=X0, largest=largest, k=k),
            (A, ),
            order=1,
            modes=["rev"],
            rtol=1e-3,
        )
Example #15
    def test_eigh_coo_vjp(self):
        n = 20
        dtype = np.float64
        a = random_spd_coo(n=n, dtype=dtype)

        def eigh_coo(data, row, col, size):
            data = coo.symmetrize(data, row, col, size)
            a = coo.to_dense(data, row, col, (size, size))
            w, v = jnp.linalg.eigh(a)
            v = standardize_eigenvector_signs(v)
            return w, v

        def eigh_coo_fwd(data, row, col, size):
            w, v = eigh_coo(data, row, col, size)
            return (w, v), (w, v, row, col)

        def eigh_coo_rev(res, g):
            grad_w, grad_v = g
            w, v, row, col = res
            size = v.shape[0]
            grad_data = cg.eigh_rev(grad_w, grad_v, w, v,
                                    coo.masked_matmul_fun(row, col))
            grad_data = coo.symmetrize(grad_data, row, col, size)
            return (grad_data, None, None, None)

        eigh = jax.custom_vjp(eigh_coo)
        eigh.defvjp(eigh_coo_fwd, eigh_coo_rev)

        data, row, col, shape = coo_components(a)
        self.assertTrue(coo.is_symmetric(row, col, data, shape))
        jtu.check_grads(
            partial(eigh, row=row, col=col, size=n),
            (data, ),
            order=1,
            modes="rev",
            rtol=1e-3,
        )
Example #16
    def test_eigh_csr_vjp(self):
        n = 20
        dtype = np.float64
        a = random_spd_csr(n=n, dtype=dtype)

        def eigh_csr(data, indices, indptr):
            size = indptr.size - 1
            data = csr.symmetrize(data, indices)
            a = csr.to_dense(data, indices, indptr, (size, size))
            w, v = jnp.linalg.eigh(a)
            v = standardize_eigenvector_signs(v)
            return w, v

        def eigh_csr_fwd(data, indices, indptr):
            w, v = eigh_csr(data, indices, indptr)
            return (w, v), (w, v, indices, indptr)

        def eigh_csr_rev(res, g):
            grad_w, grad_v = g
            w, v, indices, indptr = res
            grad_data = cg.eigh_rev(grad_w, grad_v, w, v,
                                    csr.masked_matmul_fun(indices, indptr))
            grad_data = csr.symmetrize(grad_data, indices)
            return (grad_data, None, None)

        eigh = jax.custom_vjp(eigh_csr)
        eigh.defvjp(eigh_csr_fwd, eigh_csr_rev)

        data, indices, indptr, _ = csr_components(a)
        jtu.check_grads(
            partial(eigh, indices=indices, indptr=indptr),
            (data, ),
            order=1,
            modes="rev",
            rtol=1e-3,
        )
Example #17
import jax
import jax.numpy as jnp
import jax.scipy.linalg


def _lq(x):
    """LQ decomposition of x, computed via QR of the transpose."""
    q, r = jnp.linalg.qr(x.T)
    return r.T, q.T


def _lq_fwd(x):
    l, q = _lq(x)
    return (l, q), (l, q)


def _copyltu(x):
    """Copy lower triangular matrix to upper."""
    return jnp.tril(x) + jnp.tril(x, -1).T


def _lq_bwd(res, g):
    l, q = res
    grad_l, grad_q = g

    m = l.T @ grad_l - grad_q @ q.T
    grad_x = jax.scipy.linalg.solve_triangular(l.T, grad_q + _copyltu(m) @ q)
    return (grad_x, )


lq = jax.custom_vjp(_lq)
lq.defvjp(_lq_fwd, _lq_bwd)


def qr(x):
    """With backward gradient."""
    l, q = lq(x.T)
    return q.T, l.T
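A quick, hypothetical sanity check that the decomposition reconstructs its input and that reverse-mode gradients flow through the custom rule:

x = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
l, q = lq(x)
print(jnp.allclose(l @ q, x, atol=1e-5))             # True: x == l @ q
print(jax.grad(lambda x: lq(x)[0].sum())(x).shape)   # (4, 4)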
Example #18
      jnp.einsum("lj,ij->lij", padded_h4, gamma)
  del h1, padded_h2, h3, padded_h4

  padded_dbdc = jnp.concatenate([jnp.zeros((1, n, k + 1)), dbdc], axis=0)
  del dbdc

  dldc = 1 / epsilon * (
      -jnp.einsum("ij,ij->ij", grad_output_gamma, gamma) +
      jnp.einsum("hl,hl,hij->ij", grad_output_gamma, gamma, dxidc) +
      jnp.einsum("hl,hl,lij->ij", grad_output_gamma, gamma, padded_dbdc))

  return (dldc, None, None, None, None)


# create differentiable Monte Carlo estimate with custom backward
sinkhorn_differentiable = jax.custom_vjp(sinkhorn)
sinkhorn_differentiable.defvjp(sinkhorn_forward, sinkhorn_backward)


def differentiable_smooth_sorted_top_k(x, k, epsilon, num_iterations):
  """Differentiable smooth sorted top k."""
  n = x.shape[0]
  y = jnp.arange(k + 1)
  mu = jnp.ones(n) / n
  nu = jnp.array([(n - k) / n] + [1. / n] * k)

  # shape: n, k + 1
  costs = (x[:, jnp.newaxis] - y[jnp.newaxis, :]) ** 2

  transport_plan = sinkhorn_differentiable(costs, mu, nu, epsilon,
                                           num_iterations)
Example #19
    def unrolled_body_fn(iteration_g_gconst):
        iteration, g, g_constants = iteration_g_gconst
        state = jax.tree_map(lambda x: x[iteration // inner_iterations],
                             states)
        _, pullback = jax.vjp(unrolled_body_fn_no_errors, iteration, constants,
                              state)
        _, gi_constants, g_state = pullback(g)
        g_constants = jax.tree_multimap(lambda x, y: x + y, g_constants,
                                        gi_constants)
        out = (iteration - inner_iterations, g_state, g_constants)
        return (out, None) if force_scan else out

    if force_scan:
        (_, g_state, g_constants), _ = jax.lax.scan(
            lambda carry, x: unrolled_body_fn(carry), (0, g, g_constants),
            None,
            length=max_iterations // inner_iterations)
    else:
        _, g_state, g_constants = jax.lax.while_loop(
            bwd_cond_fn, unrolled_body_fn,
            (iteration - inner_iterations, g, g_constants))

    return g_constants, g_state


# definition of backprop friendly variant of fixpoint_iter.
fixpoint_iter_backprop = jax.custom_vjp(fixpoint_iter,
                                        nondiff_argnums=(0, 1, 2, 3, 4))

fixpoint_iter_backprop.defvjp(fixpoint_iter_fwd, fixpoint_iter_bwd)
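For context, the general pattern fixpoint_iter_backprop follows, shown on a toy problem: a lax.while_loop solver is not reverse-mode differentiable, so a custom VJP solves the transposed fixed-point equation for the cotangent instead. This sketch is essentially the fixed-point example from the JAX custom-derivatives documentation, not the implementation above:

from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x0):
    def cond(carry):
        x_prev, x = carry
        return jnp.max(jnp.abs(x - x_prev)) > 1e-6

    def body(carry):
        _, x = carry
        return x, f(a, x)

    _, x_star = jax.lax.while_loop(cond, body, (x0, f(a, x0)))
    return x_star


def fixed_point_fwd(f, a, x0):
    x_star = fixed_point(f, a, x0)
    return x_star, (a, x_star)


def fixed_point_bwd(f, res, g):
    a, x_star = res
    _, vjp_a = jax.vjp(lambda a: f(a, x_star), a)
    _, vjp_x = jax.vjp(lambda x: f(a, x), x_star)
    # Solve w = g + (df/dx)^T w by reusing the fixed-point solver itself.
    w = fixed_point(lambda _, w: g + vjp_x(w)[0], None, g)
    return vjp_a(w)[0], jnp.zeros_like(x_star)


fixed_point.defvjp(fixed_point_fwd, fixed_point_bwd)

# d/da sqrt(a), with sqrt(a) written as the fixed point of x -> 0.5 * (x + a / x):
print(jax.grad(lambda a: fixed_point(lambda a, x: 0.5 * (x + a / x), a, a))(2.0))
# ~0.3536, i.e. 1 / (2 * sqrt(2))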