def test_traceable(self):
    r"""Test the jit tracing of a parametrized model."""
    model = nn.Linear(5, 5)
    parametrize.register_parametrization(model, "weight", self.Symmetric())

    x = torch.randn(3, 5)
    y = model(x)

    # Check the tracing works. Because traced functions cannot be called
    # directly, we run the comparison on the activations.
    traced_model = torch.jit.trace_module(model, {'forward': x})
    y_hat = traced_model(x)
    self.assertEqual(y, y_hat)

    # Check traced model works with caching
    with parametrize.cached():
        y_hat = traced_model(x)
        self.assertEqual(y, y_hat)

    # Check the tracing throws an error when caching
    with self.assertRaisesRegex(RuntimeError, 'Cannot trace a model while caching'):
        with parametrize.cached():
            traced_model = torch.jit.trace_module(model, {'forward': x})
def test_scriptable(self):
    # TODO: Need to fix the scripting in parametrizations
    #       Currently, all the tests below will throw UnsupportedNodeError
    model = nn.Linear(5, 5)
    parametrize.register_parametrization(model, "weight", self.Symmetric())

    x = torch.randn(3, 5)
    y = model(x)

    with self.assertRaises(torch.jit.frontend.UnsupportedNodeError):
        # Check scripting works
        scripted_model = torch.jit.script(model)
        y_hat = scripted_model(x)
        self.assertEqual(y, y_hat)

        with parametrize.cached():
            # Check scripted model works when caching
            y_hat = scripted_model(x)
            self.assertEqual(y, y_hat)

            # Check the scripting process throws an error when caching
            with self.assertRaisesRegex(RuntimeError, 'Caching is not implemented'):
                scripted_model = torch.jit.script(model)
# --------------------------------------
#
# Parametrizations come with an inbuilt caching system via the context manager
# ``parametrize.cached()``
class NoisyParametrization(nn.Module):
    def forward(self, X):
        print("Computing the Parametrization")
        return X

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", NoisyParametrization())
print("Here, layer.weight is recomputed every time we call it")
foo = layer.weight + layer.weight.T
bar = layer.weight.sum()
with parametrize.cached():
    print("Here, it is computed just the first time layer.weight is called")
    foo = layer.weight + layer.weight.T
    bar = layer.weight.sum()

###############################################################################
# Concatenating Parametrizations
# ------------------------------
#
# Concatenating two parametrizations is as easy as registering them on the same tensor.
# We may use this to create more complex parametrizations from simpler ones. For example, the
# `Cayley map <https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map>`_
# maps the skew-symmetric matrices to the orthogonal matrices of positive determinant. We can
# concatenate ``Skew`` and a parametrization that implements the Cayley map to get a layer with
# orthogonal weights.
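#
# A minimal sketch of such a concatenation is shown below. It assumes the ``Skew``
# parametrization defined earlier in this tutorial; ``CayleyMap`` is an illustrative
# module (not part of ``torch``) implementing the Cayley map :math:`(I - X)^{-1}(I + X)`.

class CayleyMap(nn.Module):
    def __init__(self, n):
        super().__init__()
        # Store the identity as a buffer so it follows the module's device/dtype
        self.register_buffer("Id", torch.eye(n))

    def forward(self, X):
        # Cayley map: (I - X)^{-1} (I + X)
        return torch.linalg.solve(self.Id - X, self.Id + X)

layer = nn.Linear(3, 3)
# Parametrizations are applied in registration order:
# weight -> Skew(weight) -> CayleyMap(Skew(weight))
parametrize.register_parametrization(layer, "weight", Skew())
parametrize.register_parametrization(layer, "weight", CayleyMap(3))
X = layer.weight
print(torch.dist(X.T @ X, torch.eye(3)))  # X is orthogonal, so this is ~0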