# Module-level imports assumed by the snippets in this section. Helpers such as
# RankedFiltersParameterPruner, _mask_from_filter_order, _param_name_2_layer_name
# and _least_square_sklearn live in the same distiller module and are used as-is.
import logging
import math

import numpy as np
import torch

import distiller

msglogger = logging.getLogger()


def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model):
    assert param.dim() == 4, "This thresholding is only supported for 4D weights"
    # Use the parameter name to locate the module that has the activation sparsity statistics
    fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
    module = distiller.find_module_by_fq_name(model, fq_name)
    if module is None:
        raise ValueError("Could not find a layer named %s in the model."
                         "\nMake sure to use assign_layer_fq_names()" % fq_name)
    if not hasattr(module, 'apoz_channels'):
        raise ValueError("Could not find attribute 'apoz_channels' in module %s."
                         "\nMake sure to use SummaryActivationStatsCollector(\"apoz_channels\")" % fq_name)
    apoz, std = module.apoz_channels.value()
    num_filters = param.size(0)
    num_filters_to_prune = int(fraction_to_prune * num_filters)
    if num_filters_to_prune == 0:
        msglogger.info("Too few filters - can't prune %.1f%% filters", 100 * fraction_to_prune)
        return

    # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters
    filters_ordered_by_apoz = np.argsort(-apoz)[:-num_filters_to_prune]
    zeros_mask_dict[param_name].mask = RankedFiltersParameterPruner.mask_from_filter_order(
        filters_ordered_by_apoz, param, num_filters)
    msglogger.info("ActivationL1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                   param_name,
                   distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
                   fraction_to_prune, num_filters_to_prune, num_filters)
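# 'apoz_channels' holds the Average Percentage of Zeros (APoZ) of each activation
# channel, per "Network Trimming" (Hu et al., 2016): a high APoZ means the filter's
# feature-maps are mostly zeros, i.e. the filter rarely fires. Below is a minimal,
# self-contained sketch of such a per-channel summary (a re-derivation for
# illustration, not the distiller helper that SummaryActivationStatsCollector
# actually wraps):
def activation_channels_apoz_sketch(activation):
    """Return the per-channel percentage of zero activations.

    activation: 4D tensor of shape (batch, channels, h, w).
    """
    per_channel = activation.view(activation.size(0), activation.size(1), -1)
    zero_fraction = per_channel.eq(0).float().mean(dim=2)  # fraction of zeros per (sample, channel)
    return zero_fraction.mean(dim=0).mul(100)              # batch average, as a percentage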
def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
    assert param.dim() == 4, "This pruning is only supported for 4D weights"
    # Use the parameter name to locate the module that has the activation sparsity statistics
    fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
    module = distiller.find_module_by_fq_name(model, fq_name)
    if module is None:
        raise ValueError("Could not find a layer named %s in the model."
                         "\nMake sure to use assign_layer_fq_names()" % fq_name)
    if not hasattr(module, self.activation_rank_criterion):
        # Note: the string must use {}-placeholders throughout; mixing a literal %s
        # into a str.format() template would leave it unexpanded.
        raise ValueError("Could not find attribute \"{}\" in module {}"
                         "\nMake sure to use SummaryActivationStatsCollector(\"{}\")".
                         format(self.activation_rank_criterion, fq_name, self.activation_rank_criterion))
    quality_criterion, std = getattr(module, self.activation_rank_criterion).value()
    num_filters = param.size(0)
    num_filters_to_prune = int(fraction_to_prune * num_filters)
    if num_filters_to_prune == 0:
        msglogger.info("Too few filters - can't prune %.1f%% filters", 100 * fraction_to_prune)
        return

    # Sort from low to high, and remove the bottom 'num_filters_to_prune' filters
    filters_ordered_by_criterion = np.argsort(quality_criterion)[:-num_filters_to_prune]
    mask, binary_map = _mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, binary_map)
    zeros_mask_dict[param_name].mask = mask
    msglogger.info("ActivationL1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                   param_name,
                   distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
                   fraction_to_prune, num_filters_to_prune, num_filters)
    return binary_map
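# A tiny worked example of the ranking idiom above: np.argsort() orders the
# per-filter criterion from low to high, and the "[:-n]" slice keeps the survivors
# while dropping the n filters with the highest criterion values (for a criterion
# like APoZ, those are the weakest filters). Illustrative values only:
quality = np.array([0.9, 0.1, 0.5, 0.3])  # one score per filter
keep = np.argsort(quality)[:-2]           # argsort -> [1, 3, 2, 0]; keep filters 1 and 3
assert keep.tolist() == [1, 3]            # filters 0 and 2 (highest scores) get pruned
# Beware: "[:-0]" is an empty slice that would keep nothing, which is why the
# functions here must return early when num_filters_to_prune == 0.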
def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
    assert param.dim() == 4, "This pruning is only supported for 4D weights"
    # Use the parameter name to locate the module that has the activation sparsity statistics
    fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
    distiller.assign_layer_fq_names(model)
    module = distiller.find_module_by_fq_name(model, fq_name)
    assert module is not None
    if not hasattr(module, self.activation_rank_criterion):
        raise ValueError("Could not find attribute \"%s\" in module %s\n"
                         "\tThis pruner uses activation statistics collected during forward-"
                         "passes of the network.\n"
                         "\tThis error is an indication that these statistics "
                         "have not been collected yet.\n"
                         "\tMake sure to use SummaryActivationStatsCollector(\"%s\")\n"
                         "\tFor more info see issue #444 (https://github.com/NervanaSystems/distiller/issues/444)"
                         % (self.activation_rank_criterion, fq_name, self.activation_rank_criterion))
    quality_criterion, std = getattr(module, self.activation_rank_criterion).value()
    num_filters = param.size(0)
    num_filters_to_prune = int(fraction_to_prune * num_filters)
    if num_filters_to_prune == 0:
        msglogger.info("Too few filters - can't prune %.1f%% filters", 100 * fraction_to_prune)
        return

    # Sort from low to high, and remove the bottom 'num_filters_to_prune' filters
    filters_ordered_by_criterion = np.argsort(quality_criterion)[:-num_filters_to_prune]
    mask, binary_map = _mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, binary_map)
    zeros_mask_dict[param_name].mask = mask
    msglogger.info("ActivationL1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                   param_name,
                   distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
                   fraction_to_prune, num_filters_to_prune, num_filters)
    return binary_map
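# _mask_from_filter_order() is a module-local helper that is not shown in this
# section. A minimal sketch of its plausible behavior, assuming the per-filter
# binary map marks every surviving filter index with 1 and the mask zeros out
# whole filters; a reconstruction for illustration, not the distiller source:
def _mask_from_filter_order_sketch(filters_to_keep, param, num_filters, binary_map=None):
    if binary_map is None:
        binary_map = torch.zeros(num_filters, device=param.device)
        binary_map[filters_to_keep] = 1
    # Broadcast each filter's 0/1 entry over its (in_channels, k, k) weight block
    mask = binary_map.view(num_filters, 1, 1, 1).expand_as(param).contiguous()
    return mask, binary_map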
def rank_and_prune_channels(fraction_to_prune, param, param_name=None, zeros_mask_dict=None, model=None,
                            binary_map=None, magnitude_fn=distiller.norms.l1_norm, group_size=1,
                            rounding_fn=math.floor, noise=0):
    assert binary_map is None
    # 4D weights belong to a convolution; 2D weights to a fully-connected layer
    op_type = 'conv' if param.dim() == 4 else 'fc'
    if binary_map is None:
        bottomk_channels, channel_mags = distiller.norms.rank_channels(param, group_size, magnitude_fn,
                                                                       fraction_to_prune, rounding_fn, noise)
        # Todo: this little piece of code can be refactored
        if bottomk_channels is None:
            # Empty list means that fraction_to_prune is too low to prune anything
            return
        threshold = bottomk_channels[-1]
        binary_map = channel_mags.gt(threshold)

    # These are the indices of channels we want to keep
    indices = binary_map.nonzero().squeeze()
    if len(indices.shape) == 0:
        indices = indices.expand(1)

    # Find the module representing this layer
    distiller.assign_layer_fq_names(model)
    layer_name = _param_name_2_layer_name(param_name)
    conv = distiller.find_module_by_fq_name(model, layer_name)
    try:
        Y = model.intermediate_fms['output_fms'][layer_name]
        X = model.intermediate_fms['input_fms'][layer_name]
    except AttributeError:
        raise ValueError("To use FMReconstructionChannelPruner you must first collect input statistics")

    # We need to remove the chosen weights channels. Because we are using
    # min(MSE) to compute the weights, we need to start by removing feature-map
    # channels from the input. Then we perform the MSE regression to generate
    # a smaller weights tensor.
    if op_type == 'fc':
        X = X[:, binary_map]
    elif conv.kernel_size == (1, 1):
        X = X[:, binary_map, :]
        X = X.transpose(1, 2)
        X = X.contiguous().view(-1, X.size(2))
    else:
        # X is (batch, ck^2, num_pts)
        # we want: (batch, c, k^2, num_pts)
        X = X.view(X.size(0), -1, np.prod(conv.kernel_size), X.size(2))
        X = X[:, binary_map, :, :]
        X = X.view(X.size(0), -1, X.size(3))
        X = X.transpose(1, 2)
        X = X.contiguous().view(-1, X.size(2))

    # Approximate the weights given input-FMs and output-FMs
    new_w = _least_square_sklearn(X, Y)
    new_w = torch.from_numpy(new_w)  # shape: (num_filters, num_non_masked_channels * k^2)
    cnt_retained_channels = binary_map.sum()

    if op_type == 'conv':
        # Expand the weights back to their original size
        new_w = new_w.contiguous().view(param.size(0), cnt_retained_channels, param.size(2), param.size(3))

        # Copy the weights that we learned from minimizing the feature-maps least squares error
        # to our actual weights tensor.
        param.detach()[:, indices, :, :] = new_w.type(param.type())
    else:
        param.detach()[:, indices] = new_w.type(param.type())

    if zeros_mask_dict is not None:
        binary_map = binary_map.type(param.type())
        if op_type == 'conv':
            zeros_mask_dict[param_name].mask, _ = distiller.thresholding.expand_binary_map(param, 'Channels',
                                                                                           binary_map)
            msglogger.info("FMReconstructionChannelPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                           param_name,
                           distiller.sparsity_ch(zeros_mask_dict[param_name].mask),
                           fraction_to_prune, binary_map.sum().item(), param.size(1))
        else:
            msglogger.error("fc sparsity = %.2f" % (1 - binary_map.sum().item() / binary_map.size(0)))
            zeros_mask_dict[param_name].mask = binary_map.expand(param.size(0), param.size(1))
            msglogger.info("FMReconstructionChannelPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                           param_name,
                           distiller.sparsity_cols(zeros_mask_dict[param_name].mask),
                           fraction_to_prune, binary_map.sum().item(), param.size(1))
    return binary_map
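# _least_square_sklearn() is another module-local helper not shown in this
# section. The reconstruction step solves min_W ||X @ W.T - Y||^2 for the new
# weights, given the retained input feature-map columns X and the layer's
# original outputs Y. A minimal sketch, assuming a plain no-intercept
# least-squares fit (the helper's actual implementation may differ):
from sklearn import linear_model

def _least_square_sklearn_sketch(X, Y):
    reg = linear_model.LinearRegression(fit_intercept=False)
    reg.fit(X, Y)     # X: (num_pts, retained_channels * k^2), Y: (num_pts, num_filters)
    return reg.coef_  # numpy array of shape (num_filters, retained_channels * k^2)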