def _create_analyzers(
    self,
    layers: List[Module],
    layer_names: List[str],
    param_names: List[str],
):
    """
    Build a ModulePruningAnalyzer for each aligned (layer, layer_name, param_name).

    :param layers: layers whose params should be analyzed
    :param layer_names: names of the layers, positionally aligned with ``layers``
    :param param_names: names of the params to analyze, positionally aligned
        with ``layers``
    :return: list of analyzers; like ``zip``, stops at the shortest input
    """
    # map() over several iterables advances them in lockstep and stops at the
    # shortest one, passing one element of each positionally to the callable.
    return list(map(ModulePruningAnalyzer, layers, layer_names, param_names))
def initialize(self, module: Module, optimizer: Optimizer):
    """
    Grab the params to control kernel sparsity for

    Each param matched by ``self._params`` (or every param when the ALL token is
    present) gets its own ModuleParamPruningMask, built with
    ``self._mask_creator`` and appended to ``self._module_masks``, plus a
    ModulePruningAnalyzer stored in ``self._analyzers``.

    :param module: module to modify
    :param optimizer: optimizer to modify
    :raises ValueError: if no params in ``module`` match ``self._params``
    """
    super().initialize(module, optimizer)

    # The ALL token (alone or inside a list) means "every param".
    if self._params == ALL_TOKEN or ALL_TOKEN in self._params:
        patterns = ["re:.*"]
    else:
        patterns = self._params

    matches = get_named_layers_and_params_by_regex(
        module,
        patterns,
        params_strict=True,
    )

    # Alias the freshly reset attribute so appends below land on it directly.
    analyzers = self._analyzers = []
    for name, mod, pname, _param in matches:
        mask = ModuleParamPruningMask(
            mod,
            pname,
            mask_creator=self._mask_creator,
            layer_name=name,
        )
        self._module_masks.append(mask)
        analyzers.append(ModulePruningAnalyzer(mod, name, pname))

    if not analyzers:
        raise ValueError(
            f"Could not find any params matching {self._params} "
            f"in {self.__class__.__name__}"
        )
def initialize(self, module: Module, optimizer: Optimizer):
    """
    Grab the params to control kernel sparsity for.

    Each param matched by ``self._params`` (or every param when the ALL token is
    present) gets its own ModuleParamPruningMask appended to
    ``self._module_masks`` and a ModulePruningAnalyzer stored in
    ``self._analyzers``. No error is raised when nothing matches.

    :param module: module to modify
    :param optimizer: optimizer to modify
    """
    super().initialize(module, optimizer)

    # The ALL token (alone or inside a list) means "every param".
    if self._params == ALL_TOKEN or ALL_TOKEN in self._params:
        patterns = ["re:.*"]
    else:
        patterns = self._params

    matches = get_named_layers_and_params_by_regex(
        module,
        patterns,
        params_strict=True,
    )

    # Alias the freshly reset attribute so appends below land on it directly.
    analyzers = self._analyzers = []
    for name, mod, pname, _param in matches:
        self._module_masks.append(
            ModuleParamPruningMask(mod, pname, layer_name=name)
        )
        analyzers.append(ModulePruningAnalyzer(mod, name, pname))
def initialize(self, module: Module, optimizer: Optimizer):
    """
    Grab the params to control kernel sparsity for

    ``self._params`` is resolved to regex patterns as follows:

    * the ALL token (alone or inside a list) selects every param;
    * the ALL_PRUNABLE token selects the ``.weight`` of every layer returned by
      ``get_prunable_layers``;
    * anything else is used as given.

    All matched params share a single ModuleParamPruningMask (configured with
    ``self._global_sparsity``) stored in ``self._module_masks``. One
    ModulePruningAnalyzer per match is stored in ``self._analyzers``.

    :param module: module to modify
    :param optimizer: optimizer to modify
    :raises ValueError: if no params in ``module`` match ``self._params``
    """
    super().initialize(module, optimizer)

    if self._params == ALL_TOKEN or ALL_TOKEN in self._params:
        patterns = ["re:.*"]
    elif self._params == ALL_PRUNABLE_TOKEN or ALL_PRUNABLE_TOKEN in self._params:
        patterns = [name + ".weight" for (name, _) in get_prunable_layers(module)]
    else:
        patterns = self._params

    # Materialized because the matches are iterated several times below.
    named_layers_and_params = list(
        get_named_layers_and_params_by_regex(
            module,
            patterns,
            params_strict=True,
        )
    )

    # Validate before building the mask: otherwise the clear error below could
    # be preempted by whatever ModuleParamPruningMask does with empty inputs,
    # and a useless mask would be stored on self._module_masks.
    if not named_layers_and_params:
        raise ValueError(
            "Could not find any params matching {} in {}".format(
                self._params, self.__class__.__name__
            )
        )

    layers = [entry.layer for entry in named_layers_and_params]
    param_names = [entry.param_name for entry in named_layers_and_params]
    layer_names = [entry.layer_name for entry in named_layers_and_params]

    self._module_masks = ModuleParamPruningMask(
        layers,
        param_names,
        layer_names=layer_names,
        global_sparsity=self._global_sparsity,
    )
    self._analyzers = [
        ModulePruningAnalyzer(entry.layer, entry.layer_name, entry.param_name)
        for entry in named_layers_and_params
    ]
def initialize(self, module: Module, optimizer: Optimizer):
    """
    Grab the params to control kernel sparsity for.

    ``self._params`` is resolved to regex patterns as follows:

    * the ALL token selects every param;
    * the ALL_PRUNABLE token selects the ``.weight`` of every prunable layer;
    * anything else is used verbatim.

    All matched params share one ModuleParamPruningMask stored in
    ``self._module_masks``. One ModulePruningAnalyzer per match is kept in
    ``self._analyzers``. No error is raised when nothing matches.

    :param module: module to modify
    :param optimizer: optimizer to modify
    """
    super().initialize(module, optimizer)

    if self._params == ALL_TOKEN or ALL_TOKEN in self._params:
        patterns = ["re:.*"]
    elif self._params == ALL_PRUNABLE_TOKEN or ALL_PRUNABLE_TOKEN in self._params:
        patterns = []
        for prunable_name, _layer in get_prunable_layers(module):
            patterns.append(prunable_name + ".weight")
    else:
        patterns = self._params

    matches = get_named_layers_and_params_by_regex(
        module,
        patterns,
        params_strict=True,
    )

    # Split the matches into the parallel lists the shared mask expects.
    mask_layers, mask_param_names, mask_layer_names = [], [], []
    for entry in matches:
        mask_layers.append(entry.layer)
        mask_param_names.append(entry.param_name)
        mask_layer_names.append(entry.layer_name)

    self._module_masks = ModuleParamPruningMask(
        mask_layers, mask_param_names, layer_names=mask_layer_names
    )

    # Alias the freshly reset attribute so appends below land on it directly.
    analyzers = self._analyzers = []
    for name, mod, pname, _param in matches:
        analyzers.append(ModulePruningAnalyzer(mod, name, pname))