Example #1
    def test_scriptable(self):
        # TODO: Need to fix the scripting in parametrizations
        #       Currently, all the tests below will throw UnsupportedNodeError
        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", self.Symmetric())

        x = torch.randn(3, 5)
        y = model(x)

        with self.assertRaises(torch.jit.frontend.UnsupportedNodeError):
            # Check scripting works
            scripted_model = torch.jit.script(model)
            y_hat = scripted_model(x)
            self.assertEqual(y, y_hat)

            with parametrize.cached():
                # Check scripted model works when caching
                y_hat = scripted_model(x)
                self.assertEqual(y, y_hat)

                # Check the tracing process throws an error when caching
                with self.assertRaisesRegex(RuntimeError,
                                            'Caching is not implemented'):
                    scripted_model = torch.jit.trace_module(model, {'forward': x})
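
A possible workaround, until scripting of parametrized modules is supported, is to bake the parametrized value back into a plain parameter and script the result. The sketch below assumes the Symmetric parametrization from Example #6; note that after removal the symmetry is no longer enforced during further training.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)

model = nn.Linear(5, 5)
parametrize.register_parametrization(model, "weight", Symmetric())
# Materialize the current parametrized weight as an ordinary nn.Parameter
parametrize.remove_parametrizations(model, "weight", leave_parametrized=True)
# The module is now a plain nn.Linear again and scripts without errors
scripted_model = torch.jit.script(model)
print(scripted_model(torch.randn(3, 5)).shape)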
Example #2
    def load_state_dict(self, state_dict, strict=True):
        module_groups = copy.deepcopy(state_dict['module_groups'])
        states = state_dict['state']
        for fqn, s in states.items():
            layer = fqn_to_module(self.model, fqn)
            if layer is None:
                if strict:
                    raise RuntimeError(f'Error loading {fqn} into the model')
                continue

            found = False
            for p in layer.parametrizations['weight']:
                if isinstance(p, FakeSparsity):
                    found = True
                    break
            if not found:
                p = FakeSparsity(torch.ones(layer.weight.shape))
                parametrize.register_parametrization(layer, 'weight', p)
            if s.get('mask', None) is not None:
                mask = s.pop('mask')
                p.mask = mask

            for mg in module_groups:
                if mg['fqn'] == fqn:
                    mg['module'] = layer
        self.__setstate__({'state': states, 'module_groups': module_groups})
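
The helper fqn_to_module used above is defined elsewhere in the sparsifier utilities. As an illustration only (an assumption, not the actual helper), a fully-qualified-name lookup could be sketched like this:

def fqn_to_module_sketch(model, fqn):
    # Walk dot-separated attribute names, e.g. 'seq.0' resolves to model.seq[0]
    # (nn.Module.__getattr__ also resolves numeric child names of nn.Sequential).
    # Returns None when a segment is missing, matching the None check above.
    module = model
    for name in fqn.split('.'):
        module = getattr(module, name, None)
        if module is None:
            return None
    return module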
Example #3
    def test_traceable(self):
        r"""Test the jit scripting and tracing of a parametrized model."""
        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", self.Symmetric())

        x = torch.randn(3, 5)
        y = model(x)

        # Check the tracing works. Since a traced module cannot be compared
        # to the original module directly, we compare the outputs (activations).
        traced_model = torch.jit.trace_module(model, {'forward': x})
        y_hat = traced_model(x)
        self.assertEqual(y, y_hat)

        # Check traced model works with caching
        with parametrize.cached():
            y_hat = traced_model(x)
            self.assertEqual(y, y_hat)

        # Check the tracing throws an error when caching
        with self.assertRaisesRegex(RuntimeError,
                                    'Cannot trace a model while caching'):
            with parametrize.cached():
                traced_model = torch.jit.trace_module(model, {'forward': x})
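
To see what the trace actually recorded, the standard ScriptModule inspection attributes can be used. The triu/transpose ops of the Symmetric parametrization should show up inlined in the traced forward, since tracing records the tensor operations executed on the example input.

# Inspect the recorded graph and its TorchScript source
print(traced_model.graph)
print(traced_model.code)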
Example #4
    def test_jit_trace(self):
        model = ModelUnderTest(bias=False)

        mask = torch.eye(16)
        parametrize.register_parametrization(model.linear, 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model.seq[0], 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model.seq[1], 'weight',
                                             utils.FakeSparsity(mask))

        # Tracing
        example_x = torch.ones(3, 16)
        model_trace = torch.jit.trace_module(model, {'forward': example_x})

        x = torch.randn(3, 16)
        y = model(x)
        y_hat = model_trace(x)
        self.assertEqual(y_hat, y)
Example #5
    def test_weights_parametrized(self):
        model = ModelUnderTest(bias=False)

        assert not hasattr(model.linear, 'parametrizations')
        assert not hasattr(model.seq[0], 'parametrizations')
        assert not hasattr(model.seq[1], 'parametrizations')
        mask = torch.eye(16)
        parametrize.register_parametrization(model.linear, 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model.seq[0], 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model.seq[1], 'weight',
                                             utils.FakeSparsity(mask))

        assert hasattr(model.linear, 'parametrizations')
        assert parametrize.is_parametrized(model.linear, 'weight')
        assert hasattr(model.seq[0], 'parametrizations')
        assert parametrize.is_parametrized(model.seq[0], 'weight')
        assert hasattr(model.seq[1], 'parametrizations')
        assert parametrize.is_parametrized(model.seq[1], 'weight')
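
utils.FakeSparsity itself is not shown in these tests; judging from how it is used (a mask buffer applied to the weight), a minimal sketch of such a parametrization might look like the following. The real implementation may differ.

import torch
import torch.nn as nn

class FakeSparsitySketch(nn.Module):
    r"""Sketch of a mask-based parametrization in the spirit of FakeSparsity:
    multiplies the weight elementwise by a mask of the same shape."""
    def __init__(self, mask):
        super().__init__()
        self.register_buffer('mask', mask)

    def forward(self, x):
        assert self.mask.shape == x.shape
        return self.mask * x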
Example #6
# --------------------------------
#
# Parametrizations can solve all these problems as well as others.
#
# Let's start by reimplementing the code above using ``torch.nn.utils.parametrize``.
# The only thing that we have to do is to write the parametrization as a regular ``nn.Module``
class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)


###############################################################################
# This is all we need to do. Once we have this, we can transform any regular layer into a
# symmetric layer by doing
layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Symmetric())

###############################################################################
# Now, the matrix of the linear layer is symmetric
A = layer.weight
print(A)
assert torch.allclose(A, A.T)


###############################################################################
# We can do the same thing with any other layer. For example, we can create a CNN with
# `skew-symmetric <https://en.wikipedia.org/wiki/Skew-symmetric_matrix>`_ kernels.
# We use a similar parametrization, copying the upper-triangular part with signs
# reversed into the lower-triangular part
class Skew(nn.Module):
    def forward(self, X):
        A = X.triu(1)
        return A - A.transpose(-1, -2)
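
A short usage sketch continuing the comment above (not part of the original snippet): registering Skew on a convolutional layer makes every kernel skew-symmetric; the layer sizes are only illustrative.

cnn = nn.Conv2d(in_channels=5, out_channels=8, kernel_size=3)
parametrize.register_parametrization(cnn, "weight", Skew())
# Each 3x3 kernel K now satisfies K == -K.T
print(cnn.weight[0, 1])
assert torch.allclose(cnn.weight[0, 1], -cnn.weight[0, 1].T)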
Example #7
    def __init__(self, *args, parametrize: bool = True, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.parametrize = parametrize
        if parametrize:
            P.register_parametrization(self, 'weight', Std())
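
Std is not defined in this snippet; by its name it is presumably a weight-standardization parametrization. A hypothetical stand-in (an assumption, not the original Std) could be:

import torch.nn as nn

class StdSketch(nn.Module):
    r"""Hypothetical stand-in for Std: standardizes the weight to zero mean
    and unit standard deviation before it is used in the forward pass."""
    def forward(self, X):
        return (X - X.mean()) / (X.std() + 1e-7)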
Example #8
    def test_state_dict_preserved(self):
        model_save = ModelUnderTest(bias=False)

        mask = torch.eye(16)
        parametrize.register_parametrization(model_save.linear, 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model_save.seq[0], 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model_save.seq[1], 'weight',
                                             utils.FakeSparsity(mask))
        state_dict = model_save.state_dict()

        model_load = ModelUnderTest(bias=False)
        mask = torch.zeros(model_load.linear.weight.shape)
        parametrize.register_parametrization(model_load.linear, 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.zeros(model_load.seq[0].weight.shape)
        parametrize.register_parametrization(model_load.seq[0], 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.zeros(model_load.seq[1].weight.shape)
        parametrize.register_parametrization(model_load.seq[1], 'weight',
                                             utils.FakeSparsity(mask))
        # strict=False, as the 'mask' entries are not part of the model state_dict
        model_load.load_state_dict(state_dict, strict=False)

        # Check the parametrizations are preserved
        assert hasattr(model_load.linear, 'parametrizations')
        assert parametrize.is_parametrized(model_load.linear, 'weight')
        assert hasattr(model_load.seq[0], 'parametrizations')
        assert parametrize.is_parametrized(model_load.seq[0], 'weight')
        assert hasattr(model_load.seq[1], 'parametrizations')
        assert parametrize.is_parametrized(model_load.seq[1], 'weight')

        # Check the weights are preserved
        self.assertEqual(model_save.linear.parametrizations['weight'].original,
                         model_load.linear.parametrizations['weight'].original)
        self.assertEqual(model_save.seq[0].parametrizations['weight'].original,
                         model_load.seq[0].parametrizations['weight'].original)
        self.assertEqual(model_save.seq[1].parametrizations['weight'].original,
                         model_load.seq[1].parametrizations['weight'].original)

        # Check the masks are not preserved in the state_dict
        # We store the state_dicts in the sparsifier, not in the model itself.
        # TODO: Need to find a clean way of exporting the parametrized model
        self.assertNotEqual(model_save.linear.parametrizations['weight'][0].mask,
                            model_load.linear.parametrizations['weight'][0].mask)
        self.assertNotEqual(model_save.seq[0].parametrizations['weight'][0].mask,
                            model_load.seq[0].parametrizations['weight'][0].mask)
        self.assertNotEqual(model_save.seq[1].parametrizations['weight'][0].mask,
                            model_load.seq[1].parametrizations['weight'][0].mask)
Example #9
    def test_state_dict_preserved(self):
        model_save = ModelUnderTest(bias=False)

        mask = torch.eye(16)
        parametrize.register_parametrization(model_save.linear, 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model_save.seq[0], 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model_save.seq[1], 'weight',
                                             utils.FakeSparsity(mask))
        state_dict = model_save.state_dict()

        model_load = ModelUnderTest(bias=False)
        mask = torch.zeros(model_load.linear.weight.shape)
        parametrize.register_parametrization(model_load.linear, 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.zeros(model_load.seq[0].weight.shape)
        parametrize.register_parametrization(model_load.seq[0], 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.zeros(model_load.seq[1].weight.shape)
        parametrize.register_parametrization(model_load.seq[1], 'weight',
                                             utils.FakeSparsity(mask))
        model_load.load_state_dict(state_dict)

        # Check the parametrizations are preserved
        assert hasattr(model_load.linear, 'parametrizations')
        assert parametrize.is_parametrized(model_load.linear, 'weight')
        assert hasattr(model_load.seq[0], 'parametrizations')
        assert parametrize.is_parametrized(model_load.seq[0], 'weight')
        assert hasattr(model_load.seq[1], 'parametrizations')
        assert parametrize.is_parametrized(model_load.seq[1], 'weight')

        # Check the weights are preserved
        self.assertEqual(model_save.linear.parametrizations['weight'].original,
                         model_load.linear.parametrizations['weight'].original)
        self.assertEqual(model_save.seq[0].parametrizations['weight'].original,
                         model_load.seq[0].parametrizations['weight'].original)
        self.assertEqual(model_save.seq[1].parametrizations['weight'].original,
                         model_load.seq[1].parametrizations['weight'].original)

        # Check the masks are preserved
        self.assertEqual(model_save.linear.parametrizations['weight'][0].mask,
                         model_load.linear.parametrizations['weight'][0].mask)
        self.assertEqual(model_save.seq[0].parametrizations['weight'][0].mask,
                         model_load.seq[0].parametrizations['weight'][0].mask)
        self.assertEqual(model_save.seq[1].parametrizations['weight'][0].mask,
                         model_load.seq[1].parametrizations['weight'][0].mask)
Example #10
    def _prepare(self, use_path=False, *args, **kwargs):
        r"""Adds mask parametrization to the layer weight
        """
        self.activation_handles = []  # store removable hook handles
        self.bias_handles = []

        for config in self.groups:
            modules, tensor_names = self._get_modules_and_tensor_names(
                config, use_path)

            for module, tensor_name in zip(modules, tensor_names):
                if not isinstance(module, tuple(NEEDS_ZEROS)):
                    # add pruning parametrization and forward hooks
                    if getattr(module, 'mask', None) is None:
                        module.register_buffer(
                            'mask',
                            torch.tensor(
                                getattr(module, tensor_name).shape[0]))
                    param = config.get('parametrization',
                                       PruningParametrization)
                    parametrize.register_parametrization(module,
                                                         tensor_name,
                                                         param(module.mask),
                                                         unsafe=True)

                    assert isinstance(module.parametrizations,
                                      ModuleDict)  # make mypy happy
                    assert isinstance(module.parametrizations.weight,
                                      ModuleList)
                    if isinstance(module, tuple(SUPPORTED_MODULES)):
                        self.activation_handles.append(
                            module.register_forward_hook(
                                ActivationReconstruction(
                                    getattr(module.parametrizations,
                                            tensor_name)[0])))
                    else:
                        raise NotImplementedError(
                            "This module type is not supported yet.")

                else:  # needs zeros
                    if getattr(module, 'mask', None) is None:
                        module.register_buffer(
                            'mask',
                            torch.tensor(
                                getattr(module, tensor_name).shape[0]))
                    param = config.get('parametrization',
                                       ZeroesParametrization)
                    parametrize.register_parametrization(module,
                                                         tensor_name,
                                                         param(module.mask),
                                                         unsafe=True)

                if module.bias is not None:
                    module.register_parameter(
                        '_bias', nn.Parameter(module.bias.detach()))
                    module.bias = None
                self.bias_handles.append(
                    module.register_forward_hook(
                        BiasHook(module.parametrizations.weight[0],
                                 self.prune_bias)))

            if len(modules) == 2:  # (Conv2d, BN)
                # should have the same set of pruned outputs
                modules[1].parametrizations.weight[0].pruned_outputs = \
                    modules[0].parametrizations.weight[0].pruned_outputs
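
PruningParametrization and ZeroesParametrization are defined elsewhere in the pruner. As a rough illustration of the mask-driven pattern only (an assumption, not the actual classes), a parametrization that zeroes pruned output channels while keeping the tensor shape could look like:

import torch
import torch.nn as nn

class ZeroOutputsSketch(nn.Module):
    r"""Illustrative stand-in for a shape-preserving pruning parametrization:
    zeroes the output channels listed in ``pruned_outputs``."""
    def __init__(self, mask):
        super().__init__()
        self.mask = mask              # as registered on the module above
        self.pruned_outputs = set()   # indices filled in by the pruner

    def forward(self, x):
        out = x.clone()
        if self.pruned_outputs:
            out[list(self.pruned_outputs)] = 0
        return out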