def test_prepare(self):
    """Check that ``prepare`` attaches a masked ``FakeSparsity`` to each group's weight.

    For every module in ``sparsifier.module_groups`` the ``weight`` must be
    parametrized, the first parametrization must be exactly ``FakeSparsity``,
    and that parametrization must carry a ``mask`` attribute.
    """
    model = Model()
    sparsifier = NearlyDiagonalSparsifier(nearliness=1)
    sparsifier.prepare(model, config=None)
    for g in sparsifier.module_groups:
        module = g['module']
        # Check the parametrization first: the mask lookup below dereferences
        # ``module.parametrizations`` and would otherwise fail with an
        # AttributeError rather than a readable assertion failure.
        assert is_parametrized(module, 'weight')
        assert type(module.parametrizations.weight[0]) is FakeSparsity
        # Check mask exists on the parametrization
        assert hasattr(module.parametrizations['weight'][0], 'mask')
def _check_on_train_end(self, pl_module, callback):
    """Confirm that the masks are squashed after training ends.

    Calls ``callback.on_train_end`` and then checks that no entry of the data
    sparsifier's internal container is still parametrized. Each entry is looked
    up under the name ``_get_valid_name`` derives from a model parameter name.
    """
    # NOTE(review): 42 appears to be a placeholder for the trainer argument;
    # presumably unused by the callback here -- confirm.
    callback.on_train_end(42, pl_module)

    # The data sparsifier stores its data in ``_container`` (the former
    # spelling ``_continer`` was a typo and raised AttributeError).
    container = callback.data_sparsifier._container
    # check that the masks have been squashed
    for name, _ in pl_module.model.named_parameters():
        valid_name = _get_valid_name(name)
        assert not is_parametrized(container, valid_name)
def has_no_children_ignoring_parametrizations(module):
    """Return True if ``module`` has no child modules apart from parametrizations.

    A module with an empty ``_modules`` qualifies. A parametrized module
    qualifies only if its sole child is the ``'parametrizations'`` container.
    Any other module with children does not qualify.
    """
    children = module._modules
    if not children:
        return True
    if not is_parametrized(module):
        return False
    # Parametrized modules gain a 'parametrizations' child; ignore exactly that.
    return len(children) == 1 and 'parametrizations' in children
def _check_pruner_prepared(self, model, pruner, device):
    """Assert that every module in ``pruner.module_groups`` was prepared on ``device``.

    Each module's weight must live on ``device``, the module must carry a
    ``mask`` attribute, and the first ``weight`` parametrization must be
    exactly ``PruningParametrization``. ``model`` is not inspected.
    """
    for group in pruner.module_groups:
        mod = group['module']
        assert mod.weight.device == device
        # the mask is expected on the module itself
        assert hasattr(mod, 'mask')
        # a parametrization container must be present ...
        assert parametrize.is_parametrized(mod)
        assert hasattr(mod, "parametrizations")
        # ... and its first weight entry (assumed to be the only one) is the pruning one
        assert type(mod.parametrizations.weight[0]) is PruningParametrization
