def test_weights_size_attr(dataset, arch, parallel):
    model = create_model(False, dataset, arch, parallel=parallel)
    sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))

    distiller.assign_layer_fq_names(model)
    for name, mod in model.named_modules():
        if isinstance(mod, (nn.Conv2d, nn.Linear)):
            op = sgraph.find_op(name)
            assert op is not None
            assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
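
# A hedged sketch of what distiller.assign_layer_fq_names() is expected to provide for the
# snippets below: after the call, every sub-module carries its fully-qualified name (as
# reported by model.named_modules()) in a `distiller_name` attribute.  This is an illustrative
# equivalent, not the library's actual implementation.
def assign_layer_fq_names_sketch(model):
    for name, module in model.named_modules():
        module.distiller_name = name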
def rank_and_prune_filters(self, fraction_to_prune, param, param_name,
                           zeros_mask_dict, model, binary_map=None):
    assert param.dim() == 4, "This pruning is only supported for 4D weights"

    # Use the parameter name to locate the module that holds the activation sparsity statistics
    fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
    distiller.assign_layer_fq_names(model)
    module = distiller.find_module_by_fq_name(model, fq_name)
    assert module is not None

    if not hasattr(module, self.activation_rank_criterion):
        raise ValueError("Could not find attribute \"%s\" in module %s\n"
                         "\tThis pruner uses activation statistics collected during forward-"
                         "passes of the network.\n"
                         "\tThis error is an indication that these statistics "
                         "have not been collected yet.\n"
                         "\tMake sure to use SummaryActivationStatsCollector(\"%s\")\n"
                         "\tFor more info see issue #444 (https://github.com/NervanaSystems/distiller/issues/444)"
                         % (self.activation_rank_criterion, fq_name, self.activation_rank_criterion))

    quality_criterion, std = getattr(module, self.activation_rank_criterion).value()
    num_filters = param.size(0)
    num_filters_to_prune = int(fraction_to_prune * num_filters)
    if num_filters_to_prune == 0:
        msglogger.info("Too few filters - can't prune %.1f%% filters", 100 * fraction_to_prune)
        return

    # Sort from low to high, and remove the bottom 'num_filters_to_prune' filters
    filters_ordered_by_criterion = np.argsort(quality_criterion)[:-num_filters_to_prune]
    mask, binary_map = _mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, binary_map)
    zeros_mask_dict[param_name].mask = mask
    msglogger.info("ActivationL1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                   param_name,
                   distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
                   fraction_to_prune, num_filters_to_prune, num_filters)
    return binary_map
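
# A minimal sketch of the statistics-collection step that rank_and_prune_filters() expects,
# based on the SummaryActivationStatsCollector mentioned in its error message.  The collector
# arguments, the collector_context helper, the "l1_channels" stat name and the data_loader /
# device parameters are assumptions drawn from distiller's sample application and may differ
# between versions.
import distiller
from distiller.data_loggers import SummaryActivationStatsCollector, collector_context

def collect_activation_stats(model, data_loader, device):
    # Record a per-channel L1 summary of each module's activations during forward passes;
    # the pruner later reads this statistic off the module following each convolution.
    collector = SummaryActivationStatsCollector(model, "l1_channels",
                                                distiller.utils.activation_channels_l1)
    with collector_context(collector):
        for inputs, _ in data_loader:
            model(inputs.to(device))
    return collector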
def collect_intermediate_featuremap_samples(model, forward_fn, module_filter_fn,
                                            fm_caching_fwd_hook=basic_featuremaps_caching_fwd_hook):
    """Collect pairs of input/output feature-maps."""
    from functools import partial

    def install_io_collectors(m, intermediate_fms):
        if module_filter_fn(m):
            intermediate_fms['output_fms'][m.distiller_name] = []
            intermediate_fms['input_fms'][m.distiller_name] = []
            hook_handles.append(m.register_forward_hook(
                partial(fm_caching_fwd_hook, intermediate_fms=intermediate_fms)))

    # Register the forward hooks, then run the forward-pass and collect the data
    msglogger.warning("==> Collecting input/output feature-map pairs")
    distiller.assign_layer_fq_names(model)
    hook_handles = []
    intermediate_fms = {"output_fms": dict(), "input_fms": dict()}
    model.apply(partial(install_io_collectors, intermediate_fms=intermediate_fms))

    forward_fn()

    # Unregister the forward hooks
    for handle in hook_handles:
        handle.remove()

    # Concatenate the per-batch lists of feature-maps into torch tensors
    msglogger.info("Concatenating FMs...")
    model.intermediate_fms = {"output_fms": dict(), "input_fms": dict()}
    outputs = model.intermediate_fms['output_fms']
    inputs = model.intermediate_fms['input_fms']

    for (layer_name, X), Y in zip(intermediate_fms['input_fms'].items(),
                                  intermediate_fms['output_fms'].values()):
        inputs[layer_name] = torch.cat(X, dim=0)
        outputs[layer_name] = torch.cat(Y, dim=0)
    msglogger.warning("<== Done.")
    del intermediate_fms
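
# A minimal calling sketch for collect_intermediate_featuremap_samples().  The data_loader,
# device and n_points_per_fm values are illustrative assumptions; the caching hook is the one
# the AMC environment further below passes (FMReconstructionChannelPruner.cache_featuremaps_fwd_hook).
from functools import partial
import torch
from distiller.pruning import FMReconstructionChannelPruner

def run_forward_passes(model, data_loader, device):
    # Any callable that drives forward passes will do; here we simply evaluate the loader.
    model.eval()
    with torch.no_grad():
        for inputs, _ in data_loader:
            model(inputs.to(device))

def collect_fms_for_channel_pruning(model, data_loader, device):
    collect_intermediate_featuremap_samples(
        model,
        forward_fn=partial(run_forward_passes, model, data_loader, device),
        module_filter_fn=lambda m: isinstance(m, torch.nn.Conv2d),
        fm_caching_fwd_hook=partial(FMReconstructionChannelPruner.cache_featuremaps_fwd_hook,
                                    n_points_per_fm=10))
    # Afterwards model.intermediate_fms['input_fms'] / ['output_fms'] hold the concatenated tensors.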
def rank_and_prune_channels(fraction_to_prune, param, param_name=None,
                            zeros_mask_dict=None, model=None, binary_map=None,
                            magnitude_fn=distiller.norms.l1_norm, group_size=1,
                            rounding_fn=math.floor, noise=0):
    assert binary_map is None
    # Distinguish between convolution and fully-connected weights
    op_type = 'conv' if param.dim() == 4 else 'fc'
    if binary_map is None:
        bottomk_channels, channel_mags = distiller.norms.rank_channels(param, group_size, magnitude_fn,
                                                                       fraction_to_prune, rounding_fn, noise)
        # Todo: this little piece of code can be refactored
        if bottomk_channels is None:
            # Empty list means that fraction_to_prune is too low to prune anything
            return

        threshold = bottomk_channels[-1]
        binary_map = channel_mags.gt(threshold)

    # These are the indices of channels we want to keep
    indices = binary_map.nonzero().squeeze()
    if len(indices.shape) == 0:
        indices = indices.expand(1)

    # Find the module representing this layer
    distiller.assign_layer_fq_names(model)
    layer_name = _param_name_2_layer_name(param_name)
    conv = distiller.find_module_by_fq_name(model, layer_name)
    try:
        Y = model.intermediate_fms['output_fms'][layer_name]
        X = model.intermediate_fms['input_fms'][layer_name]
    except AttributeError:
        raise ValueError("To use FMReconstructionChannelPruner you must first collect input statistics")

    # We need to remove the chosen weights channels.  Because we are using
    # min(MSE) to compute the weights, we need to start by removing feature-map
    # channels from the input.  Then we perform the MSE regression to generate
    # a smaller weights tensor.
    if op_type == 'fc':
        X = X[:, binary_map]
    elif conv.kernel_size == (1, 1):
        X = X[:, binary_map, :]
        X = X.transpose(1, 2)
        X = X.contiguous().view(-1, X.size(2))
    else:
        # X is (batch, ck^2, num_pts)
        # we want: (batch, c, k^2, num_pts)
        X = X.view(X.size(0), -1, np.prod(conv.kernel_size), X.size(2))
        X = X[:, binary_map, :, :]
        X = X.view(X.size(0), -1, X.size(3))
        X = X.transpose(1, 2)
        X = X.contiguous().view(-1, X.size(2))

    # Approximate the weights given input-FMs and output-FMs
    new_w = _least_square_sklearn(X, Y)
    new_w = torch.from_numpy(new_w)  # shape: (num_filters, num_non_masked_channels * k^2)
    cnt_retained_channels = binary_map.sum()

    if op_type == 'conv':
        # Expand the weights back to their original size
        new_w = new_w.contiguous().view(param.size(0), cnt_retained_channels, param.size(2), param.size(3))

        # Copy the weights that we learned from minimizing the feature-maps least squares error,
        # to our actual weights tensor.
        param.detach()[:, indices, :, :] = new_w.type(param.type())
    else:
        param.detach()[:, indices] = new_w.type(param.type())

    if zeros_mask_dict is not None:
        binary_map = binary_map.type(param.type())
        if op_type == 'conv':
            zeros_mask_dict[param_name].mask, _ = distiller.thresholding.expand_binary_map(param, 'Channels', binary_map)
            msglogger.info("FMReconstructionChannelPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                           param_name,
                           distiller.sparsity_ch(zeros_mask_dict[param_name].mask),
                           fraction_to_prune, binary_map.sum().item(), param.size(1))
        else:
            msglogger.error("fc sparsity = %.2f" % (1 - binary_map.sum().item() / binary_map.size(0)))
            zeros_mask_dict[param_name].mask = binary_map.expand(param.size(0), param.size(1))
            msglogger.info("FMReconstructionChannelPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                           param_name,
                           distiller.sparsity_cols(zeros_mask_dict[param_name].mask),
                           fraction_to_prune, binary_map.sum().item(), param.size(1))
    return binary_map
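
# For reference, the least-squares reconstruction step above can be reproduced with
# scikit-learn directly.  This stand-in mirrors what a helper like _least_square_sklearn is
# expected to do (fit Y ~= X @ W.T with no intercept); it is a sketch, not distiller's
# implementation.
from sklearn.linear_model import LinearRegression

def least_square_reconstruction(X, Y):
    """X: (num_samples, num_retained_features), Y: (num_samples, num_filters).
    Returns W of shape (num_filters, num_retained_features) so that X @ W.T approximates Y."""
    reg = LinearRegression(fit_intercept=False)
    reg.fit(X, Y)
    return reg.coef_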
def prepare_model(self, dummy_input=None):
    """
    Traverses the model and replaces sub-modules with quantized counterparts according to the
    bit-width and overrides configuration provided to __init__(), and according to the
    replacement_factory as defined by the Quantizer sub-class being used.

    Note:
        If multiple sub-modules within the model actually reference the same module, then that
        module is replaced only once, according to the configuration (bit-width and/or overrides)
        of the first encountered reference.
        Toy example - say a module is constructed using this bit of code:

            shared_relu = nn.ReLU()
            self.relu1 = shared_relu
            self.relu2 = shared_relu

        When traversing the model, a replacement will be generated when 'self.relu1' is
        encountered. Let's call it 'new_relu1'. When 'self.relu2' is encountered, it'll simply
        be replaced with a reference to 'new_relu1'. Any override configuration made
        specifically for 'self.relu2' will be ignored. A warning message will be shown.
    """
    msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__))

    self.model.quantizer_metadata["dummy_input"] = dummy_input
    if dummy_input is not None:
        summary_graph = distiller.SummaryGraph(self.model, dummy_input)
        self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)

    model_device = distiller.model_device(self.model)

    self._pre_prepare_model(dummy_input)

    self._pre_process_container(self.model)
    for module_name, module in self.model.named_modules():
        qbits = self.module_qbits_map[module_name]
        curr_parameters = dict(module.named_parameters())
        for param_name, param in curr_parameters.items():
            n_bits = qbits.bias if param_name.endswith('bias') else qbits.wts
            if n_bits is None:
                continue
            fp_attr_name = param_name
            if self.train_with_fp_copy:
                hack_float_backup_parameter(module, param_name, n_bits)
                fp_attr_name = FP_BKP_PREFIX + param_name
            self.params_to_quantize.append(_ParamToQuant(module, module_name, fp_attr_name, param_name, n_bits))

            param_full_name = '.'.join([module_name, param_name])
            msglogger.debug("Parameter '{0}' will be quantized to {1} bits".format(param_full_name, n_bits))

    # If an optimizer was passed, assume we need to update it
    if self.optimizer:
        for pg in self._get_new_optimizer_params_groups():
            self.optimizer.add_param_group(pg)

    self._post_prepare_model()

    # Re-transfer model to the device it was on, in case the quantizer created new parameters/buffers
    self.model.to(model_device)

    distiller.assign_layer_fq_names(self.model)
    self.prepared = True

    msglogger.debug('Quantized model:\n\n{0}\n'.format(self.model))
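
# A typical call site for prepare_model(), sketched under the assumption that a concrete
# Quantizer subclass such as distiller.quantization.PostTrainLinearQuantizer is used with its
# default constructor arguments, and that the model expects ImageNet-shaped inputs.  Treat the
# quantizer choice and dataset name as illustrative.
import distiller
from distiller.quantization import PostTrainLinearQuantizer

def quantize_for_inference(model):
    quantizer = PostTrainLinearQuantizer(model)
    # prepare_model() needs a dummy input so it can build the SummaryGraph adjacency map
    quantizer.prepare_model(distiller.get_dummy_input('imagenet'))
    return quantizer.model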
def fuse_modules(model, types_sequence, fuse_fn, dummy_input=None, adjacency_map=None):
    """
    Scans the module for sequences of modules of the specified types and "fuses" them. As an example,
    consider the following sequence of 3 modules: 'm_1' --> 'm_2' --> 'm_3'. Assuming they match the
    specified sequence of types, they will be fused such that the "fused" module replaces 'm_1',
    and 'm_2' and 'm_3' are replaced with identity operations.

    For a sequence of modules to be fused, it must not contain splits. That is - no module in the
    sequence can have more than a single output. For example, consider the following sequence:

        m_1 --> m_2 --> m_3
             |
             |--> m_4

    Even if m_1, m_2 and m_3 match the types sequence, they can't be fused because m_1's output also
    goes to m_4.

    The fused module is generated by the user-specified function 'fuse_fn'.

    To infer the order of modules it is required to perform a forward pass on the model. Hence the
    need to pass the expected input shape.

    Args:
        model (nn.Module): Model instance on which the transformation is performed
        types_sequence (list or tuple): Sequence of module types. Each item in the sequence may itself
          be a list / tuple. For example - to fuse all possible convolution types with ReLU, pass:
          [[nn.Conv1d, nn.Conv2d, nn.Conv3d], nn.ReLU]
        fuse_fn (function): Function that takes a list of modules to be fused, and returns a single
          fused module. If the sequence cannot be fused, this function should return None
        dummy_input (torch.Tensor or tuple): Dummy input to the model. Required if adjacency_map is None
        adjacency_map (OrderedDict): Pre-computed adjacency map, via SummaryGraph.adjacency_map().
          Must be based on the passed model, otherwise results are unexpected. If None, then the
          adjacency map will be created internally using the passed dummy_input.
    """
    distiller.assign_layer_fq_names(model)

    if adjacency_map is None:
        if dummy_input is None:
            raise ValueError('Must pass either valid adjacency map instance or valid dummy input')
        summary_graph = distiller.SummaryGraph(model, dummy_input)
        adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)
    named_modules = OrderedDict(model.named_modules())

    in_sequence_idx = 0
    curr_sequence = []
    for node_name, adj_entry in adjacency_map.items():
        module = named_modules.get(node_name, None)
        if module is None:
            reset = True
        else:
            reset = False
            if isinstance(module, types_sequence[in_sequence_idx]):
                curr_sequence.append(module)
                in_sequence_idx += 1
                if in_sequence_idx == len(types_sequence):
                    _fuse_sequence(curr_sequence, named_modules, fuse_fn)
                    reset = True
                elif len(adj_entry.successors) > 1:
                    msglogger.debug(node_name + " is connected to multiple outputs, not fuse-able")
                    reset = True
            elif isinstance(module, types_sequence[0]):
                # Current module breaks the current sequence, check if it's the start of a new sequence
                in_sequence_idx = 1
                curr_sequence = [module]
            else:
                reset = True
        if reset:
            in_sequence_idx = 0
            curr_sequence = []

    return model
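
# A hedged example of driving fuse_modules().  The fuse_fn here is a toy that wraps the matched
# Conv->ReLU pair in an nn.Sequential; a real fuse_fn would typically fold the modules into a
# single fused op, or return None when fusion is not possible.  The input shape is an assumption
# for a 3-channel image model.
import torch
import torch.nn as nn

def naive_conv_relu_fuse_fn(sequence):
    # 'sequence' is the list of matched modules, e.g. [Conv2d, ReLU]
    return nn.Sequential(*sequence)

def fuse_conv_relu(model):
    dummy_input = torch.randn(1, 3, 224, 224)
    return fuse_modules(model,
                        types_sequence=[(nn.Conv1d, nn.Conv2d, nn.Conv3d), nn.ReLU],
                        fuse_fn=naive_conv_relu_fuse_fn,
                        dummy_input=dummy_input)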
def __init__(self, model, app_args, amc_cfg, services):
    self.pylogger = distiller.data_loggers.PythonLogger(msglogger)
    logdir = logging.getLogger().logdir
    self.tflogger = distiller.data_loggers.TensorBoardLogger(logdir)
    self.verbose = False
    self.orig_model = copy.deepcopy(model)
    self.app_args = app_args
    self.amc_cfg = amc_cfg
    self.services = services

    try:
        modules_list = amc_cfg.modules_dict[app_args.arch]
    except KeyError:
        msglogger.warning("!!! The config file does not specify the modules to compress for %s" % app_args.arch)
        # Default to using all convolution layers
        distiller.assign_layer_fq_names(model)
        modules_list = [mod.distiller_name for mod in model.modules() if type(mod) == torch.nn.Conv2d]
        msglogger.warning("Using the following layers: %s" % ", ".join(modules_list))

    self.net_wrapper = NetworkWrapper(model, app_args, services, modules_list, amc_cfg.pruning_pattern)
    self.original_model_macs, self.original_model_size = self.net_wrapper.get_resources_requirements()

    self.reset(init_only=True)
    msglogger.debug("Model %s has %d modules (%d pruned)", self.app_args.arch,
                    self.net_wrapper.model_metadata.num_layers(),
                    self.net_wrapper.model_metadata.num_pruned_layers())
    msglogger.debug("\tTotal MACs: %s" % distiller.pretty_int(self.original_model_macs))
    msglogger.debug("\tTotal weights: %s" % distiller.pretty_int(self.original_model_size))

    self._max_episode_steps = self.net_wrapper.model_metadata.num_pruned_layers()  # Hack for Coach-TD3
    log_amc_config(amc_cfg)

    self.episode = 0
    self.best_reward = float("-inf")
    self.action_low = amc_cfg.action_range[0]
    self.action_high = amc_cfg.action_range[1]

    if is_using_continuous_action_space(self.amc_cfg.agent_algo):
        if self.amc_cfg.agent_algo == "ClippedPPO-continuous":
            self.action_space = spaces.Box(PPO_MIN, PPO_MAX, shape=(1,))
        else:
            self.action_space = spaces.Box(self.action_low, self.action_high, shape=(1,))
        self.action_space.default_action = self.action_low
    else:
        self.action_space = spaces.Discrete(10)
    self.observation_space = spaces.Box(0, float("inf"), shape=(len(Observation._fields),))

    self.stats_logger = AMCStatsLogger(os.path.join(logdir, 'amc.csv'))
    self.ft_stats_logger = FineTuneStatsLogger(os.path.join(logdir, 'ft_top1.csv'))

    if self.amc_cfg.pruning_method == "fm-reconstruction":
        if self.amc_cfg.pruning_pattern != "channels":
            raise ValueError("Feature-map reconstruction is only supported when pruning weights channels")

        from functools import partial

        def acceptance_criterion(m, mod_names):
            # Collect feature-maps only for Conv2d layers, if they are in our modules list.
            return isinstance(m, torch.nn.Conv2d) and m.distiller_name in mod_names

        # For feature-map reconstruction we need to collect a representative set
        # of inter-layer feature-maps
        from distiller.pruning import FMReconstructionChannelPruner
        collect_intermediate_featuremap_samples(
            self.net_wrapper.model,
            self.net_wrapper.validate,
            partial(acceptance_criterion, mod_names=modules_list),
            partial(FMReconstructionChannelPruner.cache_featuremaps_fwd_hook,
                    n_points_per_fm=self.amc_cfg.n_points_per_fm))