Example #1
def test_magnitude_pruning():
    # Create a 4-D tensor of 1s
    a = torch.ones(3, 64, 32, 32)
    # Change one element
    a[1, 4, 17, 31] = 0.2
    # Create a masks dictionary and populate it with one ParameterMasker
    zeros_mask_dict = {}
    masker = distiller.ParameterMasker('a')
    zeros_mask_dict['a'] = masker
    # Try to use a MagnitudeParameterPruner without defining a default threshold
    with pytest.raises(AssertionError):
        pruner = distiller.pruning.MagnitudeParameterPruner("test", None)

    # Now define the default threshold
    thresholds = {"*": 0.4}
    pruner = distiller.pruning.MagnitudeParameterPruner("test", thresholds)
    assert distiller.sparsity(a) == 0
    # Create a mask for parameter 'a'
    pruner.set_param_mask(a, 'a', zeros_mask_dict, None)
    assert common.almost_equal(distiller.sparsity(zeros_mask_dict['a'].mask), 1/distiller.volume(a))

    # Let's now use the masker to prune a parameter
    masker = zeros_mask_dict['a']
    masker.apply_mask(a)
    assert common.almost_equal(distiller.sparsity(a), 1/distiller.volume(a))
    # We can use the masker on other tensors, if we want (and if they have the correct shape).
    # Remember that the mask was created already, so we're not thresholding - we are pruning
    b = torch.ones(3, 64, 32, 32)
    b[:] = 0.3
    masker.apply_mask(b)
    assert common.almost_equal(distiller.sparsity(b), 1/distiller.volume(a))
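A minimal sketch of the equivalence the test relies on, assuming the pruner thresholds by absolute magnitude the way distiller.threshold_mask does (see Example #2): thresholding a fresh copy of the tensor at the default threshold of 0.4 zeros exactly the one element below it.

t = torch.ones(3, 64, 32, 32)
t[1, 4, 17, 31] = 0.2
# Assumption: elements whose magnitude falls below the threshold are masked out
equivalent_mask = distiller.threshold_mask(t, threshold=0.4)
assert equivalent_mask.sum().item() == distiller.volume(t) - 1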
Example #2
def test_threshold_mask():
    # Create a 4-D tensor of 1s
    a = torch.ones(3, 64, 32, 32)
    # Change one element
    a[1, 4, 17, 31] = 0.2
    # Create and apply a mask
    mask = distiller.threshold_mask(a, threshold=0.3)
    assert np.sum(distiller.to_np(mask)) == (distiller.volume(a) - 1)
    assert mask[1, 4, 17, 31] == 0
    assert common.almost_equal(distiller.sparsity(mask), 1/distiller.volume(a))
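For reference, the call above is presumably just a magnitude comparison; a sketch of the equivalence, reusing a and mask from the test (this is an assumption, not the library's documented contract):

# Presumed equivalent of distiller.threshold_mask(a, threshold=0.3)
manual_mask = torch.gt(torch.abs(a), 0.3).type(a.type())
assert torch.equal(manual_mask, mask)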
Example #3
def module_visitor(self, input, output, df, model, weights_vol, macs, attrs=None):
    in_features_shape = input[0].size()
    out_features_shape = output.size()

    mod_name = distiller.model_find_module_name(model, self)
    df.loc[len(df.index)] = ([mod_name, self.__class__.__name__,
                              attrs if attrs is not None else '',
                              distiller.size_to_str(in_features_shape), distiller.volume(input[0]),
                              distiller.size_to_str(out_features_shape), distiller.volume(output),
                              int(weights_vol), int(macs)])
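The visitor appends one row per module, so the caller must supply a df whose columns match the row layout above. A sketch of a compatible DataFrame (the column names are an assumption inferred from the row contents):

import pandas as pd
df = pd.DataFrame(columns=['Name', 'Type', 'Attrs', 'IFM', 'IFM volume',
                           'OFM', 'OFM volume', 'Weights volume', 'MACs'])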
Example #4
def conv_visitor(self, input, output, df, model, memo):
    assert isinstance(self, torch.nn.Conv2d)
    if self in memo:
        return
    weights_vol = distiller.volume(self.weight)

    # Multiply-accumulate operations: MACs = volume(OFM) * (#IFM * K^2) / #Groups
    # Bias is ignored
    macs = (distiller.volume(output) *
            (self.in_channels / self.groups * self.kernel_size[0] * self.kernel_size[1]))
    attrs = 'k=' + '('+(', ').join(['%d' % v for v in self.kernel_size])+')'
    module_visitor(self, input, output, df, model, weights_vol, macs, attrs)
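A worked instance of the MAC formula from the comment (a sketch; the layer and input sizes are arbitrary):

# Conv2d(64, 128, k=3, padding=1) on a 1x64x32x32 input gives a 1x128x32x32 OFM:
# volume(OFM) = 131072 and MACs = 131072 * (64 / 1 * 3 * 3) = 75,497,472
conv = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)
output = conv(torch.randn(1, 64, 32, 32))
macs = (distiller.volume(output) *
        (conv.in_channels / conv.groups * conv.kernel_size[0] * conv.kernel_size[1]))
assert int(macs) == 75497472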
Example #5
    def get_model_representation(self):
        """Initialize an embedding representation of the entire model.

        At runtime, a specific row in the embedding matrix is chosen (depending on
        the current state) and the dynamic fields in the resulting state-embedding
        vector are updated. 
        """
        num_states = self.net_wrapper.num_pruned_layers()
        network_obs = np.empty(shape=(num_states, ObservationLen))
        for state_id, layer_id in enumerate(self.net_wrapper.model_metadata.pruned_idxs):
            layer = self.net_wrapper.get_layer(layer_id)
            layer_macs = self.net_wrapper.layer_macs(layer)
            conv_module = distiller.model_find_module(self.model, layer.name)
            obs = [state_id,
                   conv_module.out_channels,
                   conv_module.in_channels,
                   layer.ifm_h,
                   layer.ifm_w,
                   layer.stride[0],
                   layer.k,
                   distiller.volume(conv_module.weight),
                   layer_macs,
                   0, 0, 0]
            network_obs[state_id, :] = np.array(obs)

        # Feature normalization
        for feature in range(ObservationLen):
            feature_vec = network_obs[:, feature]
            fmin = min(feature_vec)
            fmax = max(feature_vec)
            if fmax - fmin > 0:
                network_obs[:, feature] = (feature_vec - fmin) / (fmax - fmin)
        # msglogger.debug("model representation=\n{}".format(network_obs))
        return network_obs
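The normalization loop is plain min-max scaling of each feature column into [0, 1]; for example:

col = np.array([10., 20., 40.])
scaled = (col - col.min()) / (col.max() - col.min())  # -> [0.0, 0.333..., 1.0]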
Example #6
def fc_visitor(self, input, output, df, model, memo):
    assert isinstance(self, torch.nn.Linear)
    if self in memo:
        return

    # Multiply-accumulate operations: MACs = #IFM * #OFM
    # Bias is ignored
    weights_vol = macs = distiller.volume(self.weight)
    module_visitor(self, input, output, df, model, weights_vol, macs)
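A worked instance of the formula from the comment (a sketch; the layer size is arbitrary):

# For Linear(1024, 10) the weight tensor holds 10 * 1024 elements,
# so weights_vol == macs == 10240
fc = torch.nn.Linear(1024, 10)
assert distiller.volume(fc.weight) == 10 * 1024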
Example #7
def test_weights_size_attr(dataset, arch, parallel):
    model = create_model(False, dataset, arch, parallel=parallel)
    sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))

    distiller.assign_layer_fq_names(model)
    for name, mod in model.named_modules():
        if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Linear):
            op = sgraph.find_op(name)
            assert op is not None
            assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
Example #8
def module_visitor(self,
                   input,
                   output,
                   df,
                   model,
                   weights_vol,
                   macs,
                   attrs=None):
    in_features_shape = input[0].size()
    out_features_shape = output.size()

    param_name = distiller.model_find_param_name(model, self.weight.data)
    if param_name is None:
        return
    mod_name = param_name[:param_name.find(".weight")]
    df.loc[len(df.index)] = ([
        mod_name, self.__class__.__name__, attrs if attrs is not None else '',
        distiller.size_to_str(in_features_shape),
        distiller.volume(input[0]),
        distiller.size_to_str(out_features_shape),
        distiller.volume(output), weights_vol,
        int(macs)
    ])
Example #9
def create_png(sgraph, display_param_nodes=False, rankdir="TB", styles=None):
    """Create a PNG object containing a graphviz-dot graph of the network,
    as represented by SummaryGraph 'sgraph'.

    Args:
        sgraph (SummaryGraph): the SummaryGraph instance to draw.
        display_param_nodes (boolean): if True, draw the parameter nodes
        rankdir: diagram direction.  'TB'/'BT' is Top-to-Bottom/Bottom-to-Top
                 'LR'/'RL' is Left-to-Right/Right-to-Left
        styles: a dictionary of styles.  Key is module name.  Value is
                a legal pydot style dictionary.  For example:
                styles['conv1'] = {'shape': 'oval',
                                   'fillcolor': 'gray',
                                   'style': 'rounded, filled'}
    """
    def annotate_op_node(op):
        if op["type"] == "Conv":
            return [
                "sh={}".format(distiller.size2str(
                    op["attrs"]["kernel_shape"])),
                "g={}".format(str(op["attrs"]["group"])),
            ]
        return ""

    op_nodes = [op["name"] for op in sgraph.ops.values()]
    data_nodes = []
    param_nodes = []
    for id, param in sgraph.params.items():
        n_data = (id, str(distiller.volume(param["shape"])),
                  str(param["shape"]))
        if data_node_has_parent(sgraph, id):
            data_nodes.append(n_data)
        else:
            param_nodes.append(n_data)
    edges = sgraph.edges

    if not display_param_nodes:
        # Use only the edges that don't have a parameter source
        non_param_ids = op_nodes + [dn[0] for dn in data_nodes]
        edges = [edge for edge in sgraph.edges if edge.src in non_param_ids]
        param_nodes = None

    op_nodes_desc = [(op["name"], op["type"], *annotate_op_node(op))
                     for op in sgraph.ops.values()]
    pydot_graph = create_pydot_graph(op_nodes_desc, data_nodes, param_nodes,
                                     edges, rankdir, styles)
    png = pydot_graph.create_png()
    return png
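A usage sketch (assumes sgraph was built as in Example #7; pydot's create_png returns raw PNG bytes):

png_bytes = create_png(sgraph, display_param_nodes=True, rankdir='LR')
with open('net.png', 'wb') as f:  # 'net.png' is an arbitrary output path
    f.write(png_bytes)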
Example #10
    def __remove_structures(self, idx, fraction_to_prune, prune_what="channels"):
        """Physically remove channels and corresponding filters from the model"""
        if idx not in range(self.num_layers()):
            raise ValueError("idx=%d is not in correct range (0-%d)" % (idx, self.num_layers()))
        if fraction_to_prune < 0:
            raise ValueError("fraction_to_prune=%f is illegal" % (fraction_to_prune))

        if fraction_to_prune == 0:
            return 0
        if fraction_to_prune == 1.0:
            # For now, prevent the removal of entire layers
            fraction_to_prune = ALMOST_ONE

        layer = self.conv_layers[idx]
        conv_pname = layer.name + ".weight"
        conv_p = distiller.model_find_param(self.model, conv_pname)

        msglogger.info("ADC: removing %.1f%% %s from %s" % (fraction_to_prune*100, prune_what, conv_pname))

        if prune_what == "channels":
            calculate_sparsity = distiller.sparsity_ch
            remove_structures = distiller.remove_channels
            group_type = "Channels"
        elif prune_what == "filters":
            calculate_sparsity = distiller.sparsity_3D
            group_type = "Filters"
            remove_structures = distiller.remove_filters
        else:
            raise ValueError("unsupported structure {}".format(prune_what))
        # Create a channel-ranking pruner
        pruner = distiller.pruning.L1RankedStructureParameterPruner("adc_pruner", group_type,
                                                                    fraction_to_prune, conv_pname)
        pruner.set_param_mask(conv_p, conv_pname, self.zeros_mask_dict, meta=None)

        if (self.zeros_mask_dict[conv_pname].mask is None or
            calculate_sparsity(self.zeros_mask_dict[conv_pname].mask) == 0):
            msglogger.info("__remove_structures: aborting because there are no channels to prune")
            return 0

        # Use the mask to prune
        self.zeros_mask_dict[conv_pname].apply_mask(conv_p)

        if PERFORM_THINNING:
            remove_structures(self.model, self.zeros_mask_dict, self.app_args.arch,
                              self.app_args.dataset, optimizer=None)
            conv_p = distiller.model_find_param(self.model, conv_pname)
            return distiller.volume(conv_p) / layer.weights_vol
        actual_sparsity = calculate_sparsity(conv_p)
        return actual_sparsity
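The ranking pruner used above can also be exercised on its own; a minimal sketch mirroring the calls in this example (the weight shape, parameter name, and 0.5 fraction are arbitrary):

w = torch.randn(16, 32, 3, 3)  # filters x channels x kH x kW
zeros_mask_dict = {'w': distiller.ParameterMasker('w')}
pruner = distiller.pruning.L1RankedStructureParameterPruner("sketch", "Channels", 0.5, 'w')
pruner.set_param_mask(w, 'w', zeros_mask_dict, meta=None)
print(distiller.sparsity_ch(zeros_mask_dict['w'].mask))  # expect roughly 0.5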
Example #11
def conv_visitor(self, input, output, df, model, memo):
    assert isinstance(self, torch.nn.Conv2d)
    if self in memo:
        return

    # Conv2d weights are shaped (out_channels, in_channels/groups, kH, kW)
    weights_vol = (self.out_channels * self.in_channels // self.groups *
                   self.kernel_size[0] * self.kernel_size[1])

    # Multiply-accumulate operations: MACs = volume(OFM) * (#IFM * K^2) / #Groups
    # Bias is ignored
    macs = distiller.volume(output) * (self.in_channels / self.groups *
                                       self.kernel_size[0] *
                                       self.kernel_size[1])
    attrs = "k=" + "(" + (", ").join(["%d" % v
                                      for v in self.kernel_size]) + ")"
    module_visitor(self, input, output, df, model, weights_vol, macs, attrs)
Example #12
    def rank_and_prune_blocks(fraction_to_prune, param, param_name=None, zeros_mask_dict=None,
                              model=None, binary_map=None, block_shape=None,
                              magnitude_fn=distiller.norms.l1_norm, group_size=1):
        """Block-wise pruning for 4D tensors.

        The block shape is specified using a tuple: [block_repetitions, block_depth, block_height, block_width].
        The dimension 'block_repetitions' specifies in how many consecutive filters the "basic block"
        (shaped as [block_depth, block_height, block_width]) repeats to produce a (4D) "super block".

        For example:

          block_pruner:
            class: L1RankedStructureParameterPruner_AGP
            initial_sparsity : 0.05
            final_sparsity: 0.70
            group_type: Blocks
            kwargs:
              block_shape: [1,8,1,1]  # [block_repetitions, block_depth, block_height, block_width]

        Currently the only supported block shape is: block_repetitions x block_depth x 1 x 1
        """
        if len(block_shape) != 4:
            raise ValueError("The block shape must be specified as a 4-element tuple")
        block_repetitions, block_depth, block_height, block_width = block_shape
        if not block_width == block_height == 1:
            raise ValueError("Currently the only supported block shape is: block_repetitions x block_depth x 1 x 1")
        super_block_volume = distiller.volume(block_shape)
        num_super_blocks = distiller.volume(param) / super_block_volume
        if distiller.volume(param) % super_block_volume != 0:
            raise ValueError("The super-block size must divide the weight tensor exactly.")

        num_filters, num_channels = param.size(0), param.size(1)
        kernel_size = param.size(2) * param.size(3)

        if block_depth > 1:
            view_dims = (num_filters*num_channels//(block_repetitions*block_depth),
                         block_repetitions*block_depth,
                         kernel_size,)
        else:
            view_dims = (num_filters // block_repetitions,
                         block_repetitions,
                         -1,)

        def rank_blocks(fraction_to_prune, param):
            # Create a view where each block is a column
            view1 = param.view(*view_dims)
            # Next, compute the sums of each column (block)
            block_mags = magnitude_fn(view1, dim=1)
            block_mags = block_mags.view(-1)  # flatten
            k = int(fraction_to_prune * block_mags.size(0))
            if k == 0:
                msglogger.info("Too few blocks (%d)- can't prune %.1f%% blocks",
                               block_mags.size(0), 100*fraction_to_prune)
                return None, None

            bottomk, _ = torch.topk(block_mags, k, largest=False, sorted=True)
            return bottomk, block_mags

        def binary_map_to_mask(binary_map, param):
            a = binary_map.view(view_dims[0], view_dims[2])
            c = a.unsqueeze(1)
            d = c.expand(*view_dims).contiguous()
            return d.view(num_filters, num_channels, param.size(2), param.size(3))

        if binary_map is None:
            bottomk_blocks, block_mags = rank_blocks(fraction_to_prune, param)
            if bottomk_blocks is None:
                # A None result means that fraction_to_prune is too low to prune anything
                return
            threshold = bottomk_blocks[-1]
            binary_map = block_mags.gt(threshold).type(param.data.type())

        if zeros_mask_dict is not None:
            zeros_mask_dict[param_name].mask = binary_map_to_mask(binary_map, param)
            msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                           magnitude_fn, param_name,
                           distiller.sparsity_blocks(zeros_mask_dict[param_name].mask, block_shape=block_shape),
                           fraction_to_prune, binary_map.sum().item(), num_super_blocks)
        return binary_map
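Worked numbers for the YAML snippet in the docstring (a sketch; the weight shape is arbitrary):

# block_shape = [1, 8, 1, 1] gives a super-block volume of 1*8*1*1 = 8;
# a 64x64x3x3 weight tensor has volume 36864, i.e. 36864 / 8 = 4608 super-blocks
shape = [1, 8, 1, 1]
w = torch.randn(64, 64, 3, 3)
assert distiller.volume(shape) == 8
assert distiller.volume(w) / distiller.volume(shape) == 4608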