from copy import deepcopy
from typing import Dict, List

from schema import And, Optional as SchemaOptional, SchemaError

from torch.nn import Module

# NORMAL_SCHEMA, GLOBAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA, CompressorSchema
# and _logger are assumed to be defined or imported at module level in the
# surrounding file.


def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
    schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
    # Extend the normal and internal schemas with an optional positive-float `rho`.
    for sub_schema in schema_list:
        sub_schema.update({SchemaOptional('rho'): And(float, lambda n: n > 0)})
    schema_list.append(deepcopy(EXCLUDE_SCHEMA))

    schema = CompressorSchema(schema_list, model, _logger)
    schema.validate(config_list)
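
# Illustrative sketch (not part of the original file): with the extension above,
# a config entry may carry an optional positive-float `rho` (e.g. an ADMM-style
# penalty coefficient). `pruner` and `model` are hypothetical placeholders.
#
#     config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5, 'rho': 1e-3}]
#     pruner._validate_config_before_canonical(model, config_list)  # passes
#
#     config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5, 'rho': -1.0}]
#     pruner._validate_config_before_canonical(model, config_list)  # SchemaError: rho must be > 0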


def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
    schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
    # Restrict the supported op types to Conv2d and Linear.
    for sub_schema in schema_list:
        sub_schema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']

    schema = CompressorSchema(schema_list, model, _logger)
    schema.validate(config_list)


def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
    schema_list = [
        deepcopy(NORMAL_SCHEMA),
        deepcopy(EXCLUDE_SCHEMA),
        deepcopy(INTERNAL_SCHEMA)
    ]

    schema = CompressorSchema(schema_list, model, _logger)
    schema.validate(config_list)


def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
    schema_list = [deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
    # Choose the sparsity schema according to the pruning mode.
    if self.mode == 'global':
        schema_list.append(deepcopy(GLOBAL_SCHEMA))
    else:
        schema_list.append(deepcopy(NORMAL_SCHEMA))
    # Restrict the supported op types to BatchNorm2d.
    for sub_schema in schema_list:
        sub_schema[SchemaOptional('op_types')] = ['BatchNorm2d']

    schema = CompressorSchema(schema_list, model, _logger)
    schema.validate(config_list)
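
# Illustrative sketch (not part of the original file): because `op_types` is
# pinned to ['BatchNorm2d'] above, every element of an entry's `op_types` list
# must be 'BatchNorm2d'. Assumes a non-'global' mode; `pruner` and `model` are
# hypothetical placeholders.
#
#     config_list = [{'op_types': ['BatchNorm2d'], 'sparsity': 0.5}]
#     pruner._validate_config_before_canonical(model, config_list)  # passes
#
#     config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
#     pruner._validate_config_before_canonical(model, config_list)  # SchemaError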


def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
    schema_list = [deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
    # Choose the sparsity schema according to the pruning mode.
    if self.mode == 'global':
        schema_list.append(deepcopy(GLOBAL_SCHEMA))
    else:
        schema_list.append(deepcopy(NORMAL_SCHEMA))
    # Restrict the supported op types to Conv2d and Linear.
    for sub_schema in schema_list:
        sub_schema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']

    schema = CompressorSchema(schema_list, model, _logger)

    try:
        schema.validate(config_list)
    except SchemaError as e:
        # Give a more actionable message for the common global-mode mistake.
        if "Missing key: 'total_sparsity'" in str(e):
            _logger.error('`config_list` validation failed. If global mode is set in this pruner, '
                          '`sparsity_per_layer` and `sparsity` are not supported; '
                          'make sure `total_sparsity` is set in config_list.')
        raise e
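
# Illustrative sketch (not part of the original file): in 'global' mode the
# appended GLOBAL_SCHEMA expects `total_sparsity`, so entries that use
# `sparsity` or `sparsity_per_layer` instead fail with the
# "Missing key: 'total_sparsity'" error handled above. `pruner` and `model`
# are hypothetical placeholders.
#
#     config_list = [{'op_types': ['Conv2d', 'Linear'], 'total_sparsity': 0.5}]
#     pruner._validate_config_before_canonical(model, config_list)  # passes
#
#     config_list = [{'op_types': ['Conv2d'], 'sparsity_per_layer': 0.5}]
#     pruner._validate_config_before_canonical(model, config_list)  # raises SchemaError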