Example #1
import torch.nn as nn

def update_attributes(self, conv: nn.Conv2d):
    # note: excerpted as a method from its defining class, hence the `self` parameter.
    # Conv2d stores its weight as (out_channels, in_channels // groups, kH, kW).
    if conv.groups > 1:
        # depthwise convolution: one group per input channel
        in_channels = int(conv.weight.shape[0])
        conv.groups = in_channels
        out_channels = conv.groups * int(conv.weight.shape[1])
        conv.in_channels = in_channels
        conv.out_channels = out_channels
    else:
        in_channels = int(conv.weight.shape[1])
        conv.in_channels = in_channels
Example #2
import torch
from torch.nn import Conv2d

def prune_conv2d(module: Conv2d, in_keep_idxes=None, out_keep_idxes=None):
    # `is_depthwise_conv2d` is a project-local helper
    is_depthwise = is_depthwise_conv2d(module)

    if in_keep_idxes is None:
        # a depthwise conv stores its weight as (C, 1, kH, kW), so the
        # input-channel count lives in dim 0, not dim 1
        dim = 0 if is_depthwise else 1
        in_keep_idxes = list(range(module.weight.shape[dim]))

    if out_keep_idxes is None:
        out_keep_idxes = list(range(module.weight.shape[0]))

    if is_depthwise:
        assert len(in_keep_idxes) == len(out_keep_idxes)
        module.groups = len(in_keep_idxes)
    else:
        assert (
            len(in_keep_idxes) <= module.weight.shape[1]
        ), f"len(in_keep_idxes): {len(in_keep_idxes)}, module.weight.shape[1]: {module.weight.shape[1]}"

    assert (
        len(out_keep_idxes) <= module.weight.shape[0]
    ), f"len(out_keep_idxes): {len(out_keep_idxes)}, module.weight.shape[0]: {module.weight.shape[0]}"

    module.out_channels = len(out_keep_idxes)
    module.in_channels = len(in_keep_idxes)

    module.weight = torch.nn.Parameter(module.weight.data[out_keep_idxes, :, :, :])

    if not is_depthwise:
        module.weight = torch.nn.Parameter(module.weight.data[:, in_keep_idxes, :, :])

    module.weight.grad = None

    if module.bias is not None:
        module.bias = torch.nn.Parameter(module.bias.data[out_keep_idxes])
        module.bias.grad = None

    return in_keep_idxes, out_keep_idxes
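A quick usage sketch (assumed, not from the source). `is_depthwise_conv2d` below is a plausible stand-in for the project-local helper, treating a conv as depthwise when each input channel forms its own group:

import torch
from torch.nn import Conv2d

def is_depthwise_conv2d(m: Conv2d) -> bool:
    # assumed stand-in for the project-local helper
    return m.groups > 1 and m.groups == m.in_channels

conv = Conv2d(4, 6, kernel_size=3)
prune_conv2d(conv, in_keep_idxes=[0, 2], out_keep_idxes=[1, 3, 5])
print(conv.weight.shape)                    # torch.Size([3, 2, 3, 3])
print(conv.in_channels, conv.out_channels)  # 2 3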
Example #3
import typing
from typing import Union

import numpy as np
import torch
import torch.nn as nn

# `SparseGate` (a learnable channel-gate module) and `Pruner` (a callable that
# maps a 1-D weight array to a boolean keep-mask) are project-local definitions.

def prune_conv_layer(
    conv_layer: nn.Conv2d,
    bn_layer: nn.BatchNorm2d,
    sparse_layer_out: typing.Optional[Union[nn.BatchNorm2d, SparseGate]],
    in_channel_mask: typing.Optional[np.ndarray],
    prune_output_mode: str,
    pruner: Pruner,
    prune_mode: str,
    sparse_layer_in: typing.Optional[Union[nn.BatchNorm2d, SparseGate]],
) -> typing.Tuple[np.ndarray, np.ndarray]:
    """
    Note: if the sparse_layer is SparseGate, the gate will be replaced by BatchNorm
    scaling factor. The value of the gate will be set to all ones.

    :param prune_output_mode: how to handle the output channel (case insensitive)
        "keep": keep the output channel intact
        "same": keep the output channel as same as input channel
        "prune": prune the output according to bn_layer
    :param pruner: the method to determinate the pruning mask.
    :param in_channel_mask: a 0-1 vector indicates whether the corresponding channel should be pruned (0) or not (1)
    :param bn_layer: the BatchNorm layer after the convolution layer
    :param conv_layer: the convolution layer to be pruned

    :param sparse_layer_out: the layer to determine the output sparsity. Support BatchNorm2d and SparseGate.
    :param sparse_layer_in: the layer to determine the input sparsity. Support BatchNorm2d and SparseGate.
        When the `sparse_layer_in` is None, there is no input sparse layers,
        the input channel will be determined by the `in_channel_mask`
        Note: `in_channel_mask` is CONFLICT with `sparse_layer_in`!

    :param prune_mode: pruning mode (`str`):
        - `"multiply"`: pruning threshold is determined by the multiplication of `sparse_layer` and `bn_layer`
            only available when `sparse_layer` is `SparseGate`
        - `None` or `"default"`: default behaviour. The pruning threshold is determined by `sparse_layer`
    :return out_channel_mask
    """
    assert isinstance(conv_layer, nn.Conv2d), f"conv_layer got {conv_layer}"

    assert isinstance(
        sparse_layer_out, (nn.BatchNorm2d, SparseGate)
    ), f"sparse_layer_out got {sparse_layer_out}"
    if in_channel_mask is not None and sparse_layer_in is not None:
        raise ValueError(
            "Conflict option: in_channel_mask and sparse_layer_in")

    # prune_mode may be None (see the docstring); treat None as "default"
    prune_mode = (prune_mode or "default").lower()
    prune_output_mode = prune_output_mode.lower()

    if prune_mode == 'multiply':
        if not isinstance(sparse_layer_out, SparseGate):
            raise ValueError(
                f"Do not support prune_mode {prune_mode} when the sparse_layer is {sparse_layer_out}"
            )

    with torch.no_grad():
        conv_weight: torch.Tensor = conv_layer.weight.data.clone()

        # prune the input channel of the conv layer
        # if sparse_layer_in and in_channel_mask are both None, every input channel is kept
        if sparse_layer_in is not None:
            if in_channel_mask is not None:
                raise ValueError(
                    "Conflict option: in_channel_mask and sparse_layer_in")
            sparse_weight_in: np.ndarray = sparse_layer_in.weight.view(
                -1).data.cpu().numpy()
            # the in_channel_mask will be overwritten
            in_channel_mask = pruner(sparse_weight_in)

        if in_channel_mask is None:
            # neither sparse_layer_in nor in_channel_mask was given:
            # keep every input channel (the input dim is NOT pruned)
            in_channel_mask = np.ones(conv_layer.in_channels, dtype=bool)

        # prune the input channel according to the in_channel_mask
        # convert the mask to channel indexes
        idx_in = np.squeeze(np.argwhere(np.asarray(in_channel_mask)))
        if len(idx_in.shape) == 0:
            # expand the single scalar to an array
            idx_in = np.expand_dims(idx_in, 0)
        elif len(idx_in.shape) == 1 and idx_in.shape[0] == 0:
            # nothing left, prune the whole block
            out_channel_mask = np.full(conv_layer.out_channels, False)
            return out_channel_mask, in_channel_mask

        # prune the input of the conv layer
        if conv_layer.groups == 1:
            conv_weight = conv_weight[:, idx_in.tolist(), :, :]
        else:
            assert conv_weight.shape[1] == 1, "only works for groups == num_channels"

        # prune the output channel of the conv layer
        if prune_output_mode == "prune":
            # sparse_layer.weight needs to be flattened, because the weight of SparseGate is not 1-d
            sparse_weight_out: np.ndarray = sparse_layer_out.weight.view(
                -1).data.cpu().numpy()
            if prune_mode == 'multiply':
                bn_weight = bn_layer.weight.data.cpu().numpy()
                sparse_weight_out = sparse_weight_out * bn_weight  # element-wise multiplication
            elif prune_mode != 'default':
                raise ValueError(f"Do not support prune_mode {prune_mode}")

            # prune and get the pruned mask
            out_channel_mask = pruner(sparse_weight_out)
        elif prune_output_mode == "keep":
            # do not prune the output
            out_channel_mask = np.ones(conv_layer.out_channels, dtype=bool)
        elif prune_output_mode == "same":
            # prune the output channel with the input mask
            # keep the conv layer in_channel == out_channel
            out_channel_mask = in_channel_mask
        else:
            raise ValueError(f"invalid prune_output_mode: {prune_output_mode}")

        idx_out: np.ndarray = np.squeeze(
            np.argwhere(np.asarray(out_channel_mask)))
        if len(idx_out.shape) == 0:
            # expand the single scalar to array
            idx_out = np.expand_dims(idx_out, 0)
        elif len(idx_out.shape) == 1 and idx_out.shape[0] == 0:
            # no channels left; return the masks directly
            # the block is supposed to be replaced by an identity mapping
            return out_channel_mask, in_channel_mask
        conv_weight = conv_weight[idx_out.tolist(), :, :, :]

        # change the property of the conv layer
        conv_layer.in_channels = len(idx_in)
        conv_layer.out_channels = len(idx_out)
        conv_layer.weight.data = conv_weight
        if conv_layer.groups != 1:
            # set the new number of groups for the depthwise layer
            conv_layer.groups = conv_layer.in_channels

        # prune the bn layer
        bn_layer.weight.data = bn_layer.weight.data[idx_out.tolist()].clone()
        bn_layer.bias.data = bn_layer.bias.data[idx_out.tolist()].clone()
        bn_layer.running_mean = bn_layer.running_mean[idx_out.tolist()].clone()
        bn_layer.running_var = bn_layer.running_var[idx_out.tolist()].clone()

        # set bn properties
        bn_layer.num_features = len(idx_out)

        # prune the gate
        if isinstance(sparse_layer_out, SparseGate):
            sparse_layer_out.prune(idx_out)
            # multiply the bn weight and SparseGate weight
            sparse_weight_out: torch.Tensor = sparse_layer_out.weight.view(-1)
            bn_layer.weight.data = (bn_layer.weight.data *
                                    sparse_weight_out).clone()
            bn_layer.bias.data = (bn_layer.bias.data *
                                  sparse_weight_out).clone()
            # the function of the SparseGate is now folded into the bn layer,
            # so the SparseGate should be disabled
            sparse_layer_out.set_ones()

    return out_channel_mask, in_channel_mask
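To make the call shape concrete, a hedged usage sketch: it assumes `Pruner` is any callable mapping a 1-D array of scaling factors to a boolean keep-mask (`threshold_pruner` below is a hypothetical stand-in), and it uses the BN layer itself as the output sparsity signal so the project-local `SparseGate` is not needed:

import numpy as np
import torch.nn as nn

def threshold_pruner(weight: np.ndarray) -> np.ndarray:
    # hypothetical pruner: keep channels whose |scaling factor| is above the median
    return np.abs(weight) >= np.median(np.abs(weight))

conv = nn.Conv2d(8, 16, kernel_size=3)
bn = nn.BatchNorm2d(16)
bn.weight.data.uniform_(0, 1)  # pretend these factors were learned with an L1 penalty

out_mask, in_mask = prune_conv_layer(
    conv_layer=conv,
    bn_layer=bn,
    sparse_layer_out=bn,                     # BN scaling factors drive output pruning
    in_channel_mask=np.ones(8, dtype=bool),  # keep all input channels
    prune_output_mode="prune",
    pruner=threshold_pruner,
    prune_mode="default",
    sparse_layer_in=None,
)
print(conv.weight.shape, bn.num_features)  # about half of the 16 output channels remain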