Exemplo n.º 1
0
 def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
     """Validate ``config_list`` against this pruner's config schemas.

     Fresh copies of the normal and internal sub-schemas each gain an optional
     ``rho`` entry that must be a ``float`` greater than zero. The exclude
     sub-schema is appended unchanged. All three are combined into a
     ``CompressorSchema`` that checks ``config_list``.

     NOTE(review): ``And(float, ...)`` presumably rejects an int such as ``1``
     for ``rho``; confirm that is intended.

     Raises:
         Whatever ``CompressorSchema.validate`` raises for an invalid config
         (presumably ``SchemaError``).
     """
     sub_schemas = []
     for template in (NORMAL_SCHEMA, INTERNAL_SCHEMA):
         # Copy so the shared module-level template is never mutated.
         extended = deepcopy(template)
         extended[SchemaOptional('rho')] = And(float, lambda n: n > 0)
         sub_schemas.append(extended)
     sub_schemas.append(deepcopy(EXCLUDE_SCHEMA))

     validator = CompressorSchema(sub_schemas, model, _logger)
     validator.validate(config_list)
Exemplo n.º 2
0
    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
        """Validate ``config_list`` against this pruner's config schemas.

        Copies the normal, exclude and internal sub-schemas, in that order.
        Each copy gets an optional ``op_types`` entry set to
        ``['Conv2d', 'Linear']``. The copies are then used to build the
        ``CompressorSchema`` that checks ``config_list``.

        Raises:
            Whatever ``CompressorSchema.validate`` raises for an invalid config
            (presumably ``SchemaError``).
        """
        sub_schemas = []
        for template in (NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA):
            # Work on a copy so the module-level template stays untouched.
            sub_schema = deepcopy(template)
            sub_schema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']
            sub_schemas.append(sub_schema)

        CompressorSchema(sub_schemas, model, _logger).validate(config_list)
Exemplo n.º 3
0
 def _validate_config_before_canonical(self, model: Module,
                                       config_list: List[Dict]):
     """Validate ``config_list`` against this pruner's config schemas.

     Deep copies of the normal, exclude and internal sub-schemas, in that
     order, are combined into a ``CompressorSchema`` that checks
     ``config_list``.

     Raises:
         Whatever ``CompressorSchema.validate`` raises for an invalid config
         (presumably ``SchemaError``).
     """
     templates = (NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA)
     # Copies keep the shared module-level templates isolated from this validator.
     sub_schemas = [deepcopy(template) for template in templates]
     validator = CompressorSchema(sub_schemas, model, _logger)
     validator.validate(config_list)
Exemplo n.º 4
0
    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
        """Validate ``config_list`` against this pruner's config schemas.

        The sub-schemas are the exclude schema, the internal schema, and then
        one sparsity schema. The sparsity schema is ``GLOBAL_SCHEMA`` when
        ``self.mode == 'global'`` and ``NORMAL_SCHEMA`` otherwise. Each is
        deep-copied and given an optional ``op_types`` entry set to
        ``['BatchNorm2d']``.

        Raises:
            Whatever ``CompressorSchema.validate`` raises for an invalid config
            (presumably ``SchemaError``).
        """
        sparsity_template = GLOBAL_SCHEMA if self.mode == 'global' else NORMAL_SCHEMA

        sub_schemas = []
        for template in (EXCLUDE_SCHEMA, INTERNAL_SCHEMA, sparsity_template):
            # Copy first: the templates are shared module-level objects.
            sub_schema = deepcopy(template)
            sub_schema[SchemaOptional('op_types')] = ['BatchNorm2d']
            sub_schemas.append(sub_schema)

        CompressorSchema(sub_schemas, model, _logger).validate(config_list)
Exemplo n.º 5
0
    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
        """Validate ``config_list`` against this pruner's config schemas.

        The sub-schemas are the exclude schema, the internal schema, and then
        one sparsity schema. The sparsity schema is ``GLOBAL_SCHEMA`` when
        ``self.mode == 'global'`` and ``NORMAL_SCHEMA`` otherwise. Each is
        deep-copied and given an optional ``op_types`` entry set to
        ``['Conv2d', 'Linear']``.

        Raises:
            SchemaError: re-raised unchanged when validation fails. If the
                message mentions a missing ``total_sparsity`` key, a hint about
                global-mode configuration is logged first.
        """
        schema_list = [deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
        if self.mode == 'global':
            schema_list.append(deepcopy(GLOBAL_SCHEMA))
        else:
            schema_list.append(deepcopy(NORMAL_SCHEMA))
        for sub_schema in schema_list:
            sub_schema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']
        schema = CompressorSchema(schema_list, model, _logger)

        try:
            schema.validate(config_list)
        except SchemaError as e:
            # A missing 'total_sparsity' key is read here as a global-mode misconfiguration,
            # so log an actionable hint before propagating the original error.
            if "Missing key: 'total_sparsity'" in str(e):
                _logger.error('`config_list` validation failed. If global mode is set in this pruner, `sparsity_per_layer` and `sparsity` are not supported, make sure `total_sparsity` is set in config_list.')
            # Bare `raise` re-raises the active exception without adding an extra traceback entry.
            raise