def test_reparamertize_module_fail_reset_to_original(self):
    """A failing functional_call must leave the spectral-norm parametrization
    and the original weight fully intact."""
    # NOTE(review): "reparamertize" in the method name looks like a typo for
    # "reparametrize" -- kept as-is to preserve the discovered test ID.
    module = MockModule()
    torch.nn.utils.parametrizations.spectral_norm(module.l1)
    self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
    orig_sn_weight = module.l1.weight.clone()
    # Swap only the tensor behind the parametrization; the spectral-norm
    # transform itself stays registered and will be applied to the new value.
    parameters = {
        'l1.parametrizations.weight.original': torch.nn.Parameter(torch.tensor([[1.0]])),
        'l1.bias': torch.tensor([0.0]),
        'buffer': torch.tensor([0.0]),
    }
    with self.assertRaisesRegex(RuntimeError, "shapes cannot be multiplied"):
        bad_x = torch.rand((4, 5))  # wrong shape on purpose; (1, 1) would succeed
        stateless.functional_call(module, parameters, bad_x)
    # Even after the in-call failure, the module is restored: the
    # parametrization is still registered and the weight is unchanged.
    self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
    self.assertEqual(orig_sn_weight, module.l1.weight)
def test_setattr(self): class Foo(torch.nn.Module): def __init__(self): super().__init__() self.register_buffer('foo', torch.zeros(())) def forward(self, x): self.foo = self.foo + 1 return x + self.foo a = {'foo': torch.zeros(())} mod = Foo() stateless.functional_call(mod, a, torch.ones(())) self.assertEqual(mod.foo, torch.zeros(())) self.assertEqual(a['foo'], torch.ones(()))
def _run_call_with_mock_module(self, module, device='cpu', prefix=''): x = torch.rand((1, 1)).to(device) weight = torch.tensor([[1.0]], device=device) bias = torch.tensor([0.0], device=device) buffer = torch.tensor([0.0], device=device) if prefix != '': parameters = {f'{prefix}.l1.weight': weight, f'{prefix}.l1.bias': bias, f'{prefix}.buffer': buffer} else: parameters = {'l1.weight': weight, 'l1.bias': bias, 'buffer': buffer} to_check = module if prefix != '': to_check = getattr(module, prefix) prev_weight = to_check.l1.weight.clone() prev_buffer = to_check.buffer.clone() # the parameters represent an identity function contrary to the # existing params in module. So here we expect the result to be the # same as the input if the weight swapping went well. res = stateless.functional_call(module, parameters, x) self.assertEqual(x, res) # check that the weight remain unmodified cur_weight = to_check.l1.weight cur_buffer = to_check.buffer self.assertEqual(cur_weight, prev_weight) self.assertEqual(cur_buffer, prev_buffer)
def test_functional_batch_norm(self): module = torch.nn.BatchNorm1d(10) module.train() # Allow stats update # lets replace the running_mean buffer and check if its correctly updated x = torch.full((20, 10), 128.0) rm = torch.zeros(10) parameters = {'running_mean': rm} prev_rm = module.running_mean.clone() res = stateless.functional_call(module, parameters, x) cur_rm = module.running_mean self.assertEqual(cur_rm, prev_rm) self.assertEqual(rm, torch.full((10,), 12.8)) # Now run functional without reparametrization and check that the module has # been updated res = stateless.functional_call(module, {}, x) self.assertEqual(module.running_mean, torch.full((10,), 12.8))
def wrapper(*args, **kwargs):
    """Run the wrapped module with every parameter swapped for its
    expanded-weight counterpart for this single call."""
    # Resolve the batch size lazily: fall back to inferring it from the
    # actual call arguments when none was fixed up front.
    effective_batch_size = batch_size
    if effective_batch_size is None:
        effective_batch_size = compute_batch_size(*args, **kwargs)
    # NOTE(review): only named_parameters() are substituted here; buffers are
    # presumably taken from the module itself -- confirm against the caller.
    params = {}
    for name, value in module.named_parameters():
        params[name] = maybe_build_expanded_weight(value, effective_batch_size)
    return functional_call(module, params, args, kwargs)
def test_reparametrized_module_change_parametrization_original(self):
    """Swapping the tensor behind a parametrization still applies the
    registered transform during the call and restores everything after."""
    module = MockModule()
    torch.nn.utils.parametrizations.spectral_norm(module.l1)
    self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
    orig_sn_weight = module.l1.weight.clone()
    x = torch.rand((1, 1))
    # Only the underlying "original" tensor is replaced; spectral norm stays
    # registered and normalizes the substituted value during the call.
    parameters = {
        'l1.parametrizations.weight.original': torch.nn.Parameter(torch.tensor([[1.0]])),
        'l1.bias': torch.tensor([0.0]),
        'buffer': torch.tensor([0.0]),
    }
    res = stateless.functional_call(module, parameters, x)
    # The substituted tensors implement the identity, so output == input.
    self.assertEqual(x, res)
    # After the call, the parametrization and the original weight are intact.
    self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
    self.assertEqual(orig_sn_weight, module.l1.weight)
def test_functional_call_with_gradient(self):
    """Gradients flow to the substituted tensors, never to the module's own
    parameters or buffers."""
    module = MockModule()
    x = torch.rand((1, 1))
    weight = torch.tensor([[1.0]], requires_grad=True)
    bias = torch.tensor([0.0], requires_grad=True)
    buffer = torch.tensor([0.0])  # buffers do not require grad
    parameters = {'l1.weight': weight, 'l1.bias': bias, 'buffer': buffer}
    res = stateless.functional_call(module, parameters, x)
    # A backward step populates gradients on the supplied tensors ...
    res.backward()
    self.assertIsNotNone(weight.grad)
    self.assertIsNotNone(bias.grad)
    self.assertIsNone(buffer.grad)
    # ... while the module's own state and buffers stay out of the graph.
    self.assertIsNone(module.l1.weight.grad)
    self.assertIsNone(module.l1.bias.grad)
    self.assertIsNone(module.buffer.grad)
def test_circular_references(self):
    """functional_call tolerates a module graph containing a reference cycle."""
    module = MockModule()
    module.l1.m = module  # introduce a cycle: module -> l1 -> module
    x = torch.rand((1, 1))
    # Address the same tensors through the cycle (the 'l1.m.' prefix) so that
    # dotted-path resolution is exercised across the circular reference.
    parameters = {
        'l1.m.l1.weight': torch.tensor([[1.0]]),
        'l1.bias': torch.tensor([0.0]),
        'l1.m.buffer': torch.tensor([0.0]),
    }
    prev_weight = module.l1.weight.clone()
    prev_buffer = module.buffer.clone()
    res = stateless.functional_call(module, parameters, x)
    # Identity parameters were substituted, so the output equals the input.
    self.assertEqual(x, res)
    # The module's own tensors were correctly accessed and then restored.
    self.assertEqual(module.l1.weight, prev_weight)
    self.assertEqual(module.buffer, prev_buffer)
def functional_call(named_params, named_buffers, *args, **kwargs):
    """Adapter: merge parameters and buffers into the single mapping that
    ``stateless.functional_call`` expects and invoke the captured module."""
    state = dict(named_params)
    state.update(named_buffers)
    return stateless.functional_call(mod, state, args, kwargs)