def test_vjp_vjp(self, device):
    x = torch.randn(3, device=device)
    y, vjp_fn = vjp(torch.sin, x)
    self.assertEqual(y, x.sin())

    y, vjp_fn = vjp(lambda x: vjp_fn(x)[0], x)
    self.assertEqual(y, x * x.cos())

    # Smoke test: just check that the second-level vjp runs.
    y = vjp_fn(x)[0]
def test_vjp_vmap(self, device):
    x = torch.randn(3, device=device)
    y, vjp_fn = vjp(vmap(torch.sin), x)
    self.assertEqual(y, x.sin())

    v = torch.randn(3, device=device)
    self.assertEqual(vjp_fn(v)[0], x.cos() * v)
def test_vjp_grad(self, device):
    x = torch.randn([], device=device)
    y, vjp_fn = vjp(grad(torch.sin), x)
    self.assertEqual(y, x.cos())

    v = torch.randn([])
    self.assertEqual(vjp_fn(v)[0], -x.sin() * v)
def vhp(model, inp, v=None, strict=None):
    assert v is not None
    argnums = tuple(range(len(inp)))
    # Take the vjp of the gradient function: applying the resulting pullback
    # to v yields the vector-Hessian product, while the model's value comes
    # back as the auxiliary output of grad_and_value.
    _, vjpfunc, aux = ft.vjp(ft.grad_and_value(model, argnums), *inp, has_aux=True)
    return aux, vjpfunc(v)
def test_vjp(self, device):
    x = torch.randn([], device=device)
    out, vjp_fn = vjp(torch.sin, x)
    self.assertEqual(out, x.sin())

    v = torch.randn([], device=device)
    result, = vjp_fn(v)
    self.assertEqual(result, v * x.cos())
def test_vjp_pytree_input(self, device):
    def f(x):
        return x[0] * x[1][0]

    x = torch.randn([], device=device)
    v = torch.randn([], device=device)
    out, vjp_fn = vjp(f, (x, (x, x)))
    self.assertEqual(out, x * x)

    result = vjp_fn(v)
    self.assertEqual(result, ((x * v, (x * v, 0.)),))
def test_make_fx_vjp(self, device):
    def f(x):
        return torch.sin(x).sum()

    primals = torch.randn(3)
    _, vjp_fn = vjp(f, primals)
    cotangent = torch.randn(())
    fx_f = make_fx(vjp_fn)(cotangent, True, True)

    new_cotangent = torch.randn(())
    self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))
def test_vjp_pytree_error(self, device):
    def f(x):
        return x, (x, x)

    x = torch.randn([], device=device)
    v1 = torch.randn([], device=device)
    v2 = torch.randn([], device=device)
    v3 = torch.randn([], device=device)
    _, vjp_fn = vjp(f, x)
    with self.assertRaisesRegex(RuntimeError, 'Expected pytree structure'):
        result, = vjp_fn((v1, (v2, v3)))
def test_vjp_pytree_output(self, device):
    def f(x):
        return x, (x, x)

    x = torch.randn([], device=device)
    v1 = torch.randn([], device=device)
    v2 = torch.randn([], device=device)
    v3 = torch.randn([], device=device)
    _, vjp_fn = vjp(f, x)
    result, = vjp_fn(v1, (v2, v3))
    self.assertEqual(result, v1 + v2 + v3)
def test_unrelated_vjp(self, device):
    x = torch.tensor(1., device=device)
    y = torch.tensor(2., device=device)
    v = torch.tensor(1., device=device)

    def unrelated(x):
        return y

    out, vjp_fn = vjp(unrelated, x)
    result = vjp_fn(v)
    expected = (torch.zeros_like(x),)
    self.assertEqual(result, expected)
def test_unrelated_vjp_multiple_inputs_outputs(self, device):
    w = torch.tensor(3., device=device)
    x = torch.tensor(4., device=device)
    y = torch.tensor(2., device=device)
    v = torch.tensor(1., device=device)

    def unrelated(w, x):
        return y, y, x

    out, vjp_fn = vjp(unrelated, w, x)
    result = vjp_fn(v, v, v)
    expected = (torch.zeros_like(x), torch.ones_like(x))
    self.assertEqual(result, expected)
def test_vmap_vjp(self, device):
    x = torch.randn(3, device=device)
    _, vjp_fn = vjp(torch.sin, x)

    def foo(x):
        _, vjp_fn = vjp(torch.sin, x)
        return vjp_fn(x)

    y = vmap(foo)(x)
    self.assertEqual(y, vjp_fn(x))

    # TODO: there's a very interesting error message when the following
    # is on CPU
    xs = torch.randn(5, 3, device=device)
    expected = torch.stack([vjp_fn(x)[0] for x in xs])
    result = vmap(lambda x: vjp_fn(x)[0])(xs)
    self.assertEqual(result, expected)
def vjp(model, inp, v=None, strict=None):
    assert v is not None
    out, vjpfunc = ft.vjp(model, *inp)
    return out, vjpfunc(v)
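# A hypothetical usage sketch of the wrapper above (not part of the original
# code). It assumes `ft` is the functorch module, as the helper itself does,
# and uses a made-up single-input `toy_model`. Note that the call below goes
# through the wrapper defined directly above, not functorch's own `vjp`.
import torch
import functorch as ft

def toy_model(x):
    return x.sin()

x = torch.randn(3)
v = torch.randn(3)
out, grads = vjp(toy_model, (x,), v=v)
# The forward output is sin(x); the pullback of cotangent v is cos(x) * v.
assert torch.allclose(out, x.sin())
assert torch.allclose(grads[0], x.cos() * v)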
def compute_jac(xp):
    jacobian_rows = [
        torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
        for vec in unit_vectors
    ]
    return torch.stack(jacobian_rows)

jacobian = compute_jac(xp)

# Instead of computing the jacobian row-by-row, we can use ``vmap`` to get rid
# of the for-loop and vectorize the computation. We can't directly apply vmap
# to PyTorch Autograd; instead, functorch provides a ``vjp`` transform:
from functorch import vmap, vjp

_, vjp_fn = vjp(partial(predict, weight, bias), x)
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
assert torch.allclose(ft_jacobian, jacobian)

# In another tutorial a composition of reverse-mode AD and vmap gave us
# per-sample-gradients. In this tutorial, composing reverse-mode AD and vmap
# gives us Jacobian computation! Various compositions of vmap and autodiff
# transforms can give us different interesting quantities.
#
# functorch provides ``jacrev`` as a convenience function that performs
# the vmap-vjp composition to compute jacobians. ``jacrev`` accepts an argnums
# argument that says which argument we would like to compute Jacobians with
# respect to.
from functorch import jacrev

ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
assert torch.allclose(ft_jacobian, jacobian)
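# A self-contained sketch of the same vmap-vjp trick on a simpler function
# (an illustrative assumption, not part of the tutorial): for the elementwise
# f(x) = sin(x), the Jacobian is diag(cos(x)), so each row recovered by the
# pullback of a standard basis vector should match the corresponding row of
# that diagonal matrix.
import torch
from functorch import vmap, vjp

x = torch.randn(4)
_, vjp_fn = vjp(torch.sin, x)
unit_vectors = torch.eye(4)             # one standard basis vector per row
jac_rows, = vmap(vjp_fn)(unit_vectors)  # row i is e_i^T J, i.e. J[i, :]
assert torch.allclose(jac_rows, torch.diag(x.cos()))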
def foo(x, y):
    out, vjp_fn = vjp(grad(torch.sin), x)
    return vjp_fn(y)[0]
def foo(x, y):
    df, vjp_fn = vjp(grad(lambda x: -torch.cos(x)), x)
    return grad(lambda y: vjp_fn(y)[0])(y)
def foo(x):
    _, vjp_fn = vjp(torch.sin, x)
    return vjp_fn(x)
def foo(x, y):
    out, vjp_fn = vjp(torch.sin, x)
    return grad(lambda y: vjp_fn(y)[0])(y)
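# A quick sanity check of the first of the compositions above (a sketch; it
# only assumes the same functorch `grad`/`vjp` imports used in the snippets).
# grad(torch.sin)(x) = cos(x), so the pullback applied to cotangent y gives
# d(cos(x))/dx * y = -sin(x) * y.
import torch
from functorch import grad, vjp

x = torch.randn([])
y = torch.randn([])
out, vjp_fn = vjp(grad(torch.sin), x)
assert torch.allclose(out, x.cos())
assert torch.allclose(vjp_fn(y)[0], -x.sin() * y)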