def finalize(f):
  """Decorate f with a custom gradient."""
  if JAX_MODE:
    # https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
    # For JAX, we prefer to specify a custom JVP, as JAX can use a function
    # transform to transpose a JVP (must be linear in the tangents) to a VJP.
    if jvp_fn is not None:
      f_jvp = custom_jvp(f, nondiff_argnums=nondiff_argnums)
      f_jvp.defjvp(jvp_fn)
      return f_jvp
    else:
      from jax import custom_vjp  # pylint: disable=g-import-not-at-top
      f_vjp = custom_vjp(f, nondiff_argnums=nondiff_argnums)
      f_vjp.defvjp(vjp_fwd, vjp_bwd)
      return f_vjp
  else:
    # TF custom gradients support only custom VJPs.
    @tf.custom_gradient
    def f_wrapped(*args, **kwargs):
      val, aux = vjp_fwd(*args, **kwargs)

      def vjp_bwd_wrapped(*g):
        result = vjp_bwd(aux, tf.nest.pack_sequence_as(val, g))
        for i in nondiff_argnums:
          result = tuple(result[:i]) + (None,) + tuple(result[i:])
        return result

      return val, vjp_bwd_wrapped

    return f_wrapped
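# The comment above notes that JAX can transpose a custom JVP (linear in the tangents)
# into a VJP. Below is a minimal, self-contained sketch of that behaviour, in the spirit
# of the example from the linked JAX notebook and not part of the snippet above:
# defining only a JVP rule is enough for jax.grad (reverse mode) to work.
import jax
import jax.numpy as jnp


@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    ans = log1pexp(x)
    # The tangent output is linear in x_dot, so JAX can transpose this rule
    # to obtain a VJP automatically.
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot


print(jax.jvp(log1pexp, (3.0,), (1.0,)))  # forward mode uses the rule directly
print(jax.grad(log1pexp)(3.0))            # reverse mode via transposition of the JVP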
def inner(scope_fn, repack_fn, variable_groups_xs, rng_groups_xs, *args):
  assert len(variable_groups_xs) == 1, (
      'transform does not support multi-scope lifting.')
  grad_variables, other_variables = variable_groups_xs[0]

  def simple_scope_fn(grad_variables):
    return scope_fn(((freeze(grad_variables), other_variables),), rng_groups_xs)

  def f(grad_variables, *args):
    scope = scope_fn(((grad_variables, other_variables),), rng_groups_xs)
    y, _ = module_fn(scope, *args)
    vars_out = repack_fn(scope)
    return y, vars_out

  f = jax.custom_vjp(f, nondiff_argnums=nondiff_argnums)

  def f_fwd(grad_variables, *args):
    scope = simple_scope_fn(grad_variables)
    y, res = module_fn(scope, *args)
    vars_out = repack_fn(scope)
    return (y, vars_out), (res, grad_variables)

  def f_bwd(*args):
    nondiff_args = args[:-2]
    res, g = args[-2:]
    g_y, _ = g
    user_res, grad_variables = res
    return backward_fn(*nondiff_args, simple_scope_fn, grad_variables,
                       user_res, g_y)

  f.defvjp(f_fwd, f_bwd)
  return f(grad_variables, *args)
def test_eigh_vjp(self):
    n = 20
    dtype = np.float64
    a = random_spd_coo(n=n, dtype=dtype)
    a = a.todense()

    def eigh(a):
        w, v = jax.numpy.linalg.eigh(a)
        v = standardize_eigenvector_signs(v)
        return w, v

    def eigh_fwd(a):
        w, v = eigh(a)
        return (w, v), (w, v)

    def eigh_rev(res, g):
        grad_w, grad_v = g
        w, v = res
        grad_a = cg.eigh_rev(grad_w, grad_v, w, v)
        grad_a = symmetrize(grad_a)
        return (grad_a,)

    eigh_fun = jax.custom_vjp(eigh)
    eigh_fun.defvjp(eigh_fwd, eigh_rev)

    jtu.check_grads(eigh_fun, (a,), order=1, modes="rev", rtol=1e-3)
    w, v = eigh(a)
    self.assertAllClose(a @ v, v * w, rtol=1e-6)
def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args):
  grad_variables, other_variables = variable_groups

  def simple_scope_fn(grad_variables):
    grad_variables = tuple(freeze(x) for x in grad_variables)
    return scope_fn((grad_variables, other_variables), rng_groups)

  def f(grad_variables, *args):
    scope = scope_fn((grad_variables, other_variables), rng_groups)
    y, _ = fn(scope, *args)
    vars_out = repack_fn(scope)
    return y, vars_out

  f = jax.custom_vjp(f, nondiff_argnums=nondiff_argnums)

  def f_fwd(grad_variables, *args):
    scope = simple_scope_fn(grad_variables)
    y, res = fn(scope, *args)
    vars_out = repack_fn(scope)
    return (y, vars_out), (res, grad_variables)

  def f_bwd(*args):
    nondiff_args = args[:-2]
    res, g = args[-2:]
    g_y, _ = g
    user_res, grad_variables = res
    return backward_fn(*nondiff_args, simple_scope_fn, grad_variables,
                       user_res, g_y)

  f.defvjp(f_fwd, f_bwd)
  return f(grad_variables, *args)
def test_eigh_partial_vjp(self):
    dtype = np.float64
    n = 20
    k = 4
    largest = False
    a = random_spd_coo(n, dtype=dtype).todense()

    def eigh_partial_fwd(a, k: int, largest: bool):
        w, v = eigh_partial(a, k, largest)
        return (w, v), (w, v, a)

    def eigh_partial_rev(res, g):
        w, v, a = res
        grad_w, grad_v = g
        rng_key = jax.random.PRNGKey(0)
        x0 = jax.random.normal(rng_key, v.shape, dtype=v.dtype)
        grad_a, x0 = cg.eigh_partial_rev(grad_w, grad_v, w, v, a, x0)
        grad_a = symmetrize(grad_a)
        return (grad_a, None, None)

    eigh_partial_fun = jax.custom_vjp(eigh_partial)
    eigh_partial_fun.defvjp(eigh_partial_fwd, eigh_partial_rev)

    jtu.check_grads(
        partial(eigh_partial_fun, k=k, largest=largest),
        (a,),
        1,
        modes=["rev"],
        rtol=1e-3,
    )
def test_lobpcg_csr_vjp(self):
    m = 50
    k = 10
    largest = False
    dtype = np.float64

    def lobpcg_csr(data, indices, indptr, X0, largest, k):
        data = csr.symmetrize(data, indices)
        A = csr.matmul_fun(data, indices, indptr)
        w, v = lobpcg(A, X0, largest=largest, k=k)
        v = standardize_eigenvector_signs(v)
        return w, v

    def lobpcg_fwd(data, indices, indptr, X0, largest, k):
        w, v = lobpcg_csr(data, indices, indptr, X0, largest, k)
        return (w, v), (w, v, data, indices, indptr)

    def lobpcg_rev(res, g):
        grad_w, grad_v = g
        w, v, data, indices, indptr = res
        A = csr.matmul_fun(data, indices, indptr)
        x0 = jax.random.normal(jax.random.PRNGKey(0), shape=v.shape, dtype=v.dtype)
        grad_data, x0 = eigh_partial_rev(
            grad_w,
            grad_v,
            w,
            v,
            A,
            x0,
            outer_impl=csr.masked_outer_fun(indices, indptr),
        )
        grad_data = csr.symmetrize(grad_data, indices)
        return grad_data, None, None, None, None, None

    lobpcg_fun = jax.custom_vjp(lobpcg_csr)
    lobpcg_fun.defvjp(lobpcg_fwd, lobpcg_rev)

    rng = np.random.default_rng(0)
    A = random_spd_csr(m, sparsity=0.1, dtype=dtype)
    data, indices, indptr, _ = csr_components(A)
    X0 = rng.uniform(size=(m, k)).astype(dtype)

    jtu.check_grads(
        partial(lobpcg_fun, indices=indices, indptr=indptr, X0=X0,
                largest=largest, k=k),
        (data,),
        order=1,
        modes=["rev"],
        rtol=2e-3,
    )
def test_eigh_partial_coo_vjp(self):
    dtype = np.float64
    n = 20
    k = 4
    largest = False
    a = random_spd_coo(n, dtype=dtype)

    def eigh_partial_coo(data, row, col, size, k: int, largest: bool):
        data = coo.symmetrize(data, row, col, size)
        a = coo.to_dense(data, row, col, (size, size))
        w, v = eigh_partial(a, k, largest)
        v = standardize_eigenvector_signs(v)
        return w, v

    def eigh_partial_fwd(data, row, col, size, k: int, largest: bool):
        w, v = eigh_partial_coo(data, row, col, size, k, largest)
        return (w, v), (w, v, data, row, col)

    def eigh_partial_rev(res, g):
        w, v, data, row, col = res
        size = v.shape[0]
        grad_w, grad_v = g
        rng_key = jax.random.PRNGKey(0)
        x0 = jax.random.normal(rng_key, shape=v.shape, dtype=w.dtype)
        grad_data, x0 = cg.eigh_partial_rev(
            grad_w,
            grad_v,
            w,
            v,
            coo.matmul_fun(data, row, col, jnp.zeros((size,))),
            x0,
            outer_impl=coo.masked_outer_fun(row, col),
        )
        grad_data = coo.symmetrize(grad_data, row, col, size)
        return (grad_data, None, None, None, None, None)

    eigh_partial_fn = jax.custom_vjp(eigh_partial_coo)
    eigh_partial_fn.defvjp(eigh_partial_fwd, eigh_partial_rev)

    data, row, col, _ = coo_components(a)
    self.assertTrue(coo.is_symmetric(row, col, data))
    self.assertTrue(coo.is_ordered(row, col))

    jtu.check_grads(
        partial(eigh_partial_fn, k=k, largest=largest, row=row, col=col, size=n),
        (data,),
        1,
        modes=["rev"],
        rtol=1e-3,
    )
def __init__(self,
             fun: Callable[Concatenate[U, P], R_co],
             static_argnums: Tuple[int, ...] = ()):
    """
    Args:
        fun: the function to decorate.
        static_argnums: The indices of the static arguments.
    """
    super().__init__()
    static_argnums = tuple(sorted(static_argnums))
    self.vjp = jax.custom_vjp(fun, nondiff_argnums=static_argnums)
def test_lobpcg_coo_vjp(self):
    m = 50
    k = 10
    largest = False
    dtype = np.float64

    def lobpcg_coo(data, row, col, X0, largest, k):
        size = X0.shape[0]
        data = coo.symmetrize(data, row, col, size)
        A = coo.matmul_fun(data, row, col, jnp.zeros((size,)))
        w, v = lobpcg(A, X0, largest=largest, k=k)
        v = standardize_eigenvector_signs(v)
        return w, v

    def lobpcg_fwd(data, row, col, X0, largest, k):
        w, v = lobpcg_coo(data, row, col, X0, largest, k)
        return (w, v), (w, v, data, row, col)

    def lobpcg_rev(res, g):
        grad_w, grad_v = g
        w, v, data, row, col = res
        size = v.shape[0]
        A = coo.matmul_fun(data, row, col, jnp.zeros((size,)))
        x0 = jax.random.normal(jax.random.PRNGKey(0), shape=v.shape, dtype=v.dtype)
        grad_data, x0 = eigh_partial_rev(
            grad_w, grad_v, w, v, A, x0,
            outer_impl=coo.masked_outer_fun(row, col))
        grad_data = coo.symmetrize(grad_data, row, col, size)
        return grad_data, None, None, None, None, None

    lobpcg_fun = jax.custom_vjp(lobpcg_coo)
    lobpcg_fun.defvjp(lobpcg_fwd, lobpcg_rev)

    rng = np.random.default_rng(0)
    A = random_spd_coo(m, sparsity=0.1, dtype=dtype)
    data, row, col, _ = coo_components(A)
    X0 = rng.uniform(size=(m, k)).astype(dtype)

    jtu.check_grads(
        partial(lobpcg_fun, row=row, col=col, X0=X0, largest=largest, k=k),
        (data,),
        order=1,
        modes=["rev"],
        rtol=2e-3,
    )
def finalize(f):
  """Decorate f with a custom gradient."""
  if JAX_MODE:
    # https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
    # For JAX, we prefer to specify a custom JVP, as JAX can use a function
    # transform to transpose a JVP (must be linear in the tangents) to a VJP.
    if jvp_fn is not None:
      f_jvp = custom_jvp(f, nondiff_argnums=nondiff_argnums)
      f_jvp.defjvp(jvp_fn)
      return f_jvp
    else:
      from jax import custom_vjp  # pylint: disable=g-import-not-at-top
      f_vjp = custom_vjp(f, nondiff_argnums=nondiff_argnums)
      f_vjp.defvjp(vjp_fwd, vjp_bwd)
      return f_vjp
  else:
    # TF custom gradients support only custom VJPs.
    def none_wrapper(*args, **kwargs):  # custom_gradient can't handle None.
      closure = {i: a for i, a in enumerate(args)
                 if i in nondiff_argnums or a is None}
      trimmed_args = [a for i, a in enumerate(args) if i not in closure]

      @tf.custom_gradient
      def f_wrapped(*args, **kwargs):
        reconstruct_args = []
        args_structure = tf.nest.map_structure(lambda _: 0, args)
        for i in range(len(args) + len(closure)):
          if i in closure:
            reconstruct_args.append(closure[i])
          else:
            reconstruct_args.append(args[0])
            args = args[1:]
        val, aux = vjp_fwd(*reconstruct_args, **kwargs)

        def vjp_bwd_wrapped(*g):
          result = tf.nest.flatten(
              vjp_bwd(aux, tf.nest.pack_sequence_as(val, g)))
          for i in nondiff_argnums:
            result = tuple(result[:i]) + (None,) + tuple(result[i:])
          result = [a for i, a in enumerate(result) if i not in closure]
          return tf.nest.pack_sequence_as(args_structure, result)

        return val, vjp_bwd_wrapped

      return f_wrapped(*trimmed_args, **kwargs)

    return none_wrapper
def test_eigh_partial_csr_vjp(self):
    dtype = np.float64
    n = 20
    k = 4
    largest = False
    a = random_spd_csr(n, dtype=dtype)

    def eigh_partial_coo(data, indices, indptr, k: int, largest: bool):
        size = indptr.size - 1
        data = csr.symmetrize(data, indices)
        a = csr.to_dense(data, indices, indptr, (size, size))
        w, v = eigh_partial(a, k, largest)
        v = standardize_eigenvector_signs(v)
        return w, v

    def eigh_partial_fwd(data, indices, indptr, k: int, largest: bool):
        w, v = eigh_partial_coo(data, indices, indptr, k, largest)
        return (w, v), (w, v, data, indices, indptr)

    def eigh_partial_rev(res, g):
        w, v, data, indices, indptr = res
        grad_w, grad_v = g
        rng_key = jax.random.PRNGKey(0)
        x0 = jax.random.normal(rng_key, shape=v.shape, dtype=w.dtype)
        grad_data, x0 = cg.eigh_partial_rev(
            grad_w,
            grad_v,
            w,
            v,
            csr.matmul_fun(data, indices, indptr),
            x0,
            outer_impl=csr.masked_outer_fun(indices, indptr),
        )
        grad_data = csr.symmetrize(grad_data, indices)
        return grad_data, None, None, None, None

    eigh_partial_fn = jax.custom_vjp(eigh_partial_coo)
    eigh_partial_fn.defvjp(eigh_partial_fwd, eigh_partial_rev)

    data, indices, indptr, _ = csr_components(a)
    jtu.check_grads(
        partial(eigh_partial_fn, k=k, largest=largest,
                indices=indices, indptr=indptr),
        (data,),
        1,
        modes=["rev"],
        rtol=1e-3,
    )
def custom_gradient(self, f: Callable, gradient: Callable) -> Callable:
    jax_fun = jax.custom_vjp(f)  # custom vector-Jacobian product (reverse-mode differentiation)

    def forward(*x):
        y = f(*x)
        return y, (x, y)

    def backward(x_y, dy):
        x, y = x_y
        dx = gradient(x, y, dy)
        return tuple(dx)

    jax_fun.defvjp(forward, backward)
    return jax_fun
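# Hedged usage sketch for the custom_gradient wrapper above. Its enclosing class is
# not shown, so the standalone helper below mirrors the same wiring; `square` and its
# gradient are illustrative assumptions. The `gradient` callable receives the tuple of
# primal inputs, the primal output, and the output cotangent, and must return one
# cotangent per primal input.
import jax


def custom_gradient_standalone(f, gradient):
    jax_fun = jax.custom_vjp(f)

    def forward(*x):
        y = f(*x)
        return y, (x, y)

    def backward(x_y, dy):
        x, y = x_y
        return tuple(gradient(x, y, dy))

    jax_fun.defvjp(forward, backward)
    return jax_fun


square = custom_gradient_standalone(
    lambda x: x ** 2,
    lambda x, y, dy: (2.0 * x[0] * dy,))
print(jax.grad(square)(3.0))  # 6.0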
def permute_via_gather(val, permutation, inverse_permutation, axis=0):
  """Permutation helper for LSH attention."""

  def permute_impl(p, unused_ip, val):
    return jnp.take(val, p, axis=axis)

  def permute_fwd(p, ip, val):
    return jnp.take(val, p, axis=axis), ip

  def permute_bwd(ip, permuted_grad):
    # JAX autodiff would synthesize a scatter operation because it doesn't
    # know that the indices are a permutation. However on TPU, gathers are
    # faster than scatters (at least in the regime the LSH attention uses).
    return (None, None, jnp.take(permuted_grad, ip, axis=axis))

  permute = jax.custom_vjp(permute_impl)
  permute.defvjp(permute_fwd, permute_bwd)
  return permute(permutation, inverse_permutation, val)
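# Hedged usage sketch for permute_via_gather above: the custom backward pass gathers
# with the inverse permutation instead of letting autodiff synthesize a scatter, and
# its gradient should match that of a plain jnp.take. The small arrays are
# illustrative only.
import jax
import jax.numpy as jnp

perm = jnp.array([2, 0, 1])
inv_perm = jnp.argsort(perm)  # inverse permutation
val = jnp.array([10.0, 20.0, 30.0])

loss_custom = lambda v: jnp.sum(permute_via_gather(v, perm, inv_perm) ** 2)
loss_plain = lambda v: jnp.sum(jnp.take(v, perm) ** 2)

print(jax.grad(loss_custom)(val))  # [20. 40. 60.], via the inverse-permutation gather
print(jax.grad(loss_plain)(val))   # same values, via the default scatter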
def test_lobpcg_vjp(self):
    m = 50
    k = 10
    largest = False
    dtype = np.float64

    def lobpcg_simple(A, X0, largest, k):
        A = symmetrize(A)
        w, v = lobpcg(A, X0, largest=largest, k=k)
        v = standardize_eigenvector_signs(v)
        return w, v

    def lobpcg_fwd(A, X0, largest, k):
        w, v = lobpcg_simple(A, X0, largest, k)
        return (w, v), (w, v, A)

    def lobpcg_rev(res, g):
        grad_w, grad_v = g
        w, v, a = res
        x0 = jax.random.normal(jax.random.PRNGKey(0), shape=v.shape, dtype=v.dtype)
        grad_a, x0 = eigh_partial_rev(grad_w, grad_v, w, v, a, x0)
        grad_a = symmetrize(grad_a)
        return grad_a, None, None, None

    lobpcg_fun = jax.custom_vjp(lobpcg_simple)
    lobpcg_fun.defvjp(lobpcg_fwd, lobpcg_rev)

    A = random_spd_coo(m, dtype=dtype).todense()
    rng = np.random.default_rng(0)
    X0 = rng.uniform(size=(m, k)).astype(dtype)

    jtu.check_grads(
        partial(lobpcg_fun, X0=X0, largest=largest, k=k),
        (A,),
        order=1,
        modes=["rev"],
        rtol=1e-3,
    )
def test_eigh_coo_vjp(self):
    n = 20
    dtype = np.float64
    a = random_spd_coo(n=n, dtype=dtype)

    def eigh_coo(data, row, col, size):
        data = coo.symmetrize(data, row, col, size)
        a = coo.to_dense(data, row, col, (size, size))
        w, v = jnp.linalg.eigh(a)
        v = standardize_eigenvector_signs(v)
        return w, v

    def eigh_coo_fwd(data, row, col, size):
        w, v = eigh_coo(data, row, col, size)
        return (w, v), (w, v, row, col)

    def eigh_coo_rev(res, g):
        grad_w, grad_v = g
        w, v, row, col = res
        size = v.shape[0]
        grad_data = cg.eigh_rev(grad_w, grad_v, w, v,
                                coo.masked_matmul_fun(row, col))
        grad_data = coo.symmetrize(grad_data, row, col, size)
        return (grad_data, None, None, None)

    eigh = jax.custom_vjp(eigh_coo)
    eigh.defvjp(eigh_coo_fwd, eigh_coo_rev)

    data, row, col, shape = coo_components(a)
    self.assertTrue(coo.is_symmetric(row, col, data, shape))

    jtu.check_grads(
        partial(eigh, row=row, col=col, size=n),
        (data,),
        order=1,
        modes="rev",
        rtol=1e-3,
    )
def test_eigh_csr_vjp(self):
    n = 20
    dtype = np.float64
    a = random_spd_csr(n=n, dtype=dtype)

    def eigh_csr(data, indices, indptr):
        size = indptr.size - 1
        data = csr.symmetrize(data, indices)
        a = csr.to_dense(data, indices, indptr, (size, size))
        w, v = jnp.linalg.eigh(a)
        v = standardize_eigenvector_signs(v)
        return w, v

    def eigh_csr_fwd(data, indices, indptr):
        w, v = eigh_csr(data, indices, indptr)
        return (w, v), (w, v, indices, indptr)

    def eigh_csr_rev(res, g):
        grad_w, grad_v = g
        w, v, indices, indptr = res
        grad_data = cg.eigh_rev(grad_w, grad_v, w, v,
                                csr.masked_matmul_fun(indices, indptr))
        grad_data = csr.symmetrize(grad_data, indices)
        return (grad_data, None, None)

    eigh = jax.custom_vjp(eigh_csr)
    eigh.defvjp(eigh_csr_fwd, eigh_csr_rev)

    data, indices, indptr, _ = csr_components(a)
    jtu.check_grads(
        partial(eigh, indices=indices, indptr=indptr),
        (data,),
        order=1,
        modes="rev",
        rtol=1e-3,
    )
    return r.T, q.T


def _lq_fwd(x):
    l, q = _lq(x)
    return (l, q), (l, q)


def _copyltu(x):
    """Copy lower triangular matrix to upper."""
    return jnp.tril(x) + jnp.tril(x, -1).T


def _lq_bwd(res, g):
    l, q = res
    grad_l, grad_q = g
    m = l.T @ grad_l - grad_q @ q.T
    grad_x = jax.scipy.linalg.solve_triangular(l.T, grad_q + _copyltu(m) @ q)
    return (grad_x,)


lq = jax.custom_vjp(_lq)
lq.defvjp(_lq_fwd, _lq_bwd)


def qr(x):
    """With backward gradient."""
    l, q = lq(x.T)
    return q.T, l.T
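# Hedged usage sketch for the lq / qr wrappers above. The start of _lq is not shown
# here; this assumes it computes q, r = jnp.linalg.qr(x.T) and returns r.T, q.T, which
# is consistent with the tail shown. The random test matrix is illustrative only.
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.asarray(np.random.default_rng(0).normal(size=(5, 5)))

q, r = qr(x)  # QR factors with the custom LQ-based backward pass
print(jnp.allclose(q @ r, x, atol=1e-5))             # reconstruction
print(jnp.allclose(q.T @ q, jnp.eye(5), atol=1e-5))  # orthonormal columns

# Reverse-mode differentiation routes through _lq_bwd rather than the default QR rule.
loss = lambda x: jnp.sum(qr(x)[1] ** 2)
print(jax.grad(loss)(x).shape)  # (5, 5)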
jnp.einsum("lj,ij->lij", padded_h4, gamma) del h1, padded_h2, h3, padded_h4 padded_dbdc = jnp.concatenate([jnp.zeros((1, n, k + 1)), dbdc], axis=0) del dbdc dldc = 1 / epsilon * ( -jnp.einsum("ij,ij->ij", grad_output_gamma, gamma) + jnp.einsum("hl,hl,hij->ij", grad_output_gamma, gamma, dxidc) + jnp.einsum("hl,hl,lij->ij", grad_output_gamma, gamma, padded_dbdc)) return (dldc, None, None, None, None) # create differentiable Monte Carlo estimate with custom backward sinkhord_differentiable = jax.custom_vjp(sinkhorn) sinkhord_differentiable.defvjp(sinkhorn_forward, sinkhorn_backward) def differentiable_smooth_sorted_top_k(x, k, epsilon, num_iterations): """Differentiable smooth sorted top k.""" n = x.shape[0] y = jnp.arange(k + 1) mu = jnp.ones(n) / n nu = jnp.array([(n - k) / n] + [1. / n] * k) # shape: n, k + 1 costs = (x[:, jnp.newaxis] - y[jnp.newaxis, :]) ** 2 transport_plan = sinkhord_differentiable(costs, mu, nu, epsilon, num_iterations)
  def unrolled_body_fn(iteration_g_gconst):
    iteration, g, g_constants = iteration_g_gconst
    state = jax.tree_map(lambda x: x[iteration // inner_iterations], states)
    _, pullback = jax.vjp(unrolled_body_fn_no_errors, iteration, constants, state)
    _, gi_constants, g_state = pullback(g)
    g_constants = jax.tree_multimap(lambda x, y: x + y, g_constants, gi_constants)
    out = (iteration - inner_iterations, g_state, g_constants)
    return (out, None) if force_scan else out

  if force_scan:
    (_, g_state, g_constants), _ = jax.lax.scan(
        lambda carry, x: unrolled_body_fn(carry), (0, g, g_constants),
        None, length=max_iterations // inner_iterations)
  else:
    _, g_state, g_constants = jax.lax.while_loop(
        bwd_cond_fn, unrolled_body_fn,
        (iteration - inner_iterations, g, g_constants))

  return g_constants, g_state


# definition of backprop friendly variant of fixpoint_iter.
fixpoint_iter_backprop = jax.custom_vjp(fixpoint_iter,
                                        nondiff_argnums=(0, 1, 2, 3, 4))
fixpoint_iter_backprop.defvjp(fixpoint_iter_fwd, fixpoint_iter_bwd)
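# The snippet above passes nondiff_argnums=(0, 1, 2, 3, 4) to jax.custom_vjp. Below is
# a minimal, self-contained sketch (not part of the code above) of the calling
# convention this implies: fwd takes the same arguments as the primal function, while
# bwd receives the non-differentiable arguments first, then the residuals and the
# output cotangent, and returns cotangents only for the differentiable arguments.
from functools import partial

import jax


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def scale_by(n, x):  # n is treated as non-differentiable / static
    return n * x


def scale_by_fwd(n, x):
    return scale_by(n, x), None  # no residuals needed for this rule


def scale_by_bwd(n, res, g):  # n arrives before the residuals
    return (n * g,)           # one cotangent, for x only


scale_by.defvjp(scale_by_fwd, scale_by_bwd)
print(jax.grad(scale_by, argnums=1)(3, 2.0))  # 3.0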