Ejemplo n.º 1
0
    def validate_config(self, model, config_list):
        """
        Validate the quantization configuration list against the schema.

        Parameters
        ----------
        model : torch.nn.Module
            Model to be quantized
        config_list : list of dict
            List of configurations. Each dict may contain:
                - quant_types : list whose elements are each 'weight' or 'output'
                - quant_bits : int with 0 < n < 32, or a dict whose optional
                  'weight' / 'output' entries are such ints
                - quant_start_step : int >= 0
                - op_types / op_names : list of str
        """
        # A single validator reused for every bit-width value; both bounds
        # are exclusive, so 1..31 bits are accepted.
        bits = And(int, lambda n: 0 < n < 32)
        schema = CompressorSchema([{
            Optional('quant_types'):
            Schema([lambda x: x in ['weight', 'output']]),
            Optional('quant_bits'):
            Or(
                bits,
                Schema({
                    Optional('weight'): bits,
                    Optional('output'): bits,
                })),
            Optional('quant_start_step'):
            And(int, lambda n: n >= 0),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)

        schema.validate(config_list)
Ejemplo n.º 2
0
    def validate_config(self, model, config_list):
        """
        Validate ``config_list`` against this quantizer's schema.

        Parameters
        ----------
        model
            Model the configuration applies to; passed to CompressorSchema.
        config_list : list of dict
            Each dict may contain:
                - quant_types : list whose elements are all 'weight'
                - quant_bits : 8, or the dict {'weight': 8}
                - op_types / op_names : list of str
        """
        allowed_keys = {
            Optional('quant_types'): ['weight'],
            Optional('quant_bits'): Or(8, {'weight': 8}),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }
        CompressorSchema([allowed_keys], model, logger).validate(config_list)
Ejemplo n.º 3
0
    def validate_config(self, model, config_list):
        """
        Validate the configuration for the Slim pruner.

        Parameters
        ----------
        model
            Model to be pruned; passed to CompressorSchema.
        config_list : list of dict
            Each dict must contain:
                - sparsity : float strictly between 0 and 1
                - op_types : list whose elements are all 'BatchNorm2d'
            and may contain:
                - op_names : list of str

        Only one configuration is supported; extra entries produce a
        warning, not an error.
        """
        slim_schema = CompressorSchema([{
            'sparsity': And(float, lambda ratio: 0 < ratio < 1),
            'op_types': ['BatchNorm2d'],
            Optional('op_names'): [str]
        }], model, logger)
        slim_schema.validate(config_list)

        has_extra_configs = len(config_list) > 1
        if has_extra_configs:
            logger.warning('Slim pruner only supports 1 configuration')
Ejemplo n.º 4
0
    def validate_config(self, model, config_list):
        """
        Validate the pruning configuration list.

        Parameters
        ----------
        model
            Model to be pruned; passed to CompressorSchema.
        config_list : list of dict
            Each dict may contain:
                - sparsity : float strictly between 0 and 1
                - op_types : list whose elements are all 'Conv2d'
                - op_names : list of str
                - exclude : bool
            Every dict must specify at least one of ``sparsity`` or
            ``exclude``.

        Raises
        ------
        SchemaError
            If a config has neither ``sparsity`` nor ``exclude``.
        """
        schema = CompressorSchema(
            [{
                Optional('sparsity'): And(float, lambda n: 0 < n < 1),
                Optional('op_types'): ['Conv2d'],
                Optional('op_names'): [str],
                Optional('exclude'): bool
            }], model, logger)

        schema.validate(config_list)
        # Both keys are optional in the schema, so the "at least one of
        # them" rule has to be enforced separately.
        for config in config_list:
            if 'exclude' not in config and 'sparsity' not in config:
                raise SchemaError(
                    'Either sparsity or exclude must be specified!')
Ejemplo n.º 5
0
    def validate_config(self, model, config_list):
        """
        Check ``config_list`` against the pruning configuration schema.

        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            Pruning configurations. Each dict requires ``sparsity`` (a float
            strictly between 0 and 1) and may give ``op_types`` and
            ``op_names`` as lists of str.
        """
        config_schema = {
            'sparsity': And(float, lambda value: 0 < value < 1),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }
        validator = CompressorSchema([config_schema], model, logger)
        validator.validate(config_list)
Ejemplo n.º 6
0
    def validate_config(self, model, config_list):
        """
        Validate the configuration list for iterative pruning.

        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            Supported keys:
                - prune_iterations : The number of rounds for the iterative pruning.
                - sparsity : The final sparsity when the compression is done.
            Both keys are required in every config; ``op_types`` and
            ``op_names`` (lists of str) are optional.

        Raises
        ------
        AssertionError
            If the configs do not yield exactly one distinct
            ``prune_iterations`` value (an empty list yields none).
        """
        schema = CompressorSchema([{
            'sparsity': And(float, lambda n: 0 < n < 1),
            'prune_iterations': And(int, lambda n: n > 0),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)

        schema.validate(config_list)
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for callers that catch it.
        distinct_iterations = {config['prune_iterations'] for config in config_list}
        if len(distinct_iterations) != 1:
            raise AssertionError('The values of prune_iterations must be equal in your config')
Ejemplo n.º 7
0
    def validate_config(self, model, config_list):
        """
        Check every pruning config against the schema.

        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            Pruning configs. Each dict requires:
                - initial_sparsity, final_sparsity : float in [0, 1]
                - start_epoch, end_epoch : int >= 0
                - frequency : int > 0
            and may give ``op_types`` / ``op_names`` as lists of str.

        No cross-field checks (e.g. ``start_epoch <= end_epoch``) are
        performed here.
        """
        unit_fraction = And(float, lambda n: 0 <= n <= 1)
        epoch_index = And(int, lambda n: n >= 0)
        config_schema = {
            'initial_sparsity': unit_fraction,
            'final_sparsity': unit_fraction,
            'start_epoch': epoch_index,
            'end_epoch': epoch_index,
            'frequency': And(int, lambda n: n > 0),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }
        CompressorSchema([config_schema], model, logger).validate(config_list)
Ejemplo n.º 8
0
    def validate_config(self, model, config_list):
        """
        Validate ``config_list`` against the schema for ``self._base_algo``.

        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            Pruning configs. Every config requires ``sparsity`` (float
            strictly between 0 and 1). For ``'level'``, ``op_types`` and
            ``op_names`` are optional lists of str. For ``'l1'``, ``'l2'``
            and ``'fpgm'``, ``op_types`` is required and may only contain
            ``'Conv2d'``.

        Raises
        ------
        ValueError
            If ``self._base_algo`` is not one of the supported algorithms.
        """
        if self._base_algo == 'level':
            schema = CompressorSchema(
                [{
                    'sparsity': And(float, lambda n: 0 < n < 1),
                    Optional('op_types'): [str],
                    Optional('op_names'): [str],
                }], model, _logger)
        elif self._base_algo in ['l1', 'l2', 'fpgm']:
            schema = CompressorSchema(
                [{
                    'sparsity': And(float, lambda n: 0 < n < 1),
                    'op_types': ['Conv2d'],
                    Optional('op_names'): [str]
                }], model, _logger)
        else:
            # Without this branch an unknown algorithm would reach
            # ``schema.validate`` with ``schema`` unbound.
            raise ValueError(
                'Unsupported base algorithm for config validation: {!r}'.format(self._base_algo))

        schema.validate(config_list)