def test_make_fx_jacrev(self, device):
    def f(x):
        return x.sin().sum()

    inp = torch.randn(3)
    f = jacrev(jacrev(f))
    fx_f = make_fx(f)(inp)
    new_inp = torch.randn(3)
    self.assertEqual(fx_f(new_inp), f(new_inp))
def test_hessian_simple(self, device):
    def foo(x):
        return x.sin().sum()

    x = torch.randn(3, device=device)
    y = jacrev(jacrev(foo))(x)
    # The gradient of sum(sin(x)) is cos(x), so the Hessian is diag(-sin(x)).
    expected = torch.diagflat(-x.sin())
    assert torch.allclose(y, expected)
def test_unrelated_hessian(self, device):
    N = 5
    M = 3
    W = torch.randn(N, M, device=device)

    def f(x):
        return W @ x

    # f is linear in x, so its Hessian is identically zero.
    x = torch.randn(M, device=device)
    result = jacrev(jacrev(f))(x)
    expected = torch.zeros(N, M, M, device=device)
    self.assertEqual(result, expected)
def make_prediction(model, drs):
    norms = torch.norm(drs, dim=1).reshape(-1, 1)
    energies = model(norms)

    # Differentiate the energy with respect to each per-sample norm,
    # then project back onto the displacement vectors to get forces.
    network_derivs = vmap(jacrev(model))(norms).squeeze(-1)
    forces = -network_derivs * drs / norms
    return energies, forces
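# A hypothetical call of ``make_prediction``; the toy network and the shapes
# below are assumptions for illustration, not taken from the snippet above:

import torch
from torch import nn
from functorch import vmap, jacrev

model = nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, 1))
drs = torch.randn(8, 3)                      # 8 displacement vectors
energies, forces = make_prediction(model, drs)
print(energies.shape, forces.shape)          # torch.Size([8, 1]) torch.Size([8, 3])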
def jacrev(model, inp, v=None, strict=None):
    # ``v`` and ``strict`` are accepted for signature compatibility but unused here.
    argnums = tuple(range(len(inp)))
    return ft.jacrev(model, argnums)(*inp)
def hessian_fwdrev(model, inp, v=None, strict=None):
    # Forward-over-reverse Hessian: jacfwd of jacrev with respect to every input.
    argnums = tuple(range(len(inp)))
    return ft.jacfwd(ft.jacrev(model, argnums=argnums), argnums=argnums)(*inp)
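# A minimal usage sketch for the two wrappers above, assuming ``ft`` is
# ``import functorch as ft`` and that ``model`` is a plain callable over tensor
# positional arguments (the toy model below is made up):

import torch
import functorch as ft

def model(x):
    return (x ** 3).sum()

inp = (torch.randn(4),)
jac = jacrev(model, inp)           # tuple of Jacobians, one per entry of ``inp``
hess = hessian_fwdrev(model, inp)  # nested tuple of second-derivative blocks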
from functorch import vmap, vjp

_, vjp_fn = vjp(partial(predict, weight, bias), x)
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
assert torch.allclose(ft_jacobian, jacobian)

# In another tutorial a composition of reverse-mode AD and vmap gave us
# per-sample-gradients. In this tutorial, composing reverse-mode AD and vmap
# gives us Jacobian computation! Various compositions of vmap and autodiff
# transforms can give us different interesting quantities.
#
# functorch provides ``jacrev`` as a convenience function that performs
# the vmap-vjp composition to compute Jacobians. ``jacrev`` accepts an
# ``argnums`` argument that says which argument we would like to compute
# Jacobians with respect to.

from functorch import jacrev

ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
assert torch.allclose(ft_jacobian, jacobian)

# Let's compare the performance of the two ways of computing the Jacobian.
# The functorch version is much faster (and becomes even faster the more
# outputs there are). In general, we expect that vectorization via ``vmap``
# can help eliminate overhead and give better utilization of your hardware.

from torch.utils.benchmark import Timer

without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
print(without_vmap.timeit(500))
print(with_vmap.timeit(500))

# It's pretty easy to flip the problem around and say we want to compute
# Jacobians of the parameters of our model (weight, bias) instead of the input.
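# As a sketch of that last point (not shown in the snippet above): passing a
# tuple to ``argnums`` should return one Jacobian per selected argument, so
# ``argnums=(0, 1)`` would give the Jacobians with respect to ``weight`` and
# ``bias``:

ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
print(ft_jac_weight.shape)  # Jacobian with respect to ``weight``
print(ft_jac_bias.shape)    # Jacobian with respect to ``bias``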
def test_vmap_on_jacrev_simple(self, device):
    x = torch.randn(2, 3, device=device)
    y = vmap(jacrev(torch.sin))(x)
    expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)])
    assert torch.allclose(y, expected)
def test_simple_not_flat(self, device):
    x = torch.randn(2, 3, device=device)
    y = jacrev(torch.sin)(x)
    # For an elementwise op the Jacobian is diagonal in the flattened view;
    # jacrev returns it with shape (*output_shape, *input_shape) = (2, 3, 2, 3).
    expected = torch.diagflat(x.view(-1).cos())
    expected = expected.view(2, 3, 2, 3)
    assert torch.allclose(y, expected)
def test_simple(self, device):
    x = torch.randn(3, device=device)
    y = jacrev(torch.sin)(x)
    expected = torch.diagflat(x.cos())
    assert torch.allclose(y, expected)