def leapfrog(q, p, potential_fn, inverse_mass_matrix, step_size):
    r"""
    Second order symplectic integrator that uses the velocity leapfrog algorithm
    to advance a single step.

    :param torch.Tensor q: current position.
    :param torch.Tensor p: current momentum.
    :param callable potential_fn: function that returns the potential energy given
        ``q``. The negative gradient of this function with respect to ``q``
        determines the rate of change of the momentum ``p``.
    :param torch.Tensor inverse_mass_matrix: a tensor :math:`M^{-1}` which is used
        to calculate kinetic energy:
        :math:`E_{kinetic} = \frac{1}{2} p^T M^{-1} p`. Here :math:`M` can be a 1D
        tensor (diagonal matrix) or a 2D tensor (dense matrix).
    :param float step_size: step size for each time step iteration.
    :return tuple (q_next, p_next, q_grads, potential_energy): next position and
        momentum, together with the gradient of the potential energy w.r.t.
        ``q_next`` and the potential energy itself.
    """
    q_grads = grad(potential_fn)(q)
    p = p + 0.5 * step_size * (-q_grads)  # p(n+1/2)

    p_grads = _kinetic_grad(inverse_mass_matrix, p)
    q = q + step_size * p_grads  # q(n+1)

    q_grads, potential_energy = grad_and_value(potential_fn)(q)
    p = p + 0.5 * step_size * (-q_grads)  # p(n+1)

    return q, p, q_grads, potential_energy

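# Hypothetical usage sketch for `leapfrog` above (not from the source): one step
# on a standard-normal potential. It assumes `grad` and `grad_and_value` come
# from functorch and that `_kinetic_grad` returns dE_kinetic/dp = M^{-1} p; the
# diagonal `_kinetic_grad` below and all inputs are illustrative.
import torch
from functorch import grad, grad_and_value


def _kinetic_grad(inverse_mass_matrix, p):
    # For a 1D (diagonal) inverse mass matrix, M^{-1} p is an elementwise product.
    return inverse_mass_matrix * p


def potential_fn(q):
    # Standard normal potential: U(q) = q^T q / 2, so dU/dq = q.
    return 0.5 * (q ** 2).sum()


q0 = torch.randn(3)
p0 = torch.randn(3)
inverse_mass_matrix = torch.ones(3)
q1, p1, q_grads, pe = leapfrog(q0, p0, potential_fn, inverse_mass_matrix, step_size=0.1)
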
def _test_attributes(self, get_attr_lambda, device):
    x = torch.randn(2, 3, 5, dtype=torch.double, device=device)
    expected = get_attr_lambda(x)

    def foo(x):
        self.assertEqual(get_attr_lambda(x), expected)
        return x.sum()

    grad(foo)(x)

def test_grad_vjp(self, device):
    x = torch.randn(3, device=device)

    def foo(x):
        _, vjp_fn = vjp(torch.sin, x)
        return vjp_fn(x)[0].sum()

    y = grad(foo)(x)
    expected = grad(lambda x: (x * x.cos()).sum())(x)
    self.assertEqual(y, expected)

def test_per_sample_grads_simple(self, device):
    def compute_loss(weight, x, t):
        y = x @ weight
        return ((y - t) ** 2).sum()

    weight = torch.randn(16, 2, device=device)
    x = torch.randn(64, 16, device=device)
    t = torch.randn(64, 2, device=device)
    result = vmap(partial(grad(compute_loss), weight))(x, t)
    expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
    expected = torch.stack(expected)
    # TODO: Check if the rtol is a problem
    self.assertEqual(result, expected, atol=0, rtol=5e-4)

def test_argnums(self, device):
    x = torch.randn([])
    y = torch.randn([])

    gx = grad(torch.mul, argnums=0)(x, y)
    self.assertEqual(gx, y)

    gy = grad(torch.mul, argnums=1)(x, y)
    self.assertEqual(gy, x)

    gx, = grad(torch.mul, argnums=(0,))(x, y)
    self.assertEqual(gx, y)

    gx, gy = grad(torch.mul, argnums=(0, 1))(x, y)
    self.assertEqual(gx, y)
    self.assertEqual(gy, x)

def test_make_fx_no_decompose(self, device):
    # FIXME
    return self.skipTest("error: maximum recursion reached")

    def f(x):
        return torch.tanh(x).sum()

    fx_f = make_fx(grad(f))(torch.randn(5))
    ops = {i.target for i in fx_f.graph.nodes}
    self.assertEqual(torch.ops.aten.tanh_backward in ops, True)

    fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5))
    ops = {i.target for i in fx_f.graph.nodes}
    self.assertEqual(torch.ops.aten.tanh_backward in ops, False)

def test_vjp_grad(self, device):
    x = torch.randn([], device=device)
    y, vjp_fn = vjp(grad(torch.sin), x)
    self.assertEqual(y, x.cos())

    v = torch.randn([])
    self.assertEqual(vjp_fn(v)[0], -x.sin() * v)

def test_zero_grad(self, device):
    def f(x):
        return (x['a'] ** 2.0).sum()

    inps = {'a': torch.randn(10, device=device) + 3, 'b': torch.randn(10, device=device)}
    grads = grad(f)(inps)
    self.assertNotEqual(grads['a'].sum(), 0.0)
    self.assertEqual(grads['b'].sum(), 0.0)

def test_per_sample_grads_embeddingnet(self, device):
    class SampleNet(nn.Module):
        def __init__(self, vocab_size: int):
            super().__init__()
            self.emb = nn.Embedding(vocab_size, 16)
            self.fc1 = nn.Linear(16, 16)
            self.fc2 = nn.Linear(16, 2)

        def forward(self, x):
            x = self.emb(x)
            x = torch.transpose(x, -1, -2)
            x = torch.mean(x, -1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.fc2(x)
            return x

        def name(self):
            return "SampleNet"

    # Create our inputs...
    vocab_size = 1000
    batch_shape = [64]
    words_per_sentence = 5
    data = torch.randint(0, vocab_size, (*batch_shape, words_per_sentence), device=device)
    targets = torch.randint(0, 1, (*batch_shape,), device=device)

    # Construct our module
    net = SampleNet(vocab_size).to(device=device)
    criterion = nn.CrossEntropyLoss()
    params = dict(net.named_parameters())

    weights, net_func, _ = make_functional(net)

    def compute_loss(weights, data, target):
        output = net_func(weights, (data,))
        result = criterion(output, target)
        return result

    expected = [grad(compute_loss)(weights, data[i], targets[i]) for i in range(64)]
    expected = zip(*expected)
    expected = tuple(torch.stack(shards) for shards in expected)

    result = vmap(partial(grad(compute_loss), weights))(data, targets)
    for r, e in zip(result, expected):
        # TODO: Check if the rtol is a problem
        self.assertEqual(r, e, atol=0, rtol=1e-4)

def test_inplace(self, device):
    x = torch.randn([], device=device)

    def foo(x):
        return x.clone().sin_()

    result = grad(foo)(x)
    self.assertEqual(result, x.cos())

def test_grad_vmap(self, device):
    def foo(x):
        y = vmap(torch.sin)(x)
        return y.sum()

    x = torch.randn(3)
    y = grad(foo)(x)
    self.assertEqual(y, x.cos())

def test_unrelated_grad(self, device):
    x = torch.tensor(1., device=device)
    y = torch.tensor(2., device=device)

    def unrelated(x):
        return y

    result = grad(unrelated)(x)
    self.assertEqual(result, torch.zeros_like(x))

def test_view_inplace_simple(self, device):
    def foo(x):
        x = x.clone()
        x.view([]).sin_()
        return x

    x = torch.randn([], requires_grad=True, device=device)
    result = grad(foo)(x)
    self.assertEqual(result, x.cos())

def test_make_fx_grad(self, device):
    def f(x):
        return torch.sin(x).sum()

    inp = torch.randn(3)
    f = grad(f)
    fx_f = make_fx(f)(inp)

    new_inp = torch.randn(3)
    self.assertEqual(fx_f(new_inp), f(new_inp))

def test_invalid_argnums(self, device):
    x = torch.randn([])
    y = torch.randn([])
    with self.assertRaisesRegex(RuntimeError, 'but only'):
        grad(torch.mul, argnums=-1)(x, y)
    with self.assertRaisesRegex(RuntimeError, 'but only'):
        grad(torch.mul, argnums=2)(x, y)
    with self.assertRaisesRegex(RuntimeError, 'int or Tuple'):
        grad(torch.mul, argnums=[0])(x, y)
    with self.assertRaisesRegex(RuntimeError, 'must be int'):
        grad(torch.mul, argnums=('0',))(x, y)

def functorch_per_sample_grad():
    compute_grad = grad(compute_loss)
    compute_per_sample_grad = vmap(compute_grad, (None, 0, 0))

    start = time.time()
    result = compute_per_sample_grad(weights, images, targets)
    torch.cuda.synchronize()
    end = time.time()

    return result, end - start  # end - start in seconds

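# Hypothetical context for `functorch_per_sample_grad` above (assumed, not shown
# in the source): the globals it closes over, sketched with a minimal linear
# classifier on CUDA. All names and shapes here are illustrative.
import time

import torch
import torch.nn.functional as F
from functorch import grad, vmap

weights = torch.randn(784, 10, device='cuda')
images = torch.randn(64, 784, device='cuda')
targets = torch.randint(0, 10, (64,), device='cuda')


def compute_loss(weights, image, target):
    # Per-sample loss: unsqueeze to give cross_entropy a batch dimension of 1.
    logits = image @ weights
    return F.cross_entropy(logits.unsqueeze(0), target.unsqueeze(0))


result, seconds = functorch_per_sample_grad()
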
def test_escaped_wrappers_are_marked_as_dead(self, device):
    x = torch.randn([], device=device)
    escaped = []

    def foo(x):
        y = x.sin()
        escaped.append(y)
        return y

    result = grad(foo)(x)
    self.assertEqual(functorch._C.dlevel(escaped[0]), -1)

def get_loss_for_task(x1, y1, x2, y2):
    def inner_loss(params, x1, y1):
        f = net(params, x1)
        loss = mse_loss(f, y1)
        return loss

    grads = grad(inner_loss)(tuple(params), x1, y1)
    new_params = [(params[i] - alpha * grads[i]) for i in range(len(params))]

    v_f = net(new_params, x2)
    return mse_loss(v_f, y2)

def test_escaped_wrappers_are_ignored(self, device):
    x = torch.randn([], device=device)
    escaped = []

    def foo(x):
        y = x.sin()
        escaped.append(y)
        return y

    result = grad(foo)(x)

    something = escaped[0].sum()
    self.assertEqual(functorch._C.dlevel(something), 0)
    self.assertEqual(something, x.sin().sum())

def test_new_empty_materializes_tensor(self, device):
    N = 3
    C = 5

    def foo(y, x):
        result = x.new_empty((C,))
        result.copy_(y)
        return result.sum()

    x = torch.randn(N, device=device)
    y = torch.randn(N, C, device=device)
    result = vmap(grad(foo))(y, x)
    self.assertEqual(result, torch.ones_like(y))

def test_composite_two_ops(self, device):
    N, C = 2, 5
    y = torch.randn(N, C, device=device)
    targets = torch.randint(0, C, (N,), device=device)

    def foo(y, targets):
        return F.cross_entropy(y, targets)

    result = grad(foo)(y, targets)

    y.requires_grad_()
    expected, = torch.autograd.grad(foo(y, targets), y)
    self.assertEqual(result, expected)

def get_loss_for_task(use_transform, x1, y1, x2, y2):
    def inner_loss(params, x1, y1):
        f = net(params, (x1,))
        loss = mse_loss(f, y1)
        return loss

    if use_transform:
        grads = grad(inner_loss)(params, x1, y1)
    else:
        loss = inner_loss(params, x1, y1)
        grads = torch.autograd.grad(loss, params, create_graph=True)
    new_params = [(params[i] - alpha * grads[i]) for i in range(len(params))]

    v_f = net(new_params, (x2,))
    return mse_loss(v_f, y2)

def test_composite_complicated(self, device):
    x = torch.randn(3, device=device)
    y = torch.randn(3, 5, device=device)

    def foo(x, y):
        result = x @ y
        return result.sum()

    result = grad(foo)(x, y)

    x.requires_grad_()
    out = foo(x, y)
    expected, = torch.autograd.grad(out, x)
    self.assertEqual(result, expected)

def test_inplace_on_view_base(self, device):
    x = torch.randn(3, device=device)

    def foo(x):
        y = x.clone()
        y0 = y[0]
        y.sin_()
        return y0

    result = grad(foo)(x)

    x.requires_grad_()
    out = foo(x)
    expected, = torch.autograd.grad(out, x)
    self.assertEqual(result, expected)

def update_fn(step_size, inverse_mass_matrix, state):
    """
    :param float step_size: Size of a single step.
    :param inverse_mass_matrix: Inverse of mass matrix, which is used to
        calculate kinetic energy.
    :param state: Current state of the integrator.
    :return: new state for the integrator.
    """
    q, p, _, q_grad = state
    # tree_multimap maps a function over pytrees, returning a new pytree
    p = tree_multimap(lambda p, q_grad: p - 0.5 * step_size * q_grad, p, q_grad)  # p(n+1/2)
    p_grad = grad(kinetic_fn, argnums=1)(inverse_mass_matrix, p)
    q = tree_multimap(lambda q, p_grad: q + step_size * p_grad, q, p_grad)  # q(n+1)
    potential_energy, q_grad = value_and_grad(potential_fn)(q)
    p = tree_multimap(lambda p, q_grad: p - 0.5 * step_size * q_grad, p, q_grad)  # p(n+1)
    return IntegratorState(q, p, potential_energy, q_grad)

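# Hypothetical setup sketch for `update_fn` above (not from the source). It
# assumes the JAX-style closure the snippet appears to rely on: `tree_multimap`
# (aliased here to `jax.tree_util.tree_map`), `value_and_grad` and `grad` from
# jax, an `IntegratorState` namedtuple, and quadratic kinetic/potential
# energies. Everything below is illustrative.
from collections import namedtuple

import jax.numpy as jnp
from jax import grad, value_and_grad
from jax.tree_util import tree_map as tree_multimap

IntegratorState = namedtuple("IntegratorState", ["q", "p", "potential_energy", "q_grad"])


def potential_fn(q):
    return 0.5 * jnp.sum(q ** 2)


def kinetic_fn(inverse_mass_matrix, p):
    return 0.5 * jnp.sum(inverse_mass_matrix * p ** 2)


q0 = jnp.array([0.3, -1.2])
p0 = jnp.array([1.0, 0.5])
pe0, q_grad0 = value_and_grad(potential_fn)(q0)
state = IntegratorState(q0, p0, pe0, q_grad0)
state = update_fn(0.1, jnp.ones(2), state)  # one leapfrog step
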
def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
    params, buffers, fnet = net
    querysz = x_qry.size(0)

    def compute_loss(new_params, buffers, x, y):
        logits = fnet(new_params, buffers, (x,))
        loss = F.cross_entropy(logits, y)
        return loss

    new_params = params
    for _ in range(n_inner_iter):
        grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
        new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]

    # The final set of adapted parameters will induce some
    # final loss and accuracy on the query dataset.
    # These will be used to update the model's meta-parameters.
    qry_logits = fnet(new_params, buffers, (x_qry,))
    qry_loss = F.cross_entropy(qry_logits, y_qry)
    qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz

    return qry_loss, qry_acc

def loss_for_task(net, n_inner_iter, use_transform, x_spt, y_spt, x_qry, y_qry):
    params, buffers, fnet = net
    querysz = x_qry.size(0)

    def compute_loss(new_params, buffers, x, y):
        logits = fnet(new_params, buffers, (x,))
        loss = F.cross_entropy(logits, y)
        return loss

    new_params = params
    for _ in range(n_inner_iter):
        if use_transform:
            grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
        else:
            res = compute_loss(new_params, buffers, x_spt, y_spt)
            grads = torch.autograd.grad(res, new_params, create_graph=True)
        new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]

    qry_logits = fnet(new_params, buffers, (x_qry,))
    qry_loss = F.cross_entropy(qry_logits, y_qry)
    qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz
    return qry_loss, qry_acc

# Next, let's define a function to compute the loss of the model given a single
# input rather than a batch of inputs. It is important that this function accepts
# the parameters, the input, and the target, because we will be transforming over
# them. Because the model was originally written to handle batches, we'll use
# ``torch.unsqueeze`` to add a batch dimension.
def compute_loss(params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    predictions = fmodel(params, buffers, batch)
    loss = loss_fn(predictions, targets)
    return loss

# Now, let's use ``grad`` to create a new function that computes the gradient
# with respect to the first argument of ``compute_loss`` (i.e. the params).
ft_compute_grad = grad(compute_loss)

# ``ft_compute_grad`` computes the gradient for a single (sample, target) pair.
# We can use ``vmap`` to get it to compute the gradient over an entire batch
# of samples and targets. Note that ``in_dims=(None, None, 0, 0)`` because we
# wish to map ``ft_compute_grad`` over the 0th dimension of the data and targets
# and use the same params and buffers for each.
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))

# Finally, let's use our transformed function to compute per-sample gradients:
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)

for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads):
    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=1e-6, rtol=1e-6)

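# For reference, a sketch (assumed, not part of this excerpt) of how the
# `per_sample_grads` baseline compared against above might be computed without
# vmap: a plain Python loop calling torch.autograd.grad once per sample. It
# assumes `params` is a tuple of tensors with requires_grad=True, as returned
# by functorch's make_functional. The helper name is hypothetical.
def naive_per_sample_grads(params, buffers, data, targets):
    sample_grads = [
        torch.autograd.grad(compute_loss(params, buffers, data[i], targets[i]), params)
        for i in range(data.shape[0])
    ]
    # Regroup: one stacked tensor of per-sample gradients per parameter.
    return tuple(torch.stack(shards) for shards in zip(*sample_grads))


per_sample_grads = naive_per_sample_grads(params, buffers, data, targets)
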
import time

import torch
from functorch import grad, make_fx, make_nnc, nnc_jit


def f(x):
    return torch.sin(x).sum()


inp = torch.randn(100)
grad_pt = grad(f)
grad_fx = make_fx(grad_pt)(inp)
grad_nnc = nnc_jit(grad_pt)
loopnest = make_nnc(grad_pt)(inp)
print(loopnest)


def bench(name, f, iters=10000, warmup=3):
    for _ in range(warmup):
        f()
    begin = time.time()
    for _ in range(iters):
        f()
    print(f"{name}: ", time.time() - begin)


bench("PyTorch", lambda: grad_pt(inp))
bench("FX", lambda: grad_fx(inp))
bench("NNC", lambda: grad_nnc(inp))

def test_resnet18_per_sample_grads(self, device):
    # Straight out of opacus
    def _replace_child(
        root: nn.Module, child_name: str, converter: Callable[[nn.Module], nn.Module]
    ) -> None:
        # find the immediate parent
        parent = root
        nameList = child_name.split(".")
        for name in nameList[:-1]:
            parent = parent._modules[name]
        # set to identity
        parent._modules[nameList[-1]] = converter(parent._modules[nameList[-1]])

    def replace_all_modules(
        root: nn.Module,
        target_class: Type[nn.Module],
        converter: Callable[[nn.Module], nn.Module],
    ) -> nn.Module:
        # base case
        if isinstance(root, target_class):
            return converter(root)

        for name, obj in root.named_modules():
            if isinstance(obj, target_class):
                _replace_child(root, name, converter)
        return root

    def _batchnorm_to_groupnorm(module: nn.modules.batchnorm._BatchNorm) -> nn.Module:
        return nn.GroupNorm(min(32, module.num_features), module.num_features, affine=True)

    def convert_batchnorm_modules(
        model: nn.Module,
        converter: Callable[
            [nn.modules.batchnorm._BatchNorm], nn.Module
        ] = _batchnorm_to_groupnorm,
    ) -> nn.Module:
        return replace_all_modules(model, nn.modules.batchnorm._BatchNorm, converter)

    import torchvision.models as models
    model = convert_batchnorm_modules(models.resnet18(num_classes=10)).to(device)
    criterion = nn.CrossEntropyLoss()

    weights, func_model, descriptors = make_functional(model)

    def compute_loss(weights, image, target):
        images = image.unsqueeze(0)
        targets = target.unsqueeze(0)
        output = func_model(weights, (images,))
        loss = criterion(output, targets)
        return loss

    batch_size = 3
    images = torch.randn(batch_size, 3, 32, 32, device=device)
    targets = torch.randint(0, 10, (batch_size,), device=device)

    result_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, images, targets)

    expected_grads = [
        torch.autograd.grad(compute_loss(weights, images[i], targets[i]), weights)
        for i in range(batch_size)
    ]
    expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)]

    self.assertEqual(result_grads, expected_grads)