def test_masking_logic(self):
    """Attach an all-zero ``FakeSparsity`` mask and expect zero output.

    A bias-free 16x16 linear layer with an identity weight is first checked
    to behave like ``torch.mm(x, I)``. A ``FakeSparsity`` parametrization
    built from an all-zero mask is then registered on its ``weight``, after
    which the layer is expected to produce an all-zero output.
    """
    layer = nn.Linear(16, 16, bias=False)
    layer.weight = nn.Parameter(torch.eye(16))

    # Unmasked: an identity weight should reproduce a plain matmul with I.
    inputs = torch.randn(3, 16)
    expected = torch.mm(inputs, torch.eye(16))
    self.assertEqual(expected, layer(inputs))

    # Masked with zeros: the layer output is expected to be all zeros.
    zero_mask = torch.zeros(16, 16)
    sparsifier = utils.FakeSparsity(zero_mask)
    parametrize.register_parametrization(layer, 'weight', sparsifier)
    inputs = torch.randn(3, 16)
    self.assertEqual(torch.zeros(3, 16), layer(inputs))
def test_jit_trace(self):
    """A traced, sparsity-parametrized model matches eager execution.

    Identity ``FakeSparsity`` masks are registered on the ``weight`` of
    ``linear``, ``seq[0]`` and ``seq[1]``. The model is traced with
    ``torch.jit.trace_module`` on an all-ones example input, and the traced
    and eager models are compared on a random input.
    """
    model = ModelUnderTest(bias=False)
    for module in (model.linear, model.seq[0], model.seq[1]):
        identity_mask = torch.eye(16)
        parametrize.register_parametrization(
            module, 'weight', utils.FakeSparsity(identity_mask))

    # Tracing
    traced = torch.jit.trace_module(model, {'forward': torch.ones(3, 16)})

    sample = torch.randn(3, 16)
    # Run eager first, then traced, mirroring the original call order.
    eager_out = model(sample)
    traced_out = traced(sample)
    self.assertEqual(traced_out, eager_out)
def test_weights_parametrized(self):
    """Registering ``FakeSparsity`` on ``weight`` parametrizes each layer.

    For each of ``linear``, ``seq[0]`` and ``seq[1]`` this checks two things.
    Before registration the module has no ``parametrizations`` attribute.
    After registering an identity ``FakeSparsity`` mask, the attribute exists
    and ``parametrize.is_parametrized(module, 'weight')`` is true.
    """
    model = ModelUnderTest(bias=False)
    # register_parametrization mutates the modules in place, so these
    # references stay valid across registration.
    layers = (model.linear, model.seq[0], model.seq[1])

    for layer in layers:
        assert not hasattr(layer, 'parametrizations')

    for layer in layers:
        mask = torch.eye(16)
        parametrize.register_parametrization(layer, 'weight',
                                             utils.FakeSparsity(mask))

    for layer in layers:
        assert hasattr(layer, 'parametrizations')
        # Check each layer itself, not just model.linear.
        assert parametrize.is_parametrized(layer, 'weight')
def test_state_dict_preserved(self):
    """Round-trip through ``state_dict``: weights survive, masks do not.

    A model with identity ``FakeSparsity`` masks on ``linear``, ``seq[0]``
    and ``seq[1]`` is saved. Its state is loaded non-strictly into a fresh
    model whose layers carry all-zero masks. The test then checks that the
    parametrizations exist, that the ``original`` weights match, and that
    the masks differ.

    NOTE(review): a second ``test_state_dict_preserved`` is defined later in
    this file with the same name. If both sit in the same class body, that
    later definition replaces this one and this test never runs. The two
    also assert opposite things about the masks (``assertNotEqual`` here vs
    ``assertEqual`` there). Confirm which reflects current ``FakeSparsity``
    behaviour and delete the other rather than renaming this one.
    """
    model_save = ModelUnderTest(bias=False)
    for layer in (model_save.linear, model_save.seq[0], model_save.seq[1]):
        mask = torch.eye(16)
        parametrize.register_parametrization(layer, 'weight',
                                             utils.FakeSparsity(mask))
    state_dict = model_save.state_dict()

    model_load = ModelUnderTest(bias=False)
    for layer in (model_load.linear, model_load.seq[0], model_load.seq[1]):
        # All-zero masks differ from the saved identity masks, so the
        # comparison below can tell whether a mask was loaded.
        mask = torch.zeros(layer.weight.shape)
        parametrize.register_parametrization(layer, 'weight',
                                             utils.FakeSparsity(mask))
    # Load non-strictly: missing/unexpected keys (e.g. a 'mask' entry) are
    # tolerated instead of raising.
    model_load.load_state_dict(state_dict, strict=False)

    saved_layers = (model_save.linear, model_save.seq[0], model_save.seq[1])
    loaded_layers = (model_load.linear, model_load.seq[0], model_load.seq[1])
    for saved, loaded in zip(saved_layers, loaded_layers):
        # Check the parametrizations are preserved on every layer
        assert hasattr(loaded, 'parametrizations')
        assert parametrize.is_parametrized(loaded, 'weight')
        # Check the weights are preserved
        self.assertEqual(saved.parametrizations['weight'].original,
                         loaded.parametrizations['weight'].original)
        # Check the masks are not preserved in the state_dict
        # We store the state_dicts in the sparsifier, not in the model itself.
        # TODO: Need to find a clean way of exporting the parametrized model
        self.assertNotEqual(saved.parametrizations['weight'][0].mask,
                            loaded.parametrizations['weight'][0].mask)
def test_state_dict_preserved(self):
    """Round-trip a sparsity-parametrized model through ``state_dict``.

    A model with identity ``FakeSparsity`` masks on ``linear``, ``seq[0]``
    and ``seq[1]`` is saved. Its state is loaded strictly into a fresh model
    whose layers carry all-zero masks. The test then checks that, on every
    layer, the parametrization exists and both the ``original`` weight and
    the mask match the saved model.
    """
    model_save = ModelUnderTest(bias=False)
    for layer in (model_save.linear, model_save.seq[0], model_save.seq[1]):
        mask = torch.eye(16)
        parametrize.register_parametrization(layer, 'weight',
                                             utils.FakeSparsity(mask))
    state_dict = model_save.state_dict()

    model_load = ModelUnderTest(bias=False)
    for layer in (model_load.linear, model_load.seq[0], model_load.seq[1]):
        # Start from all-zero masks so the mask comparison below can only
        # pass if the saved masks were actually loaded.
        mask = torch.zeros(layer.weight.shape)
        parametrize.register_parametrization(layer, 'weight',
                                             utils.FakeSparsity(mask))
    model_load.load_state_dict(state_dict)

    saved_layers = (model_save.linear, model_save.seq[0], model_save.seq[1])
    loaded_layers = (model_load.linear, model_load.seq[0], model_load.seq[1])
    for saved, loaded in zip(saved_layers, loaded_layers):
        # Check the parametrizations are preserved on every layer
        assert hasattr(loaded, 'parametrizations')
        assert parametrize.is_parametrized(loaded, 'weight')
        # Check the weights are preserved
        self.assertEqual(saved.parametrizations['weight'].original,
                         loaded.parametrizations['weight'].original)
        # Check the masks are preserved
        self.assertEqual(saved.parametrizations['weight'][0].mask,
                         loaded.parametrizations['weight'][0].mask)