Example 1
def test_weights_size_attr(dataset, arch, parallel):
    model = create_model(False, dataset, arch, parallel=parallel)
    sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))

    distiller.assign_layer_fq_names(model)
    for name, mod in model.named_modules():
        if isinstance(mod, (nn.Conv2d, nn.Linear)):
            op = sgraph.find_op(name)
            assert op is not None
            assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
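The test above depends on distiller.assign_layer_fq_names(model) tagging every sub-module with its fully-qualified name. As a rough illustration only (an assumption about the helper's behaviour, not distiller's actual implementation), such a helper can be sketched like this:

import torch.nn as nn

def assign_fq_names_sketch(model: nn.Module, attr_name: str = "distiller_name"):
    # named_modules() yields (dotted_path, module) pairs, e.g. ("layer1.0.conv1", Conv2d),
    # so storing that path on the module lets later passes look layers up by name.
    for fq_name, module in model.named_modules():
        setattr(module, attr_name, fq_name)
    return model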
Example 2
    def rank_and_prune_filters(self,
                               fraction_to_prune,
                               param,
                               param_name,
                               zeros_mask_dict,
                               model,
                               binary_map=None):
        assert param.dim() == 4, "This pruning is only supported for 4D weights"

        # Use the parameter name to locate the module that has the activation sparsity statistics
        fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
        distiller.assign_layer_fq_names(model)
        module = distiller.find_module_by_fq_name(model, fq_name)
        assert module is not None

        if not hasattr(module, self.activation_rank_criterion):
            raise ValueError(
                "Could not find attribute \"%s\" in module %s\n"
                "\tThis is pruner uses activation statistics collected during forward-"
                "passes of the network.\n"
                "\tThis error is an indication that these statistics "
                "have not been collected yet.\n"
                "\tMake sure to use SummaryActivationStatsCollector(\"%s\")\n"
                "\tFor more info see issue #444 (https://github.com/NervanaSystems/distiller/issues/444)"
                % (self.activation_rank_criterion, fq_name,
                   self.activation_rank_criterion))

        quality_criterion, std = getattr(
            module, self.activation_rank_criterion).value()
        num_filters = param.size(0)
        num_filters_to_prune = int(fraction_to_prune * num_filters)
        if num_filters_to_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters",
                           100 * fraction_to_prune)
            return

        # Sort from low to high, and remove the bottom 'num_filters_to_prune' filters
        filters_ordered_by_criterion = np.argsort(
            quality_criterion)[:-num_filters_to_prune]
        mask, binary_map = _mask_from_filter_order(
            filters_ordered_by_criterion, param, num_filters, binary_map)
        zeros_mask_dict[param_name].mask = mask

        msglogger.info(
            "ActivationL1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
            param_name,
            distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
            fraction_to_prune, num_filters_to_prune, num_filters)
        return binary_map
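The pruner above reads its ranking statistic from an attribute (named by self.activation_rank_criterion) that must already exist on the matching ReLU module, which is what the error message about SummaryActivationStatsCollector refers to. Below is a self-contained sketch, not distiller's collector, of gathering a comparable per-filter statistic with a plain forward hook; the attribute name and the .value() -> (criterion, std) shape are assumptions matching how the method consumes them:

import torch
import torch.nn as nn

class _MeanActivationStat:
    def __init__(self, num_channels):
        self.sum = torch.zeros(num_channels)
        self.count = 0

    def update(self, output):
        # output: (batch, channels, H, W) -> accumulate the per-channel mean |activation|
        self.sum += output.detach().abs().mean(dim=(0, 2, 3)).cpu()
        self.count += 1

    def value(self):
        # Mirrors the (quality_criterion, std) tuple read by rank_and_prune_filters;
        # the std is not tracked in this sketch.
        return (self.sum / max(self.count, 1)).numpy(), None

def attach_mean_activation_stat(relu: nn.ReLU, num_channels, attr="mean_activation"):
    stat = _MeanActivationStat(num_channels)
    setattr(relu, attr, stat)
    relu.register_forward_hook(lambda m, inp, out: stat.update(out))
    return stat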
Example 3
def collect_intermediate_featuremap_samples(model, forward_fn, module_filter_fn, 
                                            fm_caching_fwd_hook=basic_featuremaps_caching_fwd_hook):
    """Collect pairs of input/output feature-maps.
    """
    from functools import partial

    def install_io_collectors(m, intermediate_fms):
        if module_filter_fn(m):
            intermediate_fms['output_fms'][m.distiller_name] = []
            intermediate_fms['input_fms'][m.distiller_name] = []
            hook_handles.append(m.register_forward_hook(partial(fm_caching_fwd_hook, 
                                                                intermediate_fms=intermediate_fms)))

    # Register to the forward hooks, then run the forward-pass and collect the data
    msglogger.warning("==> Collecting input/ouptput feature-map pairs")
    distiller.assign_layer_fq_names(model)
    hook_handles = []
    intermediate_fms = {"output_fms": dict(), "input_fms": dict()}
    model.apply(partial(install_io_collectors, intermediate_fms=intermediate_fms))
    
    forward_fn()
    
    # Unregister from the forward hooks
    for handle in hook_handles:
        handle.remove()

    # We now need to concatenate the lists of feature-maps into torch tensors.
    msglogger.info("Concatenating FMs...")
    model.intermediate_fms = {"output_fms": dict(), "input_fms": dict()}
    outputs = model.intermediate_fms['output_fms']
    inputs = model.intermediate_fms['input_fms']

    for (layer_name, X), Y in zip(intermediate_fms['input_fms'].items(), intermediate_fms['output_fms'].values()):                
        inputs[layer_name] = torch.cat(X, dim=0)
        outputs[layer_name] = torch.cat(Y, dim=0)

    msglogger.warning("<== Done.")
    del intermediate_fms 
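A hedged usage sketch of the collector above, using a toy model and random batches as stand-ins for a real network and validation loader (distiller must be importable for assign_layer_fq_names to run):

import torch
import torch.nn as nn

toy_model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 8, 3, padding=1))

def forward_fn():
    # Run a few batches so every hooked layer caches representative feature-map pairs.
    toy_model.eval()
    with torch.no_grad():
        for _ in range(4):
            toy_model(torch.randn(2, 3, 32, 32))

def module_filter_fn(m):
    # Collect only for convolutions; any boolean predicate over modules works here.
    return isinstance(m, nn.Conv2d)

collect_intermediate_featuremap_samples(toy_model, forward_fn, module_filter_fn)
# The cached pairs now live in toy_model.intermediate_fms['input_fms'] / ['output_fms'],
# keyed by each layer's distiller_name.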
Example 4
    def rank_and_prune_channels(fraction_to_prune, param, param_name=None,
                                zeros_mask_dict=None, model=None, binary_map=None, 
                                magnitude_fn=distiller.norms.l1_norm, group_size=1, rounding_fn=math.floor,
                                noise=0):
        assert binary_map is None
        if binary_map is None:
            # 'op_type' distinguishes convolution (4D) from fully-connected (2D) weights;
            # it is referenced below when slicing X and when expanding the binary map.
            op_type = 'conv' if param.dim() == 4 else 'fc'
            bottomk_channels, channel_mags = distiller.norms.rank_channels(param, group_size, magnitude_fn,
                                                                           fraction_to_prune, rounding_fn, noise)

            # Todo: this little piece of code can be refactored
            if bottomk_channels is None:
                # Empty list means that fraction_to_prune is too low to prune anything
                return

            threshold = bottomk_channels[-1]
            binary_map = channel_mags.gt(threshold)

            # These are the indices of channels we want to keep
            indices = binary_map.nonzero().squeeze()
            if len(indices.shape) == 0:
                indices = indices.expand(1)

            # Find the module representing this layer
            distiller.assign_layer_fq_names(model)
            layer_name = _param_name_2_layer_name(param_name)
            conv = distiller.find_module_by_fq_name(model, layer_name)
            try:
                Y = model.intermediate_fms['output_fms'][layer_name]
                X = model.intermediate_fms['input_fms'][layer_name]
            except AttributeError:
                raise ValueError("To use FMReconstructionChannelPruner you must first collect input statistics")

            # We need to remove the chosen weights channels.  Because we are using 
            # min(MSE) to compute the weights, we need to start by removing feature-map 
            # channels from the input.  Then we perform the MSE regression to generate
            # a smaller weights tensor.
            if op_type == 'fc':
                X = X[:, binary_map]
            elif conv.kernel_size == (1, 1):
                X = X[:, binary_map, :]
                X = X.transpose(1, 2)
                X = X.contiguous().view(-1, X.size(2))
            else:
                # X is (batch, ck^2, num_pts)
                # we want:   (batch, c, k^2, num_pts)
                X = X.view(X.size(0), -1, np.prod(conv.kernel_size), X.size(2))
                X = X[:, binary_map, :, :]
                X = X.view(X.size(0), -1, X.size(3))
                X = X.transpose(1, 2)
                X = X.contiguous().view(-1, X.size(2))

            # Approximate the weights given input-FMs and output-FMs
            new_w = _least_square_sklearn(X, Y)
            new_w = torch.from_numpy(new_w) # shape: (num_filters, num_non_masked_channels * k^2)
            cnt_retained_channels = binary_map.sum()

            if op_type == 'conv':
                # Expand the weights back to their original size,
                new_w = new_w.contiguous().view(param.size(0), cnt_retained_channels, param.size(2), param.size(3))

                # Copy the weights that we learned from minimizing the feature-maps least squares error,
                # to our actual weights tensor.
                param.detach()[:, indices, :,   :] = new_w.type(param.type())
            else:
                param.detach()[:, indices] = new_w.type(param.type())

        if zeros_mask_dict is not None:
            binary_map = binary_map.type(param.type())
            if op_type == 'conv':
                zeros_mask_dict[param_name].mask, _ = distiller.thresholding.expand_binary_map(param,
                                                                                               'Channels', binary_map)
                msglogger.info("FMReconstructionChannelPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                               param_name,
                               distiller.sparsity_ch(zeros_mask_dict[param_name].mask),
                               fraction_to_prune, binary_map.sum().item(), param.size(1))
            else:
                msglogger.error("fc sparsity = %.2f" % (1 - binary_map.sum().item() / binary_map.size(0)))
                zeros_mask_dict[param_name].mask = binary_map.expand(param.size(0), param.size(1))
                msglogger.info("FMReconstructionChannelPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                               param_name,
                               distiller.sparsity_cols(zeros_mask_dict[param_name].mask),
                               fraction_to_prune, binary_map.sum().item(), param.size(1))
        return binary_map
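The heart of this pruner is the least-squares step: given the retained input channels X and the original output feature-maps Y, fit a smaller weights tensor that reproduces Y as closely as possible. A hedged sketch of what _least_square_sklearn likely amounts to (its exact options may differ):

import numpy as np
from sklearn.linear_model import LinearRegression

def least_square_sketch(X, Y):
    # X: (num_samples, num_retained_channels * k^2), Y: (num_samples, num_filters).
    # Fit W such that X @ W.T approximates Y in the mean-squared-error sense.
    reg = LinearRegression(fit_intercept=False)
    reg.fit(X, Y)
    return reg.coef_  # (num_filters, num_retained_channels * k^2), the shape new_w expects

# Tiny self-check with synthetic data:
X = np.random.randn(64, 18)
Y = X @ np.random.randn(18, 4)
assert least_square_sketch(X, Y).shape == (4, 18)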
Example 5
    def prepare_model(self, dummy_input=None):
        """
        Traverses the model and replaces sub-modules with quantized counterparts according to the bit-width
        and overrides configuration provided to __init__(), and according to the replacement_factory as
        defined by the Quantizer sub-class being used.

        Note:
            If multiple sub-modules within the model actually reference the same module, then that module
            is replaced only once, according to the configuration (bit-width and/or overrides) of the
            first encountered reference.
            Toy Example - say a module is constructed using this bit of code:

                shared_relu = nn.ReLU()
                self.relu1 = shared_relu
                self.relu2 = shared_relu

            When traversing the model, a replacement will be generated when 'self.relu1' is encountered.
            Let's call it 'new_relu1'. When 'self.relu2' is encountered, it will simply be replaced
            with a reference to 'new_relu1'. Any override configuration made specifically for 'self.relu2'
            will be ignored. A warning message will be shown.
        """
        msglogger.info('Preparing model for quantization using {0}'.format(
            self.__class__.__name__))

        self.model.quantizer_metadata["dummy_input"] = dummy_input
        if dummy_input is not None:
            summary_graph = distiller.SummaryGraph(self.model, dummy_input)
            self.adjacency_map = summary_graph.adjacency_map(
                dedicated_modules_only=False)

        model_device = distiller.model_device(self.model)

        self._pre_prepare_model(dummy_input)

        self._pre_process_container(self.model)
        for module_name, module in self.model.named_modules():
            qbits = self.module_qbits_map[module_name]
            curr_parameters = dict(module.named_parameters())
            for param_name, param in curr_parameters.items():
                n_bits = qbits.bias if param_name.endswith(
                    'bias') else qbits.wts
                if n_bits is None:
                    continue
                fp_attr_name = param_name
                if self.train_with_fp_copy:
                    hack_float_backup_parameter(module, param_name, n_bits)
                    fp_attr_name = FP_BKP_PREFIX + param_name
                self.params_to_quantize.append(
                    _ParamToQuant(module, module_name, fp_attr_name,
                                  param_name, n_bits))

                param_full_name = '.'.join([module_name, param_name])
                msglogger.debug(
                    "Parameter '{0}' will be quantized to {1} bits".format(
                        param_full_name, n_bits))

        # If an optimizer was passed, assume we need to update it
        if self.optimizer:
            for pg in self._get_new_optimizer_params_groups():
                self.optimizer.add_param_group(pg)

        self._post_prepare_model()

        # Re-transfer model to the device it was on, in case the quantizer created new parameters/buffers
        self.model.to(model_device)

        distiller.assign_layer_fq_names(self.model)

        self.prepared = True

        msglogger.debug('Quantized model:\n\n{0}\n'.format(self.model))
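A hedged usage sketch of prepare_model() through one of distiller's quantizers. The constructor defaults and the 'imagenet' dummy-input shortcut are assumptions; check the distiller release you use for the exact signatures:

import distiller
from distiller.quantization import PostTrainLinearQuantizer
from torchvision.models import resnet18

model = resnet18(pretrained=False)
quantizer = PostTrainLinearQuantizer(model)            # default bit-widths assumed
dummy_input = distiller.get_dummy_input('imagenet')    # same helper as in Example 1
quantizer.prepare_model(dummy_input)
# prepare_model() builds the adjacency map from the dummy input, swaps in quantized
# module replacements, and finally calls assign_layer_fq_names on the model.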
Example 6
def fuse_modules(model,
                 types_sequence,
                 fuse_fn,
                 dummy_input=None,
                 adjacency_map=None):
    """
    Scans the module for sequences of modules of the specified types and "fuses" them. As an example, consider the
    following sequence of 3 modules: 'm_1' --> 'm_2' --> 'm_3'. Assuming they match the specified sequence of types,
    they will be fused such that the "fused" module replaces 'm_1', and 'm_2' and 'm_3' are replaced with identity
    operations.

    For a sequence of modules to be fused, it must not contain splits. That is - no module in the sequence can
    have more than a single output. For example, consider the following sequence:

    m_1 --> m_2 --> m_3
        |
        |
        --> m_4

    Even if m_1, m_2 and m_3 match the types sequence, they can't be fused because m_1's output also goes to m_4.

    The fused module is generated by the user specified function 'fuse_fn'.

    To infer the order of modules, a forward pass must be performed on the model. Hence the need to pass a
    dummy input (or a pre-computed adjacency map).

    Args:
        model (nn.Module): Model instance on which the transformation is performed
        types_sequence (list or tuple): Sequence of module types. Each item in the sequence may itself be a
          list / tuple. For example - to fuse all possible convolution types with ReLU, pass:
          [[nn.Conv1d, nn.Conv2d, nn.Conv3d], nn.ReLU]
        fuse_fn (function): Function that takes a list of models to be fused, and returns a single fused module.
          If the sequence cannot be fused, this function should return None
        dummy_input (torch.Tensor or tuple): Dummy input to the model. Required if adjacency_map is None
        adjacency_map (OrderedDict): Pre-computed adjacency map, via SummaryGraph.adjacency_map(). Must be based
          on the passed model, otherwise results are unexpected. If None, then the adjacency map will be created
          internally using the passed dummy_input.
    """
    distiller.assign_layer_fq_names(model)
    if adjacency_map is None:
        if dummy_input is None:
            raise ValueError(
                'Must pass either valid adjacency map instance or valid dummy input'
            )
        summary_graph = distiller.SummaryGraph(model, dummy_input)
        adjacency_map = summary_graph.adjacency_map(
            dedicated_modules_only=False)
    named_modules = OrderedDict(model.named_modules())
    in_sequence_idx = 0
    curr_sequence = []

    for node_name, adj_entry in adjacency_map.items():
        module = named_modules.get(node_name, None)
        if module is None:
            reset = True
        else:
            reset = False
            if isinstance(module, types_sequence[in_sequence_idx]):
                curr_sequence.append(module)
                in_sequence_idx += 1
                if in_sequence_idx == len(types_sequence):
                    _fuse_sequence(curr_sequence, named_modules, fuse_fn)
                    reset = True
                elif len(adj_entry.successors) > 1:
                    msglogger.debug(
                        node_name +
                        " is connected to multiple outputs, not fuse-able")
                    reset = True
            elif isinstance(module, types_sequence[0]):
                # Current module breaks the current sequence, check if it's the start of a new sequence
                in_sequence_idx = 1
                curr_sequence = [module]
            else:
                reset = True
        if reset:
            in_sequence_idx = 0
            curr_sequence = []
    return model
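A hedged usage sketch of fuse_modules(). The fuse_fn below is a deliberately trivial placeholder that keeps the convolution and drops the ReLU; real fusers return a dedicated fused module (or None when the sequence cannot be fused):

import torch
import torch.nn as nn

def keep_conv_fuse_fn(sequence):
    # 'sequence' is the list of matched modules, e.g. [Conv2d, ReLU].
    return sequence[0]

toy_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3), nn.ReLU())
toy_model = fuse_modules(toy_model,
                         types_sequence=[nn.Conv2d, nn.ReLU],
                         fuse_fn=keep_conv_fuse_fn,
                         dummy_input=torch.randn(1, 3, 32, 32))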
Example 7
    def __init__(self, model, app_args, amc_cfg, services):
        self.pylogger = distiller.data_loggers.PythonLogger(msglogger)
        logdir = logging.getLogger().logdir
        self.tflogger = distiller.data_loggers.TensorBoardLogger(logdir)
        self.verbose = False
        self.orig_model = copy.deepcopy(model)
        self.app_args = app_args
        self.amc_cfg = amc_cfg
        self.services = services

        try:
            modules_list = amc_cfg.modules_dict[app_args.arch]
        except KeyError:
            msglogger.warning(
                "!!! The config file does not specify the modules to compress for %s"
                % app_args.arch)
            # Default to using all convolution layers
            distiller.assign_layer_fq_names(model)
            modules_list = [
                mod.distiller_name for mod in model.modules()
                if type(mod) == torch.nn.Conv2d
            ]
            msglogger.warning("Using the following layers: %s" %
                              ", ".join(modules_list))

        self.net_wrapper = NetworkWrapper(model, app_args, services,
                                          modules_list,
                                          amc_cfg.pruning_pattern)
        self.original_model_macs, self.original_model_size = self.net_wrapper.get_resources_requirements(
        )
        self.reset(init_only=True)
        msglogger.debug("Model %s has %d modules (%d pruned)",
                        self.app_args.arch,
                        self.net_wrapper.model_metadata.num_layers(),
                        self.net_wrapper.model_metadata.num_pruned_layers())
        msglogger.debug("\tTotal MACs: %s" %
                        distiller.pretty_int(self.original_model_macs))
        msglogger.debug("\tTotal weights: %s" %
                        distiller.pretty_int(self.original_model_size))
        self._max_episode_steps = self.net_wrapper.model_metadata.num_pruned_layers(
        )  # Hack for Coach-TD3
        log_amc_config(amc_cfg)

        self.episode = 0
        self.best_reward = float("-inf")
        self.action_low = amc_cfg.action_range[0]
        self.action_high = amc_cfg.action_range[1]

        if is_using_continuous_action_space(self.amc_cfg.agent_algo):
            if self.amc_cfg.agent_algo == "ClippedPPO-continuous":
                self.action_space = spaces.Box(PPO_MIN, PPO_MAX, shape=(1, ))
            else:
                self.action_space = spaces.Box(self.action_low,
                                               self.action_high,
                                               shape=(1, ))
            self.action_space.default_action = self.action_low
        else:
            self.action_space = spaces.Discrete(10)
        self.observation_space = spaces.Box(0,
                                            float("inf"),
                                            shape=(len(Observation._fields), ))
        self.stats_logger = AMCStatsLogger(os.path.join(logdir, 'amc.csv'))
        self.ft_stats_logger = FineTuneStatsLogger(
            os.path.join(logdir, 'ft_top1.csv'))

        if self.amc_cfg.pruning_method == "fm-reconstruction":
            if self.amc_cfg.pruning_pattern != "channels":
                raise ValueError(
                    "Feature-map reconstruction is only supported when pruning weights channels"
                )

            from functools import partial

            def acceptance_criterion(m, mod_names):
                # Collect feature-maps only for Conv2d layers, if they are in our modules list.
                return isinstance(
                    m, torch.nn.Conv2d) and m.distiller_name in mod_names

            # For feature-map reconstruction we need to collect a representative set
            # of inter-layer feature-maps
            from distiller.pruning import FMReconstructionChannelPruner
            collect_intermediate_featuremap_samples(
                self.net_wrapper.model, self.net_wrapper.validate,
                partial(acceptance_criterion, mod_names=modules_list),
                partial(
                    FMReconstructionChannelPruner.cache_featuremaps_fwd_hook,
                    n_points_per_fm=self.amc_cfg.n_points_per_fm))
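The predicate passed to collect_intermediate_featuremap_samples here is just the module_filter_fn from Example 3, pre-bound to the AMC modules list. A stand-alone restatement of that wiring (the layer names are hypothetical):

from functools import partial
import torch.nn as nn

def acceptance_criterion(m, mod_names):
    # True only for Conv2d layers whose distiller_name is in the AMC modules list.
    return isinstance(m, nn.Conv2d) and getattr(m, "distiller_name", None) in mod_names

modules_list = ["features.0", "features.3"]   # hypothetical layer names
module_filter_fn = partial(acceptance_criterion, mod_names=modules_list)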