def check_state_dict(self, data_list, data_with_config, defaults, **kwargs):
    """Round-trip a data sparsifier's state dict and compare the two sparsifiers.

    The source sparsifier is built from all of ``data_list`` /
    ``data_with_config`` and stepped once. The target is built only from the
    first entry of ``data_list``. The two must differ before the source's
    state dict is loaded into the target. Afterwards, for every key in the
    saved state, the masks, data groups, container tensors and
    parametrizations must agree.
    """
    source = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs)
    target = self._make_sparsifier(data_list=[data_list[0]], data_with_config=[], defaults=defaults, **kwargs)

    source.step()
    saved = source.state_dict()

    # Before loading, the sparsifiers must be distinguishable.
    assert source.state != target.state
    first_name, _, _ = self._get_name_data_config(data_list[0])
    self.assertNotEqual(source.get_mask(first_name), target.get_mask(first_name))

    target.load_state_dict(saved)
    assert len(source.state) == len(target.state)
    assert len(source.data_groups) == len(target.data_groups)

    saved_state = saved['state']
    for key in saved_state:
        # --- masks: stored sparse (COO) in the state dict, dense once loaded
        assert key in target.state
        assert 'mask' in target.state[key]
        assert 'mask' in source.state[key]
        saved_mask = saved_state[key]['mask']
        loaded_mask = target.state[key]['mask']
        assert saved_mask.is_sparse and not loaded_mask.is_sparse
        assert torch.all(saved_mask.to_dense() == loaded_mask)

        # --- data groups
        groups_src, groups_dst = source.data_groups, target.data_groups
        assert key in groups_src and key in groups_dst
        assert groups_src[key] == groups_dst[key]

        # --- container tensors and their parametrizations
        box_src, box_dst = source._container, target._container
        assert torch.all(getattr(box_src, key) == getattr(box_dst, key))
        src_parametrized = is_parametrized(box_src, key)
        assert src_parametrized == is_parametrized(box_dst, key)
        if src_parametrized:
            p_src = getattr(box_src.parametrizations, key)[0]
            p_dst = getattr(box_dst.parametrizations, key)[0]
            assert hasattr(p_src, 'mask')
            assert hasattr(p_dst, 'mask')
            self.assertEqual(p_src.__dict__, p_dst.__dict__)
def test_mask_squash(self):
    """After ``step`` and ``squash_mask``, only diagonal weight entries may remain.

    Each grouped module must lose both its ``weight`` parametrization and its
    ``mask`` attribute. Its weight must equal the weight masked by an identity
    matrix of the same shape.
    """
    sparsifier = NearlyDiagonalSparsifier(nearliness=1)
    sparsifier.prepare(Model(), config=None)
    sparsifier.step()
    sparsifier.squash_mask()
    for group in sparsifier.groups:
        mod = group['module']
        # squashing removes the parametrization and the mask
        assert not is_parametrized(mod, 'weight')
        assert not hasattr(mod, 'mask')
        w = mod.weight
        rows, cols = w.shape
        # only diagonal to be present
        assert torch.all(w == torch.eye(rows, cols) * w)
def test_weights_parametrized(self):
    """Registering ``FakeSparsity`` must parametrize each targeted weight.

    ``model.linear``, ``model.seq[0]`` and ``model.seq[1]`` start without
    parametrizations. After an identity-mask ``FakeSparsity`` is registered on
    each ``weight``, every one of them must report being parametrized.
    """
    model = ModelUnderTest(bias=False)
    modules = (model.linear, model.seq[0], model.seq[1])

    for module in modules:
        assert not hasattr(module, 'parametrizations')

    for module in modules:
        # A fresh mask per module so no tensor is shared between parametrizations.
        mask = torch.eye(16)
        parametrize.register_parametrization(module, 'weight', utils.FakeSparsity(mask))

    # Verify every module individually (previously only model.linear was
    # re-checked three times, leaving seq[0] and seq[1] untested).
    for module in modules:
        assert hasattr(module, 'parametrizations')
        assert parametrize.is_parametrized(module, 'weight')
def test_prepare(self):
    """Check that ``prepare`` parametrizes every grouped module's weight.

    Each module must carry a ``mask`` attribute and its first ``weight``
    parametrization must be exactly ``PruningParametrization``.
    """
    pruner = ImplementedPruner(Model(), None, None)
    pruner.prepare()
    for group in pruner.module_groups:
        mod = group['module']
        # the mask is attached to the module itself
        assert hasattr(mod, 'mask')
        # a parametrization container must be present ...
        assert parametrize.is_parametrized(mod)
        assert hasattr(mod, "parametrizations")
        # ... and its first weight entry must be the pruning parametrization
        assert type(mod.parametrizations.weight[0]) is PruningParametrization
