Example 1
def training_fn(params, buffers, args):
    params_and_buffers = {**params, **buffers}
    _stateless.functional_call(bert_model, params_and_buffers, args,
                               {}).sum().backward()
    optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
    optim.step()
    return params, buffers
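The helper get_sorted_params and bert_model above come from the surrounding benchmark script and are not shown here. A minimal sketch of the helper, assuming it only needs to hand the optimizer the parameter tensors in a stable, name-sorted order, could be:

def get_sorted_params(params):
    # Hypothetical helper: return the parameter tensors in name-sorted order
    # so the optimizer always receives them in a deterministic sequence.
    return [params[name] for name in sorted(params)]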
Example 2
 def test_functional_batch_norm(self):
     module = torch.nn.BatchNorm1d(10)
     module.train()  # Allow stats update
     # let's replace the running_mean buffer and check that it is correctly updated
     x = torch.full((20, 10), 128.0)
     rm = torch.zeros(10)
     parameters = {'running_mean': rm}
     prev_rm = module.running_mean.clone()
     res = _stateless.functional_call(module, parameters, x)
     cur_rm = module.running_mean
     self.assertEqual(cur_rm, prev_rm)
     self.assertEqual(rm, torch.full((10,), 12.8))
     # Now run the functional call without reparameterization and check that
     # the module's own buffer has been updated
     res = _stateless.functional_call(module, {}, x)
     self.assertEqual(module.running_mean, torch.full((10,), 12.8))
Example 3
 def _run_call_with_mock_module(self, module, device='cpu', prefix=''):
     x = torch.rand((1, 1)).to(device)
     weight = torch.tensor([[1.0]], device=device)
     bias = torch.tensor([0.0], device=device)
     buffer = torch.tensor([0.0], device=device)
     if prefix != '':
         parameters = {f'{prefix}.l1.weight': weight,
                       f'{prefix}.l1.bias': bias,
                       f'{prefix}.buffer': buffer}
     else:
         parameters = {'l1.weight': weight,
                       'l1.bias': bias,
                       'buffer': buffer}
     to_check = module
     if prefix != '':
         to_check = getattr(module, prefix)
     prev_weight = to_check.l1.weight.clone()
     prev_buffer = to_check.buffer.clone()
     # The supplied parameters represent an identity function, unlike the
     # module's existing params, so the result should equal the input if
     # the weight swapping worked.
     res = _stateless.functional_call(module, parameters, x)
     self.assertEqual(x, res)
     # check that the original weight and buffer remain unmodified
     cur_weight = to_check.l1.weight
     cur_buffer = to_check.buffer
     self.assertEqual(cur_weight, prev_weight)
     self.assertEqual(cur_buffer, prev_buffer)
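Examples 3 through 5 rely on a MockModule that is not reproduced in this listing. A minimal sketch consistent with how the tests use it (a single Linear(1, 1) submodule named l1, a registered buffer named buffer, and a forward that adds the buffer to the linear output) might look like:

class MockModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical reconstruction matching the names used above:
        # 'l1.weight', 'l1.bias' and 'buffer'.
        self.l1 = torch.nn.Linear(1, 1)
        self.register_buffer('buffer', torch.zeros(1))

    def forward(self, x):
        return self.l1(x) + self.buffer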
Example 4
 def test_functional_call_with_gradient(self):
     module = MockModule()
     x = torch.rand((1, 1))
     weight = torch.tensor([[1.0]], requires_grad=True)
     bias = torch.tensor([0.0], requires_grad=True)
     buffer = torch.tensor([0.0])
     parameters = {'l1.weight': weight,
                   'l1.bias': bias,
                   'buffer': buffer}
     res = _stateless.functional_call(module, parameters, x)
     # Check that a backward step calculates the gradient of the supplied parameters
     res.backward()
     self.assertIsNotNone(weight.grad)
     self.assertIsNotNone(bias.grad)
     self.assertIsNone(buffer.grad)
     # Gradients were not calculated for the module's own parameters and buffers
     self.assertIsNone(module.l1.weight.grad)
     self.assertIsNone(module.l1.bias.grad)
     self.assertIsNone(module.buffer.grad)
Example 5
 def test_circular_references(self):
     module = MockModule()
     # Add a circular reference
     module.l1.m = module
     x = torch.rand((1, 1))
     weight = torch.tensor([[1.0]])
     bias = torch.tensor([0.0])
     buffer = torch.tensor([0.0])
     parameters = {'l1.m.l1.weight': weight,
                   'l1.bias': bias,
                   'l1.m.buffer': buffer}
     prev_weight = module.l1.weight.clone()
     prev_buffer = module.buffer.clone()
     res = _stateless.functional_call(module, parameters, x)
     self.assertEqual(x, res)
     # check that the weights remain unmodified and were correctly accessed
     cur_weight = module.l1.weight
     cur_buffer = module.buffer
     self.assertEqual(cur_weight, prev_weight)
     self.assertEqual(cur_buffer, prev_buffer)
Example 6
 def func(*params: torch.Tensor):
     _output: torch.Tensor = _stateless.functional_call(
         module, {n: p for n, p in zip(keys, params)}, _input)
     return _output.log_softmax(dim=1)  # (N, C)
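In Example 6, module, keys and _input are captured from the enclosing test. A hedged usage sketch of the same pattern, with hypothetical stand-ins for those names, differentiating the stateless call with respect to every parameter:

import torch
from torch.nn.utils import _stateless

# Hypothetical stand-ins for the objects the closure captures in Example 6.
module = torch.nn.Linear(4, 3)
_input = torch.randn(2, 4)
keys, values = zip(*module.named_parameters())

def func(*params: torch.Tensor):
    _output = _stateless.functional_call(
        module, {n: p for n, p in zip(keys, params)}, _input)
    return _output.log_softmax(dim=1)  # (N, C)

# Jacobian of the log-probabilities with respect to each parameter tensor.
jac = torch.autograd.functional.jacobian(func, values)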
Example 7
 def functional_call(named_params, named_buffers, *args, **kwargs):
     params_and_buffers = {**named_params, **named_buffers}
     return _stateless.functional_call(mod, params_and_buffers, args,
                                       kwargs)
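A hedged usage sketch of this wrapper, assuming mod (the module the closure captures) is, say, a small Linear layer:

import torch

# Hypothetical usage: run `mod` with its own state, or with replacement
# tensors, without mutating the module in place.
mod = torch.nn.Linear(3, 3)
x = torch.randn(2, 3)
out = functional_call(dict(mod.named_parameters()), dict(mod.named_buffers()), x)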
Example 8
def call_for_per_sample_grads(module, batch_size, args, kwargs=None):
    r"""
    call_for_per_sample_grads(module, batch_size, args, kwargs=None) -> Tensor
    Invoked just like a forward pass, ``call_for_per_sample_grads`` will produce the same
    forward result. Then, when backward is invoked, the parameters of ``module``
    will have a ``grad_sample`` field populated with the per sample gradients
    instead of the regular gradients.

    Args:
        module: The ``nn.Module`` to get per sample gradients with respect to. Per sample
          gradients will be computed for all trainable parameters and stored in a
          ``grad_sample`` field when ``backward`` is invoked
        batch_size: The batch size of the input. Typically the input's first dimension
        args: Tuple of positional args passed to ``module`` to perform the forward pass
        kwargs: Dict of named args passed to ``module`` to perform the forward pass. Default: None

    Examples::
        >>> model = nn.Linear(4, 3)
        >>> batched_input = torch.randn(5, 4)  # batch size of 5
        >>> res = call_for_per_sample_grads(model, batched_input.shape[0], batched_input).sum()
        >>> res.backward()
        >>> assert model.weight.shape == (3, 4)
        >>> assert model.weight.grad_sample.shape == (5, 3, 4)
        >>> assert model.weight.grad is None
        >>> assert model.bias.shape == (3,)
        >>> assert model.bias.grad_sample.shape == (5, 3)
        >>> assert model.bias.grad is None

    Note::
        Does not work with any `nn.RNN`, including `nn.GRU` or `nn.LSTM`. Please use custom
        rewrites that wrap an `nn.Linear` module. See Opacus for an example.
    """
    def maybe_build_expanded_weight(og_tensor):
        if og_tensor.requires_grad:
            return ExpandedWeight(og_tensor, batch_size)
        else:
            return og_tensor

    if not isinstance(module, torch.nn.Module):
        raise RuntimeError(
            f"Module passed must be nn.Module, got {type(module).__name__}")
    if not isinstance(batch_size, int):
        raise RuntimeError(
            f"Batch size passed must be an integer, got {type(batch_size).__name__}"
        )
    if batch_size < 1:
        raise RuntimeError(f"Batch size must be positive, got {batch_size}")
    for weight in module.parameters():
        if hasattr(
                weight, "grad_sample"
        ) and weight.grad_sample is not None:  # type: ignore[attr-defined]
            raise RuntimeError(
                "Expanded Weights currently accumulates gradients, which will be incorrect for multiple "
                f"calls without clearing gradients. Please clear out the grad_sample field of {weight} or "
                "post an issue to pytorch/pytorch to prioritize correct behavior"
            )
    params = {
        name: maybe_build_expanded_weight(value)
        for (name, value) in module.named_parameters()
    }
    return functional_call(module, params, args, kwargs)
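Because call_for_per_sample_grads refuses to run while any parameter still carries a populated grad_sample, a training loop that calls it repeatedly needs to clear that field between iterations. A hedged sketch, where module is the same model previously passed to call_for_per_sample_grads:

# Clear per-sample gradients between iterations so the grad_sample check
# above does not raise on the next call.
for p in module.parameters():
    if hasattr(p, "grad_sample"):
        p.grad_sample = None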
# Another way to use ``nn.Module`` with forward AD is to utilize
# the stateless API. NB: At the time of writing the stateless API is still
# experimental and may be subject to change.

from torch.nn.utils._stateless import functional_call

# We need a fresh module because the functional call requires the
# model to have parameters registered.
model = nn.Linear(5, 5)

dual_params = {}
with fwAD.dual_level():
    for name, p in params.items():
        # Using the same ``tangents`` from the above section
        dual_params[name] = fwAD.make_dual(p, tangents[name])
    out = functional_call(model, dual_params, input)
    jvp2 = fwAD.unpack_dual(out).tangent

# Check our results
assert torch.allclose(jvp, jvp2)
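# For reference, the ``params`` and ``tangents`` dictionaries used above come
# from an earlier section of this tutorial. A minimal sketch of how they might
# be constructed (an assumption, not the tutorial's exact code):
#
#     params = {name: p for name, p in model.named_parameters()}
#     tangents = {name: torch.rand_like(p) for name, p in params.items()}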

######################################################################
# Custom autograd Function
# --------------------------------------------------------------------
# Custom Functions also support forward-mode AD. To create a custom Function
# supporting forward-mode AD, register the ``jvp()`` static method. It is
# possible, but not mandatory, for custom Functions to support both forward
# and backward AD. See the
# `documentation <https://pytorch.org/docs/master/notes/extending.html#forward-mode-ad>`_
# for more information.
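# A hedged sketch of such a Function (an illustrative element-wise example,
# not the tutorial's own code). The ``jvp()`` static method receives one
# tangent per input and returns the tangent of the output; ``fwAD`` is the
# ``torch.autograd.forward_ad`` alias used above.

class ScaleFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):
        # ``factor`` is a plain Python number, so it can be stashed on ctx.
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse-mode rule; optional if only forward-mode AD is exercised.
        return grad_output * ctx.factor, None

    @staticmethod
    def jvp(ctx, x_tangent, factor_tangent):
        # Forward-mode rule: d(out) = d(x) * factor.
        return x_tangent * ctx.factor

with fwAD.dual_level():
    dual_x = fwAD.make_dual(torch.randn(3), torch.ones(3))
    dual_out = ScaleFn.apply(dual_x, 2.0)
    assert torch.allclose(fwAD.unpack_dual(dual_out).tangent,
                          torch.full((3,), 2.0))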
Example 10
 def func(*params: torch.Tensor, _input: torch.Tensor = None):
     _output: torch.Tensor = _stateless.functional_call(
         module, {n: p for n, p in zip(names, params)}, _input)
     return _output  # (N, C)