Example #1
    def test_vjp_vjp(self, device):
        x = torch.randn(3, device=device)
        y, vjp_fn = vjp(torch.sin, x)
        self.assertEqual(y, x.sin())

        y, vjp_fn = vjp(lambda x: vjp_fn(x)[0], x)
        self.assertEqual(y, x * x.cos())

        # call the second-level vjp_fn once more; the result is not checked here,
        # the point is that the nested vjp composes and runs
        y = vjp_fn(x)[0]
Example #2
    def test_vjp_vmap(self, device):
        x = torch.randn(3, device=device)
        y, vjp_fn = vjp(vmap(torch.sin), x)
        self.assertEqual(y, x.sin())

        v = torch.randn(3, device=device)
        self.assertEqual(vjp_fn(v)[0], x.cos() * v)
Example #3
    def test_vjp_grad(self, device):
        x = torch.randn([], device=device)
        y, vjp_fn = vjp(grad(torch.sin), x)
        self.assertEqual(y, x.cos())

        v = torch.randn([], device=device)
        self.assertEqual(vjp_fn(v)[0], -x.sin() * v)
Example #4
def vhp(model, inp, v=None, strict=None):
    # `ft` is the functorch module; grad_and_value(model, argnums) returns
    # (grads, value), so with has_aux=True `aux` below is the model's value
    # and vjpfunc(v) yields the vector-Hessian products.
    assert v is not None
    argnums = tuple(range(len(inp)))
    _, vjpfunc, aux = ft.vjp(ft.grad_and_value(model, argnums),
                             *inp,
                             has_aux=True)
    return aux, vjpfunc(v)
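For context, a hedged reference for what the helper computes: for a scalar-valued model, the vjp of its gradient at the inputs with cotangent v is a vector-Hessian product. A plain torch.autograd sketch for a single tensor input (the names below are illustrative, not part of the benchmark):

import torch

def vhp_reference(f, x, v):
    x = x.detach().requires_grad_()
    (g,) = torch.autograd.grad(f(x), x, create_graph=True)  # gradient of the scalar f
    (hv,) = torch.autograd.grad(g, x, grad_outputs=v)       # v^T H, the vector-Hessian product
    return hv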
Example #5
    def test_vjp(self, device):
        x = torch.randn([], device=device)
        out, vjp_fn = vjp(torch.sin, x)
        self.assertEqual(out, x.sin())

        v = torch.randn([], device=device)
        result, = vjp_fn(v)
        self.assertEqual(result, v * x.cos())
Example #6
    def test_vjp_pytree_input(self, device):
        def f(x):
            return x[0] * x[1][0]

        x = torch.randn([], device=device)
        v = torch.randn([], device=device)
        out, vjp_fn = vjp(f, (x, (x, x)))
        self.assertEqual(out, x * x)
        result = vjp_fn(v)
        self.assertEqual(result, ((x * v, (x * v, 0.)),))
Example #7
    def test_make_fx_vjp(self, device):
        def f(x):
            return torch.sin(x).sum()

        primals = torch.randn(3)
        _, vjp_fn = vjp(f, primals)
        cotangent = torch.randn(())
        fx_f = make_fx(vjp_fn)(cotangent, True, True)
        new_cotangent = torch.randn(())
        self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))
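Since make_fx returns a torch.fx.GraphModule, the backward computation recorded from vjp_fn can be inspected directly. An illustrative follow-up, reusing the fx_f built above:

print(fx_f.code)  # Python source of the traced graph for the backward of sin().sum()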
Example #8
    def test_vjp_pytree_error(self, device):
        def f(x):
            return x, (x, x)

        x = torch.randn([], device=device)
        v1 = torch.randn([], device=device)
        v2 = torch.randn([], device=device)
        v3 = torch.randn([], device=device)
        _, vjp_fn = vjp(f, x)
        with self.assertRaisesRegex(RuntimeError, 'Expected pytree structure'):
            result, = vjp_fn((v1, (v2, v3)))
Example #9
    def test_vjp_pytree_output(self, device):
        def f(x):
            return x, (x, x)

        x = torch.randn([], device=device)
        v1 = torch.randn([], device=device)
        v2 = torch.randn([], device=device)
        v3 = torch.randn([], device=device)
        _, vjp_fn = vjp(f, x)
        result, = vjp_fn(v1, (v2, v3))
        self.assertEqual(result, v1 + v2 + v3)
Example #10
    def test_unrelated_vjp(self, device):
        x = torch.tensor(1., device=device)
        y = torch.tensor(2., device=device)
        v = torch.tensor(1., device=device)

        def unrelated(x):
            return y

        out, vjp_fn = vjp(unrelated, x)
        result = vjp_fn(v)
        expected = (torch.zeros_like(x),)
        self.assertEqual(result, expected)
Example #11
    def test_unrelated_vjp_multiple_inputs_outputs(self, device):
        w = torch.tensor(3., device=device)
        x = torch.tensor(4., device=device)
        y = torch.tensor(2., device=device)
        v = torch.tensor(1., device=device)

        def unrelated(w, x):
            return y, y, x

        out, vjp_fn = vjp(unrelated, w, x)
        result = vjp_fn(v, v, v)
        expected = (torch.zeros_like(w), torch.ones_like(x))  # w is unrelated; x passes straight through
        self.assertEqual(result, expected)
Example #12
    def test_vmap_vjp(self, device):
        x = torch.randn(3, device=device)
        _, vjp_fn = vjp(torch.sin, x)

        def foo(x):
            _, vjp_fn = vjp(torch.sin, x)
            return vjp_fn(x)

        y = vmap(foo)(x)
        self.assertEqual(y, vjp_fn(x))

        # TODO: there's a very interesting error message when the following
        # is on CPU
        xs = torch.randn(5, 3, device=device)
        expected = torch.stack([vjp_fn(x)[0] for x in xs])
        result = vmap(lambda x: vjp_fn(x)[0])(xs)
        self.assertEqual(result, expected)
Example #13
def vjp(model, inp, v=None, strict=None):
    # `ft` is the functorch module; this thin wrapper applies ft.vjp to the
    # unpacked inputs and immediately evaluates it at the cotangent v.
    assert v is not None
    out, vjpfunc = ft.vjp(model, *inp)
    return out, vjpfunc(v)
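A hedged usage sketch for this wrapper (the toy model and values are assumptions for illustration; ft is assumed to be "import functorch as ft"):

import torch
import functorch as ft

def model(a):
    return a.sin()                     # single tensor in, single tensor out

inp = (torch.randn(3),)                # the wrapper unpacks this via *inp
v = torch.randn(3)                     # cotangent matching the model's output
out, vjp_out = vjp(model, inp, v=v)    # vjp_out[0] == inp[0].cos() * v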
Example #14
def compute_jac(xp):
    jacobian_rows = [
        torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
        for vec in unit_vectors
    ]
    return torch.stack(jacobian_rows)


jacobian = compute_jac(xp)

# Instead of computing the jacobian row-by-row, we can use ``vmap`` to get rid
# of the for-loop and vectorize the computation. We can't directly apply vmap
# to PyTorch Autograd; instead, functorch provides a ``vjp`` transform:
from functorch import vmap, vjp
_, vjp_fn = vjp(partial(predict, weight, bias), x)
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
assert torch.allclose(ft_jacobian, jacobian)

# In another tutorial a composition of reverse-mode AD and vmap gave us
# per-sample-gradients. In this tutorial, composing reverse-mode AD and vmap
# gives us Jacobian computation! Various compositions of vmap and autodiff
# transforms can give us different interesting quantities.
#
# functorch provides ``jacrev`` as a convenience function that performs
# the vmap-vjp composition to compute jacobians. ``jacrev`` accepts an argnums
# argument that says which argument we would like to compute Jacobians with
# respect to.
from functorch import jacrev
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
assert torch.allclose(ft_jacobian, jacobian)
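The names predict, weight, bias, x, xp, and unit_vectors come from earlier in the tutorial; a minimal sketch of that setup for reference (the model and shapes here are illustrative):

import torch
from functools import partial

def predict(weight, bias, x):
    return torch.nn.functional.linear(x, weight, bias).tanh()

D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D)
xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)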
Example #15
def foo(x, y):
    out, vjp_fn = vjp(grad(torch.sin), x)
    return vjp_fn(y)[0]
Example #16
def foo(x, y):
    df, vjp_fn = vjp(grad(lambda x: -torch.cos(x)), x)
    return grad(lambda y: vjp_fn(y)[0])(y)
Example #17
def foo(x):
    _, vjp_fn = vjp(torch.sin, x)
    return vjp_fn(x)
Example #18
def foo(x, y):
    out, vjp_fn = vjp(torch.sin, x)
    return grad(lambda y: vjp_fn(y)[0])(y)
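A quick sanity check of the composition above (illustrative; assumes torch is imported along with "from functorch import vjp, grad"):

x = torch.randn([])
y = torch.randn([])
# vjp_fn(y)[0] equals x.cos() * y, so its gradient with respect to y is x.cos()
assert torch.allclose(foo(x, y), x.cos())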