def test_mask_squash_with_params1(self):
    """``squash_mask`` keeps only the requested sparse params per layer.

    ``linear`` keeps ``foo`` and ``bar``; ``seq.0`` keeps ``baz``. Both layers
    lose their ``weight`` parametrization and gain a ``sparse_params`` dict
    containing exactly the kept values.
    """
    model = Model()
    sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
    config = [{'tensor_fqn': 'linear.weight'}, {'tensor_fqn': 'seq.0.weight'}]
    sparsifier.prepare(model, config)
    keep = {'linear': ('foo', 'bar'), 'seq.0': ('baz', )}
    sparsifier.squash_mask(params_to_keep_per_layer=keep)

    layers = (model.seq[0], model.linear)
    for layer in layers:
        assert not is_parametrized(layer, 'weight')
    for layer in layers:
        assert hasattr(layer, 'sparse_params')

    # None means "must be absent"; anything else must match exactly.
    expected = (
        (model.seq[0], {'foo': None, 'bar': None, 'baz': 1}),
        (model.linear, {'foo': 3, 'bar': 2, 'baz': None}),
    )
    for layer, wanted in expected:
        for key, value in wanted.items():
            if value is None:
                assert layer.sparse_params.get(key, None) is None
            else:
                assert layer.sparse_params.get(key, None) == value
def test_prepare_bias(self):
    """Check ``prepare`` on a model with biases (``ModelB``, ``SimplePruner``).

    Each grouped module must carry a ``mask`` attribute and its first
    ``weight`` parametrization must be exactly ``PruningParametrization``.
    """
    pruner = SimplePruner(ModelB(), None, None)
    pruner.prepare()
    for group in pruner.module_groups:
        mod = group['module']
        # the mask is attached to the module itself
        assert hasattr(mod, 'mask')
        # a parametrization container must be present ...
        assert parametrize.is_parametrized(mod)
        assert hasattr(mod, "parametrizations")
        # ... whose first (assumed only) weight entry is the pruning parametrization
        assert type(mod.parametrizations.weight[0]) is PruningParametrization
def get_module_pruned_outputs(self, module):
    r"""Return the pruned output indices recorded for ``module``.

    ``module`` must already be parametrized and must appear in
    ``self.module_groups``, either directly or inside a tuple entry; both
    conditions are checked with ``assert``. The value returned is the
    ``pruned_outputs`` attribute of the first parametrization on ``weight``.
    """
    # can only get pruned indices of a pruned (parametrized) module
    assert parametrize.is_parametrized(module)
    known = set()
    for group in self.module_groups:
        entry = group['module']
        if type(entry) is tuple:
            known.update(entry)
        else:
            known.add(entry)
    # the module must be managed by this pruner
    assert module in known
    # assumes only one parametrization is attached to weight
    return module.parametrizations.weight[0].pruned_outputs
def get_data(self, name: str, return_original: bool = True):
    r"""Return the tensor (data) registered under ``name``.

    Args:
        - name: name of the data to be returned
        - return_original: if True, return the raw tensor without applying the
          parametrization; otherwise return the value read through the
          container attribute, i.e. the sparsified (parametrized) version.

    Raises:
        ValueError: if ``name`` is not in ``self.data_groups``, or if
            ``return_original`` is True but ``name`` is no longer
            parametrized (the mask has been squashed).
    """
    if name not in self.data_groups:
        raise ValueError("data with specified name does not exist")
    if not return_original:
        return getattr(self._container, name)
    if not parametrize.is_parametrized(self._container, name):
        raise ValueError("mask squashed - original mask value does not exist")
    return getattr(self._container.parametrizations, name).original
def _check_pruner_prepared(self, model, pruner, device):
    """Assert that every module in ``pruner.module_groups`` was prepared on ``device``.

    A group's ``'module'`` entry may be a single module or a tuple of modules;
    each module is checked individually. Its weight must live on ``device``
    and it must carry a ``mask`` attribute. Its first ``weight``
    parametrization must be ``ZeroesParametrization`` for types listed in
    ``NEEDS_ZEROS`` and ``PruningParametrization`` otherwise. ``model`` is not
    inspected.
    """
    for config in pruner.module_groups:
        entry = config['module']
        targets = list(entry) if type(entry) is tuple else [entry]
        for mod in targets:
            assert mod.weight.device == device
            # Check mask exists
            assert hasattr(mod, 'mask')
            # Check parametrization exists and is correct
            assert parametrize.is_parametrized(mod)
            assert hasattr(mod, "parametrizations")
            # Assume that this is the 1st/only parametrization
            expected = ZeroesParametrization if isinstance(mod, tuple(NEEDS_ZEROS)) else PruningParametrization
            assert type(mod.parametrizations.weight[0]) is expected