Example #1
    def test_masking_logic(self):
        model = nn.Linear(16, 16, bias=False)
        model.weight = nn.Parameter(torch.eye(16))
        x = torch.randn(3, 16)
        self.assertEqual(torch.mm(x, torch.eye(16)), model(x))

        mask = torch.zeros(16, 16)
        sparsity = utils.FakeSparsity(mask)
        parametrize.register_parametrization(model, 'weight', sparsity)

        x = torch.randn(3, 16)
        self.assertEqual(torch.zeros(3, 16), model(x))
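These tests pull FakeSparsity in from a local utils module. It is, in essence, an elementwise mask applied to the weight on every access; a minimal sketch of that idea (illustrative, not the exact upstream source):

import torch
from torch import nn


class FakeSparsity(nn.Module):
    """Parametrization that multiplies the parametrized tensor by a fixed mask."""

    def __init__(self, mask):
        super().__init__()
        self.register_buffer('mask', mask)

    def forward(self, x):
        # Elementwise masking: a zero mask yields an all-zero weight,
        # which is exactly what the test above asserts.
        assert self.mask.shape == x.shape
        return self.mask * x

    def state_dict(self, *args, **kwargs):
        # Keep the mask out of the owning module's state_dict; Example #4
        # below depends on this (it loads with strict=False and asserts
        # the masks differ after the round-trip).
        return {}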
Example #2
    def test_jit_trace(self):
        model = ModelUnderTest(bias=False)

        mask = torch.eye(16)
        parametrize.register_parametrization(model.linear, 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model.seq[0], 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model.seq[1], 'weight',
                                             utils.FakeSparsity(mask))

        # Tracing
        example_x = torch.ones(3, 16)
        model_trace = torch.jit.trace_module(model, {'forward': example_x})

        x = torch.randn(3, 16)
        y = model(x)
        y_hat = model_trace(x)
        self.assertEqual(y_hat, y)
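ModelUnderTest itself is not shown on this page. From how the tests use it, a plausible reconstruction (hypothetical; only the .linear and .seq attributes and the 16x16 shapes are implied by the snippets) is:

import torch
from torch import nn


class ModelUnderTest(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.linear = nn.Linear(16, 16, bias=bias)
        self.seq = nn.Sequential(
            nn.Linear(16, 16, bias=bias),
            nn.Linear(16, 16, bias=bias),
        )

    def forward(self, x):
        # The order of submodules in forward is a guess; the tests only
        # compare outputs against themselves, so any order works.
        return self.linear(self.seq(x))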
Example #3
    def test_weights_parametrized(self):
        model = ModelUnderTest(bias=False)

        assert not hasattr(model.linear, 'parametrizations')
        assert not hasattr(model.seq[0], 'parametrizations')
        assert not hasattr(model.seq[1], 'parametrizations')
        mask = torch.eye(16)
        parametrize.register_parametrization(model.linear, 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model.seq[0], 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model.seq[1], 'weight',
                                             utils.FakeSparsity(mask))

        assert hasattr(model.linear, 'parametrizations')
        assert parametrize.is_parametrized(model.linear, 'weight')
        assert hasattr(model.seq[0], 'parametrizations')
        assert parametrize.is_parametrized(model.seq[0], 'weight')
        assert hasattr(model.seq[1], 'parametrizations')
        assert parametrize.is_parametrized(model.seq[1], 'weight')
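The attributes asserted here come from torch.nn.utils.parametrize itself: registering a parametrization attaches a parametrizations container to the module and moves the raw Parameter under parametrizations[name].original, with the named attribute recomputed on access. A small self-contained check, reusing the FakeSparsity sketch from Example #1:

import torch
from torch import nn
from torch.nn.utils import parametrize

lin = nn.Linear(16, 16, bias=False)
parametrize.register_parametrization(lin, 'weight', FakeSparsity(torch.eye(16)))

assert hasattr(lin, 'parametrizations')
assert parametrize.is_parametrized(lin, 'weight')
# The raw Parameter lives here; lin.weight is recomputed from it
# through FakeSparsity on every access.
original = lin.parametrizations['weight'].original
assert torch.equal(lin.weight, torch.eye(16) * original)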
Example #4
    def test_state_dict_preserved(self):
        model_save = ModelUnderTest(bias=False)

        mask = torch.eye(16)
        parametrize.register_parametrization(model_save.linear, 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model_save.seq[0], 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model_save.seq[1], 'weight',
                                             utils.FakeSparsity(mask))
        state_dict = model_save.state_dict()

        model_load = ModelUnderTest(bias=False)
        mask = torch.zeros(model_load.linear.weight.shape)
        parametrize.register_parametrization(model_load.linear, 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.zeros(model_load.seq[0].weight.shape)
        parametrize.register_parametrization(model_load.seq[0], 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.zeros(model_load.seq[1].weight.shape)
        parametrize.register_parametrization(model_load.seq[1], 'weight',
                                             utils.FakeSparsity(mask))
        # Load with strict=False, as the 'mask' is not stored in the state_dict
        model_load.load_state_dict(state_dict, strict=False)

        # Check the parametrizations are preserved
        assert hasattr(model_load.linear, 'parametrizations')
        assert parametrize.is_parametrized(model_load.linear, 'weight')
        assert hasattr(model_load.seq[0], 'parametrizations')
        assert parametrize.is_parametrized(model_load.seq[0], 'weight')
        assert hasattr(model_load.seq[1], 'parametrizations')
        assert parametrize.is_parametrized(model_load.seq[1], 'weight')

        # Check the weights are preserved
        self.assertEqual(model_save.linear.parametrizations['weight'].original,
                         model_load.linear.parametrizations['weight'].original)
        self.assertEqual(model_save.seq[0].parametrizations['weight'].original,
                         model_load.seq[0].parametrizations['weight'].original)
        self.assertEqual(model_save.seq[1].parametrizations['weight'].original,
                         model_load.seq[1].parametrizations['weight'].original)

        # Check the masks are not preserved in the state_dict
        # We store the state_dicts in the sparsifier, not in the model itself.
        # TODO: Need to find a clean way of exporting the parametrized model
        self.assertNotEqual(
            model_save.linear.parametrizations['weight'][0].mask,
            model_load.linear.parametrizations['weight'][0].mask)
        self.assertNotEqual(
            model_save.seq[0].parametrizations['weight'][0].mask,
            model_load.seq[0].parametrizations['weight'][0].mask)
        self.assertNotEqual(
            model_save.seq[1].parametrizations['weight'][0].mask,
            model_load.seq[1].parametrizations['weight'][0].mask)
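The strict=False load and the "masks are not preserved" assertions both follow from the state_dict layout: the raw weight is serialized under parametrizations.weight.original, while the mask (with the state_dict suppression from the Example #1 sketch) never makes it in. A quick way to see this, reusing the sketches above:

import torch
from torch.nn.utils import parametrize

# Reuses the FakeSparsity and ModelUnderTest sketches from above.
model = ModelUnderTest(bias=False)
parametrize.register_parametrization(model.linear, 'weight',
                                     FakeSparsity(torch.eye(16)))

keys = model.state_dict().keys()
assert 'linear.parametrizations.weight.original' in keys
assert not any('mask' in k for k in keys)  # mask deliberately not serialized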
Example #5
    def test_state_dict_preserved(self):
        model_save = ModelUnderTest(bias=False)

        mask = torch.eye(16)
        parametrize.register_parametrization(model_save.linear, 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model_save.seq[0], 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model_save.seq[1], 'weight',
                                             utils.FakeSparsity(mask))
        state_dict = model_save.state_dict()

        model_load = ModelUnderTest(bias=False)
        mask = torch.zeros(model_load.linear.weight.shape)
        parametrize.register_parametrization(model_load.linear, 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.zeros(model_load.seq[0].weight.shape)
        parametrize.register_parametrization(model_load.seq[0], 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.zeros(model_load.seq[1].weight.shape)
        parametrize.register_parametrization(model_load.seq[1], 'weight',
                                             utils.FakeSparsity(mask))
        model_load.load_state_dict(state_dict)

        # Check the parametrizations are preserved
        assert hasattr(model_load.linear, 'parametrizations')
        assert parametrize.is_parametrized(model_load.linear, 'weight')
        assert hasattr(model_load.seq[0], 'parametrizations')
        assert parametrize.is_parametrized(model_load.seq[0], 'weight')
        assert hasattr(model_load.seq[1], 'parametrizations')
        assert parametrize.is_parametrized(model_load.seq[1], 'weight')

        # Check the weights are preserved
        self.assertEqual(model_save.linear.parametrizations['weight'].original,
                         model_load.linear.parametrizations['weight'].original)
        self.assertEqual(model_save.seq[0].parametrizations['weight'].original,
                         model_load.seq[0].parametrizations['weight'].original)
        self.assertEqual(model_save.seq[1].parametrizations['weight'].original,
                         model_load.seq[1].parametrizations['weight'].original)

        # Check the masks are preserved
        self.assertEqual(model_save.linear.parametrizations['weight'][0].mask,
                         model_load.linear.parametrizations['weight'][0].mask)
        self.assertEqual(model_save.seq[0].parametrizations['weight'][0].mask,
                         model_load.seq[0].parametrizations['weight'][0].mask)
        self.assertEqual(model_save.seq[1].parametrizations['weight'][0].mask,
                         model_load.seq[1].parametrizations['weight'][0].mask)
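Note this second variant loads strictly and expects the masks to round-trip, so it only passes with a FakeSparsity whose mask buffer is serialized, i.e. the Example #1 sketch without the state_dict override; presumably it is an earlier revision of the same test. With the buffer left in, each parametrization contributes a mask key:

import torch
from torch import nn
from torch.nn.utils import parametrize


class FakeSparsityWithMask(nn.Module):
    # Same as the Example #1 sketch, minus the state_dict override,
    # so the mask buffer is saved and loaded like any other buffer.
    def __init__(self, mask):
        super().__init__()
        self.register_buffer('mask', mask)

    def forward(self, x):
        return self.mask * x


lin = nn.Linear(16, 16, bias=False)
parametrize.register_parametrization(lin, 'weight',
                                     FakeSparsityWithMask(torch.eye(16)))
assert 'parametrizations.weight.0.mask' in lin.state_dict()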