Example #1
 def test_prepare(self):
     model = Model()
     sparsifier = NearlyDiagonalSparsifier(nearliness=1)
     sparsifier.prepare(model, config=None)
     for g in sparsifier.module_groups:
         module = g['module']
         # Check mask exists
         assert hasattr(module.parametrizations['weight'][0], 'mask')
         # Check parametrization exists and is correct
         assert is_parametrized(module, 'weight')
         assert type(module.parametrizations.weight[0]) == FakeSparsity
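These snippets all revolve around torch.nn.utils.parametrize. As a point of reference, here is a minimal, self-contained sketch of the kind of mask parametrization that prepare() is expected to attach. The FakeSparsity class below is a simplified stand-in written for illustration, not the torch.ao implementation used by the tests:

import torch
import torch.nn as nn
from torch.nn.utils import parametrize

class FakeSparsity(nn.Module):
    """Simplified stand-in: multiplies the weight by a fixed mask."""
    def __init__(self, mask):
        super().__init__()
        self.register_buffer('mask', mask)

    def forward(self, weight):
        return self.mask * weight

linear = nn.Linear(4, 4)
parametrize.register_parametrization(linear, 'weight',
                                     FakeSparsity(torch.ones(4, 4)))
assert parametrize.is_parametrized(linear, 'weight')
assert hasattr(linear.parametrizations['weight'][0], 'mask')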
Example #2
    def _check_on_train_end(self, pl_module, callback):
        """Confirms that the mask is squashed after the training ends
        This is achieved by making sure that each parameter in the internal container
        is no longer parametrized.
        """
        callback.on_train_end(42, pl_module)

        # check that the masks have been squashed
        for name, _ in pl_module.model.named_parameters():
            valid_name = _get_valid_name(name)
        assert not is_parametrized(callback.data_sparsifier._container,
                                   valid_name)
Example #3
def has_no_children_ignoring_parametrizations(module):
    """
    Checks if module._modules is empty or
    if module is a parametrization, checks that module._modules only has
    the 'parametrizations' module
    """
    if len(module._modules) == 0:
        return True
    elif is_parametrized(module):
        return len(module._modules) == 1 and 'parametrizations' in module._modules
    else:
        return False
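A quick usage sketch for this helper (reusing the FakeSparsity stand-in from the sketch after Example #1): registering a parametrization adds a hidden 'parametrizations' submodule, and the helper still treats such a module as a leaf.

linear = nn.Linear(4, 4)
assert has_no_children_ignoring_parametrizations(linear)  # no submodules at all

parametrize.register_parametrization(linear, 'weight',
                                     FakeSparsity(torch.ones(4, 4)))
assert 'parametrizations' in linear._modules  # added by register_parametrization
assert has_no_children_ignoring_parametrizations(linear)  # still treated as a leaf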
Example #4
File: test_pruner.py Project: vors/pytorch
 def _check_pruner_prepared(self, model, pruner, device):
     for g in pruner.module_groups:
         module = g['module']
         assert module.weight.device == device
         # Check mask exists
         assert hasattr(module, 'mask')
         # Check parametrization exists and is correct
         assert parametrize.is_parametrized(module)
         assert hasattr(module, "parametrizations")
         # Assume that this is the 1st/only parametrization
         assert type(
             module.parametrizations.weight[0]) == PruningParametrization
Example #5
    def check_state_dict(self, data_list, data_with_config, defaults, **kwargs):
        sparsifier1 = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs)
        sparsifier2 = self._make_sparsifier(data_list=[data_list[0]], data_with_config=[], defaults=defaults, **kwargs)
        sparsifier1.step()

        state_dict1 = sparsifier1.state_dict()

        assert sparsifier1.state != sparsifier2.state
        name, _, _ = self._get_name_data_config(data_list[0])
        self.assertNotEqual(sparsifier1.get_mask(name), sparsifier2.get_mask(name))

        sparsifier2.load_state_dict(state_dict1)
        assert len(sparsifier1.state) == len(sparsifier2.state)
        assert len(sparsifier1.data_groups) == len(sparsifier2.data_groups)

        state1 = state_dict1['state']
        for name in state1.keys():
            # compare mask
            assert name in sparsifier2.state
            assert 'mask' in sparsifier2.state[name]
            assert 'mask' in sparsifier1.state[name]
            mask1, mask2 = state1[name]['mask'], sparsifier2.state[name]['mask']
            assert mask1.is_sparse and not mask2.is_sparse
            assert torch.all(mask1.to_dense() == mask2)  # mask1 is stored as sparse coo now

            # compare data_groups
            dg1, dg2 = sparsifier1.data_groups, sparsifier2.data_groups
            assert name in dg1 and name in dg2
            assert dg1[name] == dg2[name]

            # compare container
            container1, container2 = sparsifier1._container, sparsifier2._container
            assert torch.all(getattr(container1, name) == getattr(container2, name))
            assert is_parametrized(container1, name) == is_parametrized(container2, name)
            if is_parametrized(container1, name):
                param1 = getattr(container1.parametrizations, name)[0]
                param2 = getattr(container2.parametrizations, name)[0]
                assert hasattr(param1, 'mask')
                assert hasattr(param2, 'mask')
                self.assertEqual(param1.__dict__, param2.__dict__)
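The mask comparison above depends on the state_dict storing masks in sparse COO format. The round trip it relies on can be shown in isolation; this is a generic sketch, not the sparsifier's actual serialization code:

mask = (torch.rand(4, 4) > 0.5).float()
sparse_mask = mask.to_sparse()  # COO representation, as in the saved state

assert sparse_mask.is_sparse and not mask.is_sparse
assert torch.all(sparse_mask.to_dense() == mask)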
Example #6
 def test_mask_squash(self):
     model = Model()
     sparsifier = NearlyDiagonalSparsifier(nearliness=1)
     sparsifier.prepare(model, config=None)
     sparsifier.step()
     sparsifier.squash_mask()
     for g in sparsifier.groups:
         module = g['module']
         assert not is_parametrized(module, 'weight')
         assert not hasattr(module, 'mask')
         weights = module.weight
         height, width = weights.shape
         assert torch.all(weights == torch.eye(height, width) * weights)  # only diagonal to be present
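The observable effect of squash_mask matches folding the mask into the weight and then dropping the parametrization machinery. A hedged sketch of that behavior using only the public torch.nn.utils.parametrize API (again with the FakeSparsity stand-in):

linear = nn.Linear(4, 4)
parametrize.register_parametrization(linear, 'weight', FakeSparsity(torch.eye(4)))

# Bake the masked weight in, then remove the parametrization machinery.
parametrize.remove_parametrizations(linear, 'weight', leave_parametrized=True)

assert not parametrize.is_parametrized(linear, 'weight')
assert torch.all(linear.weight == torch.eye(4) * linear.weight)  # off-diagonal is zero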
Example #7
    def test_weights_parametrized(self):
        model = ModelUnderTest(bias=False)

        assert not hasattr(model.linear, 'parametrizations')
        assert not hasattr(model.seq[0], 'parametrizations')
        assert not hasattr(model.seq[1], 'parametrizations')
        mask = torch.eye(16)
        parametrize.register_parametrization(model.linear, 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model.seq[0], 'weight',
                                             utils.FakeSparsity(mask))
        mask = torch.eye(16)
        parametrize.register_parametrization(model.seq[1], 'weight',
                                             utils.FakeSparsity(mask))

        assert hasattr(model.linear, 'parametrizations')
        assert parametrize.is_parametrized(model.linear, 'weight')
        assert hasattr(model.seq[0], 'parametrizations')
        assert parametrize.is_parametrized(model.seq[0], 'weight')
        assert hasattr(model.seq[1], 'parametrizations')
        assert parametrize.is_parametrized(model.seq[1], 'weight')
Example #8
 def test_prepare(self):
     model = Model()
     pruner = ImplementedPruner(model, None, None)
     pruner.prepare()
     for g in pruner.module_groups:
         module = g['module']
         # Check mask exists
         assert hasattr(module, 'mask')
         # Check parametrization exists and is correct
         assert parametrize.is_parametrized(module)
         assert hasattr(module, "parametrizations")
         assert type(
             module.parametrizations.weight[0]) == PruningParametrization
Example #9
 def test_mask_squash_with_params1(self):
     model = Model()
     sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
     sparsifier.prepare(model, [{
         'tensor_fqn': 'linear.weight'
     }, {
         'tensor_fqn': 'seq.0.weight'
     }])
     sparsifier.squash_mask(params_to_keep_per_layer={
         'linear': ('foo', 'bar'),
         'seq.0': ('baz', )
     })
     assert not is_parametrized(model.seq[0], 'weight')
     assert not is_parametrized(model.linear, 'weight')
     assert hasattr(model.seq[0], 'sparse_params')
     assert hasattr(model.linear, 'sparse_params')
     assert model.seq[0].sparse_params.get('foo', None) is None
     assert model.seq[0].sparse_params.get('bar', None) is None
     assert model.seq[0].sparse_params.get('baz', None) == 1
     assert model.linear.sparse_params.get('foo', None) == 3
     assert model.linear.sparse_params.get('bar', None) == 2
     assert model.linear.sparse_params.get('baz', None) is None
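As far as these assertions show, sparse_params is a plain dict attribute that squash_mask leaves on each module, holding only the kwargs listed for that layer in params_to_keep_per_layer. A hypothetical reconstruction of the post-squash state being asserted:

linear = nn.Linear(4, 4)
seq0 = nn.Linear(4, 4)
# Hypothetical reconstruction: only the listed kwargs survive, per layer.
linear.sparse_params = {'foo': 3, 'bar': 2}  # 'baz' was not kept for 'linear'
seq0.sparse_params = {'baz': 1}              # only 'baz' was kept for 'seq.0'

assert linear.sparse_params.get('baz', None) is None
assert seq0.sparse_params.get('baz', None) == 1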
Example #10
 def test_prepare_bias(self):
     model = ModelB()
     pruner = SimplePruner(model, None, None)
     pruner.prepare()
     for g in pruner.module_groups:
         module = g['module']
         # Check mask exists
         assert hasattr(module, 'mask')
         # Check parametrization exists and is correct
         assert parametrize.is_parametrized(module)
         assert hasattr(module, "parametrizations")
         # Assume that this is the 1st/only parametrization
         assert type(
             module.parametrizations.weight[0]) == PruningParametrization
Example #11
 def get_module_pruned_outputs(self, module):
     r"""Returns the set of pruned indices of module"""
     assert parametrize.is_parametrized(
         module)  # can only get pruned indices of pruned module
     modules = {config['module'] for config in self.module_groups}
     module_list = set()
     for m in modules:
         if type(m) is tuple:
             module_list.update(m)
         else:
             module_list.add(m)
     assert module in module_list  # check that module is in pruner.module_groups
     return module.parametrizations.weight[
         0].pruned_outputs  # assume only one parametrization attached
Example #12
    def get_data(self, name: str, return_original: bool = True):
        r"""Returns weight tensor (or data)
        Args:
            - name: name of the data to be returned
            - return_original returns weight tensor without applying parametrization if True
                else - returns the sparsified version (parametrized)
        """
        if name not in self.data_groups:
            raise ValueError("data with specified name does not exist")

        if return_original:
            if not parametrize.is_parametrized(self._container, name):
                raise ValueError("mask squashed - original mask value does not exist")
            data = getattr(self._container.parametrizations, name).original
            return data
        else:
            return getattr(self._container, name)
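The .original lookup used here is standard parametrize behavior: the pre-parametrization tensor lives at parametrizations.<name>.original, while plain attribute access recomputes the parametrized value. A minimal illustration on an ordinary module (with the FakeSparsity stand-in from earlier):

linear = nn.Linear(4, 4)
parametrize.register_parametrization(linear, 'weight',
                                     FakeSparsity(torch.zeros(4, 4)))

original = linear.parametrizations.weight.original  # raw tensor, mask not applied
assert not torch.all(original == 0)                 # untouched weights
assert torch.all(linear.weight == 0)                # attribute access applies the mask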
Example #13
 def _check_pruner_prepared(self, model, pruner, device):
     for config in pruner.module_groups:
         modules = []
         if type(config['module']) is tuple:
             for module in config['module']:
                 modules.append(module)
         else:
             module = config['module']
             modules.append(module)
         for module in modules:
             assert module.weight.device == device
             # Check mask exists
             assert hasattr(module, 'mask')
             # Check parametrization exists and is correct
             assert parametrize.is_parametrized(module)
             assert hasattr(module, "parametrizations")
             # Assume that this is the 1st/only parametrization
             if isinstance(module, tuple(NEEDS_ZEROS)):
                 assert type(module.parametrizations.weight[0]) == ZeroesParametrization
             else:
                 assert type(module.parametrizations.weight[0]) == PruningParametrization