    def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model):
        assert param.dim() == 4, "This thresholding is only supported for 4D weights"

        # Use the parameter name to locate the module that has the activation sparsity statistics
        fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
        module = distiller.find_module_by_fq_name(model, fq_name)
        if module is None:
            raise ValueError("Could not find a layer named %s in the model."
                             "\nMake sure to use assign_layer_fq_names()" % fq_name)
        if not hasattr(module, 'apoz_channels'):
            raise ValueError("Could not find attribute \'apoz_channels\' in module %s."
                             "\nMake sure to use SummaryActivationStatsCollector(\"apoz_channels\")" % fq_name)

        apoz, std = module.apoz_channels.value()
        num_filters = param.size(0)
        num_filters_to_prune = int(fraction_to_prune * num_filters)
        if num_filters_to_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
            return

        # Sort the filter indices by APoZ from high to low, then drop the last
        # 'num_filters_to_prune' indices from the keep-list; the dropped filters get pruned
        filters_ordered_by_apoz = np.argsort(-apoz)[:-num_filters_to_prune]
        zeros_mask_dict[param_name].mask = RankedFiltersParameterPruner.mask_from_filter_order(filters_ordered_by_apoz,
                                                                                               param, num_filters)

        msglogger.info("ActivationL1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                       param_name,
                       distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
                       fraction_to_prune, num_filters_to_prune, num_filters)
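
The helper RankedFiltersParameterPruner.mask_from_filter_order is called above but not shown. A minimal sketch of what such a helper can look like, assuming (as the call site suggests) that the index array holds the filters to keep and that the returned mask must match the shape of the 4D weights:

import torch

def mask_from_filter_order_sketch(filters_to_keep, param):
    """Build a 4D weights mask that keeps the listed filters and zeros the rest."""
    binary_map = torch.zeros(param.size(0))
    binary_map[filters_to_keep] = 1
    # Broadcast each filter's keep/prune bit across its (in_channels, k, k) weights
    return binary_map.view(-1, 1, 1, 1).expand_as(param)

# Example: keep filters 0 and 2 of a 4-filter conv layer with 3x3 kernels
weights = torch.randn(4, 16, 3, 3)
mask = mask_from_filter_order_sketch([0, 2], weights)
assert mask.shape == weights.shape and mask.sum().item() == 2 * 16 * 9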
Example #2
    def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
        assert param.dim() == 4, "This pruning is only supported for 4D weights"

        # Use the parameter name to locate the module that has the activation sparsity statistics
        fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
        module = distiller.find_module_by_fq_name(model, fq_name)
        if module is None:
            raise ValueError("Could not find a layer named %s in the model."
                             "\nMake sure to use assign_layer_fq_names()" % fq_name)
        if not hasattr(module, self.activation_rank_criterion):
            raise ValueError("Could not find attribute \"{}\" in module %s"
                             "\nMake sure to use SummaryActivationStatsCollector(\"{}\")".
                             format(self.activation_rank_criterion, fq_name, self.activation_rank_criterion))

        quality_criterion, std = getattr(module, self.activation_rank_criterion).value()
        num_filters = param.size(0)
        num_filters_to_prune = int(fraction_to_prune * num_filters)
        if num_filters_to_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
            return

        # Sort the filter indices by the quality criterion from low to high, then drop the
        # last 'num_filters_to_prune' indices from the keep-list; the dropped filters get pruned
        filters_ordered_by_criterion = np.argsort(quality_criterion)[:-num_filters_to_prune]
        mask, binary_map = _mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, binary_map)
        zeros_mask_dict[param_name].mask = mask

        msglogger.info("ActivationL1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                       param_name,
                       distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
                       fraction_to_prune, num_filters_to_prune, num_filters)
        return binary_map
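
Both variants fail unless the activation statistics were collected beforehand, so the two setup calls named in the error messages must run first. A minimal sketch of that collection step, assuming a SummaryActivationStatsCollector(model, stat_name, summary_fn) signature and distiller's per-channel APoZ summary function; model, evaluate and data_loader are placeholders:

import distiller
from distiller.data_loggers import SummaryActivationStatsCollector

distiller.assign_layer_fq_names(model)   # give every module its fully-qualified name
collector = SummaryActivationStatsCollector(model, "apoz_channels",
                                            distiller.utils.activation_channels_apoz)
collector.start()             # installs forward hooks on the activation modules
evaluate(model, data_loader)  # any forward passes populate the per-channel statistics
collector.stop()              # removes the hooks; each hooked module now carries
                              # the 'apoz_channels' attribute that the pruner reads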
Example #3
    def rank_and_prune_filters(self, fraction_to_prune, param, param_name,
                               zeros_mask_dict, model, binary_map=None):
        assert param.dim() == 4, "This pruning is only supported for 4D weights"

        # Use the parameter name to locate the module that has the activation sparsity statistics
        fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
        distiller.assign_layer_fq_names(model)
        module = distiller.find_module_by_fq_name(model, fq_name)
        assert module is not None

        if not hasattr(module, self.activation_rank_criterion):
            raise ValueError(
                "Could not find attribute \"%s\" in module %s\n"
                "\tThis pruner uses activation statistics collected during forward "
                "passes of the network.\n"
                "\tThis error is an indication that these statistics "
                "have not been collected yet.\n"
                "\tMake sure to use SummaryActivationStatsCollector(\"%s\")\n"
                "\tFor more info see issue #444 (https://github.com/NervanaSystems/distiller/issues/444)"
                % (self.activation_rank_criterion, fq_name, self.activation_rank_criterion))

        quality_criterion, std = getattr(module, self.activation_rank_criterion).value()
        num_filters = param.size(0)
        num_filters_to_prune = int(fraction_to_prune * num_filters)
        if num_filters_to_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters",
                           100 * fraction_to_prune)
            return

        # Sort the filter indices by the quality criterion from low to high, then drop the
        # last 'num_filters_to_prune' indices from the keep-list; the dropped filters get pruned
        filters_ordered_by_criterion = np.argsort(quality_criterion)[:-num_filters_to_prune]
        mask, binary_map = _mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, binary_map)
        zeros_mask_dict[param_name].mask = mask

        msglogger.info("ActivationL1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                       param_name,
                       distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
                       fraction_to_prune, num_filters_to_prune, num_filters)
        return binary_map
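
All three variants locate the statistics module with the same string transformation of the parameter name, which assumes a naming convention where every convolution has a sibling ReLU module. A small illustration with a made-up parameter name:

param_name = "module.layer1.0.conv1.weight"          # hypothetical parameter name
fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
print(fq_name)                                       # module.layer1.0.relu1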
Example #4
    def rank_and_prune_channels(fraction_to_prune, param, param_name=None,
                                zeros_mask_dict=None, model=None, binary_map=None, 
                                magnitude_fn=distiller.norms.l1_norm, group_size=1, rounding_fn=math.floor,
                                noise=0):
        assert binary_map is None
        if binary_map is None:
            # 'op_type' distinguishes convolution weights (4D) from fully-connected
            # weights (2D); the input feature-maps are reshaped differently for each case
            op_type = 'conv' if param.dim() == 4 else 'fc'
            bottomk_channels, channel_mags = distiller.norms.rank_channels(param, group_size, magnitude_fn,
                                                                           fraction_to_prune, rounding_fn, noise)

            # TODO: this little piece of code can be refactored
            if bottomk_channels is None:
                # 'None' means that fraction_to_prune is too low to prune anything
                return

            threshold = bottomk_channels[-1]
            binary_map = channel_mags.gt(threshold)

            # These are the indices of channels we want to keep
            indices = binary_map.nonzero().squeeze()
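            # If only a single channel survives, squeeze() yields a 0-dim tensor;
            # the expand(1) below restores a 1-element vector so the fancy-indexing
            # of 'param' further down keeps working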
            if len(indices.shape) == 0:
                indices = indices.expand(1)

            # Find the module representing this layer
            distiller.assign_layer_fq_names(model)
            layer_name = _param_name_2_layer_name(param_name)
            conv = distiller.find_module_by_fq_name(model, layer_name)
            try:
                Y = model.intermediate_fms['output_fms'][layer_name]
                X = model.intermediate_fms['input_fms'][layer_name]
            except AttributeError:
                raise ValueError("To use FMReconstructionChannelPruner you must first collect input statistics")

            # We need to remove the chosen weight channels.  Because we are using
            # min(MSE) to compute the new weights, we must first remove the corresponding
            # feature-map channels from the input.  Then we perform the MSE regression to
            # generate a smaller weights tensor.
            if op_type == 'fc':
                X = X[:, binary_map]
            elif conv.kernel_size == (1, 1):
                X = X[:, binary_map, :]
                X = X.transpose(1, 2)
                X = X.contiguous().view(-1, X.size(2))
            else:
                # X is (batch, ck^2, num_pts)
                # we want:   (batch, c, k^2, num_pts)
                X = X.view(X.size(0), -1, np.prod(conv.kernel_size), X.size(2))
                X = X[:, binary_map, :, :]
                X = X.view(X.size(0), -1, X.size(3))
                X = X.transpose(1, 2)
                X = X.contiguous().view(-1, X.size(2))
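            # Shape walk-through with hypothetical sizes: with batch=32, c=64 input
            # channels, k=3 and num_pts sampled positions, X starts as (32, 64*9, num_pts);
            # the view yields (32, 64, 9, num_pts); channel selection keeps c' of the 64
            # channels; the final reshape yields (32*num_pts, c'*9) rows for the regression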

            # Approximate the weights given input-FMs and output-FMs
            new_w = _least_square_sklearn(X, Y)
            new_w = torch.from_numpy(new_w)  # shape: (num_filters, num_non_masked_channels * k^2)
            cnt_retained_channels = binary_map.sum().item()

            if op_type == 'conv':
                # Reshape the flat regression result back to a 4D weights tensor:
                # (num_filters, num_retained_channels, k, k)
                new_w = new_w.contiguous().view(param.size(0), cnt_retained_channels, param.size(2), param.size(3))

                # Copy the weights that we learned from minimizing the feature-maps
                # least-squares error into our actual weights tensor
                param.detach()[:, indices, :, :] = new_w.type(param.type())
            else:
                param.detach()[:, indices] = new_w.type(param.type())

        if zeros_mask_dict is not None:
            binary_map = binary_map.type(param.type())
            if op_type == 'conv':
                zeros_mask_dict[param_name].mask, _ = distiller.thresholding.expand_binary_map(param,
                                                                                               'Channels', binary_map)
                msglogger.info("FMReconstructionChannelPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                               param_name,
                               distiller.sparsity_ch(zeros_mask_dict[param_name].mask),
                               fraction_to_prune, binary_map.sum().item(), param.size(1))
            else:
                msglogger.error("fc sparsity = %.2f" % (1 - binary_map.sum().item() / binary_map.size(0)))
                zeros_mask_dict[param_name].mask = binary_map.expand(param.size(0), param.size(1))
                msglogger.info("FMReconstructionChannelPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                               param_name,
                               distiller.sparsity_cols(zeros_mask_dict[param_name].mask),
                               fraction_to_prune, binary_map.sum().item(), param.size(1))
        return binary_map
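
The helper _least_square_sklearn is referenced above but not shown. A minimal sketch, under the assumption that it solves the unconstrained least-squares fit Y ≈ X·Wᵀ with scikit-learn; the returned coefficient matrix then has the (num_filters, num_retained_channels * k^2) shape the caller expects:

from sklearn.linear_model import LinearRegression

def _least_square_sklearn_sketch(X, Y):
    # Fit Y ~ X @ W.T without an intercept; reg.coef_ has shape (n_targets, n_features),
    # i.e. (num_filters, num_retained_channels * k^2), matching the comment at the call site
    reg = LinearRegression(fit_intercept=False)
    reg.fit(X, Y)
    return reg.coef_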