def test_scriptable(self):
    # TODO: Need to fix the scripting in parametrizations
    # Currently, all the tests below will throw UnsupportedNodeError
    model = nn.Linear(5, 5)
    parametrize.register_parametrization(model, "weight", self.Symmetric())

    x = torch.randn(3, 5)
    y = model(x)

    with self.assertRaises(torch.jit.frontend.UnsupportedNodeError):
        # Check scripting works
        scripted_model = torch.jit.script(model)
        y_hat = scripted_model(x)
        self.assertEqual(y, y_hat)

        with parametrize.cached():
            # Check scripted model works when caching
            y_hat = scripted_model(x)
            self.assertEqual(y, y_hat)

            # Check the scripting process throws an error when caching
            with self.assertRaisesRegex(RuntimeError, 'Caching is not implemented'):
                scripted_model = torch.jit.trace_module(model)
def load_state_dict(self, state_dict, strict=True):
    module_groups = copy.deepcopy(state_dict['module_groups'])
    states = state_dict['state']
    for fqn, s in states.items():
        layer = fqn_to_module(self.model, fqn)
        if strict and layer is None:
            raise RuntimeError(f'Error loading {fqn} into the model')

        # Re-register the FakeSparsity parametrization if the layer does not have one yet
        found = False
        for p in layer.parametrizations['weight']:
            if isinstance(p, FakeSparsity):
                found = True
                break
        if not found:
            p = FakeSparsity(torch.ones(layer.weight.shape))
            parametrize.register_parametrization(layer, 'weight', p)
        if s.get('mask', None) is not None:
            mask = s.pop('mask')
            p.mask = mask

        for mg in module_groups:
            if mg['fqn'] == fqn:
                mg['module'] = layer
    self.__setstate__({'state': states, 'module_groups': module_groups})
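# The `FakeSparsity` parametrization re-registered above is defined elsewhere in the
# package. A minimal sketch of such a mask parametrization, assuming it simply multiplies
# the tensor elementwise by a mask buffer of the same shape. `MaskParametrization` is an
# illustrative stand-in, not the library's exact implementation (the real FakeSparsity
# also keeps its mask out of the model's state_dict, which is omitted here for brevity).
import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize


class MaskParametrization(nn.Module):
    def __init__(self, mask):
        super().__init__()
        # A buffer moves with the module (device/dtype) but is not trained
        self.register_buffer('mask', mask)

    def forward(self, x):
        # Zeroed entries of `mask` zero out the corresponding weight entries
        return self.mask * x


# Hypothetical usage on a plain Linear layer
layer = nn.Linear(16, 16)
parametrize.register_parametrization(layer, 'weight', MaskParametrization(torch.eye(16)))
assert parametrize.is_parametrized(layer, 'weight')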
def test_traceable(self):
    r"""Test the jit scripting and tracing of a parametrized model."""
    model = nn.Linear(5, 5)
    parametrize.register_parametrization(model, "weight", self.Symmetric())

    x = torch.randn(3, 5)
    y = model(x)

    # Check the tracing works. Because traced functions cannot be called
    # directly, we run the comparison on the activations.
    traced_model = torch.jit.trace_module(model, {'forward': x})
    y_hat = traced_model(x)
    self.assertEqual(y, y_hat)

    # Check traced model works with caching
    with parametrize.cached():
        y_hat = traced_model(x)
        self.assertEqual(y, y_hat)

    # Check the tracing throws an error when caching
    with self.assertRaisesRegex(RuntimeError, 'Cannot trace a model while caching'):
        with parametrize.cached():
            traced_model = torch.jit.trace_module(model, {'forward': x})
def test_jit_trace(self):
    model = ModelUnderTest(bias=False)

    mask = torch.eye(16)
    parametrize.register_parametrization(model.linear, 'weight',
                                         utils.FakeSparsity(mask))
    mask = torch.eye(16)
    parametrize.register_parametrization(model.seq[0], 'weight',
                                         utils.FakeSparsity(mask))
    mask = torch.eye(16)
    parametrize.register_parametrization(model.seq[1], 'weight',
                                         utils.FakeSparsity(mask))

    # Tracing
    example_x = torch.ones(3, 16)
    model_trace = torch.jit.trace_module(model, {'forward': example_x})

    x = torch.randn(3, 16)
    y = model(x)
    y_hat = model_trace(x)
    self.assertEqual(y_hat, y)
def test_weights_parametrized(self):
    model = ModelUnderTest(bias=False)

    assert not hasattr(model.linear, 'parametrizations')
    assert not hasattr(model.seq[0], 'parametrizations')
    assert not hasattr(model.seq[1], 'parametrizations')

    mask = torch.eye(16)
    parametrize.register_parametrization(model.linear, 'weight',
                                         utils.FakeSparsity(mask))
    mask = torch.eye(16)
    parametrize.register_parametrization(model.seq[0], 'weight',
                                         utils.FakeSparsity(mask))
    mask = torch.eye(16)
    parametrize.register_parametrization(model.seq[1], 'weight',
                                         utils.FakeSparsity(mask))

    assert hasattr(model.linear, 'parametrizations')
    assert parametrize.is_parametrized(model.linear, 'weight')
    assert hasattr(model.seq[0], 'parametrizations')
    assert parametrize.is_parametrized(model.seq[0], 'weight')
    assert hasattr(model.seq[1], 'parametrizations')
    assert parametrize.is_parametrized(model.seq[1], 'weight')
# --------------------------------
#
# Parametrizations can solve all these problems as well as others.
#
# Let's start by reimplementing the code above using ``torch.nn.utils.parametrize``.
# The only thing that we have to do is to write the parametrization as a regular ``nn.Module``

class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)

###############################################################################
# This is all we need to do. Once we have this, we can transform any regular layer into a
# symmetric layer by doing

layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Symmetric())

###############################################################################
# Now, the matrix of the linear layer is symmetric

A = layer.weight
print(A)
assert torch.allclose(A, A.T)

###############################################################################
# We can do the same thing with any other layer. For example, we can create a CNN with
# `skew-symmetric <https://en.wikipedia.org/wiki/Skew-symmetric_matrix>`_ kernels.
# We use a similar parametrization, copying the upper-triangular part with signs
# reversed into the lower-triangular part

class Skew(nn.Module):
    def forward(self, X):
def __init__(self, *args, parametrize: bool = True, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.parametrize = parametrize
    if parametrize:
        P.register_parametrization(self, 'weight', Std())
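# A minimal usage sketch for a layer built with an ``__init__`` like the one above.
# ``StdLinear`` and the ``Std`` parametrization shown here are illustrative stand-ins
# (the real class and parametrization are defined elsewhere); the point is only that
# the ``parametrize`` flag decides at construction time whether the weight is parametrized.
import torch
from torch import nn
import torch.nn.utils.parametrize as P


class Std(nn.Module):
    # Example parametrization: standardize the weight to zero mean / unit std
    def forward(self, X):
        return (X - X.mean()) / (X.std() + 1e-8)


class StdLinear(nn.Linear):
    def __init__(self, *args, parametrize: bool = True, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.parametrize = parametrize
        if parametrize:
            P.register_parametrization(self, 'weight', Std())


layer = StdLinear(4, 4, parametrize=True)
assert P.is_parametrized(layer, 'weight')

plain = StdLinear(4, 4, parametrize=False)
assert not P.is_parametrized(plain, 'weight')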
def test_state_dict_preserved(self):
    model_save = ModelUnderTest(bias=False)
    mask = torch.eye(16)
    parametrize.register_parametrization(model_save.linear, 'weight',
                                         utils.FakeSparsity(mask))
    mask = torch.eye(16)
    parametrize.register_parametrization(model_save.seq[0], 'weight',
                                         utils.FakeSparsity(mask))
    mask = torch.eye(16)
    parametrize.register_parametrization(model_save.seq[1], 'weight',
                                         utils.FakeSparsity(mask))
    state_dict = model_save.state_dict()

    model_load = ModelUnderTest(bias=False)
    mask = torch.zeros(model_load.linear.weight.shape)
    parametrize.register_parametrization(model_load.linear, 'weight',
                                         utils.FakeSparsity(mask))
    mask = torch.zeros(model_load.seq[0].weight.shape)
    parametrize.register_parametrization(model_load.seq[0], 'weight',
                                         utils.FakeSparsity(mask))
    mask = torch.zeros(model_load.seq[1].weight.shape)
    parametrize.register_parametrization(model_load.seq[1], 'weight',
                                         utils.FakeSparsity(mask))
    # strict=False, as the 'mask' is not saved in the model's state_dict
    model_load.load_state_dict(state_dict, strict=False)

    # Check the parametrizations are preserved
    assert hasattr(model_load.linear, 'parametrizations')
    assert parametrize.is_parametrized(model_load.linear, 'weight')
    assert hasattr(model_load.seq[0], 'parametrizations')
    assert parametrize.is_parametrized(model_load.seq[0], 'weight')
    assert hasattr(model_load.seq[1], 'parametrizations')
    assert parametrize.is_parametrized(model_load.seq[1], 'weight')

    # Check the weights are preserved
    self.assertEqual(model_save.linear.parametrizations['weight'].original,
                     model_load.linear.parametrizations['weight'].original)
    self.assertEqual(model_save.seq[0].parametrizations['weight'].original,
                     model_load.seq[0].parametrizations['weight'].original)
    self.assertEqual(model_save.seq[1].parametrizations['weight'].original,
                     model_load.seq[1].parametrizations['weight'].original)

    # Check the masks are not preserved in the state_dict
    # We store the state_dicts in the sparsifier, not in the model itself.
    # TODO: Need to find a clean way of exporting the parametrized model
    self.assertNotEqual(model_save.linear.parametrizations['weight'][0].mask,
                        model_load.linear.parametrizations['weight'][0].mask)
    self.assertNotEqual(model_save.seq[0].parametrizations['weight'][0].mask,
                        model_load.seq[0].parametrizations['weight'][0].mask)
    self.assertNotEqual(model_save.seq[1].parametrizations['weight'][0].mask,
                        model_load.seq[1].parametrizations['weight'][0].mask)
def test_state_dict_preserved(self):
    model_save = ModelUnderTest(bias=False)
    mask = torch.eye(16)
    parametrize.register_parametrization(model_save.linear, 'weight',
                                         utils.FakeSparsity(mask))
    mask = torch.eye(16)
    parametrize.register_parametrization(model_save.seq[0], 'weight',
                                         utils.FakeSparsity(mask))
    mask = torch.eye(16)
    parametrize.register_parametrization(model_save.seq[1], 'weight',
                                         utils.FakeSparsity(mask))
    state_dict = model_save.state_dict()

    model_load = ModelUnderTest(bias=False)
    mask = torch.zeros(model_load.linear.weight.shape)
    parametrize.register_parametrization(model_load.linear, 'weight',
                                         utils.FakeSparsity(mask))
    mask = torch.zeros(model_load.seq[0].weight.shape)
    parametrize.register_parametrization(model_load.seq[0], 'weight',
                                         utils.FakeSparsity(mask))
    mask = torch.zeros(model_load.seq[1].weight.shape)
    parametrize.register_parametrization(model_load.seq[1], 'weight',
                                         utils.FakeSparsity(mask))
    model_load.load_state_dict(state_dict)

    # Check the parametrizations are preserved
    assert hasattr(model_load.linear, 'parametrizations')
    assert parametrize.is_parametrized(model_load.linear, 'weight')
    assert hasattr(model_load.seq[0], 'parametrizations')
    assert parametrize.is_parametrized(model_load.seq[0], 'weight')
    assert hasattr(model_load.seq[1], 'parametrizations')
    assert parametrize.is_parametrized(model_load.seq[1], 'weight')

    # Check the weights are preserved
    self.assertEqual(model_save.linear.parametrizations['weight'].original,
                     model_load.linear.parametrizations['weight'].original)
    self.assertEqual(model_save.seq[0].parametrizations['weight'].original,
                     model_load.seq[0].parametrizations['weight'].original)
    self.assertEqual(model_save.seq[1].parametrizations['weight'].original,
                     model_load.seq[1].parametrizations['weight'].original)

    # Check the masks are preserved
    self.assertEqual(model_save.linear.parametrizations['weight'][0].mask,
                     model_load.linear.parametrizations['weight'][0].mask)
    self.assertEqual(model_save.seq[0].parametrizations['weight'][0].mask,
                     model_load.seq[0].parametrizations['weight'][0].mask)
    self.assertEqual(model_save.seq[1].parametrizations['weight'][0].mask,
                     model_load.seq[1].parametrizations['weight'][0].mask)
def _prepare(self, use_path=False, *args, **kwargs):
    r"""Adds mask parametrization to the layer weight
    """
    self.activation_handles = []  # store removable hook handles
    self.bias_handles = []

    for config in self.groups:
        modules, tensor_names = self._get_modules_and_tensor_names(config, use_path)

        for module, tensor_name in zip(modules, tensor_names):
            if not isinstance(module, tuple(NEEDS_ZEROS)):
                # add pruning parametrization and forward hooks
                if getattr(module, 'mask', None) is None:
                    module.register_buffer(
                        'mask',
                        torch.tensor(getattr(module, tensor_name).shape[0]))
                param = config.get('parametrization', PruningParametrization)
                parametrize.register_parametrization(module, tensor_name,
                                                     param(module.mask),
                                                     unsafe=True)

                assert isinstance(module.parametrizations, ModuleDict)  # make mypy happy
                assert isinstance(module.parametrizations.weight, ModuleList)
                if isinstance(module, tuple(SUPPORTED_MODULES)):
                    self.activation_handles.append(
                        module.register_forward_hook(
                            ActivationReconstruction(
                                getattr(module.parametrizations, tensor_name)[0])))
                else:
                    raise NotImplementedError(
                        "This module type is not supported yet.")
            else:  # needs zeros
                if getattr(module, 'mask', None) is None:
                    module.register_buffer(
                        'mask',
                        torch.tensor(getattr(module, tensor_name).shape[0]))
                param = config.get('parametrization', ZeroesParametrization)
                parametrize.register_parametrization(module, tensor_name,
                                                     param(module.mask),
                                                     unsafe=True)

            if module.bias is not None:
                module.register_parameter('_bias', nn.Parameter(module.bias.detach()))
                module.bias = None
            self.bias_handles.append(
                module.register_forward_hook(
                    BiasHook(module.parametrizations.weight[0], self.prune_bias)))

        if len(modules) == 2:  # (Conv2d, BN)
            # should have the same set of pruned outputs
            modules[1].parametrizations.weight[0].pruned_outputs = \
                modules[0].parametrizations.weight[0].pruned_outputs
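# The `PruningParametrization` / `ZeroesParametrization` classes registered above are
# defined elsewhere in the package. A minimal sketch of the zeroing flavour, assuming it
# keeps a set of pruned output indices and zeroes the corresponding rows (output
# channels) of the weight. `RowZeroingParametrization` and its details are illustrative,
# not the library's exact implementation.
import torch
from torch import nn


class RowZeroingParametrization(nn.Module):
    def __init__(self, original_outputs):
        super().__init__()
        self.original_outputs = original_outputs  # number of output channels
        self.pruned_outputs = set()  # indices to prune, filled in by the pruner

    def forward(self, x):
        # Build a per-row mask and zero the rows marked as pruned
        mask = torch.ones(x.shape[0], device=x.device, dtype=x.dtype)
        mask[list(self.pruned_outputs)] = 0
        # Broadcast the row mask over the remaining dims (works for Linear and Conv2d weights)
        return mask.reshape(-1, *([1] * (x.dim() - 1))) * x


# Sharing the same `pruned_outputs` set between a Conv2d and its BatchNorm, as done at
# the end of `_prepare`, keeps the two parametrizations pruning the same channels.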