def validate_config(self, model, config_list):
        """
        Validate ``config_list`` against this pruner's config schema.

        Every key is optional. ``sparsity`` must be a float strictly between
        0 and 1, each ``op_types`` entry must be ``'Conv2d'``, ``op_names``
        must be a list of strings and ``exclude`` must be a bool.

        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned; handed to ``PrunerSchema`` together with the rule.
        config_list : list
            List of pruning configs.
        """
        config_rule = {
            Optional('sparsity'): And(float, lambda n: 0 < n < 1),
            Optional('op_types'): ['Conv2d'],
            Optional('op_names'): [str],
            Optional('exclude'): bool
        }
        validator = PrunerSchema([config_rule], model, logger)
        validator.validate(config_list)
Beispiel #2
0
    def validate_config(self, model, config_list):
        """
        Validate ``config_list`` for the Slim pruner.

        ``op_types`` is mandatory and its entries must all be ``'BatchNorm2d'``.
        ``sparsity`` (a float strictly between 0 and 1), ``op_names`` (a list
        of strings) and ``exclude`` (a bool) are optional. When more than one
        config is supplied, a warning is logged because the Slim pruner only
        supports a single configuration; the configs are not rejected.

        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned; handed to ``PrunerSchema`` together with the rule.
        config_list : list
            List of pruning configs.
        """
        bn_rule = {
            Optional('sparsity'): And(float, lambda n: 0 < n < 1),
            'op_types': ['BatchNorm2d'],
            Optional('op_names'): [str],
            Optional('exclude'): bool
        }
        PrunerSchema([bn_rule], model, logger).validate(config_list)

        has_extra_configs = len(config_list) > 1
        if has_extra_configs:
            logger.warning('Slim pruner only supports 1 configuration')
Beispiel #3
0
    def validate_config(self, model, config_list):
        """
        Validate ``config_list`` against this pruner's config schema.

        All keys are optional. Unlike the strict (0, 1) range used by sibling
        pruners, ``sparsity`` here may be any float in the closed range [0, 1].
        ``op_types`` and ``op_names`` must be lists of strings and ``exclude``
        must be a bool.

        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List of pruning configs
        """
        config_rule = {
            Optional('sparsity'): And(float, lambda n: 0 <= n <= 1),
            Optional('op_types'): [str],
            Optional('op_names'): [str],
            Optional('exclude'): bool
        }
        validator = PrunerSchema([config_rule], model, logger)
        validator.validate(config_list)
Beispiel #4
0
    def validate_config(self, model, config_list):
        """
        Validate ``config_list`` against the schema matching ``self.base_algo``.

        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List of pruning configs

        Raises
        ------
        ValueError
            If ``self.base_algo`` is not one of ``'level'``, ``'l1'``, ``'l2'``
            or ``'fpgm'``.
        """

        if self.base_algo == 'level':
            # op_types is optional and may list any op type names.
            schema = PrunerSchema(
                [{
                    Optional('sparsity'): And(float, lambda n: 0 < n < 1),
                    Optional('op_types'): [str],
                    Optional('op_names'): [str],
                    Optional('exclude'): bool
                }], model, _logger)
        elif self.base_algo in ['l1', 'l2', 'fpgm']:
            # These algorithms require op_types, and only 'Conv2d' is accepted.
            schema = PrunerSchema(
                [{
                    Optional('sparsity'): And(float, lambda n: 0 < n < 1),
                    'op_types': ['Conv2d'],
                    Optional('op_names'): [str],
                    Optional('exclude'): bool
                }], model, _logger)
        else:
            # Without this branch ``schema`` is never bound and the call below
            # fails with an opaque UnboundLocalError instead of a clear message.
            raise ValueError(
                "Unsupported base_algo {!r}: expected one of "
                "'level', 'l1', 'l2', 'fpgm'".format(self.base_algo))

        schema.validate(config_list)
Beispiel #5
0
    def validate_config(self, model, config_list):
        """
        Validate ``config_list`` for the iterative pruner.

        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List of pruning configs. Supported keys:
                - prune_iterations : The number of rounds for the iterative pruning.
                  Required in every config, must be a positive int, and must have
                  the same value in all configs.
                - sparsity : The final sparsity when the compression is done
                  (optional float strictly between 0 and 1).
                - op_types, op_names : optional lists of strings.
                - exclude : optional bool.

        Raises
        ------
        AssertionError
            If the configs do not share exactly one ``prune_iterations`` value
            (this includes an empty ``config_list`` that passes the schema).
            Errors raised by ``schema.validate`` propagate unchanged.
        """
        schema = PrunerSchema(
            [{
                Optional('sparsity'): And(float, lambda n: 0 < n < 1),
                'prune_iterations': And(int, lambda n: n > 0),
                Optional('op_types'): [str],
                Optional('op_names'): [str],
                Optional('exclude'): bool
            }], model, logger)

        schema.validate(config_list)

        # An explicit raise (rather than ``assert``) keeps this check active under
        # ``python -O``; AssertionError is kept so callers see the same type as before.
        iteration_counts = {config['prune_iterations'] for config in config_list}
        if len(iteration_counts) != 1:
            raise AssertionError('The values of prune_iterations must be equal in your config')