Example #1
0
    def test_traceable(self):
        r"""Test jit tracing of a model whose weight is parametrized.

        Only ``torch.jit.trace_module`` is exercised here; scripting is not
        covered by this test. Checks that:

        * the traced model reproduces the eager output,
        * the traced model still works inside ``parametrize.cached()``,
        * tracing while caching is enabled raises a ``RuntimeError``.
        """
        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", self.Symmetric())

        x = torch.randn(3, 5)
        # Eager reference output, computed through the parametrization.
        y = model(x)

        # Check that tracing works by comparing the traced model's output
        # against the eager output on the same input.
        traced_model = torch.jit.trace_module(model, {'forward': x})
        y_hat = traced_model(x)
        self.assertEqual(y, y_hat)

        # Running an already-traced model with caching enabled must still
        # produce the eager result.
        with parametrize.cached():
            y_hat = traced_model(x)
            self.assertEqual(y, y_hat)

        # Tracing itself while caching is enabled is expected to fail. The
        # result is not needed, so it is not bound to a name.
        with self.assertRaisesRegex(RuntimeError,
                                    'Cannot trace a model while caching'):
            with parametrize.cached():
                torch.jit.trace_module(model, {'forward': x})
Example #2
0
    def test_scriptable(self):
        r"""Test jit scripting of a model whose weight is parametrized.

        Scripting parametrized modules is not supported yet, so the whole
        scenario is wrapped in ``assertRaises``. ``torch.jit.script`` is
        expected to fail immediately. The checks after it describe the
        intended behaviour once scripting is fixed, and currently do not run.
        """
        # TODO: Need to fix the scripting in parametrizations.
        #       Currently torch.jit.script(model) below throws
        #       UnsupportedNodeError, so nothing after it in the
        #       assertRaises block is executed.
        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", self.Symmetric())

        x = torch.randn(3, 5)
        # Eager reference output, computed through the parametrization.
        y = model(x)

        with self.assertRaises(torch.jit.frontend.UnsupportedNodeError):
            # Check scripting works
            scripted_model = torch.jit.script(model)
            y_hat = scripted_model(x)
            self.assertEqual(y, y_hat)

            with parametrize.cached():
                # Check scripted model works when caching
                y_hat = scripted_model(x)
                self.assertEqual(y, y_hat)

                # Check the scripting process throws an error when caching.
                # This must call torch.jit.script. torch.jit.trace_module
                # requires an ``inputs`` argument and would fail with a
                # TypeError instead of the intended RuntimeError.
                with self.assertRaisesRegex(RuntimeError, 'Caching is not implemented'):
                    torch.jit.script(model)
Example #3
0
# --------------------------------------
#
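# Parametrizations come with a built-in caching system via the context manager
# ``parametrize.cached()``. Inside it, a parametrized tensor is computed the
# first time it is accessed, and that result is reused afterwards.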
class NoisyParametrization(nn.Module):
    """Identity parametrization that announces every evaluation.

    The input tensor is returned unchanged. The printed line makes it easy to
    see how many times a parametrized tensor is recomputed.
    """

    def forward(self, X):
        # Announce the evaluation, then hand the input back as-is.
        message = "Computing the Parametrization"
        print(message)
        return X


# Demo: register the printing identity parametrization on ``layer.weight`` so
# that each evaluation of the parametrization is visible on stdout.
layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", NoisyParametrization())
print("Here, layer.weight is recomputed every time we call it")
# Outside ``parametrize.cached()``, every access to ``layer.weight`` runs the
# parametrization again. The two lines below access it three times, so the
# message is expected three times. ``foo``/``bar`` exist only to trigger those
# accesses.
foo = layer.weight + layer.weight.T
bar = layer.weight.sum()
with parametrize.cached():
    print("Here, it is computed just the first time layer.weight is called")
    # Inside the context, the first access computes the weight and the later
    # accesses reuse it, so the message is expected only once here. The
    # cache presumably lives only for the duration of the ``with`` block;
    # confirm against the parametrize docs.
    foo = layer.weight + layer.weight.T
    bar = layer.weight.sum()


###############################################################################
# Concatenating Parametrizations
# ------------------------------
#
# Concatenating two parametrizations is as easy as registering them on the same tensor.
# We may use this to create more complex parametrizations from simpler ones. For example, the
# `Cayley map <https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map>`_
# maps the skew-symmetric matrices to the orthogonal matrices of positive determinant. We can
# concatenate ``Skew`` and a parametrization that implements the Cayley map to get a layer with
# orthogonal weights