def training_fn(params, buffers, args):
    params_and_buffers = {**params, **buffers}
    _stateless.functional_call(bert_model, params_and_buffers, args, {}).sum().backward()
    optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
    optim.step()
    return params, buffers
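# `bert_model` and `get_sorted_params` are assumed by the snippet above but not
# defined in it. A minimal, hypothetical sketch of how they might look and how
# training_fn could be driven from a module's named parameters and buffers:
import torch
from torch.nn.utils import _stateless

def get_sorted_params(named_params):
    # Hypothetical helper: hand the optimizer the parameter tensors in a
    # stable, name-sorted order.
    return [p for _, p in sorted(named_params.items())]

bert_model = torch.nn.Linear(8, 8)      # stand-in for the real model
args = (torch.randn(4, 8),)             # stand-in for a batch of inputs

params = dict(bert_model.named_parameters())
buffers = dict(bert_model.named_buffers())
params, buffers = training_fn(params, buffers, args)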
def test_functional_batch_norm(self):
    module = torch.nn.BatchNorm1d(10)
    module.train()  # Allow stats update
    # Let's replace the running_mean buffer and check that it is correctly updated
    x = torch.full((20, 10), 128.0)
    rm = torch.zeros(10)
    parameters = {'running_mean': rm}
    prev_rm = module.running_mean.clone()
    res = _stateless.functional_call(module, parameters, x)
    cur_rm = module.running_mean
    self.assertEqual(cur_rm, prev_rm)
    # 0.9 * 0 + 0.1 * 128.0 = 12.8 with BatchNorm's default momentum of 0.1
    self.assertEqual(rm, torch.full((10,), 12.8))
    # Now run functional_call without reparametrization and check that the
    # module's own buffer has been updated
    res = _stateless.functional_call(module, {}, x)
    self.assertEqual(module.running_mean, torch.full((10,), 12.8))
def _run_call_with_mock_module(self, module, device='cpu', prefix=''):
    x = torch.rand((1, 1)).to(device)
    weight = torch.tensor([[1.0]], device=device)
    bias = torch.tensor([0.0], device=device)
    buffer = torch.tensor([0.0], device=device)
    if prefix != '':
        parameters = {f'{prefix}.l1.weight': weight,
                      f'{prefix}.l1.bias': bias,
                      f'{prefix}.buffer': buffer}
    else:
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
    to_check = module
    if prefix != '':
        to_check = getattr(module, prefix)
    prev_weight = to_check.l1.weight.clone()
    prev_buffer = to_check.buffer.clone()
    # The replacement parameters represent an identity function, contrary to the
    # existing parameters in the module, so we expect the result to equal the
    # input if the weight swapping went well.
    res = _stateless.functional_call(module, parameters, x)
    self.assertEqual(x, res)
    # Check that the module's own weight and buffer remain unmodified
    cur_weight = to_check.l1.weight
    cur_buffer = to_check.buffer
    self.assertEqual(cur_weight, prev_weight)
    self.assertEqual(cur_buffer, prev_buffer)
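# The MockModule used by these tests is not defined in this excerpt. A minimal,
# hypothetical sketch consistent with the names the tests access (an `l1`
# linear submodule plus a registered `buffer` added to the output):
import torch

class MockModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1, 1)
        self.register_buffer('buffer', torch.ones(1))

    def forward(self, x):
        # With weight=1, bias=0 and buffer=0 this reduces to the identity,
        # which is what _run_call_with_mock_module checks for.
        return self.l1(x) + self.buffer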
def test_functional_call_with_gradient(self):
    module = MockModule()
    x = torch.rand((1, 1))
    weight = torch.tensor([[1.0]], requires_grad=True)
    bias = torch.tensor([0.0], requires_grad=True)
    buffer = torch.tensor([0.0])
    parameters = {'l1.weight': weight,
                  'l1.bias': bias,
                  'buffer': buffer}
    res = _stateless.functional_call(module, parameters, x)
    # Check that a backward step calculates the gradients of the supplied parameters
    res.backward()
    self.assertIsNotNone(weight.grad)
    self.assertIsNotNone(bias.grad)
    self.assertIsNone(buffer.grad)
    # Gradients were not calculated for the module's own parameters and buffers
    self.assertIsNone(module.l1.weight.grad)
    self.assertIsNone(module.l1.bias.grad)
    self.assertIsNone(module.buffer.grad)
def test_circular_references(self):
    module = MockModule()
    # Add a circular reference
    module.l1.m = module
    x = torch.rand((1, 1))
    weight = torch.tensor([[1.0]])
    bias = torch.tensor([0.0])
    buffer = torch.tensor([0.0])
    parameters = {'l1.m.l1.weight': weight,
                  'l1.bias': bias,
                  'l1.m.buffer': buffer}
    prev_weight = module.l1.weight.clone()
    prev_buffer = module.buffer.clone()
    res = _stateless.functional_call(module, parameters, x)
    self.assertEqual(x, res)
    # Check that the weights remain unmodified and were correctly accessed
    cur_weight = module.l1.weight
    cur_buffer = module.buffer
    self.assertEqual(cur_weight, prev_weight)
    self.assertEqual(cur_buffer, prev_buffer)
def func(*params: torch.Tensor):
    _output: torch.Tensor = _stateless.functional_call(
        module, {n: p for n, p in zip(keys, params)}, _input)
    return _output.log_softmax(dim=1)  # (N, C)
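# A closure of this shape is usually consumed by torch.autograd.functional.
# A hedged sketch of the assumed surrounding setup (`module`, `keys` and
# `_input` are not defined in this excerpt) and of computing per-parameter
# Jacobians of the log-softmax output:
import torch

module = torch.nn.Linear(4, 3)          # assumed stand-in for `module`
_input = torch.randn(2, 4)              # assumed stand-in for `_input` (N=2)
keys, params = zip(*module.named_parameters())

# One Jacobian per parameter, each with shape (N, C, *param.shape).
jacobians = torch.autograd.functional.jacobian(func, params)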
def functional_call(named_params, named_buffers, *args, **kwargs):
    params_and_buffers = {**named_params, **named_buffers}
    return _stateless.functional_call(mod, params_and_buffers, args, kwargs)
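# A hedged usage sketch, assuming `mod` is the module captured by the wrapper
# above (not shown in this excerpt):
import torch

mod = torch.nn.Linear(3, 3)
x = torch.randn(2, 3)
out = functional_call(dict(mod.named_parameters()), dict(mod.named_buffers()), x)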
def call_for_per_sample_grads(module, batch_size, args, kwargs=None):
    r"""
    call_for_per_sample_grads(module, batch_size, args, kwargs=None) -> Tensor

    Invoked just like a forward pass, ``call_for_per_sample_grads`` will produce the same
    forward result. Then, when backward is invoked, the parameters of ``module`` will have
    a ``grad_sample`` field populated with the per sample gradients instead of the regular
    gradients.

    Args:
        module: The ``nn.Module`` to get per sample gradients with respect to. All trainable
          parameters will compute per sample gradients, located in a ``grad_sample``
          field when ``backward`` is invoked
        batch_size: The batch size of the input. Typically the input's first dimension
        args: Tuple of positional args passed to ``module`` to perform the forward pass
        kwargs: Dict of named args passed to ``module`` to perform the forward pass. Default: None

    Examples::
        >>> model = nn.Linear(4, 3)
        >>> batched_input = torch.randn(5, 4)  # batch size of 5
        >>> res = call_for_per_sample_grads(model, batched_input.shape[0], batched_input).sum()
        >>> res.backward()
        >>> assert model.weight.shape == (3, 4)
        >>> assert model.weight.grad_sample.shape == (5, 3, 4)
        >>> assert model.weight.grad is None
        >>> assert model.bias.shape == (3,)
        >>> assert model.bias.grad_sample.shape == (5, 3)
        >>> assert model.bias.grad is None

    Note::
        Does not work with any `nn.RNN`, including `nn.GRU` or `nn.LSTM`. Please use custom
        rewrites that wrap an `nn.Linear` module. See Opacus for an example
    """
    def maybe_build_expanded_weight(og_tensor):
        if og_tensor.requires_grad:
            return ExpandedWeight(og_tensor, batch_size)
        else:
            return og_tensor

    if not isinstance(module, torch.nn.Module):
        raise RuntimeError(f"Module passed must be nn.Module, got {type(module).__name__}")
    if not isinstance(batch_size, int):
        raise RuntimeError(f"Batch size passed must be an integer, got {type(batch_size).__name__}")
    if batch_size < 1:
        raise RuntimeError(f"Batch size must be positive, got {batch_size}")
    for weight in module.parameters():
        if hasattr(weight, "grad_sample") and weight.grad_sample is not None:  # type: ignore[attr-defined]
            raise RuntimeError("Current Expanded Weights accumulates the gradients, which will be incorrect for multiple "
                               f"calls without clearing gradients. Please clear out the grad_sample parameter of {weight} or "
                               "post an issue to pytorch/pytorch to prioritize correct behavior")

    params = {name: maybe_build_expanded_weight(value) for (name, value) in module.named_parameters()}
    return functional_call(module, params, args, kwargs)
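# The error above asks callers to clear `grad_sample` between invocations.
# A minimal sketch of a helper doing that (assuming resetting the attribute to
# None is the intended way to clear it):
import torch

def clear_grad_sample(module: torch.nn.Module) -> None:
    for p in module.parameters():
        if hasattr(p, "grad_sample"):
            p.grad_sample = None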
# Another way to use ``nn.Module`` with forward AD is to utilize
# the stateless API. NB: At the time of writing the stateless API is still
# experimental and may be subject to change.

from torch.nn.utils._stateless import functional_call

# We need a fresh module because the functional call requires the
# model to have parameters registered.
model = nn.Linear(5, 5)

dual_params = {}
with fwAD.dual_level():
    for name, p in params.items():
        # Using the same ``tangents`` from the above section
        dual_params[name] = fwAD.make_dual(p, tangents[name])
    out = functional_call(model, dual_params, input)
    jvp2 = fwAD.unpack_dual(out).tangent

# Check our results
assert torch.allclose(jvp, jvp2)

######################################################################
# Custom autograd Function
# --------------------------------------------------------------------
# Custom Functions also support forward-mode AD. To create a custom Function
# supporting forward-mode AD, register the ``jvp()`` static method. It is
# possible, but not mandatory, for custom Functions to support both forward
# and backward AD. See the
# `documentation <https://pytorch.org/docs/master/notes/extending.html#forward-mode-ad>`_
# for more information.
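# A minimal sketch of a custom Function supporting forward-mode AD via a
# ``jvp()`` static method (illustrative; not part of this excerpt):
import torch
import torch.autograd.forward_ad as fwAD

class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        result = torch.exp(x)
        # Stash the primal result so jvp() can reuse it.
        ctx.result = result
        return result

    @staticmethod
    def jvp(ctx, x_t):
        # d/dx exp(x) = exp(x), so the output tangent is exp(x) * x_t.
        return ctx.result * x_t

primal = torch.randn(3)
tangent = torch.randn(3)
with fwAD.dual_level():
    dual_out = Exp.apply(fwAD.make_dual(primal, tangent))
    out_tangent = fwAD.unpack_dual(dual_out).tangent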
def func(*params: torch.Tensor, _input: torch.Tensor = None):
    _output: torch.Tensor = _stateless.functional_call(
        module, {n: p for n, p in zip(names, params)}, _input)
    return _output  # (N, C)