Example #1
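    # make_fx traces the composed transform jacrev(jacrev(f)) into an FX graph;
    # the traced graph is then checked against eager execution on a fresh input.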
    def test_make_fx_jacrev(self, device):
        def f(x):
            return x.sin().sum()
        inp = torch.randn(3)
        f = jacrev(jacrev(f))
        fx_f = make_fx(f)(inp)
        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))
Example #2
    def test_hessian_simple(self, device):
        def foo(x):
            return x.sin().sum()

        x = torch.randn(3, device=device)
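        # jacrev(jacrev(foo)) is the Hessian; for foo(x) = sin(x).sum() it is the
        # diagonal matrix with entries -sin(x).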
        y = jacrev(jacrev(foo))(x)
        expected = torch.diagflat(-x.sin())
        assert torch.allclose(y, expected)
Example #3
    def test_unrelated_hessian(self, device):
        N = 5
        M = 3
        W = torch.randn(N, M, device=device)

        def f(x):
            return W @ x

        x = torch.randn(M, device=device)
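        # f is linear in x, so every second derivative is zero and the result is
        # an all-zero tensor of shape (N, M, M).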
        result = jacrev(jacrev(f))(x)
        expected = torch.zeros(N, M, M, device=device)
        self.assertEqual(result, expected)
Example #4
def make_prediction(model, drs):
    norms = torch.norm(drs, dim=1).reshape(-1, 1)
    energies = model(norms)

    network_derivs = vmap(jacrev(model))(norms).squeeze(-1)
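    # network_derivs holds dE/d|dr| per sample (jacrev differentiates the model,
    # vmap maps it over the batch); the force is its negative along the unit vector dr/|dr|.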
    forces = -network_derivs * drs / norms
    return energies, forces
Example #5
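# Thin wrapper around ft.jacrev that differentiates ``model`` with respect to
# every positional input.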
def jacrev(model, inp, v=None, strict=None):
    argnums = tuple(range(len(inp)))
    return ft.jacrev(model, argnums)(*inp)
Example #6
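# Forward-over-reverse Hessian: ft.jacfwd composed with ft.jacrev, again taken
# with respect to every positional input.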
def hessian_fwdrev(model, inp, v=None, strict=None):
    argnums = tuple(range(len(inp)))
    return ft.jacfwd(ft.jacrev(model, argnums=argnums),
                     argnums=argnums)(*inp)
Example #7
from functools import partial
from functorch import vmap, vjp
_, vjp_fn = vjp(partial(predict, weight, bias), x)
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
assert torch.allclose(ft_jacobian, jacobian)

# In another tutorial a composition of reverse-mode AD and vmap gave us
# per-sample-gradients. In this tutorial, composing reverse-mode AD and vmap
# gives us Jacobian computation! Various compositions of vmap and autodiff
# transforms can give us different interesting quantities.
#
# functorch provides ``jacrev`` as a convenience function that performs
# the vmap-vjp composition to compute Jacobians. ``jacrev`` accepts an argnums
# argument that specifies which argument we would like to compute Jacobians
# with respect to.
from functorch import jacrev
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
assert torch.allclose(ft_jacobian, jacobian)

# Let's compare the performance of the two ways to compute the Jacobian.
# The functorch version is much faster (and becomes even faster the more outputs
# there are). In general, we expect that vectorization via ``vmap`` can help
# eliminate overhead and give better utilization of your hardware.
from torch.utils.benchmark import Timer
without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)",
                  globals=globals())
print(without_vmap.timeit(500))
print(with_vmap.timeit(500))

# It's pretty easy to flip the problem around and compute Jacobians with respect
# to the parameters of our model (weight, bias) instead of the input.
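# ``argnums`` also accepts a tuple, in which case jacrev returns a tuple of
# Jacobians, one per selected argument. A minimal sketch of that, reusing the
# same ``predict``, ``weight``, ``bias``, and ``x`` from above:
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
print(ft_jac_weight.shape, ft_jac_bias.shape)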
Example #8
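    # vmap maps jacrev(torch.sin) over the leading batch dimension, giving one
    # 3x3 Jacobian (a diagonal of cosines) per row of x.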
    def test_vmap_on_jacrev_simple(self, device):
        x = torch.randn(2, 3, device=device)
        y = vmap(jacrev(torch.sin))(x)
        expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)])
        assert torch.allclose(y, expected)
Example #9
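    # For non-flat inputs, jacrev returns a Jacobian of shape
    # output_shape + input_shape, here (2, 3, 2, 3).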
    def test_simple_not_flat(self, device):
        x = torch.randn(2, 3, device=device)
        y = jacrev(torch.sin)(x)
        expected = torch.diagflat(x.view(-1).cos())
        expected = expected.view(2, 3, 2, 3)
        assert torch.allclose(y, expected)
Example #10
    def test_simple(self, device):
        x = torch.randn(3, device=device)
        y = jacrev(torch.sin)(x)
        expected = torch.diagflat(x.cos())
        assert torch.allclose(y, expected)