Code Example #1
    def squash_mask(self, use_path=False, *args, **kwargs):
        for config in self.module_groups:
            modules = []
            if use_path:
                if type(config['module']) is tuple:  # (Conv2d, BN)
                    for fqn in config['fqn']:
                        module = fqn_to_module(self.model, fqn)
                        modules.append(module)
                else:
                    module = fqn_to_module(self.model, config['fqn'])
                    modules.append(module)
            else:
                if type(config['module']) is tuple:
                    for module in config['module']:
                        modules.append(module)
                else:
                    module = config['module']
                    modules.append(module)

            for module in modules:
                parametrize.remove_parametrizations(module,
                                                    'weight',
                                                    leave_parametrized=True)
                if getattr(module._parameters, 'mask', None):
                    del module._parameters['mask']
                elif getattr(module._buffers, 'mask', None):
                    del module._buffers['mask']
                delattr(module, 'mask')
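
Code Example #1 relies on a weight parametrization that multiplies the weight by a sparsity mask; squashing removes that parametrization with leave_parametrized=True so the masked values become the plain weight, then deletes the leftover mask. A minimal, self-contained sketch of the same flow on a single layer (FakeSparsity here is a hypothetical stand-in, not the project's actual class):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class FakeSparsity(nn.Module):
    """Hypothetical mask parametrization: multiplies the weight by a fixed 0/1 mask."""
    def __init__(self, mask):
        super().__init__()
        self.register_buffer('mask', mask)

    def forward(self, x):
        return self.mask * x

layer = nn.Linear(4, 4)
mask = (torch.rand_like(layer.weight) > 0.5).float()
parametrize.register_parametrization(layer, 'weight', FakeSparsity(mask))

# "Squash": bake the masked values into the plain weight tensor and drop the parametrization.
parametrize.remove_parametrizations(layer, 'weight', leave_parametrized=True)
assert not parametrize.is_parametrized(layer)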
Code Example #2
File: test_subclass.py Project: yuguo68/pytorch
    def test_parametrization(self, tensor_cls, leave_parametrized):
        # TODO: Either implement set_() properly for these tensor subclasses or apply a
        # more general fix to avoid the need for special set_() handling. For now, skip
        # testing these as they're expected to fail.
        if tensor_cls in [LoggingTensor, DiagTensorBelow]:
            return

        create_fn = partial(self._create_tensor, tensor_cls)

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = nn.Parameter(create_fn())

            def forward(self, x):
                return self.weight + x

        class MyParametrization(nn.Module):
            def forward(self, X):
                return -X

        m = MyModule()
        self.assertEqual(len(m.state_dict()), 1)
        register_parametrization(m, 'weight', MyParametrization())
        self.assertIsInstance(m.weight, tensor_cls)
        output = m(self._create_tensor(torch.Tensor))
        self.assertIsInstance(output, tensor_cls)
        remove_parametrizations(m,
                                'weight',
                                leave_parametrized=leave_parametrized)
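
The test above stops right after remove_parametrizations; what leave_parametrized controls is which value survives the removal (the skip at the top exists because removal may call set_() internally, which these subclasses do not yet support). A hedged sketch of the outcome for a negation parametrization on a plain tensor:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Negate(nn.Module):
    def forward(self, X):
        return -X

m = nn.Linear(2, 2)
original = m.weight.detach().clone()
parametrize.register_parametrization(m, 'weight', Negate())

# leave_parametrized=True keeps the parametrized value (-original) as the new plain weight;
# leave_parametrized=False would restore the stored original tensor instead.
parametrize.remove_parametrizations(m, 'weight', leave_parametrized=True)
assert torch.allclose(m.weight, -original)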
Code Example #3
File: __init__.py Project: ain-soph/trojanzoo
 def parametrize_(self, parametrize: bool = True):
     if parametrize:
         if not self.parametrize:
             P.register_parametrization(self, 'weight', Std())
     elif self.parametrize:
         P.remove_parametrizations(self, 'weight')
     self.parametrize = parametrize
     return self
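
Code Example #3 makes the toggle idempotent: it registers the parametrization only when the module is not already parametrized and removes it only when it is. Std() is presumably a weight-standardization parametrization defined elsewhere in the project; a hedged usage sketch (the conv object below is an assumption):

conv = StdConv2d(3, 16, kernel_size=3)  # hypothetical module exposing parametrize_()
conv.parametrize_(True)    # registers Std() on 'weight' (only once)
conv.parametrize_(True)    # no-op: already parametrized
conv.parametrize_(False)   # removes the parametrization again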
Code Example #4
 def squash_mask(self, *args, leave_parametrized=True, names=None, **kwargs):
     r"""Squashes the sparse masks into the appropriate tensors. Also, accepts list of strings
     to squash mask for. If none, squashes mask for all the keys
     kwargs:
         * names: list of strings to squash mask for
         * sparsified: if true - applies the mask before squashing
                       if false - does not apply the mask before squashing
     """
     if names is None:
         names = list(self.data_groups.keys())
     for name in names:
         parametrize.remove_parametrizations(self._container, name, leave_parametrized=leave_parametrized)
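
A hedged usage sketch of the data-sparsifier variant above: squash the mask for only some of the registered names, or for all of them while keeping the unmasked values (the data names are hypothetical):

# Only squash the mask for 'tensor_1'; other registered data stays parametrized.
sparsifier.squash_mask(names=['tensor_1'], leave_parametrized=True)

# Squash everything, keeping the original (unmasked) values instead of the sparsified ones.
sparsifier.squash_mask(leave_parametrized=False)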
Code Example #5
 def convert(self, use_path=False, *args, **kwargs):
     for config in self.module_groups:
         if use_path:
             module = _path_to_module(self.model, config['path'])
         else:
             module = config['module']
         parametrize.remove_parametrizations(module,
                                             'weight',
                                             leave_parametrized=True)
         if getattr(module._parameters, 'mask', None):
             del module._parameters['mask']
         elif getattr(module._buffers, 'mask', None):
             del module._buffers['mask']
         delattr(module, 'mask')
Code Example #6
File: base_pruner.py Project: timgates42/pytorch
    def squash_mask(self, use_path=False, *args, **kwargs):
        for config in self.groups:
            modules, tensor_names = self._get_modules_and_tensor_names(
                config, use_path)

            for module, tensor_name in zip(modules, tensor_names):
                parametrize.remove_parametrizations(module,
                                                    tensor_name,
                                                    leave_parametrized=True)
                if getattr(module._parameters, 'mask', None):
                    del module._parameters['mask']
                elif getattr(module._buffers, 'mask', None):
                    del module._buffers['mask']
                delattr(module, 'mask')
Code Example #7
File: base_sparsifier.py Project: yanboliang/pytorch
    def squash_mask(
            self,
            params_to_keep: Optional[Tuple[str, ...]] = None,
            params_to_keep_per_layer: Optional[Dict[str, Tuple[str,
                                                               ...]]] = None,
            *args,
            **kwargs):
        r"""Squashes the sparse masks into the appropriate tensors.

        If either the `params_to_keep` or `params_to_keep_per_layer` is set,
        the module will have a `sparse_params` dict attached to it.

        Args:
            params_to_keep: List of keys to save in the module or a dict
                            representing the modules and keys that will have
                            sparsity parameters saved
            params_to_keep_per_layer: Dict to specify the params that should be
                            saved for specific layers. The keys in the dict
                            should be the module fqn, while the values should
                            be a list of strings with the names of the variables
                            to save in the `sparse_params`

        Examples:
            >>> # Don't save any sparse params
            >>> sparsifier.squash_mask()
            >>> hasattr(model.submodule1, 'sparse_params')
            False

            >>> # Keep sparse params per layer
            >>> sparsifier.squash_mask(
            ...     params_to_keep_per_layer={
            ...         'submodule1.linear1': ('foo', 'bar'),
            ...         'submodule2.linear42': ('baz',)
            ...     })
            >>> print(model.submodule1.linear1.sparse_params)
            {'foo': 42, 'bar': 24}
            >>> print(model.submodule2.linear42.sparse_params)
            {'baz': 0.1}

            >>> # Keep sparse params for all layers
            >>> sparsifier.squash_mask(params_to_keep=('foo', 'bar'))
            >>> print(model.submodule1.linear1.sparse_params)
            {'foo': 42, 'bar': 24}
            >>> print(model.submodule2.linear42.sparse_params)
            {'foo': 42, 'bar': 24}

            >>> # Keep some sparse params for all layers, and specific ones for
            >>> # some other layers
            >>> sparsifier.squash_mask(
            ...     params_to_keep=('foo', 'bar'),
            ...     params_to_keep_per_layer={
            ...         'submodule2.linear42': ('baz',)
            ...     })
            >>> print(model.submodule1.linear1.sparse_params)
            {'foo': 42, 'bar': 24}
            >>> print(model.submodule2.linear42.sparse_params)
            {'foo': 42, 'bar': 24, 'baz': 0.1}
        """
        for config in self.groups:
            module = config['module']
            tensor_name = config['tensor_name']
            parametrize.remove_parametrizations(module,
                                                tensor_name,
                                                leave_parametrized=True)
            sparse_params = dict()
            if params_to_keep is not None:
                global_params = {k: config[k] for k in params_to_keep}
                sparse_params.update(global_params)
            if params_to_keep_per_layer is not None:
                params = params_to_keep_per_layer.get(config["module_fqn"],
                                                      None)
                if params is not None:
                    per_layer_params = {k: config[k] for k in params}
                    sparse_params.update(per_layer_params)
            if sparse_params:
                # TODO: handle multiple tensors being sparsified on a single module; where to store sparse_params?
                module.sparse_params = sparse_params
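
The loop above expects each entry in self.groups to carry at least the module, the tensor name, the module FQN, and whatever sparsity hyperparameters params_to_keep refers to. An illustrative, hand-built shape of one such entry (the values mirror the docstring examples and are not the sparsifier's real defaults):

config = {
    'module': model.submodule1.linear1,   # the parametrized module
    'tensor_name': 'weight',              # which tensor the mask was applied to
    'module_fqn': 'submodule1.linear1',   # key looked up in params_to_keep_per_layer
    'foo': 42,                            # example sparsity hyperparameters that
    'bar': 24,                            # params_to_keep=('foo', 'bar') would copy
}                                         # into module.sparse_params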
Code Example #8
print(layer.weight)

# Removing a Parametrization
# --------------------------
#
# We may remove all the parametrizations from a parameter or a buffer in a module
# by using ``parametrize.remove_parametrizations()``
layer = nn.Linear(3, 3)
print("Before:")
print(layer)
print(layer.weight)
parametrize.register_parametrization(layer, "weight", Skew())
print("Parametrized:")
print(layer)
print(layer.weight)
parametrize.remove_parametrizations(layer, "weight")
print("After. Weight has skew-symmetric values but it is unconstrained:")
print(layer)
print(layer.weight)

###############################################################################
# When removing a parametrization, we may choose to leave the original parameter (i.e. that in
# ``layer.parametrizations.weight.original``) rather than its parametrized version by setting
# the flag ``leave_parametrized=False``
layer = nn.Linear(3, 3)
print("Before:")
print(layer)
print(layer.weight)
parametrize.register_parametrization(layer, "weight", Skew())
print("Parametrized:")
print(layer)
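
The tutorial snippet is cut off just before it demonstrates leave_parametrized=False. For context, Skew in that tutorial is a parametrization mapping a matrix to a skew-symmetric one; a hedged reconstruction of the missing pieces (the Skew body is a sketch, and the continuation only shows the call the prose announces):

import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Skew(nn.Module):
    # Sketch: build a skew-symmetric matrix from the strict upper-triangular part.
    def forward(self, X):
        A = X.triu(1)
        return A - A.transpose(-1, -2)

# Continuation: keep the raw, unconstrained original tensor rather than its
# skew-symmetric (parametrized) value when removing the parametrization.
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=False)
print("After. Same weight values as before the parametrization was registered:")
print(layer)
print(layer.weight)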
Code Example #9
File: base_sparsifier.py Project: zacker150/pytorch
 def squash_mask(self, *args, **kwargs):
     for config in self.module_groups:
         module = config['module']
         parametrize.remove_parametrizations(module,
                                             'weight',
                                             leave_parametrized=True)
Code Example #10
 def create_transform(self):
     parametrize.remove_parametrizations(self, 'magnitude')
     self.requires_grad_(False)
     self.cpu()
     self.eval()
     return self
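
Code Example #10 removes the 'magnitude' parametrization (baking the current values into a plain tensor, since leave_parametrized defaults to True) and then freezes the module for inference. A hedged usage sketch (the flow object is an assumption):

transform = flow.create_transform()  # 'flow' is assumed to be an instance of the class above
assert not any(p.requires_grad for p in transform.parameters())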